sqlalchemy api changes
[invirt/packages/invirt-dns.git] / invirt-dns
index 32b18c2..5b24d31 100755 (executable)
@@ -45,7 +45,7 @@ class DatabaseAuthority(common.ResolverBase):
         for i in range(3):
             try:
                 value = self._lookup_unsafe(name, cls, type, timeout = None)
-            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
+            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                 if i == 2:
                     raise
                 print "Reloading database"
@@ -81,67 +81,75 @@ class DatabaseAuthority(common.ResolverBase):
         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                       3600, self.ns, auth=True))
 
-        if cls == dns.IN:
-            host = name[:-len(domain)-1]
-            if not host and type != dns.PTR: # Request for the domain itself.
-                if type in (dns.A, dns.ALL_RECORDS):
-                    record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
+        # The order of logic:
+        # - What class?
+        # - What domain: in-addr.arpa, domain root, or subdomain?
+        # - What query type: A, PTR, NS, ...?
+
+        if cls != dns.IN:
+            # Hahaha.  No.
+            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+
+        if name.endswith(".in-addr.arpa"):
+            if type in (dns.PTR, dns.ALL_RECORDS):
+                ip = '.'.join(reversed(name.split('.')[:-2]))
+                value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                if value and value.hostname:
+                    hostname = value.hostname
+                    if '.' not in hostname:
+                        hostname = hostname + "." + config.dns.domains[0]
+                    record = dns.Record_PTR(hostname, ttl)
+                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                 ttl, record, auth=True))
-                elif type == dns.NS:
-                    results.append(dns.RRHeader(domain, dns.NS, dns.IN,
-                                                ttl, self.ns, auth=True))
-                    authority = []
-                elif type == dns.SOA:
-                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                ttl, self.soa, auth=True))
-            else: # Request for a subdomain.
-                if name.endswith(".in-addr.arpa"): # Reverse resolution here
-                    if type in (dns.PTR, dns.ALL_RECORDS):
-                        ip = '.'.join(reversed(name.split('.')[:-2]))
-                        value = invirt.database.NIC.query.filter_by(ip=ip).first()
-                        if value and value.hostname:
-                            hostname = value.hostname
-                            if '.' not in hostname:
-                                hostname = hostname + "." + config.dns.domains[0]
-                            record = dns.Record_PTR(hostname, ttl)
-                            results.append(dns.RRHeader(name, dns.PTR, dns.IN,
-                                                        ttl, record, auth=True))
-                        else: # IP address doesn't point to an active host
-                            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                    # FIXME: Should only return success with no records if the name actually exists
-                else: # Forward resolution here
-                    value = invirt.database.NIC.query.filter_by(hostname=host).first()
-                    if value:
-                        ip = value.ip
-                    else:
-                        value = invirt.database.Machine.query().filter_by(name=host).first()
-                        if value:
-                            ip = value.nics[0].ip
-                        else:
-                            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                
-                    if ip is None:
-                        return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-
-                    if type in (dns.A, dns.ALL_RECORDS):
-                        record = dns.Record_A(ip, ttl)
-                        results.append(dns.RRHeader(name, dns.A, dns.IN, 
-                                                    ttl, record, auth=True))
-                    elif type == dns.SOA:
-                        results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                    ttl, self.soa, auth=True))
-            if len(results) == 0:
+                else: # IP address doesn't point to an active host
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+            # FIXME: Should only return success with no records if the name actually exists
+
+        elif name == domain or name == '.'+domain:
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.NS:
+                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
+                                            ttl, self.ns, auth=True))
                 authority = []
-                additional = []
-            return defer.succeed((results, authority, additional))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+
         else:
-            #Doesn't exist
-            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            host = name[:-len(domain)-1]
+            value = invirt.database.NIC.query.filter_by(hostname=host).first()
+            if value:
+                ip = value.ip
+            else:
+                value = invirt.database.Machine.query().filter_by(name=host).first()
+                if value:
+                    ip = value.nics[0].ip
+                else:
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if ip is None:
+                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
 
-class QuotingBindAuthority(authority.BindAuthority):
+        if len(results) == 0:
+            authority = []
+            additional = []
+        return defer.succeed((results, authority, additional))
+
+class DelegatingQuotingBindAuthority(authority.BindAuthority):
     """
-    A BindAuthority that (almost) deals with quoting correctly
+    A delegating BindAuthority that (almost) deals with quoting correctly
     
     This will catch double quotes as marking the start or end of a
     quoted phrase, unless the double quote is escaped by a backslash
@@ -186,12 +194,38 @@ class QuotingBindAuthority(authority.BindAuthority):
             L.append(split_line)
         return filter(None, L)
 
+    def _lookup(self, name, cls, type, timeout = None):
+        """Look up name; if absent, delegate via an enclosing NS record."""
+        originalResult = authority.BindAuthority._lookup(self, name, cls,
+                                                         type, timeout)
+        if not isinstance(originalResult.result, failure.Failure):
+            return originalResult
+        # No exact match: walk up toward the root, looking for an NS record
+        # delegating an enclosing subdomain to some other DNS server.  Each
+        # failed Deferred we drop gets a no-op errback so Twisted does not
+        # log "Unhandled error in Deferred" when it is garbage-collected.
+        while '.' in name:
+            name = name[name.find('.') + 1 :]
+            nsResult = authority.BindAuthority._lookup(self, name, cls,
+                                                       dns.NS, timeout)
+            if not isinstance(nsResult.result, failure.Failure):
+                # The NS _lookup() put the delegation in 'results'; return it
+                # in 'authority' with empty 'results' instead.  'additional'
+                # is kept as-is: it may hold the delegated servers' A records.
+                (nsResults, nsAuthority, nsAdditional) = nsResult.result
+                originalResult.addErrback(lambda f: None)
+                return defer.succeed(([], nsResults, nsAdditional))
+            nsResult.addErrback(lambda f: None)
+        # Nothing to delegate to: report the failure for the name we were
+        # actually asked about, not for the last ancestor we tried.
+        return originalResult
+
 if '__main__' == __name__:
     resolvers = []
     try:
         for zone in config.dns.zone_files:
             for origin in config.dns.domains:
-                r = QuotingBindAuthority(zone)
+                r = DelegatingQuotingBindAuthority(zone)
                 # This sucks, but if I want a generic zone file, I have to
                 # reload the information by hand
                 r.origin = origin