Work around a bug in Twisted's zone file parsing.
[invirt/packages/invirt-dns.git] / invirt-dns
index c320e51..188d1ce 100755 (executable)
@@ -4,16 +4,19 @@ from twisted.names import server
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
+from twisted.names import resolve
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
+from invirt.database import NIC
 import psycopg2
 import sqlalchemy
 import time
 import re
+import sys
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
@@ -93,10 +96,12 @@ class DatabaseAuthority(common.ResolverBase):
         if name.endswith(".in-addr.arpa"):
             if type in (dns.PTR, dns.ALL_RECORDS):
                 ip = '.'.join(reversed(name.split('.')[:-2]))
-                value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
                 if value and value.hostname:
                     hostname = value.hostname
                     if '.' not in hostname:
+                        if ip == value.other_ip:
+                            hostname = hostname + ".other"
                         hostname = hostname + "." + config.dns.domains[0]
                     record = dns.Record_PTR(hostname, ttl)
                     results.append(dns.RRHeader(name, dns.PTR, dns.IN,
@@ -108,7 +113,7 @@ class DatabaseAuthority(common.ResolverBase):
                                             ttl, self.soa, auth=True))
             # FIXME: Should only return success with no records if the name actually exists
 
-        elif name == domain or name == '.'+domain:
+        elif name == domain or name == '.'+domain or name == 'other.'+domain:
             if type in (dns.A, dns.ALL_RECORDS):
                 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
@@ -123,13 +128,25 @@ class DatabaseAuthority(common.ResolverBase):
 
         else:
             host = name[:-len(domain)-1]
+            other = False
+            if host.endswith(".other"):
+                host = host[:-len(".other")]
+                other = True
             value = invirt.database.NIC.query.filter_by(hostname=host).first()
             if value:
-                ip = value.ip
+                if other:
+                    ip = value.other_ip
+                    action = value.other_action
+                else:
+                    ip = value.ip
             else:
                 value = invirt.database.Machine.query.filter_by(name=host).first()
                 if value:
-                    ip = value.nics[0].ip
+                    if other:
+                        ip = value.nics[0].other_ip
+                        action = value.nics[0].other_action
+                    else:
+                        ip = value.nics[0].ip
                 else:
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
             if ip is None:
@@ -138,7 +155,11 @@ class DatabaseAuthority(common.ResolverBase):
                 record = dns.Record_A(ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
                                             ttl, record, auth=True))
-            elif type == dns.SOA:
+            if other and type in (dns.TXT, dns.ALL_RECORDS):
+                record = dns.Record_TXT(action if action else '', ttl=ttl)
+                results.append(dns.RRHeader(name, dns.TXT, dns.IN,
+                                            ttl, record, auth=True))
+            if type == dns.SOA:
                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                             ttl, self.soa, auth=True))
 
@@ -194,31 +215,56 @@ class DelegatingQuotingBindAuthority(authority.BindAuthority):
             L.append(split_line)
         return filter(None, L)
 
+    # See https://twistedmatrix.com/documents/13.1.0/api/twisted.internet.defer.html#inlineCallbacks
+    @defer.inlineCallbacks
     def _lookup(self, name, cls, type, timeout = None):
-        maybeDelegate = False
-        deferredResult = authority.BindAuthority._lookup(self, name, cls,
-                                                         type, timeout)
-        # If we didn't find an exact match for the name we were seeking,
-        # check if it's within a subdomain we're supposed to delegate to
-        # some other DNS server.
-        while (isinstance(deferredResult.result, failure.Failure)
-               and '.' in name):
-            maybeDelegate = True
-            name = name[name.find('.') + 1 :]
-            deferredResult = authority.BindAuthority._lookup(self, name, cls,
-                                                             dns.NS, timeout)
-        # If we found somewhere to delegate the query to, our _lookup()
-        # for the NS record resulted in it being in the 'results' section.
-        # We need to instead return that information in the 'authority'
-        # section to delegate, and return an empty 'results' section
-        # (because we didn't find the name we were asked about).  We
-        # leave the 'additional' section as we received it because it
-        # may contain A records for the DNS server we're delegating to.
-        if maybeDelegate and not isinstance(deferredResult.result,
-                                            failure.Failure):
-            (nsResults, nsAuthority, nsAdditional) = deferredResult.result
-            deferredResult = defer.succeed(([], nsResults, nsAdditional))
-        return deferredResult
+        """Resolve name, falling back to NS delegation by a parent zone."""
+        try:
+            result = yield authority.BindAuthority._lookup(self, name, cls,
+                                                           type, timeout)
+            defer.returnValue(result)
+        except Exception as e:
+            # XXX: Twisted returns DomainError even if it is
+            # authoritative for the domain because our SOA record
+            # incorrectly contains (origin + "." + origin)
+            if not isinstance(e, (dns.DomainError, dns.AuthoritativeDomainError)):
+                sys.stderr.write("while looking up '%s', got: %s\n" % (name, e))
+            # If we didn't find an exact match for the name we were
+            # seeking, check if it's within a subdomain we're supposed
+            # to delegate to some other DNS server.
+            while '.' in name:
+                _, name = name.split('.', 1)
+                try:
+                    # BindAuthority puts the NS in the authority
+                    # section automatically for us, so just return
+                    # it. We override the type to NS.
+                    result = yield authority.BindAuthority._lookup(self, name, cls,
+                                                                   dns.NS, timeout)
+                    defer.returnValue(result)
+                except Exception: # Should be one of (dns.DomainError, dns.AuthoritativeDomainError)
+                    pass
+            # No delegation found: re-raise the original error by name.  On
+            # Python 2 a bare raise would re-raise the last NS lookup failure.
+            raise e
+
+class TypeLenientResolverChain(resolve.ResolverChain):
+    """
+    This is a ResolverChain which is more lenient in its handling of
+    queries requesting unimplemented record types.
+    """
+
+    def query(self, query, timeout = None):
+        # Guard only the table lookup, so that a KeyError raised inside
+        # a handler is not mistaken for an unsupported record type.
+        try:
+            method = self.typeToMethod[query.type]
+        except KeyError:
+            # Twisted would answer SERVFAIL for an unsupported type.
+            # Instead, look the name up: a failure (e.g. NXDOMAIN) passes
+            # through, and success becomes an empty NOERROR result set.
+            d = self.lookupAllRecords(str(query.name), timeout)
+            return d.addCallback(lambda records: ([], [], []))
+        return method(str(query.name), timeout)
 
 if '__main__' == __name__:
     resolvers = []
@@ -228,6 +274,9 @@ if '__main__' == __name__:
                 r = DelegatingQuotingBindAuthority(zone)
                 # This sucks, but if I want a generic zone file, I have to
                 # reload the information by hand
+                # XXX: This causes our SOA record to contain
+                # (origin + "." + origin)
+                # As a result the resolver never believes it is authoritative.
                 r.origin = origin
                 lines = open(zone).readlines()
                 lines = r.collapseContinuations(r.stripComments(lines))
@@ -240,7 +289,8 @@ if '__main__' == __name__:
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
-    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
+    f = server.DNSServerFactory(verbose=verbosity)
+    f.resolver = TypeLenientResolverChain(resolvers)
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity