Work around a bug in Twisted's zone file parsing.
[invirt/packages/invirt-dns.git] / invirt-dns
index de53e0d..188d1ce 100755 (executable)
@@ -4,16 +4,19 @@ from twisted.names import server
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
+from twisted.names import resolve
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
+from invirt.database import NIC
 import psycopg2
 import sqlalchemy
 import time
 import re
 import psycopg2
 import sqlalchemy
 import time
 import re
+import sys
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
@@ -45,7 +48,7 @@ class DatabaseAuthority(common.ResolverBase):
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
-            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
+            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                 if i == 2:
                     raise
                 print "Reloading database"
                 if i == 2:
                     raise
                 print "Reloading database"
@@ -93,10 +96,12 @@ class DatabaseAuthority(common.ResolverBase):
         if name.endswith(".in-addr.arpa"):
             if type in (dns.PTR, dns.ALL_RECORDS):
                 ip = '.'.join(reversed(name.split('.')[:-2]))
         if name.endswith(".in-addr.arpa"):
             if type in (dns.PTR, dns.ALL_RECORDS):
                 ip = '.'.join(reversed(name.split('.')[:-2]))
-                value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
                 if value and value.hostname:
                     hostname = value.hostname
                     if '.' not in hostname:
                 if value and value.hostname:
                     hostname = value.hostname
                     if '.' not in hostname:
+                        if ip == value.other_ip:
+                            hostname = hostname + ".other"
                         hostname = hostname + "." + config.dns.domains[0]
                     record = dns.Record_PTR(hostname, ttl)
                     results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                         hostname = hostname + "." + config.dns.domains[0]
                     record = dns.Record_PTR(hostname, ttl)
                     results.append(dns.RRHeader(name, dns.PTR, dns.IN,
@@ -108,7 +113,7 @@ class DatabaseAuthority(common.ResolverBase):
                                             ttl, self.soa, auth=True))
             # FIXME: Should only return success with no records if the name actually exists
 
                                             ttl, self.soa, auth=True))
             # FIXME: Should only return success with no records if the name actually exists
 
-        elif name == domain or name == '.'+domain:
+        elif name == domain or name == '.'+domain or name == 'other.'+domain:
             if type in (dns.A, dns.ALL_RECORDS):
                 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
             if type in (dns.A, dns.ALL_RECORDS):
                 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
@@ -123,13 +128,25 @@ class DatabaseAuthority(common.ResolverBase):
 
         else:
             host = name[:-len(domain)-1]
 
         else:
             host = name[:-len(domain)-1]
+            other = False
+            if host.endswith(".other"):
+                host = host[:-len(".other")]
+                other = True
             value = invirt.database.NIC.query.filter_by(hostname=host).first()
             if value:
             value = invirt.database.NIC.query.filter_by(hostname=host).first()
             if value:
-                ip = value.ip
+                if other:
+                    ip = value.other_ip
+                    action = value.other_action
+                else:
+                    ip = value.ip
             else:
             else:
-                value = invirt.database.Machine.query().filter_by(name=host).first()
+                value = invirt.database.Machine.query.filter_by(name=host).first()
                 if value:
                 if value:
-                    ip = value.nics[0].ip
+                    if other:
+                        ip = value.nics[0].other_ip
+                        action = value.nics[0].other_action
+                    else:
+                        ip = value.nics[0].ip
                 else:
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
             if ip is None:
                 else:
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
             if ip is None:
@@ -138,7 +155,11 @@ class DatabaseAuthority(common.ResolverBase):
                 record = dns.Record_A(ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
                                             ttl, record, auth=True))
                 record = dns.Record_A(ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
                                             ttl, record, auth=True))
-            elif type == dns.SOA:
+            if other and type in (dns.TXT, dns.ALL_RECORDS):
+                record = dns.Record_TXT(action if action else '', ttl=ttl)
+                results.append(dns.RRHeader(name, dns.TXT, dns.IN,
+                                            ttl, record, auth=True))
+            if type == dns.SOA:
                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                             ttl, self.soa, auth=True))
 
                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                             ttl, self.soa, auth=True))
 
@@ -194,31 +215,56 @@ class DelegatingQuotingBindAuthority(authority.BindAuthority):
             L.append(split_line)
         return filter(None, L)
 
             L.append(split_line)
         return filter(None, L)
 
+    # See https://twistedmatrix.com/documents/13.1.0/api/twisted.internet.defer.html#inlineCallbacks
+    @defer.inlineCallbacks
     def _lookup(self, name, cls, type, timeout = None):
     def _lookup(self, name, cls, type, timeout = None):
-        maybeDelegate = False
-        deferredResult = authority.BindAuthority._lookup(self, name, cls,
-                                                         type, timeout)
-        # If we didn't find an exact match for the name we were seeking,
-        # check if it's within a subdomain we're supposed to delegate to
-        # some other DNS server.
-        while (isinstance(deferredResult.result, failure.Failure)
-               and name.find('.') != -1):
-            maybeDelegate = True
-            name = name[name.find('.') + 1 :]
-            deferredResult = authority.BindAuthority._lookup(self, name, cls,
-                                                             dns.NS, timeout)
-        # If we found somewhere to delegate the query to, our _lookup()
-        # for the NS record resulted in it being in the 'results' section.
-        # We need to instead return that information in the 'authority'
-        # section to delegate, and return an empty 'results' section
-        # (because we didn't find the name we were asked about).  We
-        # leave the 'additional' section as we received it because it
-        # may contain A records for the DNS server we're delegating to.
-        if maybeDelegate and not isinstance(deferredResult.result,
-                                            failure.Failure):
-            (nsResults, nsAuthority, nsAdditional) = deferredResult.result
-            deferredResult = defer.succeed(([], nsResults, nsAdditional))
-        return deferredResult
+        try:
+            result = yield authority.BindAuthority._lookup(self, name, cls,
+                                                           type, timeout)
+            defer.returnValue(result)
+        except Exception as e:
+            # XXX: Twisted returns DomainError even if it is
+            # authoritative for the domain because our SOA record
+            # incorrectly contains (origin + "." + origin)
+            if not isinstance(e, (dns.DomainError, dns.AuthoritativeDomainError)):
+                sys.stderr.write("while looking up '%s', got: %s\n" % (name, e))
+
+            # If we didn't find an exact match for the name we were
+            # seeking, check if it's within a subdomain we're supposed
+            # to delegate to some other DNS server.
+            while '.' in name:
+                _, name = name.split('.', 1)
+                try:
+                    # BindAuthority puts the NS in the authority
+                    # section automatically for us, so just return
+                    # it. We override the type to NS.
+                    result = yield authority.BindAuthority._lookup(self, name, cls,
+                                                                   dns.NS, timeout)
+                    defer.returnValue(result)
+                except Exception: # Should be one of (dns.DomainError, dns.AuthoritativeDomainError)
+                    pass
+            # We didn't find a delegation, so return the original
+            # NXDOMAIN.
+            raise
+
+class TypeLenientResolverChain(resolve.ResolverChain):
+    """
+    This is a ResolverChain which is more lenient in its handling of
+    queries requesting unimplemented record types.
+    """
+
+    def query(self, query, timeout = None):
+        try:
+            return self.typeToMethod[query.type](str(query.name), timeout)
+        except KeyError, e:
+            # We don't support the requested record type.  Twisted would
+            # have us return SERVFAIL.  Instead, we'll check whether the
+            # name exists in our zone at all and return NXDOMAIN or an empty
+            # result set with NOERROR as appropriate.
+            deferredResult = self.lookupAllRecords(str(query.name), timeout)
+            if isinstance(deferredResult.result, failure.Failure):
+                return deferredResult
+            return defer.succeed(([], [], []))
 
 if '__main__' == __name__:
     resolvers = []
 
 if '__main__' == __name__:
     resolvers = []
@@ -228,6 +274,9 @@ if '__main__' == __name__:
                 r = DelegatingQuotingBindAuthority(zone)
                 # This sucks, but if I want a generic zone file, I have to
                 # reload the information by hand
                 r = DelegatingQuotingBindAuthority(zone)
                 # This sucks, but if I want a generic zone file, I have to
                 # reload the information by hand
+                # XXX: This causes our SOA record to contain
+                # (origin + "." + origin)
+                # As a result the resolver never believes it is authoritative.
                 r.origin = origin
                 lines = open(zone).readlines()
                 lines = r.collapseContinuations(r.stripComments(lines))
                 r.origin = origin
                 lines = open(zone).readlines()
                 lines = r.collapseContinuations(r.stripComments(lines))
@@ -240,7 +289,8 @@ if '__main__' == __name__:
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
-    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
+    f = server.DNSServerFactory(verbose=verbosity)
+    f.resolver = TypeLenientResolverChain(resolvers)
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity
     
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity