Add ".other" pseudo-subdomain to find IPs under transition.
[invirt/packages/invirt-dns.git] / invirt-dns
1 #!/usr/bin/python
2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
10
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 from invirt.database import NIC
15 import psycopg2
16 import sqlalchemy
17 import time
18 import re
19
class DatabaseAuthority(common.ResolverBase):
    """An authority that answers queries from the Invirt database.

    Names under one of the served domains are resolved to the address of
    the NIC with that hostname, falling back to the first NIC of the
    machine with that name.  The ".other" pseudo-subdomain
    (<host>.other.<domain>) resolves to the NIC's other_ip instead of its
    ip.  Names under in-addr.arpa are resolved back to a hostname by
    matching either address.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the fixed SOA/NS/A records.

        domains -- the domain names to serve; when None, config.dns.domains
                   is used.
        database -- when not None, passed to invirt.database.connect();
                    otherwise connect() is called without arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # The first configured nameserver is used for the SOA mname, the
        # NS record, and the A record sent in the additional section.
        ns = config.dns.nameservers[0]
        # rname is the contact address with its first '@' turned into '.'.
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Run _lookup_unsafe(), retrying when the database errors out.

        Up to three attempts are made, sleeping half a second between
        them.  If the third attempt also raises a database error, that
        exception is re-raised to the caller.
        """
        for i in range(3):
            try:
                return self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                if i == 2:
                    raise
                print("Reloading database")
                time.sleep(0.5)

    def _host_ip(self, host, other):
        """Return the address recorded for host, or None if there is none.

        The NIC whose hostname is host is preferred; otherwise the first
        NIC of the machine named host is used.  When other is true the
        NIC's other_ip is returned instead of its ip.
        """
        nic = invirt.database.NIC.query.filter_by(hostname=host).first()
        if not nic:
            machine = invirt.database.Machine.query.filter_by(name=host).first()
            # A machine without any NIC has no address to hand out.
            if not machine or not machine.nics:
                return None
            nic = machine.nics[0]
        if other:
            return nic.other_ip
        return nic.ip

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer a single query without any database error handling.

        Returns a Deferred that has already fired: either with a
        (results, authority, additional) tuple, or with a Failure wrapping
        dns.DomainError (name is outside every served domain) or
        dns.AuthoritativeDomainError (wrong class, unknown host, or no
        address on record).  timeout is accepted but not used.
        """
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest-matching domain.
            best_domain = ''
            for candidate in self.domains:
                if name.endswith('.'+candidate) and len(candidate) > len(best_domain):
                    best_domain = candidate
            if best_domain == '':
                if name.endswith('.in-addr.arpa'):
                    # Act authoritative for the IP address for reverse resolution requests
                    best_domain = name
                else:
                    return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain
        results = []
        additional = [self.ns1]
        authority = [dns.RRHeader(domain, dns.NS, dns.IN,
                                  3600, self.ns, auth=True)]

        # The order of logic:
        # - What class?
        # - What domain: in-addr.arpa, domain root, or subdomain?
        # - What query type: A, PTR, NS, ...?

        if cls != dns.IN:
            # Hahaha.  No.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # Drop the trailing "in-addr" and "arpa" labels and reverse
                # the rest to get the dotted address back.
                ip = '.'.join(reversed(name.split('.')[:-2]))
                value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
                if value and value.hostname:
                    hostname = value.hostname
                    if '.' not in hostname:
                        # Unqualified hostname: qualify it, marking the
                        # address as the ".other" one when that is what
                        # matched.
                        if ip == value.other_ip:
                            hostname = hostname + ".other"
                        hostname = hostname + "." + config.dns.domains[0]
                    record = dns.Record_PTR(hostname, ttl)
                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                ttl, record, auth=True))
                else: # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
            # FIXME: Should only return success with no records if the name actually exists

        elif name == domain or name == '.'+domain or name == 'other.'+domain:
            # The domain root (and the bare "other" pseudo-subdomain)
            # resolve to the first nameserver itself.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                # The NS record is the answer, so don't repeat it.
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        else:
            # Strip ".<domain>" to get the host part of the name.
            host = name[:-len(domain)-1]
            other = host.endswith(".other")
            if other:
                host = host[:-len(".other")]
            ip = self._host_ip(host, other)
            if ip is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if len(results) == 0:
            # An empty answer carries no authority/additional records.
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))
163
class DelegatingQuotingBindAuthority(authority.BindAuthority):
    """
    A delegating BindAuthority that (almost) deals with quoting correctly
    
    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line.  This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def collapseContinuations(self, lines):
        """Join parenthesized continuations and split each line into tokens.

        lines is a sequence of zone-file lines.  Text between "(" and the
        next ")" is folded into a single logical line, with the
        parentheses themselves replaced by whitespace; text sharing a
        physical line with either parenthesis is kept.  Each logical line
        is then split into quoted/unquoted tokens with backslash escapes
        interpreted.

        Returns a list of non-empty token lists, one per logical line.
        Parentheses are not treated specially inside quoted strings.
        """
        logical_lines = []
        in_parens = False
        for line in lines:
            if in_parens:
                # Still inside "(": this line continues the previous one.
                logical_lines[-1] += ' '
            else:
                logical_lines.append('')
            rest = line
            while True:
                # Look for whichever parenthesis would toggle the state.
                if in_parens:
                    pos = rest.find(')')
                else:
                    pos = rest.find('(')
                if pos == -1:
                    logical_lines[-1] += rest
                    break
                logical_lines[-1] += rest[:pos] + ' '
                rest = rest[pos + 1:]
                in_parens = not in_parens

        records = []
        for line in logical_lines:
            tokens = []
            for m in self.string_pat.finditer(line):
                # Exactly one of the two groups matched; take that one.
                [token] = [g for g in m.groups() if g is not None]
                tokens.append(self.escape_pat.sub(r'\1', token))
            records.append(tokens)
        # Drop lines that produced no tokens at all.
        return [tokens for tokens in records if tokens]

    def _lookup(self, name, cls, type, timeout = None):
        """Look up name, falling back to NS records of its parent domains.

        The name is first looked up as requested.  While the result is a
        failure and name still contains a dot, the leading label is
        removed and an NS lookup is made for the remainder.  The Deferred
        from the last lookup attempted is returned.
        """
        # NOTE(review): reading .result right away assumes
        # BindAuthority._lookup returns an already-fired Deferred -- confirm.
        deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                         type, timeout)
        # If we didn't find an exact match for the name we were seeking,
        # check if it's within a subdomain we're supposed to delegate to
        # some other DNS server.
        while (isinstance(deferredResult.result, failure.Failure)
               and '.' in name):
            name = name[name.find('.') + 1 :]
            deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                             dns.NS, timeout)
        return deferredResult
225
class TypeLenientResolverChain(resolve.ResolverChain):
    """
    This is a ResolverChain which is more lenient in its handling of
    queries requesting unimplemented record types.
    """

    def query(self, query, timeout = None):
        """Dispatch query to the lookup method for its record type.

        When self.typeToMethod has no entry for query.type, an
        all-records lookup is made for the name instead: if that fails,
        its failed Deferred is returned; otherwise a Deferred firing with
        three empty lists is returned.
        """
        # Only the table access is guarded, so a KeyError raised inside a
        # lookup method is not mistaken for an unsupported record type.
        try:
            lookup = self.typeToMethod[query.type]
        except KeyError:
            # We don't support the requested record type.  Twisted would
            # have us return SERVFAIL.  Instead, we'll check whether the
            # name exists in our zone at all and return NXDOMAIN or an empty
            # result set with NOERROR as appropriate.
            # NOTE(review): reading .result assumes lookupAllRecords
            # returns an already-fired Deferred -- confirm.
            deferredResult = self.lookupAllRecords(str(query.name), timeout)
            if isinstance(deferredResult.result, failure.Failure):
                return deferredResult
            return defer.succeed(([], [], []))
        return lookup(str(query.name), timeout)
244
if '__main__' == __name__:
    # Build one zone-file resolver per (zone file, domain) pair, followed
    # by the database-backed resolver, and serve them on port 53.
    resolvers = []
    try:
        for zone in config.dns.zone_files:
            for origin in config.dns.domains:
                r = DelegatingQuotingBindAuthority(zone)
                # This sucks, but if I want a generic zone file, I have to
                # reload the information by hand
                r.origin = origin
                zone_file = open(zone)
                try:
                    lines = zone_file.readlines()
                finally:
                    # Close the handle even if reading fails.
                    zone_file.close()
                lines = r.collapseContinuations(r.stripComments(lines))
                r.parseLines(lines)
                
                resolvers.append(r)
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        pass
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(verbose=verbosity)
    f.resolver = TypeLenientResolverChain(resolvers)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity
    
    # Serve the same factory over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()