Add TXT records in the .other pseudo-domain to reveal each NIC's other_action value
[invirt/packages/invirt-dns.git] / invirt-dns
index de53e0d..2b920bd 100755 (executable)
@@ -4,12 +4,14 @@ from twisted.names import server
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
+from twisted.names import resolve
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
+from invirt.database import NIC
 import psycopg2
 import sqlalchemy
 import time
 import psycopg2
 import sqlalchemy
 import time
@@ -45,7 +47,7 @@ class DatabaseAuthority(common.ResolverBase):
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
-            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
+            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                 if i == 2:
                     raise
                 print "Reloading database"
                 if i == 2:
                     raise
                 print "Reloading database"
@@ -93,10 +95,12 @@ class DatabaseAuthority(common.ResolverBase):
         if name.endswith(".in-addr.arpa"):
             if type in (dns.PTR, dns.ALL_RECORDS):
                 ip = '.'.join(reversed(name.split('.')[:-2]))
         if name.endswith(".in-addr.arpa"):
             if type in (dns.PTR, dns.ALL_RECORDS):
                 ip = '.'.join(reversed(name.split('.')[:-2]))
-                value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
                 if value and value.hostname:
                     hostname = value.hostname
                     if '.' not in hostname:
                 if value and value.hostname:
                     hostname = value.hostname
                     if '.' not in hostname:
+                        if ip == value.other_ip:
+                            hostname = hostname + ".other"
                         hostname = hostname + "." + config.dns.domains[0]
                     record = dns.Record_PTR(hostname, ttl)
                     results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                         hostname = hostname + "." + config.dns.domains[0]
                     record = dns.Record_PTR(hostname, ttl)
                     results.append(dns.RRHeader(name, dns.PTR, dns.IN,
@@ -108,7 +112,7 @@ class DatabaseAuthority(common.ResolverBase):
                                             ttl, self.soa, auth=True))
             # FIXME: Should only return success with no records if the name actually exists
 
                                             ttl, self.soa, auth=True))
             # FIXME: Should only return success with no records if the name actually exists
 
-        elif name == domain or name == '.'+domain:
+        elif name == domain or name == '.'+domain or name == 'other.'+domain:
             if type in (dns.A, dns.ALL_RECORDS):
                 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
             if type in (dns.A, dns.ALL_RECORDS):
                 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                 results.append(dns.RRHeader(name, dns.A, dns.IN,
@@ -123,13 +127,25 @@ class DatabaseAuthority(common.ResolverBase):
 
         else:
             host = name[:-len(domain)-1]
 
         else:
             host = name[:-len(domain)-1]
+            other = False
+            if host.endswith(".other"):
+                host = host[:-len(".other")]
+                other = True
             value = invirt.database.NIC.query.filter_by(hostname=host).first()
             if value:
             value = invirt.database.NIC.query.filter_by(hostname=host).first()
             if value:
-                ip = value.ip
+                if other:
+                    ip = value.other_ip
+                    action = value.other_action
+                else:
+                    ip = value.ip
             else:
             else:
-                value = invirt.database.Machine.query().filter_by(name=host).first()
+                value = invirt.database.Machine.query.filter_by(name=host).first()
                 if value:
                 if value:
-                    ip = value.nics[0].ip
+                    if other:
+                        ip = value.nics[0].other_ip
+                        action = value.nics[0].other_action
+                    else:
+                        ip = value.nics[0].ip
                 else:
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
             if ip is None:
                 else:
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
             if ip is None:
@@ -141,6 +157,10 @@ class DatabaseAuthority(common.ResolverBase):
             elif type == dns.SOA:
                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                             ttl, self.soa, auth=True))
             elif type == dns.SOA:
                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                             ttl, self.soa, auth=True))
+            elif other and type == dns.TXT:
+                record = dns.Record_TXT(action if action else '', ttl=ttl)
+                results.append(dns.RRHeader(name, dns.TXT, dns.IN,
+                                            ttl, record, auth=True))
 
         if len(results) == 0:
             authority = []
 
         if len(results) == 0:
             authority = []
@@ -202,24 +222,32 @@ class DelegatingQuotingBindAuthority(authority.BindAuthority):
         # check if it's within a subdomain we're supposed to delegate to
         # some other DNS server.
         while (isinstance(deferredResult.result, failure.Failure)
         # check if it's within a subdomain we're supposed to delegate to
         # some other DNS server.
         while (isinstance(deferredResult.result, failure.Failure)
-               and name.find('.') != -1):
+               and '.' in name):
             maybeDelegate = True
             name = name[name.find('.') + 1 :]
             deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                              dns.NS, timeout)
             maybeDelegate = True
             name = name[name.find('.') + 1 :]
             deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                              dns.NS, timeout)
-        # If we found somewhere to delegate the query to, our _lookup()
-        # for the NS record resulted in it being in the 'results' section.
-        # We need to instead return that information in the 'authority'
-        # section to delegate, and return an empty 'results' section
-        # (because we didn't find the name we were asked about).  We
-        # leave the 'additional' section as we received it because it
-        # may contain A records for the DNS server we're delegating to.
-        if maybeDelegate and not isinstance(deferredResult.result,
-                                            failure.Failure):
-            (nsResults, nsAuthority, nsAdditional) = deferredResult.result
-            deferredResult = defer.succeed(([], nsResults, nsAdditional))
         return deferredResult
 
         return deferredResult
 
+class TypeLenientResolverChain(resolve.ResolverChain):
+    """
+    This is a ResolverChain which is more lenient in its handling of
+    queries requesting unimplemented record types.
+
+    Queries whose type has an entry in typeToMethod are dispatched to
+    that lookup method unchanged.  For any other type, Twisted would
+    have us return SERVFAIL; instead we run lookupAllRecords on the name
+    and either propagate its failure (e.g. NXDOMAIN when the name is not
+    in our zone) or answer with an empty NOERROR result set.
+    """
+
+    def query(self, query, timeout = None):
+        """
+        Resolve a single DNS query.
+
+        @param query: the incoming query; query.type selects the lookup
+            method and str(query.name) is the name looked up.
+        @param timeout: passed through unchanged to the lookup method.
+        @return: a Deferred firing with an (answers, authority, additional)
+            tuple, or failing with whatever the underlying lookup failed
+            with.
+        """
+        # Look the handler up *before* calling it: wrapping the call in
+        # "except KeyError" would also swallow a KeyError raised inside a
+        # lookup method and misreport it as an unsupported record type.
+        lookup = self.typeToMethod.get(query.type)
+        if lookup is not None:
+            return lookup(str(query.name), timeout)
+
+        # We don't support the requested record type.  Check whether the
+        # name exists in our zone at all: a failed lookup (NXDOMAIN) is
+        # passed through as-is, a successful one becomes an empty result
+        # set with NOERROR.  A callback is chained rather than reading
+        # deferredResult.result, which only exists once the Deferred has
+        # fired; this also works for resolvers that answer asynchronously.
+        deferredResult = self.lookupAllRecords(str(query.name), timeout)
+        deferredResult.addCallback(lambda _records: ([], [], []))
+        return deferredResult
+
 if '__main__' == __name__:
     resolvers = []
     try:
 if '__main__' == __name__:
     resolvers = []
     try:
@@ -240,7 +268,8 @@ if '__main__' == __name__:
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
-    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
+    f = server.DNSServerFactory(verbose=verbosity)
+    f.resolver = TypeLenientResolverChain(resolvers)
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity
     
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity