Include TXT records in ANY queries
[invirt/packages/invirt-dns.git] / invirt-dns
index 9bc051f..b540ff7 100755 (executable)
@@ -4,12 +4,14 @@ from twisted.names import server
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
 from twisted.names import dns
 from twisted.names import common
 from twisted.names import authority
+from twisted.names import resolve
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
 from twisted.internet import defer
 from twisted.python import failure
 
 from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
+from invirt.database import NIC
 import psycopg2
 import sqlalchemy
 import time
 import psycopg2
 import sqlalchemy
 import time
@@ -45,7 +47,7 @@ class DatabaseAuthority(common.ResolverBase):
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
-            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
+            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                 if i == 2:
                     raise
                 print "Reloading database"
                 if i == 2:
                     raise
                 print "Reloading database"
@@ -63,15 +65,15 @@ class DatabaseAuthority(common.ResolverBase):
         if name in self.domains:
             domain = name
         else:
         if name in self.domains:
             domain = name
         else:
-            # Look for the longest-matching domain.  (This works because domain
-            # will remain bound after breaking out of the loop.)
+            # Look for the longest-matching domain.
             best_domain = ''
             for domain in self.domains:
                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
                     best_domain = domain
             if best_domain == '':
                 if name.endswith('.in-addr.arpa'):
             best_domain = ''
             for domain in self.domains:
                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
                     best_domain = domain
             if best_domain == '':
                 if name.endswith('.in-addr.arpa'):
-                    best_domain = name # Act authoritative for the IP address for reverse resolution requests
+                    # Act authoritative for the IP address for reverse resolution requests
+                    best_domain = name
                 else:
                     return defer.fail(failure.Failure(dns.DomainError(name)))
             domain = best_domain
                 else:
                     return defer.fail(failure.Failure(dns.DomainError(name)))
             domain = best_domain
@@ -81,67 +83,93 @@ class DatabaseAuthority(common.ResolverBase):
         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                       3600, self.ns, auth=True))
 
         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                       3600, self.ns, auth=True))
 
-        if cls == dns.IN:
-            host = name[:-len(domain)-1]
-            if not host and type != dns.PTR: # Request for the domain itself.
-                if type in (dns.A, dns.ALL_RECORDS):
-                    record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
+        # The order of logic:
+        # - What class?
+        # - What domain: in-addr.arpa, domain root, or subdomain?
+        # - What query type: A, PTR, NS, ...?
+
+        if cls != dns.IN:
+            # Hahaha.  No.
+            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+
+        if name.endswith(".in-addr.arpa"):
+            if type in (dns.PTR, dns.ALL_RECORDS):
+                ip = '.'.join(reversed(name.split('.')[:-2]))
+                value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
+                if value and value.hostname:
+                    hostname = value.hostname
+                    if '.' not in hostname:
+                        if ip == value.other_ip:
+                            hostname = hostname + ".other"
+                        hostname = hostname + "." + config.dns.domains[0]
+                    record = dns.Record_PTR(hostname, ttl)
+                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                 ttl, record, auth=True))
                                                 ttl, record, auth=True))
-                elif type == dns.NS:
-                    results.append(dns.RRHeader(domain, dns.NS, dns.IN,
-                                                ttl, self.ns, auth=True))
-                    authority = []
-                elif type == dns.SOA:
-                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                ttl, self.soa, auth=True))
-            else: # Request for a subdomain.
-                if name.endswith(".in-addr.arpa"): # Reverse resolution here
-                    if type in (dns.PTR, dns.ALL_RECORDS):
-                        ip = '.'.join(reversed(name.split('.')[:-2]))
-                        value = invirt.database.NIC.query.filter_by(ip=ip).first()
-                        if value and value.hostname:
-                            hostname = value.hostname
-                            if '.' not in hostname:
-                                hostname = hostname + "." + config.dns.domains[0]
-                            record = dns.Record_PTR(hostname, ttl)
-                            results.append(dns.RRHeader(name, dns.PTR, dns.IN,
-                                                        ttl, record, auth=True))
-                        else: # IP address doesn't point to an active host
-                            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                    # FIXME: Should only return success with no records if the name actually exists
-                else: # Forward resolution here
-                    value = invirt.database.NIC.query.filter_by(hostname=host).first()
-                    if value:
-                        ip = value.ip
-                    else:
-                        value = invirt.database.Machine.query().filter_by(name=host).first()
-                        if value:
-                            ip = value.nics[0].ip
-                        else:
-                            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                
-                    if ip is None:
-                        return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-
-                    if type in (dns.A, dns.ALL_RECORDS):
-                        record = dns.Record_A(ip, ttl)
-                        results.append(dns.RRHeader(name, dns.A, dns.IN, 
-                                                    ttl, record, auth=True))
-                    elif type == dns.SOA:
-                        results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                    ttl, self.soa, auth=True))
-            if len(results) == 0:
+                else: # IP address doesn't point to an active host
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+            # FIXME: Should only return success with no records if the name actually exists
+
+        elif name == domain or name == '.'+domain or name == 'other.'+domain:
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.NS:
+                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
+                                            ttl, self.ns, auth=True))
                 authority = []
                 authority = []
-                additional = []
-            return defer.succeed((results, authority, additional))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+
         else:
         else:
-            #Doesn't exist
-            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            host = name[:-len(domain)-1]
+            other = False
+            if host.endswith(".other"):
+                host = host[:-len(".other")]
+                other = True
+            value = invirt.database.NIC.query.filter_by(hostname=host).first()
+            if value:
+                if other:
+                    ip = value.other_ip
+                    action = value.other_action
+                else:
+                    ip = value.ip
+            else:
+                value = invirt.database.Machine.query.filter_by(name=host).first()
+                if value:
+                    if other:
+                        ip = value.nics[0].other_ip
+                        action = value.nics[0].other_action
+                    else:
+                        ip = value.nics[0].ip
+                else:
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if ip is None:
+                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            if other and type in (dns.TXT, dns.ALL_RECORDS):
+                record = dns.Record_TXT(action if action else '', ttl=ttl)
+                results.append(dns.RRHeader(name, dns.TXT, dns.IN,
+                                            ttl, record, auth=True))
+            if type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
 
 
-class QuotingBindAuthority(authority.BindAuthority):
+        if len(results) == 0:
+            authority = []
+            additional = []
+        return defer.succeed((results, authority, additional))
+
+class DelegatingQuotingBindAuthority(authority.BindAuthority):
     """
     """
-    A BindAuthority that (almost) deals with quoting correctly
+    A delegating BindAuthority that (almost) deals with quoting correctly
     
     This will catch double quotes as marking the start or end of a
     quoted phrase, unless the double quote is escaped by a backslash
     
     This will catch double quotes as marking the start or end of a
     quoted phrase, unless the double quote is escaped by a backslash
@@ -186,12 +214,46 @@ class QuotingBindAuthority(authority.BindAuthority):
             L.append(split_line)
         return filter(None, L)
 
             L.append(split_line)
         return filter(None, L)
 
+    def _lookup(self, name, cls, type, timeout = None):
+        maybeDelegate = False
+        deferredResult = authority.BindAuthority._lookup(self, name, cls,
+                                                         type, timeout)
+        # If we didn't find an exact match for the name we were seeking,
+        # check if it's within a subdomain we're supposed to delegate to
+        # some other DNS server.
+        while (isinstance(deferredResult.result, failure.Failure)
+               and '.' in name):
+            maybeDelegate = True
+            name = name[name.find('.') + 1 :]
+            deferredResult = authority.BindAuthority._lookup(self, name, cls,
+                                                             dns.NS, timeout)
+        return deferredResult
+
+class TypeLenientResolverChain(resolve.ResolverChain):
+    """
+    This is a ResolverChain which is more lenient in its handling of
+    queries requesting unimplemented record types.
+    """
+
+    def query(self, query, timeout = None):
+        try:
+            return self.typeToMethod[query.type](str(query.name), timeout)
+        except KeyError, e:
+            # We don't support the requested record type.  Twisted would
+            # have us return SERVFAIL.  Instead, we'll check whether the
+            # name exists in our zone at all and return NXDOMAIN or an empty
+            # result set with NOERROR as appropriate.
+            deferredResult = self.lookupAllRecords(str(query.name), timeout)
+            if isinstance(deferredResult.result, failure.Failure):
+                return deferredResult
+            return defer.succeed(([], [], []))
+
 if '__main__' == __name__:
     resolvers = []
     try:
         for zone in config.dns.zone_files:
             for origin in config.dns.domains:
 if '__main__' == __name__:
     resolvers = []
     try:
         for zone in config.dns.zone_files:
             for origin in config.dns.domains:
-                r = QuotingBindAuthority(zone)
+                r = DelegatingQuotingBindAuthority(zone)
                 # This sucks, but if I want a generic zone file, I have to
                 # reload the information by hand
                 r.origin = origin
                 # This sucks, but if I want a generic zone file, I have to
                 # reload the information by hand
                 r.origin = origin
@@ -206,7 +268,8 @@ if '__main__' == __name__:
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
-    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
+    f = server.DNSServerFactory(verbose=verbosity)
+    f.resolver = TypeLenientResolverChain(resolvers)
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity
     
     p = dns.DNSDatagramProtocol(f)
     f.noisy = p.noisy = verbosity