Don't populate the authority or additional sections in responses
[invirt/packages/invirt-dns.git] / invirt-dns
index de53e0d..ed0ab01 100755 (executable)
@@ -4,6 +4,7 @@ from twisted.names import server
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
+from twisted.names import resolve
 from twisted.internet import defer
 from twisted.python import failure
 
@@ -45,7 +46,7 @@ class DatabaseAuthority(common.ResolverBase):
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
-            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
+            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                 if i == 2:
                     raise
                 print "Reloading database"
@@ -127,7 +128,7 @@ class DatabaseAuthority(common.ResolverBase):
             if value:
                 ip = value.ip
             else:
-                value = invirt.database.Machine.query().filter_by(name=host).first()
+                value = invirt.database.Machine.query.filter_by(name=host).first()
                 if value:
                     ip = value.nics[0].ip
                 else:
@@ -202,7 +203,7 @@ class DelegatingQuotingBindAuthority(authority.BindAuthority):
         # check if it's within a subdomain we're supposed to delegate to
         # some other DNS server.
         while (isinstance(deferredResult.result, failure.Failure)
-               and name.find('.') != -1):
+               and '.' in name):
             maybeDelegate = True
             name = name[name.find('.') + 1 :]
             deferredResult = authority.BindAuthority._lookup(self, name, cls,
@@ -220,6 +221,25 @@ class DelegatingQuotingBindAuthority(authority.BindAuthority):
             deferredResult = defer.succeed(([], nsResults, nsAdditional))
         return deferredResult
 
+class TypeLenientResolverChain(resolve.ResolverChain):
+    """ResolverChain that is lenient with queries for unimplemented record
+    types: instead of an error it answers NXDOMAIN or an empty NOERROR
+    result, depending on whether the name exists at all (lookupAllRecords).
+    """
+
+    def query(self, query, timeout = None):
+        try:  # NOTE(review): try also wraps the lookup call itself; a KeyError raised inside it is treated as an unsupported type
+            return self.typeToMethod[query.type](str(query.name), timeout)
+        except KeyError, e:  # query.type has no entry in typeToMethod (e is unused)
+            # We don't support the requested record type.  Twisted would
+            # have us return SERVFAIL.  Instead, we check whether the name
+            # exists in our zone at all: pass its failure through (NXDOMAIN),
+            # or return an empty (answer, authority, additional) set (NOERROR).
+            deferredResult = self.lookupAllRecords(str(query.name), timeout)
+            if isinstance(deferredResult.result, failure.Failure):  # assumes the Deferred has already fired (synchronous authorities) -- TODO confirm
+                return deferredResult
+            return defer.succeed(([], [], []))
+
 if '__main__' == __name__:
     resolvers = []
     try:
@@ -240,7 +260,8 @@ if '__main__' == __name__:
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
-    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
+    f = server.DNSServerFactory(verbose=verbosity)
+    f.resolver = TypeLenientResolverChain(resolvers)
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity