from functools import lru_cache
import dns.resolver  # pylint: disable=import-error
import dns.query  # pylint: disable=import-error
import dns.zone  # pylint: disable=import-error
from utils import Check


class CheckDNS(Check):  # pylint: disable=abstract-method
    """Base class for checks that compare answers across a zone's nameservers.

    On construction it builds one resolver per nameserver address of
    ``_domain``. Subclasses then use ``_cross_query`` to see whether those
    nameservers agree.
    """

    _domain = "shore.co.il"
    # Names whose A records are compared across nameservers by subclasses.
    _subdomains = [_domain, f"www.{_domain}"]

    @staticmethod
    @lru_cache(maxsize=8)
    def _resolvers_for(domain):
        """Return a tuple of resolvers, one per nameserver address of *domain*.

        The cache is keyed on the domain name rather than on a check instance.
        Every check therefore shares one set of NS/A lookups, and the cache does
        not keep check instances alive.
        NOTE(review): the cache lives as long as the process, so a reused
        (warm) Lambda container keeps these resolvers across invocations.
        """
        default_resolver = dns.resolver.Resolver()
        resolvers = []
        for ns in default_resolver.query(  # pylint: disable=invalid-name
            domain, "NS"
        ):
            # Resolve the nameserver host name (default rdtype is A) and make
            # a resolver that queries only that single address.
            for address in default_resolver.query(ns.to_text()):
                resolver = dns.resolver.Resolver(configure=False)
                resolver.nameservers = [address.to_text()]
                resolvers.append(resolver)
        return tuple(resolvers)

    def _get_resolvers(self):
        """Return a list of resolvers for each nameserver in the DNS zone."""
        return list(self._resolvers_for(self._domain))

    def _cross_query(self, qname, rdtype="A"):
        """Return all of the answers from all nameservers.

        The result is the union (a set of rdata objects) of what every
        nameserver returned for ``qname``/``rdtype``. More than one element
        therefore means the servers disagree, or that the name has several
        records. Resolver exceptions are not caught here and propagate to the
        caller.
        """
        answers = set()
        for resolver in self._resolvers:
            answers.update(resolver.query(qname, rdtype))
        return answers

    def __init__(self):
        # NOTE(review): does not call super().__init__(); confirm that
        # utils.Check needs no initialisation of its own.
        self._resolvers = self._get_resolvers()


class CheckSOA(CheckDNS):
    """Check that every nameserver serves the same SOA record."""

    def _check(self):
        """Validate the SOA record.

        Returns a ``(ok, message)`` tuple. The check fails in three cases:
        the nameservers return differing SOA records, no SOA record is
        returned, or any lookup raises. In the last case the exception text
        becomes the message.
        """
        try:
            soas = self._cross_query(self._domain, "SOA")
        except Exception as e:  # pylint: disable=broad-except,invalid-name
            return False, str(e)

        if len(soas) > 1:
            return False, "SOA records don't match."
        if not soas:
            return False, "No SOA record found."

        try:
            record = soas.pop()
            # Read the primary-nameserver field as a basic sanity check on
            # the returned record.
            record.mname.to_text()
        except Exception as e:  # pylint: disable=broad-except,invalid-name
            return False, str(e)

        return True, "SOA record validated."


class CheckMX(CheckDNS):
    """Check that the nameservers agree on the MX record and its address."""

    def _check(self):
        """Validate the MX record.

        Returns a ``(ok, message)`` tuple. First it checks that all
        nameservers return the same MX record. It then checks that they
        agree on the A record(s) of that MX's exchange host. Lookup errors
        are reported as failures, with the exception text as the message.
        """
        try:
            mxs = self._cross_query(self._domain, "MX")
            # NOTE: assumes the zone publishes exactly one MX record; several
            # legitimate MX records would also be reported as a mismatch.
            if len(mxs) > 1:
                return False, "MX records don't match."
            if not mxs:
                return False, "No MX record found."

            exchange = mxs.pop().exchange.to_text()
            ips = self._cross_query(exchange)
            if len(ips) > 1:
                return False, f"Addresses of MX host {exchange} don't match."
        except Exception as e:  # pylint: disable=broad-except,invalid-name
            return False, str(e)

        return True, "MX record validated."


class CheckSubDomains(CheckDNS):
    """Check that the nameservers agree on the A records of key names."""

    def _check(self):
        """Validate important subdomains.

        Returns a ``(ok, message)`` tuple. It stops at the first name in
        ``_subdomains`` that either has differing A records across
        nameservers (a failure) or whose lookup raises (also a failure, with
        the exception text as the message).
        """
        for domain in self._subdomains:
            try:
                ips = self._cross_query(domain)
            except Exception as e:  # pylint: disable=broad-except,invalid-name
                return False, str(e)
            # NOTE: assumes one A record per name; multiple legitimate A
            # records would also be reported as a mismatch.
            if len(ips) > 1:
                return False, f"Domain {domain} records don't match."
        return True, "Subdomains validated."


class CheckTransfer(CheckDNS):
    """Check that each nameserver allows a zone transfer of ``_domain``."""

    def _check(self):
        """Validate zone transfer, to check the TCP transport.

        Returns a ``(ok, message)`` tuple. It fails on the first nameserver
        whose transfer raises, and the message names that nameserver's
        address.
        """
        for resolver in self._resolvers:
            nameserver = resolver.nameservers[0]
            try:
                # from_xfr consumes the xfr message generator, which drives
                # the whole transfer; the resulting zone object is not needed.
                dns.zone.from_xfr(dns.query.xfr(nameserver, self._domain))
            except Exception as e:  # pylint: disable=broad-except,invalid-name
                return False, f"{nameserver}: {e}"

        return True, "Zone transfer validated."


def handler(event, context):  # pylint: disable=unused-argument
    """Lambda event handler: construct and run each DNS check in order.

    ``event`` and ``context`` are part of the handler signature but are not
    used.
    """
    # CheckTransfer is intentionally omitted: AXFR is blocked on ns1, need to
    # figure out why.
    for check_class in (CheckSOA, CheckMX, CheckSubDomains):
        check_class().run()


if __name__ == "__main__":
    # Local run: the handler ignores its arguments, so placeholder strings
    # stand in for the Lambda event and context.
    handler(event="event", context="context")
