c9ce43ed15bf692ab2edf1d62ceeac2113b11c14
[invirt/packages/invirt-dns.git] / invirt-dns
1 #!/usr/bin/python
2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
10
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 from invirt.database import NIC
15 import psycopg2
16 import sqlalchemy
17 import time
18 import re
19
20 class DatabaseAuthority(common.ResolverBase):
21     """An Authority that is loaded from a file."""
22
23     soa = None
24
25     def __init__(self, domains=None, database=None):
26         common.ResolverBase.__init__(self)
27         if database is not None:
28             invirt.database.connect(database)
29         else:
30             invirt.database.connect()
31         if domains is not None:
32             self.domains = domains
33         else:
34             self.domains = config.dns.domains
35         ns = config.dns.nameservers[0]
36         self.soa = dns.Record_SOA(mname=ns.hostname,
37                                   rname=config.dns.contact.replace('@','.',1),
38                                   serial=1, refresh=3600, retry=900,
39                                   expire=3600000, minimum=21600, ttl=3600)
40         self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
41         record = dns.Record_A(address=ns.ip, ttl=3600)
42         self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
43                                 3600, record, auth=True)
44
45     
46     def _lookup(self, name, cls, type, timeout = None):
47         for i in range(3):
48             try:
49                 value = self._lookup_unsafe(name, cls, type, timeout = None)
50             except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
51                 if i == 2:
52                     raise
53                 print "Reloading database"
54                 time.sleep(0.5)
55                 continue
56             else:
57                 return value
58
59     def _lookup_unsafe(self, name, cls, type, timeout):
60         invirt.database.clear_cache()
61         
62         ttl = 900
63         name = name.lower()
64
65         if name in self.domains:
66             domain = name
67         else:
68             # Look for the longest-matching domain.
69             best_domain = ''
70             for domain in self.domains:
71                 if name.endswith('.'+domain) and len(domain) > len(best_domain):
72                     best_domain = domain
73             if best_domain == '':
74                 if name.endswith('.in-addr.arpa'):
75                     # Act authoritative for the IP address for reverse resolution requests
76                     best_domain = name
77                 else:
78                     return defer.fail(failure.Failure(dns.DomainError(name)))
79             domain = best_domain
80         results = []
81         authority = []
82         additional = [self.ns1]
83         authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
84                                       3600, self.ns, auth=True))
85
86         # The order of logic:
87         # - What class?
88         # - What domain: in-addr.arpa, domain root, or subdomain?
89         # - What query type: A, PTR, NS, ...?
90
91         if cls != dns.IN:
92             # Hahaha.  No.
93             return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
94
95         if name.endswith(".in-addr.arpa"):
96             if type in (dns.PTR, dns.ALL_RECORDS):
97                 ip = '.'.join(reversed(name.split('.')[:-2]))
98                 value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
99                 if value and value.hostname:
100                     hostname = value.hostname
101                     if '.' not in hostname:
102                         if ip == value.other_ip:
103                             hostname = hostname + ".other"
104                         hostname = hostname + "." + config.dns.domains[0]
105                     record = dns.Record_PTR(hostname, ttl)
106                     results.append(dns.RRHeader(name, dns.PTR, dns.IN,
107                                                 ttl, record, auth=True))
108                 else: # IP address doesn't point to an active host
109                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
110             elif type == dns.SOA:
111                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
112                                             ttl, self.soa, auth=True))
113             # FIXME: Should only return success with no records if the name actually exists
114
115         elif name == domain or name == '.'+domain or name == 'other.'+domain:
116             if type in (dns.A, dns.ALL_RECORDS):
117                 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
118                 results.append(dns.RRHeader(name, dns.A, dns.IN,
119                                             ttl, record, auth=True))
120             elif type == dns.NS:
121                 results.append(dns.RRHeader(domain, dns.NS, dns.IN,
122                                             ttl, self.ns, auth=True))
123                 authority = []
124             elif type == dns.SOA:
125                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
126                                             ttl, self.soa, auth=True))
127
128         else:
129             host = name[:-len(domain)-1]
130             other = False
131             if host.endswith(".other"):
132                 host = host[:-len(".other")]
133                 other = True
134             value = invirt.database.NIC.query.filter_by(hostname=host).first()
135             if value:
136                 if other:
137                     ip = value.other_ip
138                     action = value.other_action
139                 else:
140                     ip = value.ip
141             else:
142                 value = invirt.database.Machine.query.filter_by(name=host).first()
143                 if value:
144                     if other:
145                         ip = value.nics[0].other_ip
146                         action = value.nics[0].other_action
147                     else:
148                         ip = value.nics[0].ip
149                 else:
150                     return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
151             if ip is None:
152                 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
153             if type in (dns.A, dns.ALL_RECORDS):
154                 record = dns.Record_A(ip, ttl)
155                 results.append(dns.RRHeader(name, dns.A, dns.IN,
156                                             ttl, record, auth=True))
157             if other and type in (dns.TXT, dns.ALL_RECORDS):
158                 record = dns.Record_TXT(action if action else '', ttl=ttl)
159                 results.append(dns.RRHeader(name, dns.TXT, dns.IN,
160                                             ttl, record, auth=True))
161             if type == dns.SOA:
162                 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
163                                             ttl, self.soa, auth=True))
164
165         if len(results) == 0:
166             authority = []
167             additional = []
168         return defer.succeed((results, authority, additional))
169
class DelegatingQuotingBindAuthority(authority.BindAuthority):
    """
    A delegating BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash.
    Parentheses inside quoted phrases (or escaped by a backslash) are
    literal text rather than line-continuation grouping.

    Lookups of names that are not in the zone fall back to NS lookups of
    successively shorter parent names, so delegated subdomains are
    answered with their delegation.
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line.  This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    @staticmethod
    def _stripGrouping(line, in_group):
        """Blank out the grouping parentheses in one zone-file line.

        A '(' seen outside a group opens one, and a ')' seen inside a group
        closes it.  Both are replaced by a space.  Parentheses inside
        double-quoted strings, parentheses escaped by a backslash, and
        stray parentheses (a ')' with no open group, or a '(' inside an
        open group) are kept as literal text.

        :param line: one line with comments already removed.
        :param in_group: whether a group was open before this line.
        :returns: ``(text, in_group)``: the rewritten line, and whether a
            group is still open at the end of it.
        """
        out = []
        in_quote = False
        escaped = False
        for ch in line:
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                in_quote = not in_quote
            elif not in_quote:
                if ch == '(' and not in_group:
                    in_group = True
                    ch = ' '
                elif ch == ')' and in_group:
                    in_group = False
                    ch = ' '
            out.append(ch)
        return ''.join(out), in_group

    def collapseContinuations(self, lines):
        """Merge parenthesised continuation lines and split into fields.

        Text on both sides of the grouping parentheses is kept, and a group
        may open and close on the same line.

        :param lines: zone-file lines with comments already removed.
        :returns: a list of non-empty field lists, one per logical record.
            Quoted fields lose their surrounding quotes, and backslash
            escapes are interpreted.
        """
        records = []
        in_group = False
        for line in lines:
            text, still_open = self._stripGrouping(line, in_group)
            if in_group:
                # Continuation of a record whose '(' was on an earlier line.
                records[-1] += ' ' + text
            else:
                records.append(text)
            in_group = still_open

        fields = []
        for record in records:
            split_line = []
            for m in self.string_pat.finditer(record):
                [x] = [x for x in m.groups() if x is not None]
                split_line.append(self.escape_pat.sub(r'\1', x))
            fields.append(split_line)
        return filter(None, fields)

    # See https://twistedmatrix.com/documents/13.1.0/api/twisted.internet.defer.html#inlineCallbacks
    @defer.inlineCallbacks
    def _lookup(self, name, cls, type, timeout = None):
        """Look *name* up in the zone, falling back to a delegation.

        If the exact lookup fails with ``dns.AuthoritativeDomainError``,
        each parent name (dropping one leading label at a time) is looked
        up with the type forced to NS.  The first successful result is
        returned.
        """
        try:
            result = yield authority.BindAuthority._lookup(self, name, cls,
                                                           type, timeout)
            defer.returnValue(result)
        except dns.AuthoritativeDomainError:
            # If we didn't find an exact match for the name we were
            # seeking, check if it's within a subdomain we're supposed
            # to delegate to some other DNS server.
            while '.' in name:
                _, name = name.split('.', 1)
                try:
                    # BindAuthority puts the NS in the authority
                    # section automatically for us, so just return
                    # it. We override the type to NS.
                    result = yield authority.BindAuthority._lookup(self, name, cls,
                                                                   dns.NS, timeout)
                    defer.returnValue(result)
                except (dns.DomainError, dns.AuthoritativeDomainError):
                    pass
            # We didn't find a delegation, so return the original
            # NXDOMAIN.
            # NOTE(review): in Python 2 a bare ``raise`` re-raises the most
            # recently handled exception.  After the loop above, that may be
            # the error from the last parent-name lookup rather than the
            # original AuthoritativeDomainError.  Confirm which one the
            # resolver chain relies on before changing this.
            raise
242
class TypeLenientResolverChain(resolve.ResolverChain):
    """
    This is a ResolverChain which is more lenient in its handling of
    queries requesting unimplemented record types.
    """

    def query(self, query, timeout = None):
        """Resolve *query*, leniently handling unsupported record types.

        For a supported type, the matching lookup method is called.  For
        an unsupported type, an ALL_RECORDS lookup of the name is made
        instead.  If that fails, its failure (e.g. NXDOMAIN) is propagated.
        If it succeeds, the result is an empty ``([], [], [])`` answer.

        :returns: a Deferred in every case.  Exceptions raised
            synchronously by a lookup method become failures of that
            Deferred.
        """
        name = str(query.name)
        # Look the method up separately, so that a KeyError raised inside a
        # resolver is not mistaken for "unsupported record type".
        method = self.typeToMethod.get(query.type)
        if method is not None:
            return defer.maybeDeferred(method, name, timeout)
        # We don't support the requested record type.  Twisted would
        # have us return SERVFAIL.  Instead, we'll check whether the
        # name exists in our zone at all and return NXDOMAIN or an empty
        # result set with NOERROR as appropriate.  Chaining a callback
        # works whether or not the lookup has already fired.
        deferredResult = defer.maybeDeferred(self.lookupAllRecords,
                                             name, timeout)
        return deferredResult.addCallback(lambda _: ([], [], []))
261
if '__main__' == __name__:
    # Build the resolver chain: one zone-file authority per
    # (zone file, served domain) pair, then the database authority last.
    resolvers = []
    try:
        zone_files = config.dns.zone_files
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        zone_files = []
    for zone in zone_files:
        # Read each zone file once, closing it promptly; the same lines are
        # re-parsed below for every served domain.
        zone_file = open(zone)
        try:
            raw_lines = zone_file.readlines()
        finally:
            zone_file.close()
        for origin in config.dns.domains:
            r = DelegatingQuotingBindAuthority(zone)
            # This sucks, but if I want a generic zone file, I have to
            # reload the information by hand
            r.origin = origin
            lines = r.collapseContinuations(r.stripComments(raw_lines))
            r.parseLines(lines)
            resolvers.append(r)
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(verbose=verbosity)
    f.resolver = TypeLenientResolverChain(resolvers)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()