#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.names import authority
from twisted.internet import defer
from twisted.python import failure

from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time
import re

class DatabaseAuthority(common.ResolverBase):
    """An Authority whose records are looked up in the invirt database.

    Queries for one of the configured domains itself are answered from
    the configuration (the first nameserver's address, an NS record and a
    synthesized SOA).  Queries for a host under a configured domain are
    answered with the IP address of the first NIC of the database machine
    whose name equals the host part of the query.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static records.

        domains -- the domains this authority answers for; defaults to
                   config.dns.domains.
        database -- passed to invirt.database.connect() when given;
                    otherwise connect() is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        ns = config.dns.nameservers[0]
        # The SOA RNAME is the contact address with its first '@' turned
        # into '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, sent in the additional section of
        # every non-empty answer.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Answer a query, retrying on database errors.

        Up to three attempts are made.  A psycopg2 OperationalError or
        SQLAlchemy SQLError on an early attempt is followed by a short
        pause and another attempt; the error from the last attempt
        propagates to the caller.
        """
        attempts = 3
        for attempt in range(attempts):
            try:
                # Forward the caller's timeout (it was previously replaced
                # by None).
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == attempts - 1:
                    raise
                print("Reloading database")
                time.sleep(0.5)

    def _match_domain(self, name):
        """Return the configured domain that `name` belongs to, or None.

        An exact match wins; otherwise the longest configured domain of
        which `name` is a proper subdomain is chosen.
        """
        if name in self.domains:
            return name
        best_domain = ''
        for domain in self.domains:
            if name.endswith('.' + domain) and len(domain) > len(best_domain):
                best_domain = domain
        if best_domain == '':
            return None
        return best_domain

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer one query without any retry logic.

        Returns a Deferred firing with (answers, authority, additional).
        Returns a failed Deferred with DomainError when `name` is outside
        every configured domain.  Returns a failed Deferred with
        AuthoritativeDomainError when the class is not IN, or when the host
        is unknown or has no usable address.  Database errors propagate as
        exceptions (which _lookup retries).
        """
        # Presumably drops cached ORM state so each query sees fresh data.
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        domain = self._match_domain(name)
        if domain is None:
            return defer.fail(failure.Failure(dns.DomainError(name)))
        if cls != dns.IN:
            # Only the IN class is served.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        results = []
        authority = [dns.RRHeader(domain, dns.NS, dns.IN,
                                  3600, self.ns, auth=True)]
        additional = [self.ns1]

        # Everything before ".<domain>"; empty when the query is for the
        # domain itself.
        host = name[:-len(domain)-1]
        if not host:
            # Request for the domain itself: answer from configuration.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer; don't repeat it in
                # the authority section.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            # Request for a subdomain: look the machine up by name.
            machine = invirt.database.Machine.query().filter_by(name=host).first()
            if machine is None or not machine.nics:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            ip = machine.nics[0].ip
            if ip is None:  # NIC without an address (deactivated?)
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        if not results:
            # Nothing matched the requested type: send empty sections.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))

class QuotingBindAuthority(authority.BindAuthority):
    """
    A BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    Parentheses inside a quoted phrase are literal text, not
    line-continuation markers.
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line.  This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def _stripParens(self, line, depth):
        """Blank out grouping parentheses that are not inside quotes.

        line -- one (comment-stripped) line of the zone file.
        depth -- parenthesis nesting depth at the start of the line.

        Returns (text, depth).  In text, every unquoted '(' and every
        unquoted ')' that closes an open group is replaced by a space.
        depth is the nesting depth at the end of the line.  A ')' with
        no open group is left in place.  A backslash escapes the following
        character, matching string_pat.
        """
        out = []
        in_quote = False
        escaped = False
        for ch in line:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_quote = not in_quote
            elif not in_quote and ch == '(':
                depth += 1
                ch = ' '
            elif not in_quote and ch == ')' and depth > 0:
                depth -= 1
                ch = ' '
            out.append(ch)
        return ''.join(out), depth

    def collapseContinuations(self, lines):
        """Join parenthesized continuations and split lines into tokens.

        Lines inside an open '(' group are appended to the line that
        opened it.  Text on either side of the parentheses is kept, and a
        group may open and close on the same line.  Each resulting logical
        line is split into tokens with string_pat; quotes are removed and
        backslash escapes are resolved.  Lines with no tokens are dropped.
        Returns a list of token lists.
        """
        logical = []
        depth = 0
        for line in lines:
            text, new_depth = self._stripParens(line, depth)
            if depth > 0:
                # Still inside a group opened on an earlier line.
                logical[-1] += ' ' + text
            else:
                logical.append(text)
            depth = new_depth

        records = []
        for line in logical:
            tokens = []
            for m in self.string_pat.finditer(line):
                [token] = [g for g in m.groups() if g is not None]
                tokens.append(self.escape_pat.sub(r'\1', token))
            if tokens:
                records.append(tokens)
        return records

if '__main__' == __name__:
    resolvers = []
    for zone in config.dns.zone_files:
        # Read each zone file once and close it promptly; the same lines
        # are re-parsed below for every configured domain.
        zone_file = open(zone)
        try:
            raw_lines = zone_file.readlines()
        finally:
            zone_file.close()
        for origin in config.dns.domains:
            r = QuotingBindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand with the origin overridden.
            r.origin = origin
            lines = r.collapseContinuations(r.stripComments(raw_lines))
            r.parseLines(lines)

            resolvers.append(r)
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve DNS on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()