Work around a bug in Twisted's zone file parsing.
[invirt/packages/invirt-dns.git] / invirt-dns
1 #!/usr/bin/python
2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
10
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 from invirt.database import NIC
15 import psycopg2
16 import sqlalchemy
17 import time
18 import re
19 import sys
20
class DatabaseAuthority(common.ResolverBase):
    """An Authority that answers queries from the invirt database.

    Forward names under one of ``self.domains`` are resolved by looking
    the host label up in the NIC table (by hostname) and then the Machine
    table (by name, using the machine's first NIC).  A ``<host>.other``
    label resolves to the NIC's ``other_ip`` and also publishes its
    ``other_action`` as a TXT record.  ``*.in-addr.arpa`` names are
    answered with PTR records for IPs found in the NIC table.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static SOA/NS records.

        :param domains: domains to act authoritative for; defaults to
            ``config.dns.domains``.
        :param database: passed to ``invirt.database.connect`` when given;
            otherwise ``connect()`` is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # The first configured nameserver is both the SOA MNAME and the
        # only NS record we hand out.
        ns = config.dns.nameservers[0]
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  # SOA RNAME writes user@host as user.host
                                  rname=config.dns.contact.replace('@', '.', 1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        # A record for the nameserver, sent in the additional section.
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout=None):
        """Answer a query, retrying on database connection errors.

        Up to three attempts are made; psycopg2 ``OperationalError`` and
        SQLAlchemy ``DBAPIError`` trigger a retry after 0.5 s, and the
        error from the final attempt is re-raised.
        """
        attempts = 3
        for attempt in range(attempts):
            try:
                # Forward the caller's timeout (previously dropped).
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                if attempt == attempts - 1:
                    raise
                print("Reloading database")
                # NOTE(review): time.sleep presumably blocks the reactor
                # thread for the retry delay.
                time.sleep(0.5)

    def _find_nic(self, host):
        """Return the NIC record for ``host``, or None if there is none.

        ``host`` is matched first against NIC hostnames, then against
        machine names, in which case that machine's first NIC is used.
        """
        nic = invirt.database.NIC.query.filter_by(hostname=host).first()
        if nic:
            return nic
        machine = invirt.database.Machine.query.filter_by(name=host).first()
        if not machine:
            return None
        try:
            return machine.nics[0]
        except IndexError:
            # A machine with no NICs has no address to hand out; answer
            # NXDOMAIN rather than letting the IndexError escape.
            return None

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a query from the database without retrying on DB errors.

        :returns: a Deferred firing with ``(answers, authority, additional)``
            lists of ``dns.RRHeader``; or failing with ``dns.DomainError``
            when ``name`` is outside every served domain (and is not an
            in-addr.arpa name), or ``dns.AuthoritativeDomainError`` when
            ``cls`` is not IN or the name does not exist.
        """
        # NOTE(review): presumably discards cached ORM state so each query
        # sees current database contents; confirm in invirt.database.
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest-matching domain.
            best_domain = ''
            for candidate in self.domains:
                if (name.endswith('.' + candidate)
                        and len(candidate) > len(best_domain)):
                    best_domain = candidate
            if best_domain == '':
                if name.endswith('.in-addr.arpa'):
                    # Act authoritative for the IP address for reverse
                    # resolution requests.
                    best_domain = name
                else:
                    return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain

        results = []
        # Named to avoid shadowing the twisted.names.authority module.
        auth_section = [dns.RRHeader(domain, dns.NS, dns.IN,
                                     3600, self.ns, auth=True)]
        additional = [self.ns1]

        # The order of logic:
        # - What class?
        # - What domain: in-addr.arpa, domain root, or subdomain?
        # - What query type: A, PTR, NS, ...?

        if cls != dns.IN:
            # Hahaha.  No.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # "4.3.2.1.in-addr.arpa" -> "1.2.3.4"
                ip = '.'.join(reversed(name.split('.')[:-2]))
                value = invirt.database.NIC.query.filter(
                    (NIC.ip == ip) | (NIC.other_ip == ip)).first()
                if not (value and value.hostname):
                    # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
                hostname = value.hostname
                if '.' not in hostname:
                    # Qualify bare hostnames with the primary domain; an
                    # other_ip lives under the "other" subdomain.
                    if ip == value.other_ip:
                        hostname = hostname + ".other"
                    hostname = hostname + "." + config.dns.domains[0]
                record = dns.Record_PTR(hostname, ttl)
                results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
            # FIXME: Should only return success with no records if the name actually exists

        elif name in (domain, '.' + domain, 'other.' + domain):
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is already the answer.
                auth_section = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        else:
            host = name[:-len(domain) - 1]
            other = host.endswith(".other")
            if other:
                host = host[:-len(".other")]
            nic = self._find_nic(host)
            if nic is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if other:
                ip = nic.other_ip
                action = nic.other_action
            else:
                ip = nic.ip
                action = None
            if ip is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            if other and type in (dns.TXT, dns.ALL_RECORDS):
                record = dns.Record_TXT(action if action else '', ttl=ttl)
                results.append(dns.RRHeader(name, dns.TXT, dns.IN,
                                            ttl, record, auth=True))
            if type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if not results:
            # An empty answer is sent without authority/additional records.
            auth_section = []
            additional = []
        return defer.succeed((results, auth_section, additional))
168             additional = []
169         return defer.succeed((results, authority, additional))
170
171 class DelegatingQuotingBindAuthority(authority.BindAuthority):
172     """
173     A delegating BindAuthority that (almost) deals with quoting correctly
174     
175     This will catch double quotes as marking the start or end of a
176     quoted phrase, unless the double quote is escaped by a backslash
177     """
178     # Match either a quoted or unquoted string literal followed by
179     # whitespace or the end of line.  This yields two groups, one of
180     # which has a match, and the other of which is None, depending on
181     # whether the string literal was quoted or unquoted; this is what
182     # necessitates the subsequent filtering out of groups that are
183     # None.
184     string_pat = \
185             re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')
186
187     # For interpreting escapes.
188     escape_pat = re.compile(r'\\(.)')
189
190     def collapseContinuations(self, lines):
191         L = []
192         state = 0
193         for line in lines:
194             if state == 0:
195                 if line.find('(') == -1:
196                     L.append(line)
197                 else:
198                     L.append(line[:line.find('(')])
199                     state = 1
200             else:
201                 if line.find(')') != -1:
202                     L[-1] += ' ' + line[:line.find(')')]
203                     state = 0
204                 else:
205                     L[-1] += ' ' + line
206         lines = L
207         L = []
208
209         for line in lines:
210             in_quote = False
211             split_line = []
212             for m in self.string_pat.finditer(line):
213                 [x] = [x for x in m.groups() if x is not None]
214                 split_line.append(self.escape_pat.sub(r'\1', x))
215             L.append(split_line)
216         return filter(None, L)
217
218     # See https://twistedmatrix.com/documents/13.1.0/api/twisted.internet.defer.html#inlineCallbacks
219     @defer.inlineCallbacks
220     def _lookup(self, name, cls, type, timeout = None):
221         try:
222             result = yield authority.BindAuthority._lookup(self, name, cls,
223                                                            type, timeout)
224             defer.returnValue(result)
225         except Exception as e:
226             # XXX: Twisted returns DomainError even if it is
227             # authoritative for the domain because our SOA record
228             # incorrectly contains (origin + "." + origin)
229             if not isinstance(e, (dns.DomainError, dns.AuthoritativeDomainError)):
230                 sys.stderr.write("while looking up '%s', got: %s\n" % (name, e))
231
232             # If we didn't find an exact match for the name we were
233             # seeking, check if it's within a subdomain we're supposed
234             # to delegate to some other DNS server.
235             while '.' in name:
236                 _, name = name.split('.', 1)
237                 try:
238                     # BindAuthority puts the NS in the authority
239                     # section automatically for us, so just return
240                     # it. We override the type to NS.
241                     result = yield authority.BindAuthority._lookup(self, name, cls,
242                                                                    dns.NS, timeout)
243                     defer.returnValue(result)
244                 except Exception: # Should be one of (dns.DomainError, dns.AuthoritativeDomainError)
245                     pass
246             # We didn't find a delegation, so return the original
247             # NXDOMAIN.
248             raise
249
250 class TypeLenientResolverChain(resolve.ResolverChain):
251     """
252     This is a ResolverChain which is more lenient in its handling of
253     queries requesting unimplemented record types.
254     """
255
256     def query(self, query, timeout = None):
257         try:
258             return self.typeToMethod[query.type](str(query.name), timeout)
259         except KeyError, e:
260             # We don't support the requested record type.  Twisted would
261             # have us return SERVFAIL.  Instead, we'll check whether the
262             # name exists in our zone at all and return NXDOMAIN or an empty
263             # result set with NOERROR as appropriate.
264             deferredResult = self.lookupAllRecords(str(query.name), timeout)
265             if isinstance(deferredResult.result, failure.Failure):
266                 return deferredResult
267             return defer.succeed(([], [], []))
268
269 if '__main__' == __name__:
270     resolvers = []
271     try:
272         for zone in config.dns.zone_files:
273             for origin in config.dns.domains:
274                 r = DelegatingQuotingBindAuthority(zone)
275                 # This sucks, but if I want a generic zone file, I have to
276                 # reload the information by hand
277                 # XXX: This causes our SOA record to contain
278                 # (origin + "." + origin)
279                 # As a result the resolver never believes it is authoritative.
280                 r.origin = origin
281                 lines = open(zone).readlines()
282                 lines = r.collapseContinuations(r.stripComments(lines))
283                 r.parseLines(lines)
284                 
285                 resolvers.append(r)
286     except InvirtConfigError:
287         # Don't care if zone_files isn't defined
288         pass
289     resolvers.append(DatabaseAuthority())
290
291     verbosity = 0
292     f = server.DNSServerFactory(verbose=verbosity)
293     f.resolver = TypeLenientResolverChain(resolvers)
294     p = dns.DNSDatagramProtocol(f)
295     f.noisy = p.noisy = verbosity
296     
297     reactor.listenUDP(53, p)
298     reactor.listenTCP(53, f)
299     reactor.run()