#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.names import authority
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time


class DatabaseAuthority(common.ResolverBase):
    """An authority whose subdomain answers come from the invirt database.

    Queries for one of the served domains itself are answered from the
    configured nameserver data; queries for a host under a served domain
    are answered with the IP of the first NIC of the matching Machine row.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and precompute the static records.

        domains  -- collection of domain names to serve; defaults to
                    config.dns.domains when None.
        database -- argument forwarded to invirt.database.connect(); when
                    None, connect() is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains

        # All static records describe the first configured nameserver.
        ns = config.dns.nameservers[0]
        self.soa = dns.Record_SOA(
            mname=ns.hostname,
            # SOA rname is a mailbox written as a domain name: the first
            # '@' of the contact address becomes a '.'.
            rname=config.dns.contact.replace('@', '.', 1),
            serial=1, refresh=3600, retry=900,
            expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, placed in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout=None):
        """Run _lookup_unsafe, retrying on database errors.

        Up to three attempts are made.  After a failed first or second
        attempt a message is printed and the call sleeps for half a second;
        a failure on the third attempt is re-raised to the caller.
        """
        for attempt in range(3):
            try:
                # Forward the caller's timeout rather than discarding it.
                value = self._lookup_unsafe(name, cls, type, timeout=timeout)
            except (psycopg2.OperationalError,
                    sqlalchemy.exceptions.SQLError):
                if attempt == 2:
                    raise
                # Parenthesised form prints the same text on Python 2 and
                # is also valid Python 3 syntax.
                print("Reloading database")
                # NOTE(review): time.sleep blocks the calling thread; if this
                # runs on the reactor thread, other requests stall meanwhile.
                time.sleep(0.5)
                continue
            else:
                return value

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a single query without any retry handling.

        Returns a Deferred.  On success it fires with a
        (results, authority, additional) tuple of RRHeader lists.  It fails
        with dns.DomainError when the name is not under any served domain,
        and with dns.AuthoritativeDomainError when the class is not IN or
        when no machine with a usable IP matches the host part.

        The timeout argument is accepted but not used here.
        """
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest served domain that is a suffix of name.
            best_domain = ''
            for candidate in self.domains:
                if (name.endswith('.' + candidate)
                        and len(candidate) > len(best_domain)):
                    best_domain = candidate
            if best_domain == '':
                return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain

        results = []
        authority = [dns.RRHeader(domain, dns.NS, dns.IN,
                                  3600, self.ns, auth=True)]
        additional = [self.ns1]

        if cls != dns.IN:
            # Doesn't exist
            return defer.fail(
                failure.Failure(dns.AuthoritativeDomainError(name)))

        # Strip the ".<domain>" suffix; empty when name is the domain itself.
        host = name[:-len(domain) - 1]
        if not host:
            # Request for the domain itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer, so the authority
                # section is left empty.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            # Request for a subdomain.
            value = invirt.database.Machine.query().filter_by(
                name=host).first()
            if value is None or not value.nics:
                return defer.fail(
                    failure.Failure(dns.AuthoritativeDomainError(name)))
            ip = value.nics[0].ip
            if ip is None:
                # Deactivated?
                return defer.fail(
                    failure.Failure(dns.AuthoritativeDomainError(name)))

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if not results:
            # Nothing to answer with for this record type: return a fully
            # empty response rather than authority/additional data alone.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))


if '__main__' == __name__:
    resolvers = []
    for zone in config.dns.zone_files:
        for origin in config.dns.domains:
            r = authority.BindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand
            r.origin = origin
            # Close the zone file explicitly instead of leaking the handle.
            # try/finally is used because the file shows no use of `with`.
            zone_file = open(zone)
            try:
                lines = zone_file.readlines()
            finally:
                zone_file.close()
            lines = r.collapseContinuations(r.stripComments(lines))
            r.parseLines(lines)

            resolvers.append(r)
    # The database-backed authority is appended after the zone-file ones.
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve DNS on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()