Skip to content
Snippets Groups Projects
test_dns_records.py 5.67 KiB
#!/usr/bin/env python3

"""
Criticality: Normal
Checks that the DNS resolvers provided by the customer are configured on this server and that the DNS records of the installed services are correctly set
"""

import ipaddress
from pathlib import Path
import re
import subprocess
import sys

import dns.resolver

try:
    import pydbus
except ImportError:
    exit(2)

sys.path.append(str(Path(__file__).parents[1].resolve()))

from utilities import logging as lg  # noqa: E402
from utilities.config import load_conf  # noqa: E402
from utilities.os import supported_platform  # noqa: E402


def _is_ip_address(value: str) -> bool:
    """Return True when ``value`` is a literal IPv4 or IPv6 address."""
    try:
        ipaddress.ip_address(value)
    except ValueError:
        return False
    return True


def _format_dbus_address(raw) -> str:
    """Convert a packed address (4 or 16 bytes) from resolved's D-Bus API to text.

    Returns an empty string when ``raw`` is not a valid packed address.
    """
    try:
        # ipaddress picks IPv4/IPv6 from the byte length and produces the
        # canonical text form (joining the byte values with ":" does not).
        return str(ipaddress.ip_address(bytes(raw)))
    except (TypeError, ValueError):
        return ""


def _parse_resolv_conf(content: str) -> list:
    """Return the values of the ``nameserver`` lines of a resolv.conf text."""
    servers = []
    for line in content.splitlines():
        fields = line.split()
        # a bare "nameserver" line without a value is skipped instead of crashing
        if len(fields) >= 2 and fields[0] == "nameserver":
            servers.append(fields[1])
    return servers


def _parse_resolved_status(output: str) -> list:
    """Extract the DNS server addresses from ``resolvectl status`` output.

    Accepts several servers on one ``DNS Servers:`` line and additional
    servers on the following continuation lines. A ``#name`` suffix after an
    address is dropped, and tokens that are not IP addresses are ignored.
    """
    servers = []
    in_dns_block = False
    for line in (raw_line.strip() for raw_line in output.splitlines()):
        if line.startswith("DNS Servers:"):
            in_dns_block = True
            tokens = line[len("DNS Servers:"):].split()
            for token in tokens:
                address = token.split("#", 1)[0]
                if _is_ip_address(address):
                    servers.append(address)
            continue
        address = line.split("#", 1)[0]
        if in_dns_block and _is_ip_address(address):
            servers.append(address)
        else:
            in_dns_block = False
    return servers


def get_dns_servers() -> set:
    """Return the set of DNS resolvers configured on this system.

    Sources are tried in order. Each one is used only when the previous ones
    found nothing:

    1. systemd-resolved over D-Bus (``DNS`` property of
       ``org.freedesktop.resolve1``), IPv4 and IPv6;
    2. NetworkManager (``nmcli`` ``IP4.DNS`` entries);
    3. ``nameserver`` entries of ``/etc/resolv.conf``.

    If the systemd-resolved stub listener ``127.0.0.53`` is found, it is
    replaced by the servers reported by ``resolvectl status`` (or by the
    legacy ``systemd-resolve --status``).

    :return: resolver addresses as strings (possibly empty)
    """
    servers = []

    # dbus method: best effort, pydbus or systemd-resolved may be unavailable
    try:
        bus = pydbus.SystemBus()
        bus_client = bus.get("org.freedesktop.resolve1", "/org/freedesktop/resolve1")
        for entry in bus_client.DNS:
            # entry[1] is the address family (2 = IPv4, 10 = IPv6),
            # entry[2] the packed address bytes
            if entry[1] in (2, 10):
                address = _format_dbus_address(entry[2])
                if address:
                    servers.append(address)
    except Exception:
        # deliberate: fall back to the other detection methods below
        pass

    # network-manager method
    if not servers and subprocess.getstatusoutput("command -v nmcli")[0] == 0:
        _, output = subprocess.getstatusoutput(
            "nmcli -f all device show | grep IP4.DNS | awk '{ print $2 }'"
        )
        servers = [
            line.strip() for line in output.splitlines() if _is_ip_address(line.strip())
        ]

    # resolvconf method
    resolv_conf = Path("/etc/resolv.conf")
    if not servers and resolv_conf.exists():
        try:
            servers = _parse_resolv_conf(resolv_conf.read_text())
        except OSError:
            servers = []

    # systemd-resolved method: the local stub listener hides the real upstreams
    if "127.0.0.53" in servers:
        servers = [server for server in servers if server != "127.0.0.53"]
        # systemd-resolve is the legacy name of resolvectl and may be absent
        status, output = subprocess.getstatusoutput("resolvectl status")
        if status != 0:
            _, output = subprocess.getstatusoutput("systemd-resolve --status")
        servers.extend(_parse_resolved_status(output))

    return set(servers)


def get_result(output: str) -> "str | None":
    """Return the address from the first ``has address`` line of ``output``.

    ``output`` is lookup text containing lines such as
    ``example.com has address 192.0.2.1`` (e.g. as printed by ``host``).

    :param output: multi-line lookup output
    :return: the text following ``has address `` on the first matching line,
        or None when no line matches
    """
    # splitlines() also drops "\r", so CRLF output yields a clean address
    for line in output.splitlines():
        _, marker, address = line.partition("has address ")
        if marker:
            return address
    return None


def check_dns(hostname: str, expected_ip: str, resolvers: set) -> tuple:
    """Resolve ``hostname`` against ``resolvers`` and compare with ``expected_ip``.

    Only the given resolvers are queried; the system resolver configuration
    is not read (``configure=False``). A failed lookup counts as one error,
    and so does every returned address that differs from ``expected_ip``.

    :param hostname: domain name to resolve
    :param expected_ip: address the domain is expected to resolve to
    :param resolvers: nameserver addresses to query
    :return: ``(warnings, errors)`` counters (warnings is currently always 0)
    """
    warnings = 0
    errors = 0

    if not resolvers:
        # give a clear message instead of dnspython's "no nameservers" error
        lg.error("cannot resolve {}: no DNS resolver found".format(hostname))
        return warnings, errors + 1

    resolver = dns.resolver.Resolver(configure=False)
    resolver.nameservers = list(resolvers)
    # dnspython >= 2.0 provides resolve() and deprecates query(); keep
    # query() as a fallback for 1.x installations
    lookup = getattr(resolver, "resolve", None) or resolver.query
    try:
        answers = [rdata.address for rdata in lookup(hostname)]
    except Exception as dns_err:
        lg.error("cannot resolve {}: {}".format(hostname, dns_err))
        errors += 1
    else:
        for address in answers:
            if address == expected_ip:
                lg.success("{}".format(address))
            else:
                lg.error("{} instead of {}".format(address, expected_ip))
                errors += 1

    return warnings, errors


def check_resolver(conf: dict, resolvers: set) -> tuple:
    """Check that the resolvers declared in the configuration are in use.

    Declared resolvers come from the comma-separated ``NETWORK_DNS`` setting.
    When that setting is missing or empty, the legacy ``NETWORK_DNS1`` to
    ``NETWORK_DNS3`` settings are used instead.

    Exits with status 2 when no resolver is declared at all.

    :param conf: configuration mapping (as returned by ``load_conf()``)
    :param resolvers: resolver addresses detected on this system
    :return: ``(warnings, errors)`` counters; each declared resolver missing
        from ``resolvers`` adds one warning
    """
    warnings = 0
    errors = 0

    # Drop blank entries: "".split(",") == [""] is truthy, which previously
    # made the legacy fallback below unreachable. Also tolerate "a, b".
    conf_resolvers = [
        value.strip()
        for value in (conf.get("NETWORK_DNS") or "").split(",")
        if value.strip()
    ]
    if not conf_resolvers:
        # backward compatibility
        legacy_values = (
            str(conf.get(key) or "")
            for key in ("NETWORK_DNS1", "NETWORK_DNS2", "NETWORK_DNS3")
        )
        conf_resolvers = [value.strip() for value in legacy_values if value.strip()]

    if not conf_resolvers:
        lg.info("no resolver defined in envsetup")
        sys.exit(2)

    for conf_resolver in conf_resolvers:
        if conf_resolver in resolvers:
            lg.success("resolver {} configured".format(conf_resolver))
        else:
            lg.warning("resolver {} not configured".format(conf_resolver))
            warnings += 1

    return warnings, errors


def main():
    """Run the DNS check and report the result through the exit status.

    Exit status: 2 when the check does not apply (unsupported platform, no
    resolver or no IP address defined in envsetup), 1 when errors were found,
    3 when only warnings were found. Returns normally (status 0) otherwise.
    """
    print("Check DNS settings:")

    if not supported_platform():
        lg.info("platform not supported")
        sys.exit(2)

    conf = load_conf()
    resolvers = get_dns_servers()
    ip = conf.get("NETWORK_IP_NAT") or conf.get("NETWORK_IP")

    warnings, errors = check_resolver(conf, resolvers)

    if not ip:
        lg.info("no ip address defined in envsetup")
        sys.exit(2)

    # (configuration key, default domain, package providing the service)
    services_info = (
        ("MS_SERVER_NAME", "mediaserver", "ubicast-mediaserver"),
        ("MONITOR_SERVER_NAME", "monitor", "ubicast-monitor"),
        ("CM_SERVER_NAME", "mirismanager", "ubicast-skyreach"),
    )

    # Domains the operator asked not to resolve; loop-invariant, so read once.
    resolution_ignored = {
        name.strip()
        for name in (conf.get("TESTER_DNS_RESOLUTION_IGNORED") or "").split(",")
        if name.strip()
    }

    for conf_name, default_domain, package in services_info:
        domain = conf.get(conf_name)
        if (
            not domain
            or domain in ("localhost", default_domain)
            or domain in resolution_ignored
        ):
            continue

        # check that the service is installed on this system
        status, _ = subprocess.getstatusoutput("dpkg -s {}".format(package))
        if status != 0:
            lg.info("{} not installed, skip {}".format(package, domain))
            continue

        lg.info("resolving {}".format(domain))
        dns_warnings, dns_errors = check_dns(domain, ip, resolvers)
        warnings += dns_warnings
        errors += dns_errors

    if errors:
        sys.exit(1)
    if warnings:
        sys.exit(3)


if __name__ == "__main__":
    # Script entry point. main() reports failures through exit codes itself;
    # a normal return gives None, which sys.exit turns into status 0.
    sys.exit(main())