#!/usr/bin/python
"""Authoritative DNS server for invirt.

Answers queries from static BIND-style zone files (when
``config.dns.zone_files`` is configured) and from the invirt database of
machines and NICs; the database authority is consulted after the zone-file
authorities in the resolver chain.
"""

import re
import sys
import time

import psycopg2
import sqlalchemy
from twisted.internet import defer
from twisted.internet import reactor
from twisted.names import authority
from twisted.names import common
from twisted.names import dns
from twisted.names import resolve
from twisted.names import server
from twisted.python import failure

from invirt.common import InvirtConfigError
from invirt.config import structs as config
import invirt.database
from invirt.database import NIC


class DatabaseAuthority(common.ResolverBase):
    """An authority whose answers are generated from the invirt database.

    Serves the configured domains (``config.dns.domains`` by default):
    the domain apex, ``<host>.<domain>`` and ``<host>.other.<domain>``
    names, plus reverse (``in-addr.arpa``) lookups of NIC addresses.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static SOA/NS records.

        @param domains: domains to answer for; defaults to
            C{config.dns.domains}.
        @param database: passed to C{invirt.database.connect}; when None,
            C{connect()} is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # The first configured nameserver is used as the SOA primary and
        # as the NS record for every served domain.
        ns = config.dns.nameservers[0]
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  # SOA RNAME writes the contact mailbox
                                  # with its first '@' replaced by '.'.
                                  rname=config.dns.contact.replace('@', '.', 1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, sent in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN, 3600, record,
                                auth=True)

    def _lookup(self, name, cls, type, timeout=None):
        """Look up C{name}, retrying up to three times on database errors.

        @return: a Deferred from C{_lookup_unsafe}; if all three attempts
            fail with a database error, a failed Deferred carrying the last
            error (so callers always receive a Deferred, never an exception).
        """
        for attempt in range(3):
            try:
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError,
                    sqlalchemy.exceptions.DBAPIError):
                if attempt == 2:
                    # Failure() captures the exception currently handled.
                    return defer.fail(failure.Failure())
                print("Reloading database")
                # NOTE(review): presumably runs on the reactor thread, so
                # this sleep stalls all other queries while it waits.
                time.sleep(0.5)

    def _find_domain(self, name):
        """Return the served domain containing C{name}, or None.

        An exact match wins; otherwise the longest configured domain that
        C{name} ends with is used.  Reverse-lookup names (``*.in-addr.arpa``)
        that match no configured domain are treated as their own domain so
        that this authority answers them.
        """
        if name in self.domains:
            return name
        # Look for the longest-matching domain.
        best_domain = ''
        for domain in self.domains:
            if name.endswith('.' + domain) and len(domain) > len(best_domain):
                best_domain = domain
        if best_domain:
            return best_domain
        if name.endswith('.in-addr.arpa'):
            # Act authoritative for the IP address for reverse resolution
            # requests.
            return name
        return None

    def _host_address(self, host, other):
        """Return C{(ip, action)} for C{host}, or None if it is unknown.

        C{host} is first matched against NIC hostnames, then against machine
        names (using the machine's first NIC).  A machine without NICs is
        treated as unknown.  When C{other} is true the NIC's C{other_ip}
        and C{other_action} are returned; otherwise C{(nic.ip, None)}.
        """
        nic = NIC.query.filter_by(hostname=host).first()
        if not nic:
            machine = invirt.database.Machine.query.filter_by(name=host).first()
            if not machine:
                return None
            try:
                nic = machine.nics[0]
            except IndexError:
                # Previously an uncaught IndexError; answer NXDOMAIN instead.
                return None
        if other:
            return nic.other_ip, nic.other_action
        return nic.ip, None

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Build the answer for one query from the database.

        @param timeout: unused.
        @return: a Deferred firing with C{(answers, authority, additional)}
            lists, or a failed Deferred carrying C{dns.DomainError} (name is
            not under a served domain) or C{dns.AuthoritativeDomainError}
            (NXDOMAIN inside a served domain).  Database errors are not
            caught here; C{_lookup} retries them.
        """
        # NOTE(review): presumably drops cached ORM state so each query
        # reads fresh rows -- confirm against invirt.database.
        invirt.database.clear_cache()
        ttl = 900
        name = name.lower()

        domain = self._find_domain(name)
        if domain is None:
            return defer.fail(failure.Failure(dns.DomainError(name)))

        results = []
        # Named to avoid shadowing the twisted.names.authority module.
        authority_rrs = [dns.RRHeader(domain, dns.NS, dns.IN, 3600, self.ns,
                                      auth=True)]
        additional = [self.ns1]

        # The order of logic:
        # - What class?
        # - What domain: in-addr.arpa, domain root, or subdomain?
        # - What query type: A, PTR, NS, ...?
        if cls != dns.IN:
            # Only the IN class is served.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # "4.3.2.1.in-addr.arpa" -> "1.2.3.4"
                ip = '.'.join(reversed(name.split('.')[:-2]))
                value = NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
                if value and value.hostname:
                    hostname = value.hostname
                    if '.' not in hostname:
                        # Unqualified hostnames are placed under the first
                        # configured domain ("<host>.other." for other_ip).
                        if ip == value.other_ip:
                            hostname = hostname + ".other"
                        hostname = hostname + "." + config.dns.domains[0]
                    record = dns.Record_PTR(hostname, ttl)
                    results.append(dns.RRHeader(name, dns.PTR, dns.IN, ttl,
                                                record, auth=True))
                else:
                    # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))
            # FIXME: Should only return success with no records if the name
            # actually exists
        elif name == domain or name == '.' + domain or name == 'other.' + domain:
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN, ttl, record,
                                            auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN, ttl,
                                            self.ns, auth=True))
                # The NS record is already in the answer section.
                authority_rrs = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))
        else:
            host = name[:-len(domain) - 1]
            other = host.endswith(".other")
            if other:
                host = host[:-len(".other")]
            address = self._host_address(host, other)
            if address is None or address[0] is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            ip, action = address

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN, ttl, record,
                                            auth=True))
            if other and type in (dns.TXT, dns.ALL_RECORDS):
                record = dns.Record_TXT(action if action else '', ttl=ttl)
                results.append(dns.RRHeader(name, dns.TXT, dns.IN, ttl, record,
                                            auth=True))
            if type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))

        if not results:
            authority_rrs = []
            additional = []
        return defer.succeed((results, authority_rrs, additional))


class DelegatingQuotingBindAuthority(authority.BindAuthority):
    """
    A delegating BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    Names with no record of their own are answered with the NS records of
    the nearest enclosing name that has them (a delegation).
    """

    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line. This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
        re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def collapseContinuations(self, lines):
        """Join parenthesised multi-line records and tokenize every line.

        An unquoted, unescaped '(' starts a record that continues onto the
        following lines until the matching ')'; both parentheses become
        whitespace and the lines are joined.  Text after '(' on the opening
        line and after ')' on the closing line is kept, a one-line
        "( ... )" record does not swallow the lines after it, and
        parentheses inside double quotes are literal.  Each joined line is
        then split with C{string_pat}, collapsing backslash escapes.

        @param lines: comment-stripped lines of a zone file.
        @return: a list of non-empty token lists, one per logical record.
        """
        joined = []
        continuing = False
        in_quote = False
        for line in lines:
            was_continuing = continuing
            if not continuing:
                # Quoted strings may only span lines inside a continuation.
                in_quote = False
            escaped = False
            chars = []
            for ch in line:
                if escaped:
                    escaped = False
                elif ch == '\\':
                    escaped = True
                elif ch == '"':
                    in_quote = not in_quote
                elif not in_quote and ch in '()':
                    continuing = (ch == '(')
                    ch = ' '
                chars.append(ch)
            text = ''.join(chars)
            if was_continuing:
                joined[-1] += ' ' + text
            else:
                joined.append(text)

        records = []
        for text in joined:
            tokens = []
            for m in self.string_pat.finditer(text):
                # Exactly one of the two groups participates in a match.
                quoted, bare = m.groups()
                token = quoted if quoted is not None else bare
                tokens.append(self.escape_pat.sub(r'\1', token))
            if tokens:
                records.append(tokens)
        # A real list: parseLines takes len() of its argument.
        return records

    # See https://twistedmatrix.com/documents/13.1.0/api/twisted.internet.defer.html#inlineCallbacks
    @defer.inlineCallbacks
    def _lookup(self, name, cls, type, timeout=None):
        """Look up C{name} in the zone; fall back to an NS delegation.

        If the exact lookup fails, each parent of C{name} is tried (closest
        first) for NS records; the first success is returned.  If none
        succeeds, the error from the original lookup is raised.
        """
        try:
            result = yield authority.BindAuthority._lookup(self, name, cls,
                                                           type, timeout)
        except Exception as e:
            lookup_error = e
        else:
            # Outside the try so the returnValue control-flow exception can
            # never be caught by the handler above.
            defer.returnValue(result)

        # XXX: Twisted returns DomainError even if it is
        # authoritative for the domain because our SOA record
        # incorrectly contains (origin + "." + origin)
        if not isinstance(lookup_error,
                          (dns.DomainError, dns.AuthoritativeDomainError)):
            sys.stderr.write("while looking up '%s', got: %s\n" % (name, lookup_error))

        # If we didn't find an exact match for the name we were
        # seeking, check if it's within a subdomain we're supposed
        # to delegate to some other DNS server.
        parent = name
        while '.' in parent:
            _, parent = parent.split('.', 1)
            try:
                # BindAuthority puts the NS in the authority
                # section automatically for us, so just return
                # it. We override the type to NS.
                result = yield authority.BindAuthority._lookup(self, parent,
                                                               cls, dns.NS,
                                                               timeout)
            except Exception:
                # Should be one of (dns.DomainError, dns.AuthoritativeDomainError)
                pass
            else:
                defer.returnValue(result)

        # We didn't find a delegation, so return the original NXDOMAIN.
        # Raised explicitly: a bare `raise` here would, on Python 2,
        # re-raise the last parent-lookup error instead.
        raise lookup_error


class TypeLenientResolverChain(resolve.ResolverChain):
    """
    This is a ResolverChain which is more lenient in its handling of
    queries requesting unimplemented record types.
    """

    def query(self, query, timeout=None):
        """Answer C{query}; unsupported types get NXDOMAIN or empty NOERROR.

        Twisted would return SERVFAIL for a record type with no lookup
        method.  Instead, an ALL_RECORDS lookup checks whether the name
        exists: if that lookup fails its failure is propagated, otherwise
        an empty C{([], [], [])} result is returned.
        """
        name = str(query.name)
        try:
            lookup = self.typeToMethod[query.type]
        except KeyError:
            # We don't support the requested record type.
            d = self.lookupAllRecords(name, timeout)
            # Works whether or not d has already fired.
            d.addCallback(lambda _ignored: ([], [], []))
            return d
        return lookup(name, timeout)


if __name__ == '__main__':
    resolvers = []
    try:
        for zone in config.dns.zone_files:
            for origin in config.dns.domains:
                r = DelegatingQuotingBindAuthority(zone)
                # A generic zone file has no fixed origin, so reload it by
                # hand once per configured domain.
                # XXX: This causes our SOA record to contain
                # (origin + "." + origin)
                # As a result the resolver never believes it is authoritative.
                r.origin = origin
                with open(zone) as zone_file:
                    lines = zone_file.readlines()
                lines = r.collapseContinuations(r.stripComments(lines))
                r.parseLines(lines)
                resolvers.append(r)
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        pass

    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(verbose=verbosity)
    f.resolver = TypeLenientResolverChain(resolvers)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()