#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.names import authority
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time
import re


class DatabaseAuthority(common.ResolverBase):
    """An Authority that answers queries from the invirt machine database.

    A query for one of the configured domains themselves is answered from
    the static nameserver configuration; a query for ``<host>.<domain>`` is
    answered from the first NIC of the database machine named ``<host>``.
    """
    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static NS/SOA records.

        :param domains: domains this authority serves; defaults to
            ``config.dns.domains``.
        :param database: passed to ``invirt.database.connect()`` when given;
            otherwise ``connect()`` is called with no arguments.

        The first entry of ``config.dns.nameservers`` is used as the SOA
        master name, the NS target, and the glue A record.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains

        ns = config.dns.nameservers[0]
        # SOA RNAME is the contact mailbox with its first '@' written as '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@', '.', 1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        # Glue A record for the nameserver; sent in the additional section.
        record = dns.Record_A(address=ns.ip, ttl=3600)
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN, 3600,
                                record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Answer a query, retrying on database connection errors.

        ``psycopg2.OperationalError`` and SQLAlchemy ``SQLError`` are treated
        as transient: the lookup is retried after a 0.5 s pause, up to three
        attempts in total, and the error is re-raised if the last attempt
        also fails. ``timeout`` is passed through to ``_lookup_unsafe``.

        :returns: the Deferred produced by ``_lookup_unsafe``.
        """
        attempts = 3
        for attempt in range(attempts):
            try:
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError,
                    sqlalchemy.exceptions.SQLError):
                if attempt == attempts - 1:
                    raise
                print("Reloading database")
                # NOTE(review): this sleep runs synchronously inside the
                # lookup, so it presumably stalls the reactor (and every other
                # query) for 0.5 s; consider twisted.internet.task.deferLater.
                time.sleep(0.5)

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer one query from the database, without retrying on errors.

        :returns: a Deferred firing with ``(answers, authority, additional)``
            lists of RRHeaders. It fails with ``dns.DomainError`` when
            ``name`` is outside every configured domain, and with
            ``dns.AuthoritativeDomainError`` when it is inside one but
            ``cls`` is not IN or no matching machine/IP exists. Database
            exceptions are raised synchronously to the caller.
        """
        invirt.database.clear_cache()
        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Use the longest configured domain that `name` lies under, so a
            # more specific zone wins over an enclosing one.
            candidates = [d for d in self.domains if name.endswith('.' + d)]
            if not candidates:
                return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = max(candidates, key=len)

        if cls != dns.IN:
            # Doesn't exist
            return defer.fail(failure.Failure(
                dns.AuthoritativeDomainError(name)))

        results = []
        auth_section = [dns.RRHeader(domain, dns.NS, dns.IN, 3600,
                                     self.ns, auth=True)]
        additional = [self.ns1]

        # Strip ".<domain>" off the end; empty when name == domain.
        host = name[:-len(domain) - 1]
        if not host:
            # Request for the domain itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN, ttl,
                                            record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN, ttl,
                                            self.ns, auth=True))
                # The NS record is already in the answer section.
                auth_section = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))
        else:
            # Request for a subdomain: look the host up as a machine name.
            machine = invirt.database.Machine.query().filter_by(
                name=host).first()
            if machine is None or not machine.nics:
                return defer.fail(failure.Failure(
                    dns.AuthoritativeDomainError(name)))
            ip = machine.nics[0].ip
            if ip is None:
                # NOTE(review): presumably a deactivated machine -- confirm.
                return defer.fail(failure.Failure(
                    dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN, ttl,
                                            record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))

        if not results:
            # No answer for this type: send empty authority/additional too.
            auth_section = []
            additional = []
        return defer.succeed((results, auth_section, additional))


class QuotingBindAuthority(authority.BindAuthority):
    """
    A BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    The quote characters are removed from the resulting token, but an
    escaping backslash is left in place (hence "almost").
    """

    # Grab everything up to the first whitespace character or
    # quotation mark not preceded by a backslash
    whitespace_re = re.compile(r'(.*?)([\t\n\x0b\x0c\r ]+|(?<!\\)")')

    def collapseContinuations(self, lines):
        """Join parenthesised continuation lines, then tokenize each line.

        Overrides the BindAuthority version so that a double-quoted phrase
        becomes a single token instead of being split on its internal
        whitespace.

        :param lines: zone-file lines, typically the output of
            ``stripComments``.
        :returns: a list of token lists, one per logical line, with lines
            that produce no tokens dropped.
        """
        # Join multi-line "( ... )" records into one logical line.
        # NOTE(review): as in the BindAuthority version this mirrors, text
        # after '(' on the opening line and after ')' on the closing line is
        # discarded, so a one-line "( ... )" group is not handled.
        joined = []
        in_parens = False
        for line in lines:
            if not in_parens:
                if '(' not in line:
                    joined.append(line)
                else:
                    joined.append(line[:line.index('(')])
                    in_parens = True
            elif ')' in line:
                joined[-1] += ' ' + line[:line.index(')')]
                in_parens = False
            else:
                joined[-1] += ' ' + line

        tokenized = [self._splitQuoted(line) for line in joined]
        return [tokens for tokens in tokenized if tokens]

    def _splitQuoted(self, line):
        """Split one line on whitespace, keeping double-quoted phrases whole.

        Whitespace inside a quoted phrase is kept verbatim and the quotes are
        dropped; ``""`` yields an empty token. An unterminated quote runs to
        the end of the line. Runs of bare whitespace never produce tokens,
        matching ``str.split()``.
        """
        in_quote = False
        tokens = []
        while line:
            match = self.whitespace_re.match(line)
            if match is None:
                # No whitespace or quote left: the rest of the line is a
                # single entity, quoted or not (so a closing quote isn't
                # strictly necessary at the end of the line).
                substr, end = line, ''
            else:
                substr, end = match.groups()

            if in_quote:
                # Inside a quote the grabbed text -- including the whitespace
                # that ended it, but not a closing quote -- extends the
                # current token.
                tokens[-1] += substr + (end if end != '"' else '')
            elif substr or end == '"':
                # Outside a quote, non-empty text starts a new token, as does
                # an opening quote. Whitespace alone (e.g. right after a
                # closing quote) starts nothing.
                tokens.append(substr)

            if end == '"':
                in_quote = not in_quote
            # Then strip off what we just processed.
            line = line[len(substr) + len(end):]
        return tokens


if __name__ == '__main__':
    resolvers = []
    for zone in config.dns.zone_files:
        # Read each zone file once; it is re-parsed below for every domain.
        zone_file = open(zone)
        try:
            zone_lines = zone_file.readlines()
        finally:
            zone_file.close()
        for origin in config.dns.domains:
            r = QuotingBindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand with the desired origin.
            r.origin = origin
            # Pass a copy in case the parser mutates its input list.
            lines = r.collapseContinuations(r.stripComments(list(zone_lines)))
            r.parseLines(lines)
            resolvers.append(r)
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()