diff --git a/src/_dns.py b/src/_dns.py new file mode 100644 index 0000000000000000000000000000000000000000..47486589818e89109831ee17aeab2b338f2452c1 --- /dev/null +++ b/src/_dns.py @@ -0,0 +1,89 @@ +# pylint: disable=import-error,invalid-name +import dns.resolver +import utils + + +DOMAIN = "shore.co.il" +SUBDOMAINS = [DOMAIN, f"www.{DOMAIN}"] + + +def get_resolvers(): + """Return a list of resolvers for each nameserver in the DNS zone.""" + default_resolver = dns.resolver.Resolver() + resolvers = [] + for ns in default_resolver.query(DOMAIN, "NS"): + for address in default_resolver.query(ns.to_text()): + resolver = dns.resolver.Resolver(configure=False) + resolver.nameservers = [address.to_text()] + resolvers.append(resolver) + return resolvers + + +RESOLVERS = get_resolvers() + + +def cross_query(qname, rdtype="A"): + """Return all of the answers from all nameservers.""" + answers = set() + for r in RESOLVERS: + for a in r.query(qname, rdtype): + answers.add(a) + return answers + + +def validate_soa(): + """Validate the SOA record.""" + soas = cross_query(DOMAIN, "SOA") + + if len(soas) > 1: + return [False, "SOA records don't match."] + + try: + r = soas.pop() + r.mname.to_text() + except Exception: # pylint: disable=broad-except + return [False, "SOA record is invalid."] + + return [True, "SOA record validated."] + + +def validate_mx(): + """Validate the MX record.""" + mxs = cross_query(DOMAIN, "MX") + if len(mxs) > 1: + return [False, "MX records don't match."] + + try: + r = mxs.pop() + ips = cross_query(r.exchange.to_text()) + if len(ips) > 1: + return [False, "MX records don't match."] + except Exception: # pylint: disable=broad-except + return [False, "MX record is invalid."] + + return [True, "MX record validated."] + + +def validate_subdomains(): + """Validate important subdomains.""" + for d in SUBDOMAINS: + try: + ips = cross_query(d) + if len(ips) > 1: + return [True, f"Domain {d} records don't match."] + except Exception: # pylint: disable=broad-except + return [False, "Failed to validate domain {d}."] + return [True, "Subdomains validated."] + + +def handler(event, context): # pylint: disable=unused-argument + """Lambda event handler.""" + for c in [validate_soa, validate_mx, validate_subdomains]: + success, message = c.__call__() + print(message) + if not success: + utils.publish(message) + + +if __name__ == "__main__": + handler("event", "context") diff --git a/src/function.py b/src/function.py deleted file mode 100644 index 71761090cc486b504e98235c6aeceeda29cda88a..0000000000000000000000000000000000000000 --- a/src/function.py +++ /dev/null @@ -1,3 +0,0 @@ -def handler(event, context): # pylint: disable=unused-argument - """Lambda event handler.""" - pass # pylint: disable=unnecessary-pass diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbc3ccd159bbc55dba9c351347174daf4197907 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,12 @@ +# pylint: disable=import-error +import os +import boto3 + + +TOPIC_ARN = os.getenv("TOPIC_ARN") + + +def publish(message): + """Publish an SNS message.""" + client = boto3.client("sns") + client.publish(TopicArn=TOPIC_ARN, Message=message)