#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.names import authority
from twisted.internet import defer
from twisted.python import failure

from invirt.common import InvirtConfigError
from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time
import re


class DatabaseAuthority(common.ResolverBase):
    """An authority that answers queries from the Invirt database.

    Forward (A) lookups under the served domains resolve NIC hostnames
    and machine names; reverse (PTR) lookups under ``in-addr.arpa``
    resolve NIC IP addresses back to hostnames.  The first configured
    nameserver is advertised as the NS (and SOA MNAME) for every answer.
    """
    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static NS/SOA records.

        domains  -- collection of domain names to serve; defaults to
                    config.dns.domains.
        database -- optional argument forwarded to
                    invirt.database.connect(); when None, connect() is
                    called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        ns = config.dns.nameservers[0]
        # RNAME is the contact address with its first '@' turned into '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@', '.', 1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # Glue A record for the nameserver, sent in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN, 3600,
                                record, auth=True)

    def _lookup(self, name, cls, type, timeout=None):
        """Answer a query, retrying up to 3 times on database errors.

        On psycopg2.OperationalError or sqlalchemy SQLError the lookup is
        retried after a short pause; the error is re-raised after the
        third failed attempt.  Returns the Deferred from _lookup_unsafe.
        """
        attempts = 3
        for attempt in range(attempts):
            try:
                # Pass the caller's timeout through (it was previously
                # replaced by None).
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == attempts - 1:
                    raise
                print("Reloading database")
                # NOTE(review): time.sleep blocks the reactor thread while
                # waiting; acceptable only because the DB calls are
                # synchronous anyway.
                time.sleep(0.5)

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer one query against the database, without retrying.

        Returns a Deferred that fires with ``(answers, authority,
        additional)`` lists of RRHeaders, or fails with:

        * dns.DomainError if ``name`` is outside every served domain and
          is not a reverse ``.in-addr.arpa`` name;
        * dns.AuthoritativeDomainError if ``name`` is in a served domain
          but no matching host/IP exists, or ``cls`` is not IN.

        Database errors propagate as exceptions (``_lookup`` retries
        them).  ``timeout`` is accepted for interface compatibility and
        is not used.
        """
        # NOTE(review): presumably discards cached ORM objects so every
        # answer reflects current database contents -- confirm.
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest-matching domain.
            best_domain = ''
            for domain in self.domains:
                if (name.endswith('.' + domain)
                        and len(domain) > len(best_domain)):
                    best_domain = domain
            if best_domain == '':
                if name.endswith('.in-addr.arpa'):
                    # Act authoritative for the IP address for reverse
                    # resolution requests.
                    best_domain = name
                else:
                    return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain

        results = []
        # Named auth_section so it does not shadow the imported
        # twisted.names.authority module.
        auth_section = [dns.RRHeader(domain, dns.NS, dns.IN, 3600,
                                     self.ns, auth=True)]
        additional = [self.ns1]

        if cls != dns.IN:
            # Doesn't exist
            return defer.fail(failure.Failure(
                dns.AuthoritativeDomainError(name)))

        if name.endswith('.in-addr.arpa'):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # "4.3.2.1.in-addr.arpa" -> "1.2.3.4"
                ip = '.'.join(reversed(name.split('.')[:-2]))
                value = invirt.database.NIC.query.filter_by(ip=ip).first()
                if value and value.hostname:
                    hostname = value.hostname
                    if '.' not in hostname:
                        # Qualify bare hostnames with the first configured
                        # domain.
                        hostname = hostname + "." + config.dns.domains[0]
                    record = dns.Record_PTR(hostname, ttl)
                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                ttl, record, auth=True))
                else:
                    # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(
                        dns.AuthoritativeDomainError(name)))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        elif name == domain or name == '.' + domain:
            # Zone apex: answer with the nameserver itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer.
                auth_section = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            # Strip the ".<domain>" suffix to get the host label(s).
            host = name[:-len(domain) - 1]
            value = invirt.database.NIC.query.filter_by(hostname=host).first()
            if value:
                ip = value.ip
            else:
                value = invirt.database.Machine.query().filter_by(
                    name=host).first()
                if value and value.nics:
                    ip = value.nics[0].ip
                else:
                    # Unknown host, or a machine with no NIC: previously
                    # value.nics[0] raised IndexError here.
                    return defer.fail(failure.Failure(
                        dns.AuthoritativeDomainError(name)))
            if ip is None:
                return defer.fail(failure.Failure(
                    dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        # FIXME: Should only return success with no records if the name
        # actually exists
        if not results:
            auth_section = []
            additional = []
        return defer.succeed((results, auth_section, additional))


class QuotingBindAuthority(authority.BindAuthority):
    """
    A BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a quoted
    phrase, unless the double quote is escaped by a backslash.
    Parentheses used to continue a record across lines are only
    recognized outside quoted phrases and when not backslash-escaped.
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line. This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
        re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def _blankParens(self, line, depth):
        """Blank out grouping parentheses in one physical line.

        Returns ``(text, depth)``: ``text`` is ``line`` with every
        parenthesis that is outside a double-quoted phrase and not
        backslash-escaped replaced by a space, and ``depth`` is the
        parenthesis nesting level after the line, starting from the
        given ``depth``.  A stray ``)`` never drives the depth below 0.
        """
        chars = []
        in_quote = False
        escaped = False
        for c in line:
            if escaped:
                escaped = False
            elif c == '\\':
                escaped = True
            elif c == '"':
                in_quote = not in_quote
            elif not in_quote and c == '(':
                depth += 1
                c = ' '
            elif not in_quote and c == ')':
                depth = max(depth - 1, 0)
                c = ' '
            chars.append(c)
        return ''.join(chars), depth

    def collapseContinuations(self, lines):
        """Join parenthesized continuations and tokenize each record.

        ``lines`` are comment-stripped physical lines.  Lines inside an
        open ``( ... )`` group are appended to the record that opened it;
        text before and after the parentheses is kept (previously text
        after ``(`` / ``)`` was dropped, and ``( ... )`` on a single line
        left the parser stuck in continuation mode).  Each record is then
        split into tokens with string_pat, with backslash escapes
        resolved.  Returns a list of non-empty token lists.
        """
        records = []
        depth = 0
        for line in lines:
            text, new_depth = self._blankParens(line, depth)
            if depth == 0:
                records.append(text)
            else:
                # depth > 0 implies an earlier line opened a record.
                records[-1] += ' ' + text
            depth = new_depth

        tokenized = []
        for record in records:
            tokens = []
            for m in self.string_pat.finditer(record):
                # Exactly one of the two groups (quoted/unquoted) matched.
                [token] = [g for g in m.groups() if g is not None]
                tokens.append(self.escape_pat.sub(r'\1', token))
            tokenized.append(tokens)
        return [tokens for tokens in tokenized if tokens]


if '__main__' == __name__:
    resolvers = []
    try:
        for zone in config.dns.zone_files:
            for origin in config.dns.domains:
                r = QuotingBindAuthority(zone)
                # This sucks, but if I want a generic zone file, I have to
                # reload the information by hand
                r.origin = origin
                zone_file = open(zone)
                try:
                    lines = zone_file.readlines()
                finally:
                    zone_file.close()
                lines = r.collapseContinuations(r.stripComments(lines))
                r.parseLines(lines)
                resolvers.append(r)
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        pass
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()