2 from twisted.internet import reactor
3 from twisted.names import server
4 from twisted.names import dns
5 from twisted.names import common
6 from twisted.names import authority
7 from twisted.names import resolve
8 from twisted.internet import defer
9 from twisted.python import failure
11 from invirt.common import InvirtConfigError
12 from invirt.config import structs as config
13 import invirt.database
14 from invirt.database import NIC
# Resolver (subclass of twisted.names common.ResolverBase) whose answers are
# built from invirt.database queries (NIC and Machine rows) plus values read
# from the invirt `config.dns` settings.
# NOTE(review): this listing carries its own embedded line numbers and they
# skip (for example 28 -> 30, 32 -> 34, 46 -> 49, 60 -> 65), so some lines --
# and all indentation -- are missing from view.  Names used below such as
# ttl, results, best_domain and other are assigned on lines that are not
# shown.  The comments added here describe only the visible statements;
# anything said about the missing lines is marked as an assumption.
20 class DatabaseAuthority(common.ResolverBase):
21 """An Authority that is loaded from a file."""
# NOTE(review): the docstring says "loaded from a file", but the visible code
# connects to invirt.database and queries it -- consider rewording.
#
# Constructor: connect to the database, record the served domains and build
# the fixed SOA / NS / A records for the first configured nameserver.
#   domains  -- stored on self.domains when not None; config.dns.domains is
#               assigned otherwise (presumably under an `else:` on embedded
#               line 33, which is not visible).
#   database -- passed to invirt.database.connect() when not None; a bare
#               connect() follows (presumably the `else:` case on embedded
#               line 29, which is not visible).
25 def __init__(self, domains=None, database=None):
26 common.ResolverBase.__init__(self)
27 if database is not None:
28 invirt.database.connect(database)
30 invirt.database.connect()
31 if domains is not None:
32 self.domains = domains
34 self.domains = config.dns.domains
# All fixed records below describe the first entry of config.dns.nameservers.
35 ns = config.dns.nameservers[0]
# SOA record: mname is the nameserver hostname, rname is config.dns.contact
# with its first '@' replaced by '.'.
36 self.soa = dns.Record_SOA(mname=ns.hostname,
37 rname=config.dns.contact.replace('@','.',1),
38 serial=1, refresh=3600, retry=900,
39 expire=3600000, minimum=21600, ttl=3600)
# NS record naming that nameserver, and an A-record header for its hostname
# (self.ns1 is what _lookup_unsafe places in its `additional` list).
40 self.ns = dns.Record_NS(name=ns.hostname, ttl=3600)
41 record = dns.Record_A(address=ns.ip, ttl=3600)
42 self.ns1 = dns.RRHeader(ns.hostname, dns.A, dns.IN,
43 3600, record, auth=True)
# Wrapper around _lookup_unsafe with a handler for database errors
# (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError); the handler
# contains a print of "Reloading database".
# NOTE(review): the `try:` line and whatever retry / return logic surrounds
# these statements are not visible (embedded lines 47-48, 51-52 and 54-58).
# NOTE(review): psycopg2 and sqlalchemy are not among the visible import
# lines -- confirm they are imported (embedded lines 15-19 are not shown).
46 def _lookup(self, name, cls, type, timeout = None):
# NOTE(review): timeout is hard-coded to None in this call, so the caller's
# timeout argument is not forwarded.
49 value = self._lookup_unsafe(name, cls, type, timeout = None)
50 except (psycopg2.OperationalError, sqlalchemy.exceptions.DBAPIError):
53 print "Reloading database"
# Answer one query from the database.
#   name -- the queried name; matched against self.domains and the
#           '.in-addr.arpa' suffix below.
#   type -- the record type; compared with dns.A, dns.PTR, dns.SOA, dns.TXT
#           and dns.ALL_RECORDS below.
#   cls, timeout -- not used on any visible line.
# Returns a Deferred: defer.succeed((results, authority, additional)) on the
# last visible line, or defer.fail(...) wrapping dns.DomainError /
# dns.AuthoritativeDomainError on the failure paths.
59 def _lookup_unsafe(self, name, cls, type, timeout):
# Clear the invirt.database cache first (presumably so this lookup sees
# current rows -- confirm against invirt.database).
60 invirt.database.clear_cache()
# Work out which served domain the name belongs to: an exact match, else the
# longest entry of self.domains that the name ends with.
65 if name in self.domains:
68 # Look for the longest-matching domain.
70 for domain in self.domains:
71 if name.endswith('.'+domain) and len(domain) > len(best_domain):
74 if name.endswith('.in-addr.arpa'):
75 # Act authoritative for the IP address for reverse resolution requests
# Failure with DomainError (presumably the case where no served domain
# matched and the name is not a reverse name; the `else:` is not visible).
78 return defer.fail(failure.Failure(dns.DomainError(name)))
# `additional` holds the nameserver's A record; `authority` receives the NS
# record for the matched domain with a fixed TTL of 3600.
82 additional = [self.ns1]
83 authority.append(dns.RRHeader(domain, dns.NS, dns.IN,
84 3600, self.ns, auth=True))
88 # - What domain: in-addr.arpa, domain root, or subdomain?
89 # - What query type: A, PTR, NS, ...?
# NOTE(review): the condition guarding this failure is not visible
# (embedded lines 90-92).
93 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
# --- Reverse names (ending in .in-addr.arpa) ---
95 if name.endswith(".in-addr.arpa"):
96 if type in (dns.PTR, dns.ALL_RECORDS):
# Drop the last two labels ("in-addr", "arpa") and reverse the remaining
# labels to obtain the dotted address.
97 ip = '.'.join(reversed(name.split('.')[:-2]))
# First NIC row whose ip or other_ip equals that address.
98 value = invirt.database.NIC.query.filter((NIC.ip == ip) | (NIC.other_ip == ip)).first()
99 if value and value.hostname:
100 hostname = value.hostname
# A hostname without a dot gets ".other" appended when the address matched
# other_ip; "." + config.dns.domains[0] is then appended (the exact nesting
# of these statements cannot be recovered from this listing).
101 if '.' not in hostname:
102 if ip == value.other_ip:
103 hostname = hostname + ".other"
104 hostname = hostname + "." + config.dns.domains[0]
105 record = dns.Record_PTR(hostname, ttl)
106 results.append(dns.RRHeader(name, dns.PTR, dns.IN,
107 ttl, record, auth=True))
108 else: # IP address doesn't point to an active host
109 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
110 elif type == dns.SOA:
111 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
112 ttl, self.soa, auth=True))
113 # FIXME: Should only return success with no records if the name actually exists
# --- The domain itself (also '.'+domain and 'other.'+domain) ---
115 elif name == domain or name == '.'+domain or name == 'other.'+domain:
116 if type in (dns.A, dns.ALL_RECORDS):
# A / ALL_RECORDS queries are answered with the first nameserver's address.
117 record = dns.Record_A(config.dns.nameservers[0].ip, ttl)
118 results.append(dns.RRHeader(name, dns.A, dns.IN,
119 ttl, record, auth=True))
# NOTE(review): the branch condition for this NS answer is not visible
# (embedded line 120).
121 results.append(dns.RRHeader(domain, dns.NS, dns.IN,
122 ttl, self.ns, auth=True))
124 elif type == dns.SOA:
125 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
126 ttl, self.soa, auth=True))
# --- Remaining names: the part before "." + domain is treated as a host ---
129 host = name[:-len(domain)-1]
# A trailing ".other" is stripped from the host before the database lookup.
131 if host.endswith(".other"):
132 host = host[:-len(".other")]
# Look the host up as a NIC hostname ...
134 value = invirt.database.NIC.query.filter_by(hostname=host).first()
138 action = value.other_action
# ... and as a Machine name (presumably the fallback when no NIC matched);
# the machine's first NIC then supplies the address and action.
142 value = invirt.database.Machine.query.filter_by(name=host).first()
145 ip = value.nics[0].other_ip
146 action = value.nics[0].other_action
148 ip = value.nics[0].ip
# Failure paths; the conditions guarding them are not visible.
150 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
152 return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name)))
# A / ALL_RECORDS: address record for the resolved ip.
153 if type in (dns.A, dns.ALL_RECORDS):
154 record = dns.Record_A(ip, ttl)
155 results.append(dns.RRHeader(name, dns.A, dns.IN,
156 ttl, record, auth=True))
# TXT / ALL_RECORDS when `other` is true (presumably set where the ".other"
# suffix was stripped): the action text, or '' when the action is falsy.
157 if other and type in (dns.TXT, dns.ALL_RECORDS):
158 record = dns.Record_TXT(action if action else '', ttl=ttl)
159 results.append(dns.RRHeader(name, dns.TXT, dns.IN,
160 ttl, record, auth=True))
# NOTE(review): the branch condition for this SOA answer is not visible
# (embedded line 161).
162 results.append(dns.RRHeader(domain, dns.SOA, dns.IN,
163 ttl, self.soa, auth=True))
# NOTE(review): the body of this empty-results check is not visible
# (embedded lines 166-167).
165 if len(results) == 0:
# Hand the three sections back wrapped in defer.succeed().
168 return defer.succeed((results, authority, additional))
# BindAuthority subclass that supplies its own collapseContinuations() and
# _lookup().
# NOTE(review): this listing carries embedded line numbers that skip (for
# example 175 -> 177, 181 -> 184, 189 -> 194, 201 -> 211) and has lost its
# indentation.  The docstring delimiters, the target of the first
# re.compile() (presumably `string_pat =`, since self.string_pat is used
# below), the loop headers and the continuation lines of both
# BindAuthority._lookup calls are among the lines not shown.
# NOTE(review): `re` is used here but is not among the visible import lines.
170 class DelegatingQuotingBindAuthority(authority.BindAuthority):
172 A delegating BindAuthority that (almost) deals with quoting correctly
174 This will catch double quotes as marking the start or end of a
175 quoted phrase, unless the double quote is escaped by a backslash
177 # Match either a quoted or unquoted string literal followed by
178 # whitespace or the end of line. This yields two groups, one of
179 # which has a match, and the other of which is None, depending on
180 # whether the string literal was quoted or unquoted; this is what
181 # necessitates the subsequent filtering out of groups that are
184 re.compile(r'"((?:[^"\\]|\\.)*)"|((?:[^\\\s]|\\.)+)(?:\s+|\s*$)')
186 # For interpreting escapes.
187 escape_pat = re.compile(r'\\(.)')
# Pre-process zone-file lines.
# Visible first part: for a line containing '(' only the text before it is
# collected; for a line containing ')' the text before it is appended, after
# a space, to the last collected entry.
# Visible second part: each line is scanned with string_pat; the single
# non-None group of every match has its backslash escapes removed by
# escape_pat (backslash + character -> character) and is added to split_line.
# Returns filter(None, L).
# NOTE(review): the loop headers, the state handling between the two '(' / ')'
# cases and the statement that stores split_line are not visible.
189 def collapseContinuations(self, lines):
194 if line.find('(') == -1:
197 L.append(line[:line.find('(')])
200 if line.find(')') != -1:
201 L[-1] += ' ' + line[:line.find(')')]
211 for m in self.string_pat.finditer(line):
212 [x] = [x for x in m.groups() if x is not None]
213 split_line.append(self.escape_pat.sub(r'\1', x))
215 return filter(None, L)
# Lookup with delegation: call BindAuthority._lookup and, while its result is
# a Failure (together with a second condition on a line that is not visible),
# cut name after its first '.' and look up again.  Returns the last Deferred
# obtained.
# NOTE(review): deferredResult.result is read directly, which relies on the
# Deferred returned by BindAuthority._lookup having already fired -- confirm.
# NOTE(review): maybeDelegate is assigned here but no visible line reads it.
217 def _lookup(self, name, cls, type, timeout = None):
218 maybeDelegate = False
219 deferredResult = authority.BindAuthority._lookup(self, name, cls,
221 # If we didn't find an exact match for the name we were seeking,
222 # check if it's within a subdomain we're supposed to delegate to
223 # some other DNS server.
224 while (isinstance(deferredResult.result, failure.Failure)
227 name = name[name.find('.') + 1 :]
228 deferredResult = authority.BindAuthority._lookup(self, name, cls,
230 return deferredResult
class TypeLenientResolverChain(resolve.ResolverChain):
    """
    This is a ResolverChain which is more lenient in its handling of
    queries requesting unimplemented record types.

    A query whose record type has no entry in ``typeToMethod`` is answered
    from an all-records lookup of the same name: a failed lookup is passed
    through unchanged, a successful one is replaced by three empty sections.
    """

    def query(self, query, timeout = None):
        """
        Dispatch *query* to the lookup method registered for its type.

        query   -- object providing ``type`` and ``name``; the name is
                   converted with str() before it is looked up.
        timeout -- passed through unchanged to the lookup method.

        Returns the Deferred produced by the lookup.  For a type missing
        from ``typeToMethod`` the Deferred fires with ``([], [], [])`` when
        the all-records lookup succeeds, and with that lookup's failure
        otherwise.
        """
        # NOTE(review): the try/except lines were not visible in the listing
        # this block was restored from; KeyError is assumed because the
        # fallback below is described as handling unsupported record types.
        try:
            return self.typeToMethod[query.type](str(query.name), timeout)
        except KeyError:
            # We don't support the requested record type.  Twisted would
            # have us return SERVFAIL.  Instead, we'll check whether the
            # name exists in our zone at all and return NXDOMAIN or an empty
            # result set with NOERROR as appropriate.
            deferredResult = self.lookupAllRecords(str(query.name), timeout)
            # Chain onto the Deferred instead of reading its .result
            # attribute: .result is only meaningful once the Deferred has
            # fired, so peeking at it breaks for a lookup that completes
            # later.  A failure skips this callback and propagates as is.
            deferredResult.addCallback(lambda _records: ([], [], []))
            return deferredResult
# Script entry point: build resolvers from the configured zone files and the
# database, put them behind a DNSServerFactory and listen on port 53.
# NOTE(review): this listing carries embedded line numbers that skip (251 ->
# 254, 258 -> 260, 261 -> 265, 268 -> 271) and has lost its indentation.  The
# `try:` matching the `except` below, the initialisation of `resolvers` and
# `verbosity`, whatever is done with `origin` and `r` after the lines shown,
# and anything after the listenTCP call are not visible.
251 if '__main__' == __name__:
# One DelegatingQuotingBindAuthority is created per (zone file, domain) pair.
254 for zone in config.dns.zone_files:
255 for origin in config.dns.domains:
256 r = DelegatingQuotingBindAuthority(zone)
257 # This sucks, but if I want a generic zone file, I have to
258 # reload the information by hand
# Re-read the zone file and run it through stripComments() and
# collapseContinuations().
# NOTE(review): the file object returned by open(zone) is never explicitly
# closed.
260 lines = open(zone).readlines()
261 lines = r.collapseContinuations(r.stripComments(lines))
265 except InvirtConfigError:
266 # Don't care if zone_files isn't defined
# The database-backed resolver is added to the resolver list.
268 resolvers.append(DatabaseAuthority())
# Server factory whose resolver is the lenient chain over all resolvers; the
# same factory backs the datagram protocol, and both get the same noisy flag.
271 f = server.DNSServerFactory(verbose=verbosity)
272 f.resolver = TypeLenientResolverChain(resolvers)
273 p = dns.DNSDatagramProtocol(f)
274 f.noisy = p.noisy = verbosity
# Listen on port 53 over both UDP (protocol p) and TCP (factory f).
276 reactor.listenUDP(53, p)
277 reactor.listenTCP(53, f)