#!/usr/bin/python
from twisted.internet import reactor
from twisted.names import server
from twisted.names import dns
from twisted.names import common
from twisted.names import authority
from twisted.internet import defer
from twisted.python import failure

from invirt.common import InvirtConfigError
from invirt.config import structs as config
import invirt.database
import psycopg2
import sqlalchemy
import time
import re

class DatabaseAuthority(common.ResolverBase):
    """An authoritative resolver whose answers come from the invirt database.

    Forward lookups under the configured domains resolve a hostname to the
    IP of the NIC with that hostname or, failing that, to the IP of the
    first NIC of the machine with that name.  Reverse (PTR) lookups under
    .in-addr.arpa map an IP back to the hostname of the NIC that owns it.
    Queries for a configured domain itself answer with the first configured
    nameserver's A, NS and SOA records.
    """

    soa = None

    # Number of times _lookup tries a query when the database raises a
    # connection/SQL error, and the delay in seconds between attempts.
    lookup_attempts = 3
    retry_delay = 0.5

    def __init__(self, domains=None, database=None):
        """Connect to the database and prepare the static zone records.

        domains: the zones to answer for; defaults to config.dns.domains.
        database: forwarded to invirt.database.connect() when given;
            otherwise connect() is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # Only the first configured nameserver is advertised.
        ns = config.dns.nameservers[0]
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  # SOA RNAME writes the contact's '@' as '.'.
                                  rname=config.dns.contact.replace('@', '.', 1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, sent in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout=None):
        """Answer a query, retrying on database errors.

        Calls _lookup_unsafe up to self.lookup_attempts times (at least
        once).  If every attempt raises psycopg2.OperationalError or
        sqlalchemy.exceptions.SQLError, the last error is re-raised.
        Returns the Deferred produced by _lookup_unsafe.
        """
        attempts = max(1, self.lookup_attempts)
        for attempt in range(attempts):
            try:
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
                if attempt == attempts - 1:
                    raise
                print("Reloading database")
                # NOTE(review): blocking sleep on the reactor thread.
                time.sleep(self.retry_delay)

    def _nxdomain(self, name):
        """Return a failed Deferred saying `name` does not exist here."""
        return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

    def _zone_for(self, name):
        """Return the zone `name` belongs to, or None if it is not ours.

        An exact match against self.domains wins.  Otherwise the longest
        configured domain that `name` is a subdomain of is chosen.  A name
        under .in-addr.arpa that matches no configured domain is treated as
        its own zone, so reverse lookups can be answered authoritatively.
        """
        if name in self.domains:
            return name
        best_domain = ''
        for domain in self.domains:
            if name.endswith('.' + domain) and len(domain) > len(best_domain):
                best_domain = domain
        if best_domain:
            return best_domain
        if name.endswith('.in-addr.arpa'):
            return name
        return None

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a query from the database without any retry logic.

        Returns a Deferred firing with (answers, authority, additional).
        Fails with dns.DomainError when `name` is outside every zone we
        serve, and with dns.AuthoritativeDomainError when the name is in
        our zone but has no host, or the class is not IN.  Database errors
        propagate to the caller (_lookup retries them).
        """
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        domain = self._zone_for(name)
        if domain is None:
            return defer.fail(failure.Failure(dns.DomainError(name)))

        answers = []
        authority_rrs = [dns.RRHeader(domain, dns.NS, dns.IN,
                                      3600, self.ns, auth=True)]
        additional_rrs = [self.ns1]

        if cls != dns.IN:
            # Only the IN class is served.
            return self._nxdomain(name)

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # Drop the "in-addr.arpa" labels and reverse the octets.
                ip = '.'.join(reversed(name.split('.')[:-2]))
                nic = invirt.database.NIC.query.filter_by(ip=ip).first()
                if not (nic and nic.hostname):
                    # IP address doesn't point to an active host
                    return self._nxdomain(name)
                hostname = nic.hostname
                if '.' not in hostname:
                    # Unqualified hostnames live in the primary domain.
                    hostname = hostname + "." + config.dns.domains[0]
                record = dns.Record_PTR(hostname, ttl)
                answers.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                answers.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
            # FIXME: Should only return success with no records if the name actually exists
        elif name == domain or name == '.' + domain:
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                answers.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                answers.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer.
                authority_rrs = []
            elif type == dns.SOA:
                answers.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
        else:
            host = name[:-len(domain)-1]
            ip = None
            nic = invirt.database.NIC.query.filter_by(hostname=host).first()
            if nic:
                ip = nic.ip
            else:
                machine = invirt.database.Machine.query.filter_by(name=host).first()
                if not machine:
                    return self._nxdomain(name)
                try:
                    ip = machine.nics[0].ip
                except IndexError:
                    # A machine without NICs has no address to publish.
                    ip = None

            if ip is None:
                return self._nxdomain(name)

            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                answers.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                answers.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if not answers:
            # NODATA response: don't attach authority/additional records.
            authority_rrs = []
            additional_rrs = []
        return defer.succeed((answers, authority_rrs, additional_rrs))

class QuotingBindAuthority(authority.BindAuthority):
    """
    A BindAuthority that (mostly) deals with quoting correctly

    Double quotes mark the start or end of a quoted phrase, unless the
    double quote is escaped by a backslash.  A backslash escapes the next
    character both inside and outside quotes.  Parentheses introduce a
    multi-line record only when they appear outside quotes and unescaped.
    Any text on the same line as the parentheses is kept.
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line.  This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def _stripParens(self, line, in_parens):
        """Replace unquoted, unescaped parentheses in `line` with spaces.

        in_parens: whether a '(' from an earlier line is still open.
        Returns (text, in_parens), where in_parens says whether a '(' is
        still open at the end of the line.  Quote state is tracked per line.
        """
        chars = []
        in_quote = False
        escaped = False
        for ch in line:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_quote = not in_quote
            elif not in_quote and ch == '(':
                in_parens = True
                ch = ' '
            elif not in_quote and ch == ')':
                in_parens = False
                ch = ' '
            chars.append(ch)
        return ''.join(chars), in_parens

    def collapseContinuations(self, lines):
        """Join parenthesised multi-line records and tokenise each record.

        lines: zone-file lines with comments already stripped.
        Returns a list with one token list per logical record; records
        with no tokens are dropped.  Quoted tokens are returned without
        their surrounding quotes.  Backslash escapes are resolved in every
        token.
        """
        records = []
        in_parens = False
        for line in lines:
            text, still_open = self._stripParens(line, in_parens)
            if in_parens:
                # Continuation of the record opened on an earlier line.
                records[-1] += ' ' + text
            else:
                records.append(text)
            in_parens = still_open

        tokenized = []
        for record in records:
            tokens = []
            for m in self.string_pat.finditer(record):
                # Exactly one of the quoted/unquoted groups matched.
                [token] = [g for g in m.groups() if g is not None]
                tokens.append(self.escape_pat.sub(r'\1', token))
            tokenized.append(tokens)
        return [tokens for tokens in tokenized if tokens]

if '__main__' == __name__:
    resolvers = []
    try:
        for zone in config.dns.zone_files:
            # Read each zone file once (and close it); it is re-parsed
            # below for every origin.
            zone_file = open(zone)
            try:
                zone_lines = zone_file.readlines()
            finally:
                zone_file.close()
            for origin in config.dns.domains:
                r = QuotingBindAuthority(zone)
                # This sucks, but if I want a generic zone file, I have to
                # reload the information by hand
                r.origin = origin
                lines = r.collapseContinuations(r.stripComments(zone_lines))
                r.parseLines(lines)

                resolvers.append(r)
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        pass
    # The database-backed authority is always consulted after zone files.
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve DNS on the standard port over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()