#!/usr/bin/python from twisted.internet import reactor from twisted.names import server from twisted.names import dns from twisted.names import common from twisted.names import authority from twisted.names import resolve from twisted.internet import defer from twisted.python import failure from invirt.common import InvirtConfigError from invirt.config import structs as config import invirt.database from invirt.database import NIC import psycopg2 import sqlalchemy import time import re class DatabaseAuthority(common.ResolverBase): """An Authority that is loaded from a file.""" soa = None def __init__(self, domains=None, database=None): common.ResolverBase.__init__(self) if database is not None: invirt.database.connect(database) else: invirt.database.connect() if domains is not None: self.domains = domains else: self.domains = config.dns.domains ns = config.dns.nameservers[0] self.soa = dns.Record_SOA(mname=ns.hostname, rname=config.dns.contact.replace('@','.',1), serial=1, refresh=3600, retry=900, expire=3600000, minimum=21600, ttl=3600) self.ns = dns.Record_NS(name=ns.hostname, ttl=3600) record = dns.Record_A(address=ns.ip, ttl=3600) self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN, 3600, record, auth=True) def _lookup(self, name, cls, type, timeout = None): for i in range(3): try: value = self._lookup_unsafe(name, cls, type, timeout = None) except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError): if i == 2: raise print "Reloading database" time.sleep(0.5) continue else: return value def _lookup_unsafe(self, name, cls, type, timeout): invirt.database.clear_cache() ttl = 900 name = name.lower() if name in self.domains: domain = name else: # Look for the longest-matching domain. best_domain = '' for domain in self.domains: if name.endswith('.'+domain) and len(domain) > len(best_domain): best_domain = domain if best_domain == '': if name.endswith('.in-addr.arpa'): # Act authoritative for the IP address for reverse resolution requests best_domain = name else: return defer.fail(failure.Failure(dns.DomainError(name))) domain = best_domain results = [] authority = [] additional = [self.ns1] authority.append(dns.RRHeader(domain, dns.NS, dns.IN, 3600, self.ns, auth=True)) # The order of logic: # - What class? # - What domain: in-addr.arpa, domain root, or subdomain? # - What query type: A, PTR, NS, ...? if cls != dns.IN: # Hahaha. No. return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) if name.endswith(".in-addr.arpa"): if type in (dns.PTR, dns.ALL_RECORDS): ip = '.'.join(reversed(name.split('.')[:-2])) value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first() if value and value.hostname: hostname = value.hostname if '.' not in hostname: if ip == value.other_ip: hostname = hostname + ".other" hostname = hostname + "." + config.dns.domains[0] record = dns.Record_PTR(hostname, ttl) results.append(dns.RRHeader(name, dns.PTR, dns.IN, ttl, record, auth=True)) else: # IP address doesn't point to an active host return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) elif type == dns.SOA: results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl, self.soa, auth=True)) # FIXME: Should only return success with no records if the name actually exists elif name == domain or name == '.'+domain or name == 'other.'+domain: if type in (dns.A, dns.ALL_RECORDS): record = dns.Record_A(config.dns.nameservers[0].ip, ttl) results.append(dns.RRHeader(name, dns.A, dns.IN, ttl, record, auth=True)) elif type == dns.NS: results.append(dns.RRHeader(domain, dns.NS, dns.IN, ttl, self.ns, auth=True)) authority = [] elif type == dns.SOA: results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl, self.soa, auth=True)) else: host = name[:-len(domain)-1] other = False if host.endswith(".other"): host = host[:-len(".other")] other = True value = invirt.database.NIC.query.filter_by(hostname=host).first() if value: if other: ip = value.other_ip action = value.other_action else: ip = value.ip else: value = invirt.database.Machine.query.filter_by(name=host).first() if value: if other: ip = value.nics[0].other_ip action = value.nics[0].other_action else: ip = value.nics[0].ip else: return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) if ip is None: return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) if type in (dns.A, dns.ALL_RECORDS): record = dns.Record_A(ip, ttl) results.append(dns.RRHeader(name, dns.A, dns.IN, ttl, record, auth=True)) elif type == dns.SOA: results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl, self.soa, auth=True)) elif other and type == dns.TXT: record = dns.Record_TXT(action if action else '', ttl=ttl) results.append(dns.RRHeader(name, dns.TXT, dns.IN, ttl, record, auth=True)) if len(results) == 0: authority = [] additional = [] return defer.succeed((results, authority, additional)) class DelegatingQuotingBindAuthority(authority.BindAuthority): """ A delegating BindAuthority that (almost) deals with quoting correctly This will catch double quotes as marking the start or end of a quoted phrase, unless the double quote is escaped by a backslash """ # Match either a quoted or unquoted string literal followed by # whitespace or the end of line. This yields two groups, one of # which has a match, and the other of which is None, depending on # whether the string literal was quoted or unquoted; this is what # necessitates the subsequent filtering out of groups that are # None. string_pat = \ re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)') # For interpreting escapes. escape_pat = re.compile(r'\\(.)') def collapseContinuations(self, lines): L = [] state = 0 for line in lines: if state == 0: if line.find('(') == -1: L.append(line) else: L.append(line[:line.find('(')]) state = 1 else: if line.find(')') != -1: L[-1] += ' ' + line[:line.find(')')] state = 0 else: L[-1] += ' ' + line lines = L L = [] for line in lines: in_quote = False split_line = [] for m in self.string_pat.finditer(line): [x] = [x for x in m.groups() if x is not None] split_line.append(self.escape_pat.sub(r'\1', x)) L.append(split_line) return filter(None, L) def _lookup(self, name, cls, type, timeout = None): maybeDelegate = False deferredResult = authority.BindAuthority._lookup(self, name, cls, type, timeout) # If we didn't find an exact match for the name we were seeking, # check if it's within a subdomain we're supposed to delegate to # some other DNS server. while (isinstance(deferredResult.result, failure.Failure) and '.' in name): maybeDelegate = True name = name[name.find('.') + 1 :] deferredResult = authority.BindAuthority._lookup(self, name, cls, dns.NS, timeout) return deferredResult class TypeLenientResolverChain(resolve.ResolverChain): """ This is a ResolverChain which is more lenient in its handling of queries requesting unimplemented record types. """ def query(self, query, timeout = None): try: return self.typeToMethod[query.type](str(query.name), timeout) except KeyError, e: # We don't support the requested record type. Twisted would # have us return SERVFAIL. Instead, we'll check whether the # name exists in our zone at all and return NXDOMAIN or an empty # result set with NOERROR as appropriate. deferredResult = self.lookupAllRecords(str(query.name), timeout) if isinstance(deferredResult.result, failure.Failure): return deferredResult return defer.succeed(([], [], [])) if '__main__' == __name__: resolvers = [] try: for zone in config.dns.zone_files: for origin in config.dns.domains: r = DelegatingQuotingBindAuthority(zone) # This sucks, but if I want a generic zone file, I have to # reload the information by hand r.origin = origin lines = open(zone).readlines() lines = r.collapseContinuations(r.stripComments(lines)) r.parseLines(lines) resolvers.append(r) except InvirtConfigError: # Don't care if zone_files isn't defined pass resolvers.append(DatabaseAuthority()) verbosity = 0 f = server.DNSServerFactory(verbose=verbosity) f.resolver = TypeLenientResolverChain(resolvers) p = dns.DNSDatagramProtocol(f) f.noisy = p.noisy = verbosity reactor.listenUDP(53, p) reactor.listenTCP(53, f) reactor.run()