import dns.resolver  # pylint: disable=import-error
import utils


# Zone whose nameservers are cross-checked.
DOMAIN = "shore.co.il"
# Names whose A records must agree across every nameserver: the apex and www.
SUBDOMAINS = [DOMAIN, "www." + DOMAIN]


def get_resolvers():
    """Build one resolver pinned to each nameserver address of DOMAIN.

    The system-configured resolver looks up the NS records of DOMAIN. It then
    looks up each nameserver name using ``query``'s default record type. Every
    address returned gets its own unconfigured resolver whose only nameserver
    is that address.

    Returns:
        A list of ``dns.resolver.Resolver`` objects, one per address.
    """
    system = dns.resolver.Resolver()
    pinned = []
    ns_answers = system.query(DOMAIN, "NS")
    for nameserver in ns_answers:
        for addr in system.query(nameserver.to_text()):
            # configure=False: ignore the system config so this resolver
            # talks only to the single nameserver address set below.
            single = dns.resolver.Resolver(configure=False)
            single.nameservers = [addr.to_text()]
            pinned.append(single)
    return pinned


# Built once when the module is imported, so the NS lookups above happen at
# import time rather than on every handler invocation.
RESOLVERS = get_resolvers()


def cross_query(qname, rdtype="A", resolvers=None):
    """Return the union of the answers every nameserver gives for a query.

    Args:
        qname: The name to look up.
        rdtype: The record type to query (default ``"A"``).
        resolvers: Resolvers to ask. Defaults to the module-level
            ``RESOLVERS``, which has one resolver per nameserver address.

    Returns:
        A set of every answer returned by any resolver. Answers are merged,
        so differing answers from different nameservers appear as extra
        entries in the set.

    Raises:
        Whatever ``resolver.query`` raises. The first failing resolver aborts
        the whole call.
    """
    if resolvers is None:
        resolvers = RESOLVERS
    return {
        answer
        for resolver in resolvers
        for answer in resolver.query(qname, rdtype)
    }


def validate_soa():
    """Validate that all nameservers serve one and the same SOA record.

    Returns:
        A two-item list ``[success, message]``.
    """
    try:
        soas = cross_query(DOMAIN, "SOA")
    except Exception as e:  # pylint: disable=broad-except,invalid-name
        # Report query failures (timeouts, NXDOMAIN, ...) as a failed check,
        # like validate_subdomains does, instead of raising out of the check.
        print(str(e))
        return [False, "Failed to query SOA records."]

    # Answers from all nameservers are merged, so more than one distinct
    # SOA means at least two nameservers disagree.
    if len(soas) > 1:
        return [False, "SOA records don't match."]

    try:
        record = soas.pop()  # KeyError if no SOA came back at all
        record.mname.to_text()  # fails if the answer has no usable mname
    except Exception as e:  # pylint: disable=broad-except,invalid-name
        print(str(e))
        return [False, "SOA record is invalid."]

    return [True, "SOA record validated."]


def validate_mx():
    """Validate the MX record and the addresses of its mail exchange.

    The MX record of DOMAIN is queried on every nameserver. The exchange host
    it names is then resolved the same way.

    Returns:
        A two-item list ``[success, message]``.
    """
    try:
        mxs = cross_query(DOMAIN, "MX")
    except Exception as e:  # pylint: disable=broad-except,invalid-name
        # Report query failures as a failed check instead of raising out of
        # the check and skipping the remaining validations.
        print(str(e))
        return [False, "Failed to query MX records."]

    if len(mxs) > 1:
        return [False, "MX records don't match."]

    try:
        exchange = mxs.pop().exchange.to_text()  # KeyError if no MX returned
        ips = cross_query(exchange)
    except Exception as e:  # pylint: disable=broad-except,invalid-name
        print(str(e))
        return [False, "MX record is invalid."]

    # The MX itself matched; this mismatch is in the exchange host's records.
    if len(ips) > 1:
        return [False, f"MX host {exchange} records don't match."]

    return [True, "MX record validated."]


def validate_subdomains():
    """Validate that all nameservers agree on each name in SUBDOMAINS.

    The check stops at the first subdomain that fails.

    Returns:
        A two-item list ``[success, message]``.
    """
    for domain in SUBDOMAINS:
        try:
            ips = cross_query(domain)
        except Exception as e:  # pylint: disable=broad-except,invalid-name
            print(str(e))
            return [False, f"Failed to validate domain {domain}."]
        # A mismatch is a failure. Reporting success here would stop the
        # handler from publishing it.
        if len(ips) > 1:
            return [False, f"Domain {domain} records don't match."]
    return [True, "Subdomains validated."]


def handler(event, context):  # pylint: disable=unused-argument
    """Lambda event handler: run every DNS check and publish each failure.

    Each check's message is printed. Failed checks are also passed to
    ``utils.publish``.

    Args:
        event: The Lambda event (unused).
        context: The Lambda context (unused).
    """
    for check in (validate_soa, validate_mx, validate_subdomains):
        try:
            success, message = check()
        except Exception as e:  # pylint: disable=broad-except,invalid-name
            # An unexpected error in one check is reported as a failure. It
            # must not abort the remaining checks before anything is published.
            print(str(e))
            success, message = False, f"Check {check.__name__} failed: {e}"
        print(message)
        if not success:
            utils.publish(message)


if __name__ == "__main__":
    handler("event", "context")
