Pull out the dns.IN class check into an early return, and comment the structure of the DNS logic
[invirt/packages/invirt-dns.git] / invirt-dns
index e49c279..d5fa021 100755 (executable)
@@ -7,6 +7,7 @@ from twisted.names import authority
 from twisted.internet import defer
 from twisted.python import failure
 
+from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
 import psycopg2
@@ -62,14 +63,17 @@ class DatabaseAuthority(common.ResolverBase):
         if name in self.domains:
             domain = name
         else:
-            # Look for the longest-matching domain.  (This works because domain
-            # will remain bound after breaking out of the loop.)
+            # Look for the longest-matching domain.
             best_domain = ''
             for domain in self.domains:
                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
                     best_domain = domain
             if best_domain == '':
-                return defer.fail(failure.Failure(dns.DomainError(name)))
+                if name.endswith('.in-addr.arpa'):
+                    # Act authoritative for the IP address for reverse resolution requests
+                    best_domain = name
+                else:
+                    return defer.fail(failure.Failure(dns.DomainError(name)))
             domain = best_domain
         results = []
         authority = []
@@ -77,42 +81,71 @@ class DatabaseAuthority(common.ResolverBase):
         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                       3600, self.ns, auth=True))
 
-        if cls == dns.IN:
-            host = name[:-len(domain)-1]
-            if not host: # Request for the domain itself.
-                if type in (dns.A, dns.ALL_RECORDS):
-                    record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
+        # The order of logic:
+        # - What class?
+        # - What domain: in-addr.arpa, domain root, or subdomain?
+        # - What query type: A, PTR, NS, ...?
+
+        if cls != dns.IN:
+            # Hahaha.  No.
+            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+
+        if name.endswith(".in-addr.arpa"):
+            if type in (dns.PTR, dns.ALL_RECORDS):
+                ip = '.'.join(reversed(name.split('.')[:-2]))
+                value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                if value and value.hostname:
+                    hostname = value.hostname
+                    if '.' not in hostname:
+                        hostname = hostname + "." + config.dns.domains[0]
+                    record = dns.Record_PTR(hostname, ttl)
+                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                 ttl, record, auth=True))
-                elif type == dns.NS:
-                    results.append(dns.RRHeader(domain, dns.NS, dns.IN,
-                                                ttl, self.ns, auth=True))
-                    authority = []
-                elif type == dns.SOA:
-                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                ttl, self.soa, auth=True))
-            else: # Request for a subdomain.
-                value = invirt.database.Machine.query().filter_by(name=host).first()
-                if value is None or not value.nics:
-                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                ip = value.nics[0].ip
-                if ip is None:  #Deactivated?
+                else: # IP address doesn't point to an active host
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+            # FIXME: Should only return success with no records if the name actually exists
 
-                if type in (dns.A, dns.ALL_RECORDS):
-                    record = dns.Record_A(ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
-                                                ttl, record, auth=True))
-                elif type == dns.SOA:
-                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                ttl, self.soa, auth=True))
-            if len(results) == 0:
+        elif name == domain or name == '.'+domain:
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.NS:
+                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
+                                            ttl, self.ns, auth=True))
                 authority = []
-                additional = []
-            return defer.succeed((results, authority, additional))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+
         else:
-            #Doesn't exist
-            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            host = name[:-len(domain)-1]
+            value = invirt.database.NIC.query.filter_by(hostname=host).first()
+            if value:
+                ip = value.ip
+            else:
+                value = invirt.database.Machine.query().filter_by(name=host).first()
+                if value:
+                    ip = value.nics[0].ip
+                else:
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if ip is None:
+                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+            if type in (dns.A, dns.ALL_RECORDS):
+                record = dns.Record_A(ip, ttl)
+                results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                            ttl, record, auth=True))
+            elif type == dns.SOA:
+                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                            ttl, self.soa, auth=True))
+
+        if len(results) == 0:
+            authority = []
+            additional = []
+        return defer.succeed((results, authority, additional))
 
 class QuotingBindAuthority(authority.BindAuthority):
     """
@@ -163,17 +196,21 @@ class QuotingBindAuthority(authority.BindAuthority):
 
 if '__main__' == __name__:
     resolvers = []
-    for zone in config.dns.zone_files:
-        for origin in config.dns.domains:
-            r = QuotingBindAuthority(zone)
-            # This sucks, but if I want a generic zone file, I have to
-            # reload the information by hand
-            r.origin = origin
-            lines = open(zone).readlines()
-            lines = r.collapseContinuations(r.stripComments(lines))
-            r.parseLines(lines)
-            
-            resolvers.append(r)
+    try:
+        for zone in config.dns.zone_files:
+            for origin in config.dns.domains:
+                r = QuotingBindAuthority(zone)
+                # This sucks, but if I want a generic zone file, I have to
+                # reload the information by hand
+                r.origin = origin
+                lines = open(zone).readlines()
+                lines = r.collapseContinuations(r.stripComments(lines))
+                r.parseLines(lines)
+                
+                resolvers.append(r)
+    except InvirtConfigError:
+        # Don't care if zone_files isn't defined
+        pass
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0