#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time


class DatabaseAuthority(common.ResolverBase):
    """An authoritative DNS resolver whose records come from the invirt database.

    Serves the zones in ``self.domains`` (``config.dns.domains`` by default):

    * the zone apex answers A with the first configured nameserver's IP,
      NS with that nameserver, and SOA with ``self.soa``;
    * ``<host>.<zone>`` answers A with the IP of the first NIC of the
      database machine named ``<host>``, unless ``<host>`` is listed in
      ``config.dns.passup``, in which case a CNAME to
      ``<host>.<config.dns.parent>`` is returned.
    """

    soa = None

    # How many times a query is attempted when the database raises a
    # psycopg2 OperationalError / SQLAlchemy SQLError, and the pause (in
    # seconds) between attempts.
    # NOTE(review): time.sleep is synchronous, so presumably it stalls the
    # Twisted reactor for the duration of the pause.
    lookup_attempts = 3
    retry_delay = 0.5

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static zone records.

        domains: zone names this authority serves; defaults to
            ``config.dns.domains``.
        database: argument passed to ``invirt.database.connect``; when None,
            ``connect()`` is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        ns = config.dns.nameservers[0]
        # SOA rname is the contact address with its first '@' replaced by
        # '.' (assumes config.dns.contact looks like user@host).
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@', '.', 1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # Glue A record for the nameserver, sent in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN, 3600,
                                record, auth=True)

    def _lookup(self, name, cls, type, timeout=None):
        """Resolve a query, retrying when the database connection errors out.

        Returns a Deferred firing with ``(answers, authority, additional)``.
        If every attempt fails with a psycopg2 OperationalError or a
        SQLAlchemy SQLError, the returned Deferred fails with that error.
        """
        attempts = max(1, self.lookup_attempts)
        for attempt in range(attempts):
            try:
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == attempts - 1:
                    # Report the failure through the Deferred, like every
                    # other outcome of this resolver, instead of raising
                    # synchronously out of a Deferred-returning method.
                    return defer.fail(failure.Failure())
                print("Reloading database")
                time.sleep(self.retry_delay)

    def _is_prod_name(self, name):
        """Return True if a dev-cluster server should refer name to prod.

        XXX hack for the transition to two separate dev/prod clusters.
        Matches the prod zone itself and its true subdomains only, so that
        e.g. ``myprod.xvm.mit.edu`` is not redirected.
        """
        return ('dev.xvm.mit.edu' in self.domains and
                (name == 'prod.xvm.mit.edu' or
                 name.endswith('.prod.xvm.mit.edu')))

    def _prod_referral(self):
        """Return a Deferred referral pointing the client at the prod nameserver."""
        authority = dns.RRHeader('prod.xvm.mit.edu', dns.NS, dns.IN, 3600,
                                 dns.Record_NS(name='ns1.prod.xvm.mit.edu',
                                               ttl=3600),
                                 auth=True)
        additional = dns.RRHeader('ns1.prod.xvm.mit.edu', dns.A, dns.IN, 3600,
                                  dns.Record_A(address='18.181.0.221',
                                               ttl=3600),
                                  auth=True)
        return defer.succeed(([], [authority], [additional]))

    def _match_domain(self, name):
        """Return the served zone containing name, or None if there is none.

        An exact match wins; otherwise the longest non-empty zone of which
        name is a proper subdomain is chosen.
        """
        if name in self.domains:
            return name
        candidates = [d for d in self.domains
                      if d and name.endswith('.' + d)]
        if not candidates:
            return None
        return max(candidates, key=len)

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a query from the database, without retry handling.

        Returns a Deferred. Database exceptions propagate synchronously so
        that _lookup can retry. timeout is accepted for interface
        compatibility and is not used here.
        """
        invirt.database.clear_cache()
        ttl = 900
        name = name.lower()

        if self._is_prod_name(name):
            return self._prod_referral()

        domain = self._match_domain(name)
        if domain is None:
            # Not one of our zones: a non-authoritative miss.
            return defer.fail(failure.Failure(dns.DomainError(name)))

        if cls != dns.IN:
            # Doesn't exist
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        results = []
        authority = [dns.RRHeader(domain, dns.NS, dns.IN, 3600,
                                  self.ns, auth=True)]
        additional = [self.ns1]

        # Strip ".<domain>" to get the host label(s); empty for the apex.
        host = name[:-len(domain) - 1]
        if not host:
            # Request for the domain itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN, ttl,
                                            record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN, ttl,
                                            self.ns, auth=True))
                # The NS record is already in the answer section.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))
        else:
            # Request for a subdomain.
            if 'passup' in dir(config.dns) and host in config.dns.passup:
                record = dns.Record_CNAME('%s.%s' % (host, config.dns.parent),
                                          ttl)
                return defer.succeed((
                    [dns.RRHeader(name, dns.CNAME, dns.IN, ttl,
                                  record, auth=True)],
                    [], []))
            value = invirt.database.Machine.query().filter_by(name=host).first()
            if value is None or not value.nics:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            ip = value.nics[0].ip
            if ip is None:  # Deactivated?
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN, ttl,
                                            record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN, ttl,
                                            self.soa, auth=True))

        if not results:
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))


if '__main__' == __name__:
    resolver = DatabaseAuthority()

    verbosity = 0
    f = server.DNSServerFactory(authorities=[resolver], verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve on the standard DNS port over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()