Add support for DNS zone files that take precedence over the database
[invirt/packages/invirt-dns.git] / invirt-dns
1 #!/usr/bin/python
2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.internet import defer
8 from twisted.python import failure
9
10 from invirt.config import structs as config
11 import invirt.database
12 import psycopg2
13 import sqlalchemy
14 import time
15
16 class DatabaseAuthority(common.ResolverBase):
17     """An Authority that is loaded from a file."""
18
19     soa = None
20
21     def __init__(self, domains=None, database=None):
22         common.ResolverBase.__init__(self)
23         if database is not None:
24             invirt.database.connect(database)
25         else:
26             invirt.database.connect()
27         if domains is not None:
28             self.domains = domains
29         else:
30             self.domains = config.dns.domains
31         ns = config.dns.nameservers[0]
32         self.soa = dns.Record_SOA(mname=ns.hostname,
33                                   rname=config.dns.contact.replace('@','.',1),
34                                   serial=1, refresh=3600, retry=900,
35                                   expire=3600000, minimum=21600, ttl=3600)
36         self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
37         record = dns.Record_A(address=ns.ip, ttl=3600)
38         self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
39                                 3600, record, auth=True)
40
41     
42     def _lookup(self, name, cls, type, timeout = None):
43         for i in range(3):
44             try:
45                 value = self._lookup_unsafe(name, cls, type, timeout = None)
46             except (psycopg2.OperationalError, sqlalchemy.exceptions.SQLError):
47                 if i == 2:
48                     raise
49                 print "Reloading database"
50                 time.sleep(0.5)
51                 continue
52             else:
53                 return value
54
55     def _lookup_unsafe(self, name, cls, type, timeout):
56         invirt.database.clear_cache()
57         
58         ttl = 900
59         name = name.lower()
60
61         if name in self.domains:
62             domain = name
63         else:
64             # Look for the longest-matching domain.  (This works because domain
65             # will remain bound after breaking out of the loop.)
66             best_domain = ''
67             for domain in self.domains:
68                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
69                     best_domain = domain
70             if best_domain == '':
71                 return defer.fail(failure.Failure(dns.DomainError(name)))
72             domain = best_domain
73         results = []
74         authority = []
75         additional = [self.ns1]
76         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
77                                       3600, self.ns, auth=True))
78
79         if cls == dns.IN:
80             host = name[:-len(domain)-1]
81             if not host: # Request for the domain itself.
82                 if type in (dns.A, dns.ALL_RECORDS):
83                     record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
84                     results.append(dns.RRHeader(name, dns.A, dns.IN, 
85                                                 ttl, record, auth=True))
86                 elif type == dns.NS:
87                     results.append(dns.RRHeader(domain, dns.NS, dns.IN,
88                                                 ttl, self.ns, auth=True))
89                     authority = []
90                 elif type == dns.SOA:
91                     results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
92                                                 ttl, self.soa, auth=True))
93             else: # Request for a subdomain.
94                 if 'passup' in dir(config.dns) and host in config.dns.passup:
95                     record = dns.Record_CNAME('%s.%s' % (host, config.dns.parent), ttl)
96                     return defer.succeed((
97                         [dns.RRHeader(name, dns.CNAME, dns.IN, ttl, record, auth=True)],
98                         [], []))
99
100                 value = invirt.database.Machine.query().filter_by(name=host).first()
101                 if value is None or not value.nics:
102                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
103                 ip = value.nics[0].ip
104                 if ip is None:  #Deactivated?
105                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
106
107                 if type in (dns.A, dns.ALL_RECORDS):
108                     record = dns.Record_A(ip, ttl)
109                     results.append(dns.RRHeader(name, dns.A, dns.IN, 
110                                                 ttl, record, auth=True))
111                 elif type == dns.SOA:
112                     results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
113                                                 ttl, self.soa, auth=True))
114             if len(results) == 0:
115                 authority = []
116                 additional = []
117             return defer.succeed((results, authority, additional))
118         else:
119             #Doesn't exist
120             return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
121
if __name__ == '__main__':
    # Zone-file authorities are listed before the database authority so that
    # the zone files take precedence over the database.
    resolvers = []
    # zone_files is optional, checked the same way as config.dns.passup.
    if 'zone_files' in dir(config.dns):
        zone_files = config.dns.zone_files
    else:
        zone_files = []
    for zone in zone_files:
        # Read each zone file once (and close it); it is parsed per domain.
        zone_file = open(zone)
        try:
            zone_lines = zone_file.readlines()
        finally:
            zone_file.close()
        for origin in config.dns.domains:
            r = authority.BindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand: override the origin, then
            # parse the file's lines again against that origin.
            r.origin = origin
            lines = r.collapseContinuations(r.stripComments(list(zone_lines)))
            r.parseLines(lines)

            resolvers.append(r)
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(authorities=resolvers, verbose=verbosity)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()