2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 from invirt.database import NIC
20 class DatabaseAuthority(common.ResolverBase):
21 """An Authority that answers queries from the invirt database and DNS config."""
# Host A records and reverse PTR records come from the NIC / Machine tables
# in invirt.database. The SOA, NS and nameserver A records are built from
# config.dns.
25 def __init__(self, domains=None, database=None):
# Connect to the database and build the static SOA / NS records.
#   domains  -- zones this authority answers for; config.dns.domains
#               is used when None.
#   database -- passed to invirt.database.connect() when not None;
#               otherwise connect() is called with no argument.
26 common.ResolverBase.__init__(self)
27 if database is not None:
28 invirt.database.connect(database)
30 invirt.database.connect()
31 if domains is not None:
32 self.domains = domains
34 self.domains = config.dns.domains
# Only the first configured nameserver is advertised: it is the SOA
# MNAME, the NS target and the glue A record.
35 ns = config.dns.nameservers[0]
# An SOA RNAME encodes the contact mailbox as a domain name, so the first
# '@' of the contact address becomes '.'.
36 self.soa = dns.Record_SOA(mname=ns.hostname,
37 rname=config.dns.contact.replace('@','.',1),
38 serial=1, refresh=3600, retry=900,
39 expire=3600000, minimum=21600, ttl=3600)
40 self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
# Glue A record for the nameserver. _lookup_unsafe puts it in the
# additional section of its answers.
41 record = dns.Record_A(address=ns.ip, ttl=3600)
42 self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
43 3600, record, auth=True)
46 def _lookup(self, name, cls, type, timeout = None):
# Run _lookup_unsafe, catching database connection errors
# (psycopg2.OperationalError, sqlalchemy DBAPIError) and printing
# "Reloading database" when one occurs.
# NOTE(review): the caller's timeout is not forwarded; _lookup_unsafe is
# always called with timeout=None.
# NOTE(review): `sqlalchemy.exceptions` is the old module name; newer
# SQLAlchemy releases expose these classes in `sqlalchemy.exc` -- confirm
# against the installed version.
49 value = self._lookup_unsafe(name, cls, type, timeout = None)
50 except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
53 print "Reloading database"
59 def _lookup_unsafe(self, name, cls, type, timeout):
# Build (answers, authority, additional) for one query and return it in a
# Deferred. Failures are returned as failed Deferreds:
#   DomainError             -- name is not under any of self.domains;
#   AuthoritativeDomainError -- the zone is ours but there is no data.
# Database errors may propagate; _lookup catches them around this call.
# Discard invirt.database's cached state first (presumably so each lookup
# reads current rows -- confirm against invirt.database.clear_cache).
60 invirt.database.clear_cache()
# Pick the zone `name` belongs to: an exact match, otherwise the longest
# configured domain that `name` ends with. Reverse (.in-addr.arpa) names
# are treated as ours. Anything else fails with DomainError, which (unlike
# AuthoritativeDomainError) lets a Twisted ResolverChain try its next
# resolver.
65 if name in self.domains:
68 # Look for the longest-matching domain.
70 for domain in self.domains:
71 if name.endswith('.'+domain) and len(domain) > len(best_domain):
74 if name.endswith('.in-addr.arpa'):
75 # Act authoritative for the IP address for reverse resolution requests
78 return defer.fail(failure.Failure(dns.DomainError(name)))
# Successful answers carry our NS record in the authority section and the
# nameserver's glue A record in the additional section.
82 additional = [self.ns1]
83 authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
84 3600, self.ns, auth=True))
88 # - What domain: in-addr.arpa, domain root, or subdomain?
89 # - What query type: A, PTR, NS, ...?
93 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
# Reverse lookup: "d.c.b.a.in-addr.arpa" becomes the IP "a.b.c.d", which is
# matched against either a NIC's ip or its other_ip.
95 if name.endswith(".in-addr.arpa"):
96 if type in (dns.PTR, dns.ALL_RECORDS):
97 ip = '.'.join(reversed(name.split('.')[:-2]))
98 value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
99 if value and value.hostname:
100 hostname = value.hostname
# Unqualified hostnames get the first configured domain appended. An
# address that matched via other_ip maps to "<host>.other.<domain>".
101 if '.' not in hostname:
102 if ip == value.other_ip:
103 hostname = hostname + ".other"
104 hostname = hostname + "." + config.dns.domains[0]
105 record = dns.Record_PTR(hostname, ttl)
106 results.append(dns.RRHeader(name, dns.PTR, dns.IN,
107 ttl, record, auth=True))
108 else: # IP address doesn't point to an active host
109 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
110 elif type == dns.SOA:
111 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
112 ttl, self.soa, auth=True))
113 # FIXME: Should only return success with no records if the name actually exists
# Zone apex (also matched as ".<domain>" and "other.<domain>"). An A query
# is answered with the first nameserver's IP; NS and SOA answers use the
# self.ns / self.soa records built in __init__.
115 elif name == domain or name == '.'+domain or name == 'other.'+domain:
116 if type in (dns.A, dns.ALL_RECORDS):
117 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
118 results.append(dns.RRHeader(name, dns.A, dns.IN,
119 ttl, record, auth=True))
121 results.append(dns.RRHeader(domain, dns.NS, dns.IN,
122 ttl, self.ns, auth=True))
124 elif type == dns.SOA:
125 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
126 ttl, self.soa, auth=True))
# Host inside a zone: strip the ".<domain>" suffix. A trailing ".other" is
# stripped too and (presumably) selects the NIC's other_ip instead of its
# ip. The host is looked up as a NIC hostname and as a Machine name.
# Not found -> AuthoritativeDomainError.
129 host = name[:-len(domain)-1]
131 if host.endswith(".other"):
132 host = host[:-len(".other")]
134 value = invirt.database.NIC.query.filter_by(hostname=host).first()
141 value = invirt.database.Machine.query.filter_by(name=host).first()
144 ip = value.nics[0].other_ip
146 ip = value.nics[0].ip
148 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
150 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
151 if type in (dns.A, dns.ALL_RECORDS):
152 record = dns.Record_A(ip, ttl)
153 results.append(dns.RRHeader(name, dns.A, dns.IN,
154 ttl, record, auth=True))
155 elif type == dns.SOA:
156 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
157 ttl, self.soa, auth=True))
159 if len(results) == 0:
162 return defer.succeed((results, authority, additional))
# BindAuthority variant with two changes:
#  (a) collapseContinuations tokenizes zone-file lines, honouring double
#      quotes and backslash escapes;
#  (b) _lookup retries a failed lookup on successively shorter parent
#      names, so that queries under a delegated subdomain can be answered
#      from the delegation records.
164 class DelegatingQuotingBindAuthority(authority.BindAuthority):
166 A delegating BindAuthority that (almost) deals with quoting correctly
168 This will catch double quotes as marking the start or end of a
169 quoted phrase, unless the double quote is escaped by a backslash
171 # Match either a quoted or unquoted string literal followed by
172 # whitespace or the end of line. This yields two groups, one of
173 # which has a match, and the other of which is None, depending on
174 # whether the string literal was quoted or unquoted; this is what
175 # necessitates the subsequent filtering out of groups that are
178 re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')
# string_pat groups:
#   group 1 -- body of a "..." literal (backslash escapes allowed);
#   group 2 -- an unquoted run of non-space characters or escapes.
# NOTE(review): `|` binds loosest, so the trailing (?:\s+|\s*$) applies only
# to the unquoted alternative. Whitespace after a quoted literal is not part
# of the match; finditer simply skips over it.
180 # For interpreting escapes.
181 escape_pat = re.compile(r'\\(.)')
# escape_pat.sub(r'\1', s) replaces every backslash escape "\c" with "c".
183 def collapseContinuations(self, lines):
# Join records split across lines with "( ... )" into single lines.
# Then tokenize each line with string_pat and unescape every token.
# Returns filter(None, L), which drops empty entries.
# NOTE(review): '(' and ')' are located with str.find, so a parenthesis
# inside a quoted string is also taken as a continuation marker.
188 if line.find('(') == -1:
191 L.append(line[:line.find('(')])
194 if line.find(')') != -1:
195 L[-1] += ' ' + line[:line.find(')')]
205 for m in self.string_pat.finditer(line):
# Exactly one alternative matched, so exactly one group is not None.
206 [x] = [x for x in m.groups() if x is not None]
207 split_line.append(self.escape_pat.sub(r'\1', x))
209 return filter(None, L)
211 def _lookup(self, name, cls, type, timeout = None):
# Look `name` up. While the result is a failure, strip the leftmost label
# and look up the parent name instead, so an enclosing delegated
# subdomain can supply the answer.
# NOTE(review): reading deferredResult.result works only because
# BindAuthority._lookup returns an already-fired Deferred; an unfired
# Deferred has no .result attribute.
# NOTE(review): each failed Deferred discarded by the loop still holds an
# unconsumed Failure. Twisted may log it as "Unhandled error in Deferred"
# when it is garbage-collected -- confirm.
212 maybeDelegate = False
213 deferredResult = authority.BindAuthority._lookup(self, name, cls,
215 # If we didn't find an exact match for the name we were seeking,
216 # check if it's within a subdomain we're supposed to delegate to
217 # some other DNS server.
218 while (isinstance(deferredResult.result, failure.Failure)
221 name = name[name.find('.') + 1 :]
222 deferredResult = authority.BindAuthority._lookup(self, name, cls,
224 return deferredResult
# ResolverChain whose query() does not turn an unsupported record type into
# SERVFAIL. It answers NXDOMAIN if the name does not exist, otherwise an
# empty NOERROR response.
226 class TypeLenientResolverChain(resolve.ResolverChain):
228 This is a ResolverChain which is more lenient in its handling of
229 queries requesting unimplemented record types.
232 def query(self, query, timeout = None):
# Dispatch to the lookup method that typeToMethod registers for
# query.type, passing the query name as a str.
234 return self.typeToMethod[query.type](str(query.name), timeout)
236 # We don't support the requested record type. Twisted would
237 # have us return SERVFAIL. Instead, we'll check whether the
238 # name exists in our zone at all and return NXDOMAIN or an empty
239 # result set with NOERROR as appropriate.
240 deferredResult = self.lookupAllRecords(str(query.name), timeout)
# NOTE(review): .result exists only on a Deferred that has already fired.
# This relies on every resolver in the chain answering synchronously.
241 if isinstance(deferredResult.result, failure.Failure):
242 return deferredResult
243 return defer.succeed(([], [], []))
245 if '__main__' == __name__:
# Script entry point: build one DelegatingQuotingBindAuthority per
# (zone file, configured domain) pair (presumably collected in
# `resolvers`), then add the database-backed DatabaseAuthority. The result
# is served through a TypeLenientResolverChain on port 53 over UDP and TCP.
248 for zone in config.dns.zone_files:
249 for origin in config.dns.domains:
250 r = DelegatingQuotingBindAuthority(zone)
251 # This sucks, but if I want a generic zone file, I have to
252 # reload the information by hand
# NOTE(review): the file object from open(zone) is never explicitly
# closed. `with open(zone) as fh: lines = fh.readlines()` would close it
# deterministically.
254 lines = open(zone).readlines()
255 lines = r.collapseContinuations(r.stripComments(lines))
259 except InvirtConfigError:
260 # Don't care if zone_files isn't defined
262 resolvers.append(DatabaseAuthority())
# UDP queries go through DNSDatagramProtocol; the factory itself is the
# TCP listener.
265 f = server.DNSServerFactory(verbose=verbosity)
266 f.resolver = TypeLenientResolverChain(resolvers)
267 p = dns.DNSDatagramProtocol(f)
268 f.noisy = p.noisy = verbosity
270 reactor.listenUDP(53, p)
271 reactor.listenTCP(53, f)