Add TXT records in .other pseudo-domain to reveal the other_action value
[invirt/packages/invirt-dns.git] / invirt-dns
1 #!/usr/bin/python
2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
10
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 from invirt.database import NIC
15 import psycopg2
16 import sqlalchemy
17 import time
18 import re
19
class DatabaseAuthority(common.ResolverBase):
    """An Authority whose records are generated from the invirt database.

    Serves the configured DNS domains, their ``other.`` pseudo-subdomain,
    and reverse (``in-addr.arpa``) lookups, using the NIC and Machine tables.
    """

    soa = None

    def __init__(self, domains=None, database=None):
        """Connect to the database and build the static SOA/NS records.

        :param domains: domains to serve; defaults to ``config.dns.domains``.
        :param database: argument forwarded to ``invirt.database.connect``;
            when None, ``connect()`` is called with no arguments.
        """
        common.ResolverBase.__init__(self)
        if database is not None:
            invirt.database.connect(database)
        else:
            invirt.database.connect()
        if domains is not None:
            self.domains = domains
        else:
            self.domains = config.dns.domains
        # The first configured nameserver is the SOA mname and the sole NS
        # record; its A record is served in the additional section.
        ns = config.dns.nameservers[0]
        self.soa = dns.Record_SOA(mname=ns.hostname,
                                  rname=config.dns.contact.replace('@','.',1),
                                  serial=1, refresh=3600, retry=900,
                                  expire=3600000, minimum=21600, ttl=3600)
        self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
        record = dns.Record_A(address=ns.ip, ttl=3600)
        self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
                                3600, record, auth=True)

    def _lookup(self, name, cls, type, timeout = None):
        """Answer a query via ``_lookup_unsafe``, retrying on DB errors.

        Up to three attempts are made, sleeping 0.5s between them; the
        database error from the final attempt is re-raised.
        """
        for attempt in range(3):
            try:
                value = self._lookup_unsafe(name, cls, type, timeout)
            except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
                if attempt == 2:
                    raise
                print("Reloading database")
                time.sleep(0.5)
            else:
                return value

    def _lookup_unsafe(self, name, cls, type, timeout):
        """Answer one query from the database, without retrying.

        Returns an already-fired Deferred: either a
        ``(answers, authority, additional)`` tuple, or a failure wrapping
        ``dns.DomainError`` (name outside our domains) or
        ``dns.AuthoritativeDomainError`` (name inside them but unknown).
        """
        invirt.database.clear_cache()

        ttl = 900
        name = name.lower()

        if name in self.domains:
            domain = name
        else:
            # Look for the longest-matching domain.
            best_domain = ''
            for domain in self.domains:
                if name.endswith('.'+domain) and len(domain) > len(best_domain):
                    best_domain = domain
            if best_domain == '':
                if name.endswith('.in-addr.arpa'):
                    # Act authoritative for the IP address for reverse resolution requests
                    best_domain = name
                else:
                    return defer.fail(failure.Failure(dns.DomainError(name)))
            domain = best_domain
        results = []
        authority = []
        additional = [self.ns1]
        authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                      3600, self.ns, auth=True))

        # The order of logic:
        # - What class?
        # - What domain: in-addr.arpa, domain root, or subdomain?
        # - What query type: A, PTR, NS, ...?

        if cls != dns.IN:
            # Hahaha.  No.
            return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))

        if name.endswith(".in-addr.arpa"):
            if type in (dns.PTR, dns.ALL_RECORDS):
                # "4.3.2.1.in-addr.arpa" -> "1.2.3.4"
                ip = '.'.join(reversed(name.split('.')[:-2]))
                value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
                if value and value.hostname:
                    hostname = value.hostname
                    if '.' not in hostname:
                        # Unqualified names live under the first domain; the
                        # secondary address maps to the ".other" pseudo-domain.
                        if ip == value.other_ip:
                            hostname = hostname + ".other"
                        hostname = hostname + "." + config.dns.domains[0]
                    record = dns.Record_PTR(hostname, ttl)
                    results.append(dns.RRHeader(name, dns.PTR, dns.IN,
                                                ttl, record, auth=True))
                else: # IP address doesn't point to an active host
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))
            # FIXME: Should only return success with no records if the name actually exists

        elif name == domain or name == '.'+domain or name == 'other.'+domain:
            # The domain apex (and the ".other" pseudo-domain apex) resolve
            # to the primary nameserver.
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            elif type == dns.NS:
                results.append(dns.RRHeader(domain, dns.NS, dns.IN,
                                            ttl, self.ns, auth=True))
                authority = []
            elif type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        else:
            host = name[:-len(domain)-1]
            other = False
            if host.endswith(".other"):
                host = host[:-len(".other")]
                other = True
            # A NIC hostname takes precedence; otherwise fall back to the
            # first NIC of the machine with that name.
            nic = invirt.database.NIC.query.filter_by(hostname=host).first()
            if not nic:
                machine = invirt.database.Machine.query.filter_by(name=host).first()
                if not machine or not machine.nics:
                    # Unknown host, or a machine with no NIC to answer for.
                    return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
                nic = machine.nics[0]
            if other:
                ip = nic.other_ip
                action = nic.other_action
            else:
                ip = nic.ip
                action = None
            if ip is None:
                return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
            if type in (dns.A, dns.ALL_RECORDS):
                record = dns.Record_A(ip, ttl)
                results.append(dns.RRHeader(name, dns.A, dns.IN,
                                            ttl, record, auth=True))
            if other and type in (dns.TXT, dns.ALL_RECORDS):
                # ".other" names expose other_action as a TXT record.
                record = dns.Record_TXT(action if action else '', ttl=ttl)
                results.append(dns.RRHeader(name, dns.TXT, dns.IN,
                                            ttl, record, auth=True))
            if type == dns.SOA:
                results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
                                            ttl, self.soa, auth=True))

        if not results:
            authority = []
            additional = []
        return defer.succeed((results, authority, additional))
169
class DelegatingQuotingBindAuthority(authority.BindAuthority):
    """
    A delegating BindAuthority that (almost) deals with quoting correctly

    This will catch double quotes as marking the start or end of a
    quoted phrase, unless the double quote is escaped by a backslash
    """
    # Match either a quoted or unquoted string literal followed by
    # whitespace or the end of line.  This yields two groups, one of
    # which has a match, and the other of which is None, depending on
    # whether the string literal was quoted or unquoted; this is what
    # necessitates the subsequent filtering out of groups that are
    # None.
    string_pat = \
            re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')

    # For interpreting escapes.
    escape_pat = re.compile(r'\\(.)')

    def _findUnquoted(self, line, char, start=0):
        """Return the index of the first `char` at or after `start` that is
        outside a double-quoted string, or -1 if there is none.

        Quote state is tracked from the beginning of `line`; a backslash
        escapes the character that follows it.
        """
        in_quote = False
        escaped = False
        for i, c in enumerate(line):
            if escaped:
                escaped = False
            elif c == '\\':
                escaped = True
            elif c == '"':
                in_quote = not in_quote
            elif c == char and not in_quote and i >= start:
                return i
        return -1

    def collapseContinuations(self, lines):
        """Join parenthesized multi-line records and tokenize each record.

        Returns a list with one token list per non-empty record.  Quoted
        tokens lose their surrounding quotes, and backslash escapes are
        resolved in every token.
        """
        records = []
        in_parens = False
        for line in lines:
            if not in_parens:
                start = self._findUnquoted(line, '(')
                if start == -1:
                    records.append(line)
                    continue
                end = self._findUnquoted(line, ')', start + 1)
                if end == -1:
                    # The record continues on following lines; keep any data
                    # after '(' on this line (e.g. an SOA serial).
                    records.append(line[:start] + ' ' + line[start + 1:])
                    in_parens = True
                else:
                    # Parentheses opened and closed on the same line.
                    records.append(line[:start] + ' ' + line[start + 1:end])
            else:
                end = self._findUnquoted(line, ')')
                if end == -1:
                    records[-1] += ' ' + line
                else:
                    records[-1] += ' ' + line[:end]
                    in_parens = False

        tokenized = []
        for record in records:
            tokens = []
            for m in self.string_pat.finditer(record):
                # Exactly one of the two groups (quoted / unquoted) matched.
                token = m.group(1)
                if token is None:
                    token = m.group(2)
                tokens.append(self.escape_pat.sub(r'\1', token))
            tokenized.append(tokens)
        return [tokens for tokens in tokenized if tokens]

    def _lookup(self, name, cls, type, timeout = None):
        """Look up `name`, falling back to NS records of enclosing names.

        If the lookup of `name` fails, leading labels are stripped one at a
        time and NS records are looked up for each enclosing name.  This
        stops at the first success or when no '.' remains, and the last
        Deferred is returned.
        """
        deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                         type, timeout)
        # If we didn't find an exact match for the name we were seeking,
        # check if it's within a subdomain we're supposed to delegate to
        # some other DNS server.
        while (isinstance(deferredResult.result, failure.Failure)
               and '.' in name):
            # This failed Deferred is being discarded; consume its failure so
            # Twisted doesn't report an unhandled error in Deferred on GC.
            deferredResult.addErrback(lambda _: None)
            name = name[name.find('.') + 1 :]
            deferredResult = authority.BindAuthority._lookup(self, name, cls,
                                                             dns.NS, timeout)
        return deferredResult
231
class TypeLenientResolverChain(resolve.ResolverChain):
    """
    This is a ResolverChain which is more lenient in its handling of
    queries requesting unimplemented record types.
    """

    def query(self, query, timeout = None):
        """Dispatch `query` to the lookup method for its record type.

        For record types with no lookup method, Twisted would return
        SERVFAIL.  Instead, this looks up all records for the name.  If that
        succeeds, the result is replaced by an empty answer (NOERROR);
        otherwise its failure (e.g. NXDOMAIN) is passed through.
        """
        name = str(query.name)
        try:
            method = self.typeToMethod[query.type]
        except KeyError:
            # We don't support the requested record type.  Only the table
            # lookup is guarded, so a KeyError raised inside a lookup method
            # is not mistaken for an unsupported type.
            deferredResult = self.lookupAllRecords(name, timeout)
            # Works whether or not the Deferred has already fired.
            deferredResult.addCallback(lambda _: ([], [], []))
            return deferredResult
        return method(name, timeout)
250
if '__main__' == __name__:
    # Static zone-file authorities are consulted before the database one.
    resolvers = []
    try:
        for zone in config.dns.zone_files:
            # Read each zone file once (and close it); the same lines are
            # reparsed below for every origin.
            zone_file = open(zone)
            try:
                zone_lines = zone_file.readlines()
            finally:
                zone_file.close()
            for origin in config.dns.domains:
                r = DelegatingQuotingBindAuthority(zone)
                # This sucks, but if I want a generic zone file, I have to
                # reload the information by hand
                r.origin = origin
                lines = r.collapseContinuations(r.stripComments(zone_lines))
                r.parseLines(lines)

                resolvers.append(r)
    except InvirtConfigError:
        # Don't care if zone_files isn't defined
        pass
    resolvers.append(DatabaseAuthority())

    verbosity = 0
    f = server.DNSServerFactory(verbose=verbosity)
    f.resolver = TypeLenientResolverChain(resolvers)
    p = dns.DNSDatagramProtocol(f)
    f.noisy = p.noisy = verbosity

    # Serve DNS on port 53 over both UDP and TCP.
    reactor.listenUDP(53, p)
    reactor.listenTCP(53, f)
    reactor.run()