Teach the DNS server how to delegate
[invirt/packages/invirt-dns.git] / invirt-dns
index 976a9a7..de53e0d 100755 (executable)
@@ -7,11 +7,13 @@ from twisted.names import authority
 from twisted.internet import defer
 from twisted.python import failure
 
 from twisted.internet import defer
 from twisted.python import failure
 
+from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
 import psycopg2
 import sqlalchemy
 import time
 from invirt.config import structs as config
 import invirt.database
 import psycopg2
 import sqlalchemy
 import time
+import re
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
@@ -61,14 +63,17 @@ class DatabaseAuthority(common.ResolverBase):
         if name in self.domains:
             domain = name
         else:
         if name in self.domains:
             domain = name
         else:
-            # Look for the longest-matching domain.  (This works because domain
-            # will remain bound after breaking out of the loop.)
+            # Look for the longest-matching domain.
             best_domain = ''
             for domain in self.domains:
                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
                     best_domain = domain
             if best_domain == '':
             best_domain = ''
             for domain in self.domains:
                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
                     best_domain = domain
             if best_domain == '':
-                return defer.fail(failure.Failure(dns.DomainError(name)))
+                if name.endswith('.in-addr.arpa'):
+                    # Act authoritative for the IP address for reverse resolution requests
+                    best_domain = name
+                else:
+                    return defer.fail(failure.Failure(dns.DomainError(name)))
             domain = best_domain
         results = []
         authority = []
             domain = best_domain
         results = []
         authority = []
@@ -76,56 +81,162 @@ class DatabaseAuthority(common.ResolverBase):
         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                       3600, self.ns, auth=True))
 
         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                       3600, self.ns, auth=True))
 
-        if cls == dns.IN:
-            host = name[:-len(domain)-1]
-            if not host: # Request for the domain itself.
-                if type in (dns.A, dns.ALL_RECORDS):
-                    record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
+        # The order of logic:
+        # - What class?
+        # - What domain: in-addr.arpa, domain root, or subdomain?
+        # - What query type: A, PTR, NS, ...?
+
+        if cls != dns.IN:
+            # Hahaha.  No.
+            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+
+        if name.endswith(".in-addr.arpa"):
+            if type in (dns.PTR, dns.ALL_RECORDS):
+                ip = '.'.join(reversed(name.split('.')[:-2]))
+                value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                if value and value.hostname:
+                    hostname = value.hostname
+                    if '.' not in hostname:
+                        hostname = hostname + "." + config.dns.domains[0]
+                    record = dns.Record_PTR(hostname, ttl)
+                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                 ttl, record, auth=True))
                                                 ttl, record, auth=True))
-                elif type == dns.NS:
-                    results.append(dns.RRHeader(domain, dns.NS, dns.IN,
-                                                ttl, self.ns, auth=True))
-                    authority = []
-                elif type == dns.SOA:
-                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                ttl, self.soa, auth=True))
-            else: # Request for a subdomain.
-                value = invirt.database.Machine.query().filter_by(name=host).first()
-                if value is None or not value.nics:
-                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                ip = value.nics[0].ip
-                if ip is None:  #Deactivated?
+                else: # IP address doesn't point to an active host
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+            # FIXME: Should only return success with no records if the name actually exists
 
 
-                if type in (dns.A, dns.ALL_RECORDS):
-                    record = dns.Record_A(ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
-                                                ttl, record, auth=True))
-                elif type == dns.SOA:
-                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                ttl, self.soa, auth=True))
-            if len(results) == 0:
+        elif name == domain or name == '.'+domain:
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.NS:
+                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
+                                            ttl, self.ns, auth=True))
                 authority = []
                 authority = []
-                additional = []
-            return defer.succeed((results, authority, additional))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+
         else:
         else:
-            #Doesn't exist
-            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            host = name[:-len(domain)-1]
+            value = invirt.database.NIC.query.filter_by(hostname=host).first()
+            if value:
+                ip = value.ip
+            else:
+                value = invirt.database.Machine.query().filter_by(name=host).first()
+                if value:
+                    ip = value.nics[0].ip
+                else:
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if ip is None:
+                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+
+        if len(results) == 0:
+            authority = []
+            additional = []
+        return defer.succeed((results, authority, additional))
+
+class DelegatingQuotingBindAuthority(authority.BindAuthority):
+    """
+    A delegating BindAuthority that (almost) deals with quoting correctly
+
+    This will catch double quotes as marking the start or end of a
+    quoted phrase, unless the double quote is escaped by a backslash
+
+    Two pieces of BindAuthority behaviour are overridden:
+
+    - collapseContinuations() joins parenthesised multi-line records
+      and then splits every line into tokens using string_pat, so a
+      double-quoted phrase becomes a single token.
+    - _lookup() falls back to looking for NS records on successively
+      shorter parent names when the exact lookup fails, and reports
+      any NS records it finds in the 'authority' section.
+    """
+    # Match either a quoted or unquoted string literal followed by
+    # whitespace or the end of line.  This yields two groups, one of
+    # which has a match, and the other of which is None, depending on
+    # whether the string literal was quoted or unquoted; this is what
+    # necessitates the subsequent filtering out of groups that are
+    # None.
+    #
+    # NOTE(review): because '|' has the lowest precedence, the trailing
+    # (?:\s+|\s*$) belongs to the unquoted alternative only; the quoted
+    # alternative ends at its closing double quote.
+    string_pat = \
+            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')
+
+    # For interpreting escapes.
+    # A backslash followed by any single character; the character is
+    # captured so it can be substituted for the whole escape sequence.
+    escape_pat = re.compile(r'\\(.)')
+
+    def collapseContinuations(self, lines):
+        """
+        Join parenthesised continuation lines and tokenise the result.
+
+        lines is an iterable of strings.  Returns the per-line token
+        lists produced by string_pat, with backslash escapes removed
+        from each token and lines that produced no tokens dropped.
+        """
+        L = []
+        # state == 0: outside parentheses; state == 1: inside an open
+        # '(' and accumulating onto the last entry of L.
+        state = 0
+        for line in lines:
+            if state == 0:
+                if line.find('(') == -1:
+                    L.append(line)
+                else:
+                    # Keep only the text before the first '('; whatever
+                    # follows it on this line is discarded.
+                    L.append(line[:line.find('(')])
+                    state = 1
+            else:
+                if line.find(')') != -1:
+                    # Closing line: append the text before the first
+                    # ')' and discard the rest of the line.
+                    L[-1] += ' ' + line[:line.find(')')]
+                    state = 0
+                else:
+                    L[-1] += ' ' + line
+        lines = L
+        L = []
+
+        for line in lines:
+            # NOTE(review): in_quote is assigned but never read.
+            in_quote = False
+            split_line = []
+            for m in self.string_pat.finditer(line):
+                # Exactly one of the two groups is not None (quoted or
+                # unquoted), so this unpacks that single value.
+                [x] = [x for x in m.groups() if x is not None]
+                # Replace each backslash escape with the escaped
+                # character itself.
+                split_line.append(self.escape_pat.sub(r'\1', x))
+            L.append(split_line)
+        # Drop lines that yielded no tokens (empty lists are falsy).
+        return filter(None, L)
+
+    def _lookup(self, name, cls, type, timeout = None):
+        """
+        Look up name, falling back to an NS lookup on its parent names.
+
+        Returns the result of BindAuthority._lookup() when the exact
+        lookup does not fail.  Otherwise the leftmost label is stripped
+        repeatedly and an NS lookup is made on each shorter name; on
+        the first success an already-fired Deferred is returned whose
+        result is ([], <NS records>, <additional records>).  If every
+        attempt fails, the Deferred of the last NS lookup is returned.
+        """
+        maybeDelegate = False
+        deferredResult = authority.BindAuthority._lookup(self, name, cls,
+                                                         type, timeout)
+        # If we didn't find an exact match for the name we were seeking,
+        # check if it's within a subdomain we're supposed to delegate to
+        # some other DNS server.
+        # NOTE(review): .result is read directly from the Deferred, so
+        # this presumably relies on BindAuthority._lookup() returning a
+        # Deferred that has already fired -- confirm.
+        while (isinstance(deferredResult.result, failure.Failure)
+               and name.find('.') != -1):
+            maybeDelegate = True
+            # Strip the leftmost label and retry against the parent
+            # name, asking for NS records rather than the original type.
+            name = name[name.find('.') + 1 :]
+            deferredResult = authority.BindAuthority._lookup(self, name, cls,
+                                                             dns.NS, timeout)
+        # If we found somewhere to delegate the query to, our _lookup()
+        # for the NS record resulted in it being in the 'results' section.
+        # We need to instead return that information in the 'authority'
+        # section to delegate, and return an empty 'results' section
+        # (because we didn't find the name we were asked about).  We
+        # leave the 'additional' section as we received it because it
+        # may contain A records for the DNS server we're delegating to.
+        if maybeDelegate and not isinstance(deferredResult.result,
+                                            failure.Failure):
+            # nsAuthority is unpacked but not used in the reply.
+            (nsResults, nsAuthority, nsAdditional) = deferredResult.result
+            deferredResult = defer.succeed(([], nsResults, nsAdditional))
+        return deferredResult
 
 if '__main__' == __name__:
     resolvers = []
 
 if '__main__' == __name__:
     resolvers = []
-    for zone in config.dns.zone_files:
-        for origin in config.dns.domains:
-            r = authority.BindAuthority(zone)
-            # This sucks, but if I want a generic zone file, I have to
-            # reload the information by hand
-            r.origin = origin
-            lines = open(zone).readlines()
-            lines = r.collapseContinuations(r.stripComments(lines))
-            r.parseLines(lines)
-            
-            resolvers.append(r)
+    try:
+        for zone in config.dns.zone_files:
+            for origin in config.dns.domains:
+                r = DelegatingQuotingBindAuthority(zone)
+                # This sucks, but if I want a generic zone file, I have to
+                # reload the information by hand
+                r.origin = origin
+                lines = open(zone).readlines()
+                lines = r.collapseContinuations(r.stripComments(lines))
+                r.parseLines(lines)
+                
+                resolvers.append(r)
+    except InvirtConfigError:
+        # Don't care if zone_files isn't defined
+        pass
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0