The DNS server shouldn't error if dns.zone_files isn't set in the config.
[invirt/packages/invirt-dns.git] / invirt-dns
index 13e00d4..24117f0 100755 (executable)
@@ -7,11 +7,13 @@ from twisted.names import authority
 from twisted.internet import defer
 from twisted.python import failure
 
 from twisted.internet import defer
 from twisted.python import failure
 
+from invirt.common import InvirtConfigError
 from invirt.config import structs as config
 import invirt.database
 import psycopg2
 import sqlalchemy
 import time
 from invirt.config import structs as config
 import invirt.database
 import psycopg2
 import sqlalchemy
 import time
+import re
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
 
 class DatabaseAuthority(common.ResolverBase):
     """An Authority that is loaded from a file."""
@@ -91,17 +93,17 @@ class DatabaseAuthority(common.ResolverBase):
                     results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                                 ttl, self.soa, auth=True))
             else: # Request for a subdomain.
                     results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                                 ttl, self.soa, auth=True))
             else: # Request for a subdomain.
-                if 'passup' in dir(config.dns) and host in config.dns.passup:
-                    record = dns.Record_CNAME('%s.%s' % (host, config.dns.parent), ttl)
-                    return defer.succeed((
-                        [dns.RRHeader(name, dns.CNAME, dns.IN, ttl, record, auth=True)],
-                        [], []))
-
-                value = invirt.database.Machine.query().filter_by(name=host).first()
-                if value is None or not value.nics:
-                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
-                ip = value.nics[0].ip
-                if ip is None:  #Deactivated?
+                value = invirt.database.NIC.query.filter_by(hostname=host).first()
+                if value:
+                    ip = value.ip
+                else:
+                    value = invirt.database.Machine.query().filter_by(name=host).first()
+                    if value:
+                        ip = value.nics[0].ip
+                    else:
+                        return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
+                
+                if ip is None:
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
 
                 if type in (dns.A, dns.ALL_RECORDS):
                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
 
                 if type in (dns.A, dns.ALL_RECORDS):
@@ -119,19 +121,70 @@ class DatabaseAuthority(common.ResolverBase):
             #Doesn't exist
             return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
 
             #Doesn't exist
             return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
 
+class QuotingBindAuthority(authority.BindAuthority):
+    """
+    A BindAuthority that (almost) deals with quoting correctly
+    
+    This will catch double quotes as marking the start or end of a
+    quoted phrase, unless the double quote is escaped by a backslash
+
+    The only method defined here is collapseContinuations.  Besides
+    joining parenthesised continuation lines into one line, it splits
+    every resulting line into a list of fields using string_pat, so
+    that a double-quoted phrase is kept together as a single field.
+    """
+    # Match either a quoted or unquoted string literal followed by
+    # whitespace or the end of line.  This yields two groups, one of
+    # which has a match, and the other of which is None, depending on
+    # whether the string literal was quoted or unquoted; this is what
+    # necessitates the subsequent filtering out of groups that are
+    # None.
+    #
+    # Group 1 is the text between the double quotes (quotes excluded);
+    # group 2 is an unquoted run of non-whitespace characters.  In both
+    # alternatives a backslash followed by any character is consumed as
+    # a pair, which is how an escaped double quote is kept inside a
+    # field.
+    #
+    # NOTE(review): because '|' has the lowest precedence, the trailing
+    # (?:\s+|\s*$) belongs to the unquoted alternative only; the quoted
+    # alternative ends at its closing double quote.
+    string_pat = \
+            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')
+
+    # For interpreting escapes.
+    # Matches a backslash plus the single character after it; used
+    # below to replace the pair with just that character.
+    escape_pat = re.compile(r'\\(.)')
+
+    def collapseContinuations(self, lines):
+        """
+        Join parenthesised continuation lines, then split each line
+        into fields.
+
+        lines is an iterable of strings (the caller in this file passes
+        the output of stripComments).  The result is filter(None, ...)
+        applied to a list holding one list of field strings per
+        collapsed line, so lines that produced no fields are dropped.
+        """
+        L = []
+        # state 0: not inside a '(' ... ')' group; state 1: inside one.
+        state = 0
+        for line in lines:
+            if state == 0:
+                if line.find('(') == -1:
+                    L.append(line)
+                else:
+                    # Keep only the text before the '('.
+                    # NOTE(review): anything after the '(' on this same
+                    # line is discarded, including a ')' on that line.
+                    L.append(line[:line.find('(')])
+                    state = 1
+            else:
+                if line.find(')') != -1:
+                    # Closing line: append the text before the ')' to
+                    # the line being built; text after ')' is dropped.
+                    L[-1] += ' ' + line[:line.find(')')]
+                    state = 0
+                else:
+                    # Still inside the group: append the whole line.
+                    L[-1] += ' ' + line
+        lines = L
+        L = []
+
+        for line in lines:
+            # NOTE(review): in_quote is assigned but never read.
+            in_quote = False
+            split_line = []
+            for m in self.string_pat.finditer(line):
+                # Exactly one of the two groups matched; unpacking into
+                # [x] picks it (an empty quoted string "" yields '').
+                [x] = [x for x in m.groups() if x is not None]
+                # Replace each backslash escape with the escaped
+                # character itself.
+                split_line.append(self.escape_pat.sub(r'\1', x))
+            L.append(split_line)
+        # Drop lines whose field list is empty.
+        return filter(None, L)
+
 if '__main__' == __name__:
     resolvers = []
 if '__main__' == __name__:
     resolvers = []
-    for zone in config.dns.zone_files:
-        for origin in config.dns.domains:
-            r = authority.BindAuthority(zone)
-            # This sucks, but if I want a generic zone file, I have to
-            # reload the information by hand
-            r.origin = origin
-            lines = open(zone).readlines()
-            lines = r.collapseContinuations(r.stripComments(lines))
-            r.parseLines(lines)
-            
-            resolvers.append(r)
+    try:
+        for zone in config.dns.zone_files:
+            for origin in config.dns.domains:
+                r = QuotingBindAuthority(zone)
+                # This sucks, but if I want a generic zone file, I have to
+                # reload the information by hand
+                r.origin = origin
+                lines = open(zone).readlines()
+                lines = r.collapseContinuations(r.stripComments(lines))
+                r.parseLines(lines)
+                
+                resolvers.append(r)
+    except InvirtConfigError:
+        # Don't care if zone_files isn't defined
+        pass
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0
     resolvers.append(DatabaseAuthority())
 
     verbosity = 0