#!/usr/bin/env python3

"""
Criticality: Normal
Checks that the DNS records provided by the customer's DNS servers are correctly set
"""

from pathlib import Path
import re
import subprocess
import sys

# Make the parent of this script's directory importable so that the
# ``envsetup`` import below can be found. The path is made absolute first:
# with a bare relative ``__file__`` (e.g. ``./script.py`` on Python < 3.9),
# ``parents[1]`` would not exist and raise IndexError.
sys.path.append(str(Path(__file__).absolute().parents[1].resolve()))

# pylint: disable=wrong-import-position
from envsetup import utils as u  # noqa: E402


def get_dns_servers() -> set:
    """Return the set of DNS servers used by this system.

    Sources are tried in order:

    1. NetworkManager (``nmcli``), when the command is available;
    2. ``/etc/resolv.conf``, when NetworkManager reported nothing;
    3. systemd-resolved, when the servers found so far include its local
       stub address ``127.0.0.53``: the stub list is then replaced by the
       servers reported by ``systemd-resolve --status`` (or by
       ``resolvectl status`` when the former command fails).

    :return: server addresses as strings; may be empty.
    """
    servers = _nmcli_dns_servers()

    if not servers:
        servers = _resolv_conf_dns_servers()

    if "127.0.0.53" in servers:
        servers = _systemd_resolved_dns_servers()

    return set(servers)


def _nmcli_dns_servers() -> list:
    """Return NetworkManager's ``IP4.DNS`` entries, or [] without nmcli."""
    if subprocess.getstatusoutput("command -v nmcli")[0] != 0:
        return []
    _, output = subprocess.getstatusoutput(
        "nmcli -f all device show | grep IP4.DNS | awk '{ print $2 }'"
    )
    return [line for line in output.split("\n") if line]


def _resolv_conf_dns_servers(path: str = "/etc/resolv.conf") -> list:
    """Return the ``nameserver`` addresses of a resolv.conf file.

    A missing or unreadable file yields an empty list instead of crashing.
    """
    try:
        with open(path, "r") as conf_file:
            lines = conf_file.read().splitlines()
    except OSError:
        return []

    servers = []
    for line in lines:
        fields = line.split()
        # skip malformed "nameserver" lines that carry no address
        if len(fields) >= 2 and fields[0] == "nameserver":
            servers.append(fields[1])
    return servers


def _systemd_resolved_dns_servers() -> list:
    """Return the DNS servers reported by systemd-resolved."""
    status, output = subprocess.getstatusoutput("systemd-resolve --status")
    if status != 0:
        # ``resolvectl`` is the successor of ``systemd-resolve``; try it when
        # the legacy command is missing or fails
        _, output = subprocess.getstatusoutput("resolvectl status")
    return _parse_resolved_status(output)


def _parse_resolved_status(output: str) -> list:
    """Extract server addresses from systemd-resolved status output.

    Addresses are taken from ``DNS Servers:`` lines (every address after the
    label) and from the IPv4-looking lines that directly follow such a line.
    """
    label = "DNS Servers:"
    ip_pattern = re.compile(r"\d+\.\d+\.\d+\.\d+")
    servers = []
    in_dns_block = False
    for line in (raw.strip() for raw in output.split("\n")):
        if line.startswith(label):
            in_dns_block = True
            # keep every address on the label line, not only the last one,
            # and nothing at all when the label has no value
            servers.extend(line[len(label):].split())
        elif in_dns_block and ip_pattern.match(line):
            servers.append(line)
        else:
            in_dns_block = False
    return servers


def get_result(output: str) -> "str | None":
    """Return the first IPv4 address found in the output of ``host``.

    :param output: text printed by ``host <hostname> [server]``.
    :return: the text following the first ``has address `` marker, stripped
        of surrounding whitespace, or ``None`` when no line contains that
        marker (e.g. only ``has IPv6 address`` lines are present).
    """
    marker = "has address "
    for line in output.split("\n"):
        if marker in line:
            return line.split(marker, 1)[1].strip()
    # explicit: callers must handle "no IPv4 address found"
    return None


def check_dns(hostname: str, expected_ip: str, resolvers: set) -> tuple:
    """Check that every resolver resolves *hostname* to *expected_ip*.

    Each resolver is queried with ``host <hostname> <resolver>`` and the
    outcome is reported through ``u.success`` / ``u.error``. When there is
    no resolver to query, a warning is reported instead of silently passing.

    :param hostname: domain name to look up.
    :param expected_ip: address the first returned A record should match.
    :param resolvers: DNS servers to query.
    :return: ``(warnings, errors)`` counters.
    """
    warnings = 0
    errors = 0

    if not resolvers:
        # otherwise the check would succeed without a single query made
        u.warning("dns: no resolver found, unable to check {}".format(hostname))
        warnings += 1

    # sorted() keeps the report order stable (set iteration order is not)
    for resolver in sorted(resolvers):
        status, output = subprocess.getstatusoutput(
            "host {} {}".format(hostname, resolver)
        )
        if status != 0:
            u.error("dns({}): cannot resolve {}".format(resolver, hostname))
            errors += 1
            continue

        address = get_result(output)
        if address == expected_ip:
            u.success("dns({}): {} -> {}".format(resolver, hostname, address))
        elif address is None:
            # lookup succeeded but no "has address" (IPv4) line was printed
            u.error(
                "dns({}): {} has no IPv4 address (should be {})".format(
                    resolver, hostname, expected_ip
                )
            )
            errors += 1
        else:
            u.error(
                "dns({}): {} -> {} (should be {})".format(
                    resolver, hostname, address, expected_ip
                )
            )
            errors += 1

    return warnings, errors


def check_resolver(conf: dict, resolvers: set, ip: str) -> tuple:
    """Check that the resolvers declared in the configuration are in use.

    Each of ``NETWORK_DNS1``, ``NETWORK_DNS2`` and ``NETWORK_DNS3`` that is
    set in *conf* but absent from *resolvers* produces a warning.

    Terminates the process with exit code 2 when *ip* is empty, because the
    DNS records cannot be tested without an expected address.

    :param conf: envsetup configuration mapping.
    :param resolvers: DNS servers currently used by the system.
    :param ip: address the service domains are expected to resolve to.
    :return: ``(warnings, errors)`` counters; ``errors`` is always 0 here.
    """
    warnings = 0
    errors = 0

    for conf_resolver_key in ("NETWORK_DNS1", "NETWORK_DNS2", "NETWORK_DNS3"):
        conf_resolver = conf.get(conf_resolver_key)
        if conf_resolver and conf_resolver not in resolvers:
            u.warning("resolver {} not configured".format(conf_resolver))
            warnings += 1

    # NOTE(review): the message talks about resolvers, but the condition is a
    # missing expected IP (NETWORK_IP_NAT / NETWORK_IP) — confirm the wording
    if not ip:
        u.info("no resolver defined in envsetup configuration, unable to test DNS")
        # sys.exit instead of the site-provided exit(), which is intended for
        # the interactive interpreter and is missing under ``python -S``
        sys.exit(2)

    return warnings, errors


def main():
    """Run the DNS checks and exit with the tester status code.

    Exit codes: 1 when at least one error was found, 3 when only warnings
    were found; otherwise the function returns normally. ``check_resolver``
    may also terminate the process with code 2.
    """
    print("Check DNS settings:")

    conf = u.load_conf()
    resolvers = get_dns_servers()
    # NETWORK_IP_NAT takes precedence over NETWORK_IP when both are set
    ip = conf.get("NETWORK_IP_NAT") or conf.get("NETWORK_IP")

    warnings, errors = check_resolver(conf, resolvers, ip)

    services_info = (
        ("MS_SERVER_NAME", "mediaserver", "ubicast-mediaserver"),
        ("MONITOR_SERVER_NAME", "monitor", "ubicast-monitor"),
        ("CM_SERVER_NAME", "mirismanager", "ubicast-skyreach"),
    )

    # comma-separated domains to skip; parsed once, entries stripped so that
    # "a.example, b.example" works too; a None value is treated as empty
    resolution_ignored = {
        name.strip()
        for name in (conf.get("TESTER_DNS_RESOLUTION_IGNORED") or "").split(",")
        if name.strip()
    }

    for conf_name, default_domain, package in services_info:
        domain = conf.get(conf_name)
        if (
            not domain
            or domain in ("localhost", default_domain)
            or domain in resolution_ignored
        ):
            continue

        # only check services whose package is installed on this system
        status, _ = subprocess.getstatusoutput("dpkg -s {}".format(package))
        if status != 0:
            u.info("{} not installed, skip {}".format(package, domain))
            continue

        u.info("checking IP of {}".format(domain))
        dns_warnings, dns_errors = check_dns(domain, ip, resolvers)
        warnings += dns_warnings
        errors += dns_errors

    # sys.exit instead of the site-provided exit(), meant for interactive use
    if errors:
        sys.exit(1)
    elif warnings:
        sys.exit(3)


if __name__ == "__main__":
    # run the checks only when executed as a script, not when imported
    main()