#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time

class DatabaseAuthority(common.ResolverBase):
    """A DNS authority whose answers are read from the invirt database.

    Names below one of the configured domains are answered with the IP of
    the first NIC of the machine whose name equals the part of the query
    preceding that domain.  Queries for a domain itself are answered from
    the first configured nameserver's settings (A/NS/SOA).
    """

    soa = None

    # How many times _lookup tries a query before a database error propagates.
    _LOOKUP_ATTEMPTS = 3
    # Seconds _lookup waits between attempts after a database error.
    _RETRY_DELAY = 0.5

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static NS/SOA records.

        domains  -- iterable of domain names this server is authoritative
                    for; defaults to config.dns.domains.
        database -- optional argument passed through to
                    invirt.database.connect(); when omitted, connect() is
                    called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # The first configured nameserver is advertised as this zone's
        # primary NS and SOA master.
        ns = config.dns.nameservers[0]
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  # SOA RNAME encodes the contact address by
                                  # replacing the first '@' with '.'.
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # Glue A record for the nameserver, sent in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Answer a query, retrying when the database connection fails.

        On psycopg2.OperationalError or sqlalchemy.exceptions.SQLError the
        query is retried, up to _LOOKUP_ATTEMPTS attempts in total with
        _RETRY_DELAY seconds between them; the error from the final attempt
        is re-raised.

        Returns whatever _lookup_unsafe returns: a Deferred firing with
        (answers, authority, additional) lists of RRHeaders.
        """
        for attempt in range(self._LOOKUP_ATTEMPTS):
            try:
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == self._LOOKUP_ATTEMPTS - 1:
                    raise
                print("Reloading database")
                # NOTE(review): time.sleep blocks the reactor thread for the
                # whole delay; consider deferLater if latency matters.
                time.sleep(self._RETRY_DELAY)

    def _find_domain(self, name):
        """Return the longest configured domain equal to or a parent of name.

        name must already be lower-cased.  Returns None when no configured
        domain matches.
        """
        if name in self.domains:
            return name
        # Require a '.' boundary so that e.g. 'fooxvm.mit.edu' does not
        # match the domain 'xvm.mit.edu'.
        candidates = [domain for domain in self.domains
                      if domain and name.endswith('.' + domain)]
        if not candidates:
            return None
        return max(candidates, key=len)

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a query from the database, without retry handling.

        Returns a Deferred that either fires with (answers, authority,
        additional) lists of RRHeaders or fails with:
          - dns.DomainError if name is outside every configured domain;
          - dns.AuthoritativeDomainError if name is inside a configured
            domain but has no matching machine/address, or cls is not IN.
        timeout is accepted for interface compatibility and is unused.
        """
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        # XXX hack for the transition to two separate dev/prod clusters:
        # the dev server refers anything under prod.xvm.mit.edu to the prod
        # cluster's own nameserver.
        if 'dev.xvm.mit.edu' in self.domains and (
                name == 'prod.xvm.mit.edu' or
                name.endswith('.prod.xvm.mit.edu')):
            authority = dns.RRHeader('prod.xvm.mit.edu', dns.NS, dns.IN, 3600,
                    dns.Record_NS(name='ns1.prod.xvm.mit.edu', ttl=3600), auth=True)
            additional = dns.RRHeader('ns1.prod.xvm.mit.edu', dns.A, dns.IN, 3600,
                    dns.Record_A(address='18.181.0.221', ttl=3600), auth=True)
            return defer.succeed(([], [authority], [additional]))

        domain = self._find_domain(name)
        if domain is None:
            return defer.fail(failure.Failure(dns.DomainError(name)))

        if cls != dns.IN:
            # Only the IN class is served.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        results = []
        authority = [dns.RRHeader(domain, dns.NS, dns.IN,
                                  3600, self.ns, auth=True)]
        additional = [self.ns1]

        # Everything before '.<domain>'; empty when name is the domain itself.
        host = name[:-len(domain)-1]
        if not host:
            # Request for the domain itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already in the answer section.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            # Request for a subdomain.  Hosts listed in config.dns.passup are
            # aliased (CNAME) to the same label under config.dns.parent.
            if 'passup' in dir(config.dns) and host in config.dns.passup:
                record = dns.Record_CNAME('%s.%s' % (host, config.dns.parent), ttl)
                return defer.succeed((
                    [dns.RRHeader(name, dns.CNAME, dns.IN, ttl, record, auth=True)],
                    [], []))

            machine = invirt.database.Machine.query().filter_by(name=host).first()
            if machine is None or not machine.nics:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            ip = machine.nics[0].ip
            if ip is None:  # NIC has no address (deactivated?)
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if not results:
            # Name exists but has no records of the requested type: reply
            # with empty authority/additional sections as well.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))

if __name__ == '__main__':
    # Serve database-backed DNS answers on port 53 over both UDP and TCP.
    authority = DatabaseAuthority()

    quiet = 0
    factory = server.DNSServerFactory(authorities=[authority], verbose=quiet)
    udp_protocol = dns.DNSDatagramProtocol(factory)
    factory.noisy = quiet
    udp_protocol.noisy = quiet

    reactor.listenUDP(53, udp_protocol)
    reactor.listenTCP(53, factory)
    reactor.run()