Be tolerant of requests for unimplemented record types.
[invirt/packages/invirt-dns.git] / invirt-dns
1 #!/usr/bin/python
2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
10
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 import psycopg2
15 import sqlalchemy
16 import time
17 import re
18
class DatabaseAuthority(common.ResolverBase):
    """An Authority whose answers come from the invirt database.

    It answers:
      * A queries for NIC hostnames / machine names under the configured
        domains,
      * PTR queries for NIC addresses under in-addr.arpa,
      * A/NS/SOA queries for the configured domains themselves.
    The SOA and NS records are synthesized from the first configured
    nameserver (``config.dns.nameservers[0]``).
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static SOA/NS records.

        :param domains: domains to act authoritative for; defaults to
            ``config.dns.domains``.
        :param database: argument passed to ``invirt.database.connect``;
            when None, ``connect()`` is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        ns = config.dns.nameservers[0]
        # SOA RNAME is the contact mailbox with its first '@' turned into '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, returned in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Run ``_lookup_unsafe``, retrying on database errors.

        Makes up to three attempts; the database error raised by the last
        attempt is re-raised to the caller.
        """
        for attempt in range(3):
            try:
                value = self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                if attempt == 2:
                    raise
                print("Reloading database")
                # NOTE(review): time.sleep blocks the Twisted reactor, so the
                # whole server stalls while we wait between attempts.
                time.sleep(0.5)
            else:
                return value

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer one query from the database, without retrying.

        Returns an already-fired Deferred: on success it carries an
        ``(answers, authority, additional)`` tuple; otherwise a failure
        wrapping ``dns.DomainError`` (name outside our domains and not
        in-addr.arpa) or ``dns.AuthoritativeDomainError`` (we are
        authoritative but the name does not exist, or the class is not IN).
        Database exceptions propagate so that ``_lookup`` can retry.
        """
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest-matching domain.
            best_domain = ''
            for candidate in self.domains:
                if (name.endswith('.' + candidate)
                        and len(candidate) > len(best_domain)):
                    best_domain = candidate
            if not best_domain:
                if name.endswith('.in-addr.arpa'):
                    # Act authoritative for the IP address for reverse
                    # resolution requests
                    best_domain = name
                else:
                    return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain
        results = []
        authority = []
        additional = [self.ns1]
        authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                      3600, self.ns, auth=True))

        # The order of logic:
        # - What class?
        # - What domain: in-addr.arpa, domain root, or subdomain?
        # - What query type: A, PTR, NS, ...?

        if cls != dns.IN:
            # Hahaha.  No.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # "d.c.b.a.in-addr.arpa" -> "a.b.c.d"
                ip = '.'.join(reversed(name.split('.')[:-2]))
                value = invirt.database.NIC.query.filter_by(ip=ip).first()
                if value and value.hostname:
                    hostname = value.hostname
                    if '.' not in hostname:
                        # Bare hostnames live under the first configured domain.
                        hostname = hostname + "." + config.dns.domains[0]
                    record = dns.Record_PTR(hostname, ttl)
                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                ttl, record, auth=True))
                else: # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
            # FIXME: Should only return success with no records if the name actually exists

        elif name == domain or name == '.'+domain:
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer; don't repeat it.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        else:
            host = name[:-len(domain)-1]
            value = invirt.database.NIC.query.filter_by(hostname=host).first()
            if value:
                ip = value.ip
            else:
                value = invirt.database.Machine.query.filter_by(name=host).first()
                if not value:
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
                try:
                    ip = value.nics[0].ip
                except IndexError:
                    # A machine with no NICs has no address to hand out;
                    # answer NXDOMAIN rather than letting IndexError escape.
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if ip is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if len(results) == 0:
            # Name exists but has no records of this type: empty NOERROR.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))
150
class DelegatingQuotingBindAuthority(authority.BindAuthority):
    """
    A delegating BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    Parentheses inside quotes, or escaped with a backslash, do not start
    or end a multi-line continuation.
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line.  This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    @staticmethod
    def _find_unquoted(line, char):
        """Return the index of the first ``char`` in ``line`` that is
        neither inside double quotes nor backslash-escaped, or -1."""
        in_quote = False
        escaped = False
        for i, c in enumerate(line):
            if escaped:
                escaped = False
            elif c == '\\':
                escaped = True
            elif c == '"':
                in_quote = not in_quote
            elif c == char and not in_quote:
                return i
        return -1

    def collapseContinuations(self, lines):
        """Join parenthesized multi-line records, then tokenize each line.

        Text between an unquoted '(' and the matching ')' is appended to
        the line that opened it; the text on the opening line after '(' and
        on the closing line after ')' is kept, and the parentheses
        themselves are dropped.  Each resulting line is split into tokens
        with ``string_pat`` and backslash escapes are removed.

        Returns a list of token lists, omitting lines with no tokens.
        """
        joined = []
        in_parens = False
        for line in lines:
            if not in_parens:
                start = self._find_unquoted(line, '(')
                if start == -1:
                    joined.append(line)
                    continue
                # Keep what follows '(' too: for an SOA record the serial
                # number usually sits on the same line as the '('.
                joined.append(line[:start])
                line = line[start + 1:]
                in_parens = True
            end = self._find_unquoted(line, ')')
            if end == -1:
                joined[-1] += ' ' + line
            else:
                joined[-1] += ' ' + line[:end] + ' ' + line[end + 1:]
                in_parens = False

        tokenized = []
        for line in joined:
            tokens = []
            for m in self.string_pat.finditer(line):
                # Exactly one of the two alternatives' groups matched.
                [token] = [g for g in m.groups() if g is not None]
                tokens.append(self.escape_pat.sub(r'\1', token))
            tokenized.append(tokens)
        # A real list (not a lazy filter) because parseLines takes len().
        return [tokens for tokens in tokenized if tokens]

    def _lookup(self, name, cls, type, timeout = None):
        """Look ``name`` up in the zone, falling back to delegation.

        If the exact lookup fails, strip leading labels one at a time and
        look for NS records on each ancestor; the first ancestor with NS
        records is returned as a referral (empty answers, NS records in
        the authority section).

        Relies on BindAuthority._lookup returning an already-fired Deferred,
        since ``.result`` is read directly.

        NOTE(review): the loop stops at the first ancestor with NS records,
        including the zone origin itself; if a zone file defines NS records
        at its origin, a missing name directly under it is answered as a
        referral rather than an error — confirm zone files don't rely on this.
        """
        maybeDelegate = False
        deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                         type, timeout)
        # If we didn't find an exact match for the name we were seeking,
        # check if it's within a subdomain we're supposed to delegate to
        # some other DNS server.
        while (isinstance(deferredResult.result, failure.Failure)
               and '.' in name):
            maybeDelegate = True
            name = name[name.find('.') + 1 :]
            deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                             dns.NS, timeout)
        # If we found somewhere to delegate the query to, our _lookup()
        # for the NS record resulted in it being in the 'results' section.
        # We need to instead return that information in the 'authority'
        # section to delegate, and return an empty 'results' section
        # (because we didn't find the name we were asked about).  We
        # leave the 'additional' section as we received it because it
        # may contain A records for the DNS server we're delegating to.
        if maybeDelegate and not isinstance(deferredResult.result,
                                            failure.Failure):
            (nsResults, nsAuthority, nsAdditional) = deferredResult.result
            deferredResult = defer.succeed(([], nsResults, nsAdditional))
        return deferredResult
223
class TypeLenientResolverChain(resolve.ResolverChain):
    """
    This is a ResolverChain which is more lenient in its handling of
    queries requesting unimplemented record types.
    """

    def query(self, query, timeout = None):
        """Resolve ``query``; answer unknown record types leniently.

        Record types with a method in ``typeToMethod`` are resolved as usual.
        For any other type, Twisted would have us return SERVFAIL.  Instead,
        the name is looked up with ``lookupAllRecords``: a failure (e.g.
        NXDOMAIN) is passed through, and a success becomes an empty answer
        section (NOERROR) that keeps the authority/additional sections.
        """
        try:
            lookup = self.typeToMethod[query.type]
        except KeyError:
            # We don't support the requested record type; report whether
            # the name exists at all rather than failing outright.
            deferredResult = self.lookupAllRecords(str(query.name), timeout)
            return deferredResult.addCallback(self._drop_answers)
        # Called outside the try so that a KeyError raised while resolving
        # is not mistaken for an unsupported record type.
        return lookup(str(query.name), timeout)

    @staticmethod
    def _drop_answers(result):
        """Empty a lookup's answer section, keeping authority/additional."""
        (answers, auth, additional) = result
        return ([], auth, additional)
243
if '__main__' == __name__:
    resolvers = []
    try:
        zone_files = config.dns.zone_files
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        zone_files = []
    for zone in zone_files:
        # Read each zone file once (and close it); the same text is
        # re-parsed below for every configured domain.
        zone_file = open(zone)
        try:
            zone_lines = zone_file.readlines()
        finally:
            zone_file.close()
        for origin in config.dns.domains:
            r = DelegatingQuotingBindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand
            r.origin = origin
            lines = r.collapseContinuations(r.stripComments(zone_lines))
            r.parseLines(lines)

            resolvers.append(r)
    # The database authority is consulted after any zone-file authorities.
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(verbose=verbosity)
    f.resolver = TypeLenientResolverChain(resolvers)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()