refactor DNS logic; fix some bugs in reverse-resolution
[invirt/packages/invirt-dns.git] / invirt-dns
index 9bc051f..1744ac8 100755 (executable)
@@ -63,15 +63,15 @@ class DatabaseAuthority(common.ResolverBase):
         if name in self.domains:
             domain = name
         else:
-            # Look for the longest-matching domain.  (This works because domain
-            # will remain bound after breaking out of the loop.)
+            # Look for the longest-matching domain.
             best_domain = ''
             for domain in self.domains:
                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
                     best_domain = domain
             if best_domain == '':
                 if name.endswith('.in-addr.arpa'):
-                    best_domain = name # Act authoritative for the IP address for reverse resolution requests
+                    # Act authoritative for the IP address for reverse resolution requests
+                    best_domain = name
                 else:
                     return defer.fail(failure.Failure(dns.DomainError(name)))
             domain = best_domain
@@ -82,11 +82,27 @@ class DatabaseAuthority(common.ResolverBase):
                                       3600, self.ns, auth=True))
 
         if cls == dns.IN:
-            host = name[:-len(domain)-1]
-            if not host and type != dns.PTR: # Request for the domain itself.
+            if name.endswith(".in-addr.arpa"):
+                if type in (dns.PTR, dns.ALL_RECORDS):
+                    ip = '.'.join(reversed(name.split('.')[:-2]))
+                    value = invirt.database.NIC.query.filter_by(ip=ip).first()
+                    if value and value.hostname:
+                        hostname = value.hostname
+                        if '.' not in hostname:
+                            hostname = hostname + "." + config.dns.domains[0]
+                        record = dns.Record_PTR(hostname, ttl)
+                        results.append(dns.RRHeader(name, dns.PTR, dns.IN,
+                                                    ttl, record, auth=True))
+                    else: # IP address doesn't point to an active host
+                        return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+                elif type == dns.SOA:
+                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                                ttl, self.soa, auth=True))
+                # FIXME: Should only return success with no records if the name actually exists
+            elif name == domain or name == '.'+domain:
                 if type in (dns.A, dns.ALL_RECORDS):
                     record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
-                    results.append(dns.RRHeader(name, dns.A, dns.IN, 
+                    results.append(dns.RRHeader(name, dns.A, dns.IN,
                                                 ttl, record, auth=True))
                 elif type == dns.NS:
                     results.append(dns.RRHeader(domain, dns.NS, dns.IN,
@@ -95,42 +111,28 @@ class DatabaseAuthority(common.ResolverBase):
                 elif type == dns.SOA:
                     results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                                 ttl, self.soa, auth=True))
-            else: # Request for a subdomain.
-                if name.endswith(".in-addr.arpa"): # Reverse resolution here
-                    if type in (dns.PTR, dns.ALL_RECORDS):
-                        ip = '.'.join(reversed(name.split('.')[:-2]))
-                        value = invirt.database.NIC.query.filter_by(ip=ip).first()
-                        if value and value.hostname:
-                            hostname = value.hostname
-                            if '.' not in hostname:
-                                hostname = hostname + "." + config.dns.domains[0]
-                            record = dns.Record_PTR(hostname, ttl)
-                            results.append(dns.RRHeader(name, dns.PTR, dns.IN,
-                                                        ttl, record, auth=True))
-                        else: # IP address doesn't point to an active host
-                            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                    # FIXME: Should only return success with no records if the name actually exists
-                else: # Forward resolution here
-                    value = invirt.database.NIC.query.filter_by(hostname=host).first()
+            else:
+                host = name[:-len(domain)-1]
+                value = invirt.database.NIC.query.filter_by(hostname=host).first()
+                if value:
+                    ip = value.ip
+                else:
+                    value = invirt.database.Machine.query().filter_by(name=host).first()
                     if value:
-                        ip = value.ip
+                        ip = value.nics[0].ip
                     else:
-                        value = invirt.database.Machine.query().filter_by(name=host).first()
-                        if value:
-                            ip = value.nics[0].ip
-                        else:
-                            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                
-                    if ip is None:
                         return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
 
-                    if type in (dns.A, dns.ALL_RECORDS):
-                        record = dns.Record_A(ip, ttl)
-                        results.append(dns.RRHeader(name, dns.A, dns.IN, 
-                                                    ttl, record, auth=True))
-                    elif type == dns.SOA:
-                        results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
-                                                    ttl, self.soa, auth=True))
+                if ip is None:
+                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+
+                if type in (dns.A, dns.ALL_RECORDS):
+                    record = dns.Record_A(ip, ttl)
+                    results.append(dns.RRHeader(name, dns.A, dns.IN,
+                                                ttl, record, auth=True))
+                elif type == dns.SOA:
+                    results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
+                                                ttl, self.soa, auth=True))
             if len(results) == 0:
                 authority = []
                 additional = []