From: Eric Price Date: Wed, 10 Oct 2007 12:52:32 +0000 (-0400) Subject: DNS server! X-Git-Tag: sipb-xen-dns/1~10^2~1 X-Git-Url: http://xvm.mit.edu/gitweb/invirt/packages/invirt-dns.git/commitdiff_plain/38d26aed26a4daf8b387d6a3b1bba084d2a88c6c DNS server! svn path=/trunk/dns/; revision=181 --- diff --git a/db.ca b/db.ca deleted file mode 100644 index a20028e..0000000 --- a/db.ca +++ /dev/null @@ -1,83 +0,0 @@ -; This file holds the information on root name servers needed to -; initialize cache of Internet domain name servers -; (e.g. reference this file in the "cache . " -; configuration file of BIND domain name servers). -; -; This file is made available by InterNIC registration services -; under anonymous FTP as -; file /domain/named.root -; on server FTP.RS.INTERNIC.NET -; -OR- under Gopher at RS.INTERNIC.NET -; under menu InterNIC Registration Services (NSI) -; submenu InterNIC Registration Archives -; file named.root -; -; last update: Aug 22, 1997 -; related version of root zone: 1997082200 -; -; -; formerly NS.INTERNIC.NET -; -. 3600000 IN NS A.ROOT-SERVERS.NET. -A.ROOT-SERVERS.NET. 3600000 A 198.41.0.4 -; -; formerly NS1.ISI.EDU -; -. 3600000 NS B.ROOT-SERVERS.NET. -B.ROOT-SERVERS.NET. 3600000 A 128.9.0.107 -; -; formerly C.PSI.NET -; -. 3600000 NS C.ROOT-SERVERS.NET. -C.ROOT-SERVERS.NET. 3600000 A 192.33.4.12 -; -; formerly TERP.UMD.EDU -; -. 3600000 NS D.ROOT-SERVERS.NET. -D.ROOT-SERVERS.NET. 3600000 A 128.8.10.90 -; -; formerly NS.NASA.GOV -; -. 3600000 NS E.ROOT-SERVERS.NET. -E.ROOT-SERVERS.NET. 3600000 A 192.203.230.10 -; -; formerly NS.ISC.ORG -; -. 3600000 NS F.ROOT-SERVERS.NET. -F.ROOT-SERVERS.NET. 3600000 A 192.5.5.241 -; -; formerly NS.NIC.DDN.MIL -; -. 3600000 NS G.ROOT-SERVERS.NET. -G.ROOT-SERVERS.NET. 3600000 A 192.112.36.4 -; -; formerly AOS.ARL.ARMY.MIL -; -. 3600000 NS H.ROOT-SERVERS.NET. -H.ROOT-SERVERS.NET. 3600000 A 128.63.2.53 -; -; formerly NIC.NORDU.NET -; -. 3600000 NS I.ROOT-SERVERS.NET. -I.ROOT-SERVERS.NET. 3600000 A 192.36.148.17 -; -; temporarily housed at NSI (InterNIC) -; -. 3600000 NS J.ROOT-SERVERS.NET. -J.ROOT-SERVERS.NET. 3600000 A 198.41.0.10 -; -; housed in LINX, operated by RIPE NCC -; -. 3600000 NS K.ROOT-SERVERS.NET. -K.ROOT-SERVERS.NET. 3600000 A 193.0.14.129 -; -; temporarily housed at ISI (IANA) -; -. 3600000 NS L.ROOT-SERVERS.NET. -L.ROOT-SERVERS.NET. 3600000 A 198.32.64.12 -; -; housed in Japan, operated by WIDE -; -. 3600000 NS M.ROOT-SERVERS.NET. -M.ROOT-SERVERS.NET. 3600000 A 202.12.27.33 -; End of File diff --git a/db.servers.csail.mit.edu b/db.servers.csail.mit.edu deleted file mode 100644 index e0f5cb5..0000000 --- a/db.servers.csail.mit.edu +++ /dev/null @@ -1,7 +0,0 @@ -servers.csail.mit.edu. 10 IN SOA sipb-xen-dev.mit.edu. sipb-xen.mit.edu. ( - 3 - 10800 - 3600 - 604800 - 3600 ) -servers.csail.mit.edu. 10 IN NS sipb-xen-dev.mit.edu. diff --git a/dnsserver.py b/dnsserver.py new file mode 100644 index 0000000..c2069df --- /dev/null +++ b/dnsserver.py @@ -0,0 +1,61 @@ +#!/usr/bin/python +from twisted.internet import reactor +from twisted.names import server +from twisted.names import dns +from twisted.names import common +from twisted.internet import defer +from twisted.python import failure + +import sipb_xen_database + +class DatabaseAuthority(common.ResolverBase): + """An Authority that is loaded from a file.""" + + soa = None + + def __init__(self, domain, database=None): + common.ResolverBase.__init__(self) + if database is not None: + sipb_xen_database.connect(database) + self.domain = domain + self.soa = dns.Record_SOA(mname='sipb-xen-dev.mit.edu', + rname='sipb-xen.mit.edu', + serial=1, refresh=3600, retry=900, + expire=3600000, minimum=21600, ttl=3600) + def _lookup(self, name, cls, type, timeout = None): + if not (name.lower() == self.domain or + name.lower().endswith('.'+self.domain)): + #Not us + return defer.fail(failure.Failure(dns.DomainError(name))) + results = [] + if cls == dns.IN and type in (dns.A, dns.ALL_RECORDS): + host = name[:-len(self.domain)-1] + value = sipb_xen_database.NIC.get_by(hostname=host) + if value is None: + return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) + ip = value.ip + if ip is None: #Deactivated? + return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) + ttl = 900 + record = dns.Record_A(ip, ttl) + results.append(dns.RRHeader(name, dns.A, dns.IN, + ttl, record, auth=True)) + authority = [] + authority.append(dns.RRHeader(self.domain, dns.SOA, dns.IN, 3600, + self.soa, auth=True)) + return defer.succeed((results, authority, [])) + #Doesn't exist + return defer.fail(failure.Failure(dns.AuthoritativeDomainError(name))) + +if '__main__' == __name__: + resolver = DatabaseAuthority('servers.csail.mit.edu', + 'postgres://sipb-xen@sipb-xen-dev/sipb_xen') + + verbosity = 0 + f = server.DNSServerFactory(authorities=[resolver], verbose=verbosity) + p = dns.DNSDatagramProtocol(f) + f.noisy = p.noisy = verbosity + + reactor.listenUDP(53, p) + reactor.listenTCP(53, f) + reactor.run() diff --git a/nameserver.py b/nameserver.py deleted file mode 100755 index 658a662..0000000 --- a/nameserver.py +++ /dev/null @@ -1,2960 +0,0 @@ -#!/usr/bin/python -# Python Domain Name Server -# Copyright (C) 2002 Digital Lumber, Inc. - -# This library is free software; you can redistribute it and/or -# modify it under the terms of the GNU Lesser General Public -# License as published by the Free Software Foundation; either -# version 2.1 of the License, or (at your option) any later version. - -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -# Lesser General Public License for more details. - -# You should have received a copy of the GNU Lesser General Public -# License along with this library; if not, write to the Free Software -# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - -import socket -import asyncore -import asynchat -import select -import types -import random -import time -import signal -import string -import sys -import sipb_xen_database -from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \ - ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT - -# EXAMPLE ZONE FILE DATA STRUCTURE - -# NOTE: -# There are no trailing dots in the internal data -# structure. Although it's hard to tell by reading -# the RFC's, the dots on the end of names are just -# used internally by the resolvers and servers to -# see if they need to append a domain name onto -# the end of names. There are no trailing dots -# on names in queries on the network. - -examplenet = {'example.net':{'SOA':[{'class':'IN', - 'ttl':10, - 'mname':'ns1.example.net', - 'rname':'hostmaster.example.net', - 'serial':1, - 'refresh':10800, - 'retry':3600, - 'expire':604800, - 'minimum':3600}], - 'NS':[{'class':'IN', - 'ttl':10, - 'nsdname':'ns1.example.net'}, - {'ttl':10, - 'nsdname':'ns2.example.net'}], - 'MX':[{'class':'IN', - 'ttl':10, - 'preference':10, - 'exchange':'mail.example.net'}]}, - 'server1.example.net':{'A':[{'class':'IN', - 'ttl':10, - 'address':'10.1.2.3'}]}, - 'www.example.net':{'CNAME':[{'class':'IN', - 'ttl':10, - 'cname':'server1.example.net'}]}, - 'router.example.net':{'A':[{'class':'IN', - 'ttl':10, - 'address':'10.1.2.1'}, - {'class':'IN', - 'ttl':10, - 'address':'10.2.1.1'}]} - - } - -# setup logging defaults -loglevel = 0 -logfile = sys.stdout - -try: - file -except NameError: - def file(name, mode='r', buffer=0): - return open(name, mode, buffer) - -def log(level,msg): - if level <= loglevel: - logfile.write(msg+'\n') - -def timestamp(): - return time.strftime('%m/%d/%y %H:%M:%S')+ '-' - -def inttoasc(number): - try: - hs = hex(number)[2:] - except: - log(0,'inttoasc cannot convert ' + repr(number)) - if hs[-1:].upper() == 'L': - hs = hs[:-1] - result = '' - while len(hs) > 2: - result = chr(int(hs[-2:],16)) + result - hs = hs[:-2] - result = chr(int(hs,16)) + result - - return result - -def asctoint(ascnum): - rascnum = '' - for i in range(len(ascnum)-1,-1,-1): - rascnum = rascnum + ascnum[i] - result = 0 - count = 0 - for c in rascnum: - x = ord(c) << (8*count) - result = result + x - count = count + 1 - - return result - -def ipv6net_aton(ip_string): - packed_ip = '' - # first account for shorthand syntax - pieces = ip_string.split(':') - pcount = 0 - for part in pieces: - if part != '': - pcount = pcount + 1 - if pcount < 8: - rs = '0:'*(8-pcount) - ip_string = ip_string.replace('::',':'+rs) - if ip_string[0] == ':': - ip_string = ip_string[1:] - pieces = ip_string.split(':') - for part in pieces: - # pad with the zeros - i = 4-len(part) - part = i*'0'+part - packed_ip = packed_ip + chr(int(part[:2],16))+ chr(int(part[2:],16)) - return packed_ip - -def ipv6net_ntoa(packed_ip): - ip_string = '' - count = 0 - for c in packed_ip: - ip_string = ip_string + hex(ord(c))[2:] - count = count + 1 - if count == 2: - ip_string = ip_string + ':' - count = 0 - return ip_string[:-1] - -def getversion(qname, id, rd, ra, versionstr): - msg = message() - msg.header.id = id - msg.header.qr = 1 - msg.header.aa = 1 - msg.header.rd = rd - msg.header.ra = ra - msg.header.rcode = 0 - msg.question.qname = qname - msg.question.qtype = 'TXT' - msg.question.qclass = 'CH' - if qname == 'version.bind': - msg.header.ancount = 2 - msg.answerlist.append({qname:{'CNAME':[{'cname':'version.oak', - 'ttl':360000, - 'class':'CH'}]}}) - msg.answerlist.append({'version.oak':{'TXT':[{'txtdata':versionstr, - 'ttl':360000, - 'class':'CH'}]}}) - else: - msg.header.ancount = 1 - msg.answerlist.append({qname:{'TXT':[{'txtdata':versionstr, - 'ttl':360000, - 'class':'CH'}]}}) - return msg - -def getrcode(rcode): - if rcode == 0: - rcodestr = 'NOERROR(No error condition)' - elif rcode == 1: - rcodestr = 'FORMERR(Format Error)' - elif rcode == 2: - rcodestr = 'SERVFAIL(Internal failure)' - elif rcode == 3: - rcodestr = 'NXDOMAIN(Name does not exist)' - elif rcode == 4: - rcodestr = 'NOTIMP(Not Implemented)' - elif rcode == 5: - rcodestr = 'REFUSED(Security violation)' - elif rcode == 6: - rcodestr = 'YXDOMAIN(Name exists)' - elif rcode == 7: - rcodestr = 'YXRRSET(RR exists)' - elif rcode == 8: - rcodestr = 'NXRRSET(RR does not exist)' - elif rcode == 9: - rcodestr = 'NOTAUTH(Server not Authoritative)' - elif rcode == 10: - rcodestr = 'NOTZONE(Name not in zone)' - else: - rcodestr = 'Unknown RCODE(' + str(rcode) + ')' - return rcodestr - -def printrdata(dnstype, rdata): - if dnstype == 'A': - return rdata['address'] - elif dnstype == 'MX': - return str(rdata['preference'])+'\t'+rdata['exchange']+'.' - elif dnstype == 'NS': - return rdata['nsdname']+'.' - elif dnstype == 'PTR': - return rdata['ptrdname']+'.' - elif dnstype == 'CNAME': - return rdata['cname']+'.' - elif dnstype == 'SOA': - return (rdata['mname']+'.\t'+rdata['rname']+'. (\n'+35*' '+str(rdata['serial'])+'\n'+ - 35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+ - str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )') - -def makezonedatalist(zonedata, origin): - # unravel structure into list - zonedatalist = [] - # get soa first - soanode = zonedata[origin] - zonedatalist.append([origin+'.','SOA',soanode['SOA'][0]]) - for item in soanode.keys(): - if item != 'SOA': - for listitem in soanode[item]: - zonedatalist.append([origin+'.', item, listitem]) - for nodename in zonedata.keys(): - if nodename != origin: - for item in zonedata[nodename].keys(): - for listitem in zonedata[nodename][item]: - zonedatalist.append([nodename+'.', item, listitem]) - return zonedatalist - -def writezonefile(zonedata, origin, file): - zonedatalist = makezonedatalist(zonedata, origin) - for rr in zonedatalist: - owner = rr[0] - dnstype = rr[1] - line = (owner + (35-len(owner))*' ' + str(rr[2]['ttl']) + '\t\tIN\t' + - dnstype + '\t' + printrdata(dnstype, rr[2])) - file.write(line + '\n') - -def readzonefiles(zonedict): - for k in zonedict.keys(): - filepath = zonedict[k]['filename'] - try: - pr = zonefileparser() - pr.parse(zonedict[k]['origin'],filepath) - zonedict[k]['zonedata'] = pr.getzdict() - except ZonefileError, lineno: - log(0,'Error reading zone file ' + filepath + ' at line ' + - str(lineno) + '\n') - del zonedict[k] - -def slowloop(tofunc='',timeout=5.0): - if not tofunc: - def tofunc(): return - map = asyncore.socket_map - while map: - r = []; w=[]; e=[] - for fd, obj in map.items(): - if obj.readable(): - r.append(fd) - if obj.writable(): - w.append(fd) - try: - starttime = time.time() - r,w,e = select.select(r,w,e,timeout) - endtime = time.time() - if endtime-starttime >= timeout: - tofunc() - except select.error, err: - if err[0] != EINTR: - raise - r=[]; w=[]; e=[] - log(0,'ERROR in select') - - for fd in r: - try: - obj=map[fd] - except KeyError: - log(0,'KeyError in socket map') - continue - try: - obj.handle_read_event() - except: - log(0,'calling HANDLE ERROR from loop') - log(0,repr(obj)) - obj.handle_error() - for fd in w: - try: - obj=map[fd] - except KeyError: - log(0,'KeyError in socket map') - continue - try: - obj.handle_read_event() - except: - log(0,'calling HANDLE ERROR from loop') - log(0,repr(obj)) - obj.handle_error() - -def fastloop(tofunc='',timeout=5.0): - if not tofunc: - def tofunc(): return - polltimeout = timeout*1000 - map = asyncore.socket_map - while map: - regfds = 0 - pollobj = select.poll() - for fd, obj in map.items(): - flags = 0 - if obj.readable(): - flags = select.POLLIN - if obj.writable(): - flags = flags | select.POLLOUT - if flags: - pollobj.register(fd, flags) - regfds = regfds + 1 - try: - starttime = time.time() - r = pollobj.poll(polltimeout) - endtime = time.time() - if endtime-starttime >= timeout: - tofunc() - except select.error, err: - if err[0] != EINTR: - raise - r = [] - log(0,'ERROR in select') - for fd, flags in r: - try: - obj = map[fd] - badvals = (select.POLLPRI + select.POLLERR + - select.POLLHUP + select.POLLNVAL) - if (flags & badvals): - if (flags & select.POLLPRI): - log(0,'POLLPRI') - if (flags & select.POLLERR): - log(0,'POLLERR') - if (flags & select.POLLHUP): - log(0,'POLLHUP') - if (flags & select.POLLNVAL): - log(0,'POLLNVAL') - obj.handle_error() - else: - if (flags & select.POLLIN): - obj.handle_read_event() - if (flags & select.POLLOUT): - obj.handle_write_event() - except KeyError: - log(0,'KeyError in socket map') - continue - except: - # print traceback - sf = StringIO.StringIO() - traceback.print_exc(file=sf) - log(0,'ERROR IN LOOP:') - log(0,sf.getvalue()) - sf.close() - log(0,repr(obj)) - obj.handle_error() - -if hasattr(select,'poll'): - loop = fastloop -else: - loop = slowloop - -class ZonefileError(Exception): - def __init__(self, linenum, errordesc=''): - self.linenum = linenum - self.errordesc = errordesc - def __str__(self): - return str(self.linenum) + ' (' + self.errordesc + ')' - -class zonefileparser: - def __init__(self): - self.zonedata = {} - self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX', - 'NS','PTR','RP','SOA','SRV','TXT'] - - def stripcomments(self, line): - i = line.find(';') - if i >= 0: - line = line[:i] - return line - - def strip(self, line): - # strip trailing linefeeds - if line[-1:] == '\n': - line = line[:-1] - return line - - def getzdict(self): - return self.zonedata - - def addorigin(self, origin, name): - if name[-1:] != '.': - return name + '.' + origin - else: - return name[:-1] - - def getstrings(self, s): - if s.find('"') == -1: - return s.split() - else: - x = s.split('"') - rlist = [] - for i in x: - if i != '' and i != ' ': - rlist.append(i) - return rlist - - def getlocsize(self, s): - if s[-1:] == 'm': - size = float(s[:-1])*100 - else: - size = float(s)*100 - i = 0 - while size > 9: - size = size/10 - i = i + 1 - return (int(size),i) - - def getloclat(self, l,c): - deg = float(l[0]) - min = 0 - secs = 0 - if len(l) == 3: - min = float(l[1]) - secs = float(l[2]) - elif len(l) == 2: - min = float(l[1]) - rval = ((((deg *60) + min) * 60) + secs) * 1000 - if c in ['N','E']: - rval = rval + (2**31) - elif c in ['S','W']: - rval = (2**31) - rval - else: - log(0,'ERROR: unsupported latitude/longitude direction') - return long(rval) - - def getgname(self, name, iter): - if name == '0' or name == 'O': - return '' - start = 0 - offset = 0 - width = 0 - base = 'd' - for x in range(name.count('$')): - i = name.find('$',start) - j = i - start = i+1 - if i>0: - if name[i-1] == '\\': - continue - if len(name)>i+1: - if name[i+1] == '$': - continue - if name[i+1] == '{': - j = name.find('}',i+1) - owb = name[i+2:j].split(',') - if len(owb) == 1: - offset = int(owb[0]) - elif len(owb) == 2: - offset = int(owb[0]) - width = int(owb[1]) - elif len(owb) == 3: - offset = int(owb[0]) - width = int(owb[1]) - base = owb[2] - val = iter - offset - if base == 'd': - rs = str(val) - elif base == 'o': - rs = oct(val) - elif base == 'x': - rs = hex(val)[2:].lower() - elif base == 'X': - rs = hex(val)[2:].upper() - else: - rs = '' - if len(rs) > width: - rs = (width-len(rs))*'0'+rs - name = name[:i]+rs+name[j+1:] - start = i+len(rs)+1 - - return name - - def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens): - rdata = {} - rdata['class'] = dnsclass - rdata['ttl'] = ttl - if dnstype == 'A': - rdata['address'] = tokens[0] - elif dnstype == 'AAAA': - rdata['address'] = tokens[0] - elif dnstype == 'CNAME': - rdata['cname'] = self.addorigin(origin,tokens[0].lower()) - elif dnstype == 'HINFO': - sl = self.getstrings(' '.join(tokens)) - rdata['cpu'] = sl[0] - rdata['os'] = sl[1] - elif dnstype == 'LOC': - if 'N' in tokens: - i = tokens.index('N') - else: - i = tokens.index('S') - lat = self.getloclat(tokens[0:i],tokens[i]) - if 'E' in tokens: - j = tokens.index('E') - else: - j = tokens.index('W') - lng = self.getloclat(tokens[i+1:j],tokens[j]) - size = self.getlocsize('1m') - horiz_pre = self.getlocsize('1000m') - vert_pre = self.getlocsize('10m') - if len(tokens[j+1:]) == 2: - size = self.getlocsize(tokens[-1:][0]) - elif len(tokens[j+1:]) == 3: - size = self.getlocsize(tokens[-2:-1][0]) - horiz_pre = self.getlocsize(tokens[-1:][0]) - elif len(tokens[j+1:]) == 4: - size = self.getlocsize(tokens[-3:-2][0]) - horiz_pre = self.getlocsize(tokens[-2:-1][0]) - vert_pre = self.getlocsize(tokens[-1:][0]) - if tokens[j+1][-1:] == 'm': - alt = int((float(tokens[j+1][:-1])*100)+10000000) - else: - size = int((float(tokens[j+1])*100)+10000000) - rdata['version'] = 0 - rdata['size'] = size - rdata['horiz_pre'] = horiz_pre - rdata['vert_pre'] = vert_pre - rdata['latitude'] = lat - rdata['longitude'] = lng - rdata['altitude'] = 0 - elif dnstype == 'MX': - rdata['preference'] = int(tokens[0]) - rdata['exchange'] = self.addorigin(origin,tokens[1].lower()) - elif dnstype == 'NS': - rdata['nsdname'] = self.addorigin(origin,tokens[0].lower()) - elif dnstype == 'PTR': - rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower()) - elif dnstype == 'RP': - rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower()) - rdata['txtdname'] = self.addorigin(origin,tokens[1].lower()) - elif dnstype == 'SOA': - rdata['mname'] = self.addorigin(origin,tokens[0].lower()) - rdata['rname'] = self.addorigin(origin,tokens[1].lower()) - rdata['serial'] = int(tokens[2]) - rdata['refresh'] = int(tokens[3]) - rdata['retry'] = int(tokens[4]) - rdata['expire'] = int(tokens[5]) - rdata['minimum'] = int(tokens[6]) - elif dnstype == 'SRV': - rdata['priority'] = int(tokens[0]) - rdata['weight'] = int(tokens[1]) - rdata['port'] = int(tokens[2]) - rdata['target'] = self.addorigin(origin,tokens[3].lower()) - elif dnstype == 'TXT': - rdata['txtdata'] = self.getstrings(' '.join(tokens))[0] - else: - raise ZonefileError(lineno,'bad DNS type') - return rdata - - def addrec(self, owner, dnstype, rrdata): - if self.zonedata.has_key(owner): - if not self.zonedata[owner].has_key(dnstype): - self.zonedata[owner][dnstype] = [] - else: - self.zonedata[owner] = {} - self.zonedata[owner][dnstype] = [] - self.zonedata[owner][dnstype].append(rrdata) - - def parse(self, origin, f): - closefile = 0 - if type(f) != types.FileType: - # must be a path - try: - f = file(f) - closefile = 1 - except: - log(0,'Invalid path to zonefile') - return - lastowner = '' - lastdnsclass = '' - lastttl = 3600 - lineno = 0 - while 1: - line = f.readline() - if not line: - break - lineno = lineno + 1 - line = self.stripcomments(line) - line = self.strip(line) - if not line: - continue - if line.find('(') >= 0: - # grab lines until end paren - if line.find(')') == -1: - line2 = self.stripcomments(f.readline()) - lineno = lineno + 1 - line2 = self.strip(line2) - line = line + line2 - while line2.find(')') == -1: - line2 = self.strip(self.stripcomments(f.readline())) - lineno = lineno + 1 - line = line + line2 - # now strip the parenthesis - line = line.replace(')','') - line = line.replace('(','') - # now line equals the entire RR entry - tokens = line.split() - if tokens[0].upper() == '$ORIGIN': - try: - origin = tokens[1].lower() - except: - raise ZonefileError(lineno, 'bad origin') - elif tokens[0].upper() == '$INCLUDE': - try: - f2 = file(tokens[1].lower()) - if len(tokens) > 2: - self.parse(tokens[2].lower(), f2) - else: - self.parse(origin, f2) - f2.close() - except: - raise ZonefileError(lineno, 'bad INCLUDE directive') - elif tokens[0].upper() == '$TTL': - try: - lastttl = int(tokens[1]) - except: - raise ZonefileError(lineno, 'bad TTL directive') - elif tokens[0].upper() == '$GENERATE': - try: - lhs = tokens[2].lower() - dnstype = tokens[3].upper() - rhs = tokens[4].lower() - rng = tokens[1].split('-') - start = int(rng[0]) - i = rng[1].find('/') - if i != -1: - stop = int(rng[1][:i])+1 - step = int(rng[1][i+1:]) - else: - stop = int(rng[1])+1 - step = 1 - for i in range(start,stop,step): - grhs = self.getgname(rhs,i) - if dnstype in ['NS','CNAME','PTR']: - grhs = self.addorigin(origin,grhs) - rrdata = self.getrrdata(origin, dnstype, 'IN', lastttl, - [grhs]) - glhs = self.addorigin(origin,self.getgname(lhs,i)) - self.addrec(glhs,dnstype, rrdata) - except KeyError: - raise ZonefileError(lineno, 'bad GENERATE directive') - else: - try: - # if line begins with blank then owner is last owner - if line[0] in string.whitespace: - owner = lastowner - else: - owner = tokens[0].lower() - tokens = tokens[1:] - if owner == '@': - owner = origin - elif owner[-1:] != '.': - owner = owner + '.' + origin - else: - owner = owner[:-1] # strip off trailing dot - # line format is either: [class] [ttl] type RDATA - # or [ttl] [class] type RDATA - # - items in brackets are optional - # - # need to figure out which token is type - # and backfill the missing data - count = 0 - for token in tokens: - if token.upper() in self.dnstypes: - break - count = count + 1 - # the following strips off the ttl and class if they exist - if count == 0: - ttl = lastttl - dnsclass = lastdnsclass - elif count == 1: - if tokens[0].isdigit(): - ttl = int(tokens[0]) - dnsclass = lastdnsclass - else: - ttl = lastttl - dnsclass = tokens[0].upper() - tokens = tokens[1:] - elif count == 2: - if tokens[0].isdigit(): - ttl = int(tokens[0]) - dnsclass = tokens[1].upper() - else: - ttl = int(tokens[1]) - dnsclass = tokens[0].upper() - tokens = tokens[2:] - else: - raise ZonefileError(lineno,'bad ttl or class') - dnstype = tokens[0] - # make sure all of the structure is there - rrdata = self.getrrdata(origin, dnstype, dnsclass, - ttl, tokens[1:]) - self.addrec(owner, dnstype, rrdata) - lastowner = owner - lastttl = ttl - lastdnsclass = dnsclass - except: - raise ZonefileError(lineno,'unable to parse line') - if closefile: - f.close() - -class dnsconfig: - def __init__(self): - # self.zonedb = zonedb({}) - self.cached = {} - self.loglevel = 0 - - def getview(self, msg, address, port): - # return: - # 1. a list of zone keys - # 2. whether or not to use the resolver - # (i.e. answer recursive queries) - # 3. a list of forwarder addresses - return ['servers.csail.mit.edu'], 1, [] - - def allowupdate(self, msg, address, port): - # return 1 if updates are allowed - # NOTE: can only update the zones - # returned by the getview func - return 1 - - def outpackets(self, packetlist): - # modify outgoing packets - return packetlist - -class dnsheader: - def __init__(self, id=1): - self.id = id # 16bit identifier generated by queryer - self.qr = 0 # one bit field specifying query(0) or response(1) - self.opcode = 0 # 4bit field specifying type of query - self.aa = 0 # authoritative answer - self.tc = 0 # message is not truncated - self.rd = 1 # recursion desired - self.ra = 0 # recursion available? - self.z = 0 # reserved for future use - self.rcode = 0 # response code (set in response) - self.qdcount = 1 # number of questions, only 1 is supported - self.ancount = 0 # number of rrs in the answer section - self.nscount = 0 # number of name server rrs in authority section - self.arcount = 0 # number or rrs in the additional section - -class dnsquestion: - def __init__(self): - self.qname = 'localhost' - self.qtype = 'A' - self.qclass = 'IN' - -class dnsupdatezone: - pass - -class message: - def __init__(self, msgdata=''): - if msgdata: - self.header = dnsheader() - else: - self.header = dnsheader(id=random.randrange(1,32768)) - self.question = dnsquestion() - self.answerlist = [] - self.authlist = [] - self.addlist = [] - self.u = '' - self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA', - 7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS', - 12:'PTR',13:'HINFO',14:'MINFO',15:'MX', - 16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV', - 38:'A6',39:'DNAME',251:'IXFR',252:'AXFR', - 253:'MAILB',254:'MAILA',255:'ANY'} - self.rqtypes = {} - for key in self.qtypes.keys(): - self.rqtypes[self.qtypes[key]] = key - self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'} - self.rqclasses = {} - for key in self.qclasses.keys(): - self.rqclasses[self.qclasses[key]] = key - - if msgdata: - self.processpkt(msgdata) - - def getdomainname(self, data, i): - log(4,'IN GETDOMAINNAME') - domainname = '' - gotpointer = 0 - labellength= ord(data[i]) - log(4,'labellength:' + str(labellength)) - i = i + 1 - while labellength != 0: - while labellength >= 192: - # pointer - if not gotpointer: - rindex = i + 1 - gotpointer = 1 - log(4,'got pointer') - i = asctoint(chr(ord(data[i-1]) & 63)+data[i]) - log(4,'new index:'+str(i)) - labellength = ord(data[i]) - log(4,'labellength:' + str(labellength)) - i = i + 1 - if domainname: - domainname = domainname + '.' + data[i:i+labellength] - else: - domainname = data[i:i+labellength] - log(4,'domainname:'+domainname) - i = i + labellength - labellength = ord(data[i]) - log(4,'labellength:' + str(labellength)) - i = i + 1 - if not gotpointer: - rindex = i - - return domainname.lower(), rindex - - def getrrdata(self, type, msgdata, rdlength, i): - log(4,'unpacking RR data') - rdata = msgdata[i:i+rdlength] - if type == 'A': - return {'address':socket.inet_ntoa(rdata)} - elif type == 'AAAA': - return {'address':ipv6net_ntoa(rdata)} - elif type == 'CNAME': - cname, i = self.getdomainname(msgdata,i) - return {'cname':cname} - elif type == 'HINFO': - cpulen = ord(rdata[0]) - cpu = rdata[1:cpulen+1] - return {'cpu':cpu, - 'os':rdata[cpulen+2:]} - elif type == 'LOC': - return {'version':ord(rdata[0]), - 'size':self.locsize(rdata[1]), - 'horiz_pre':self.locsize(rdata[2]), - 'vert_pre':self.locsize(rdata[3]), - 'latitude':asctoint(rdata[4:8]), - 'longitude':asctoint(rdata[8:12]), - 'altitude':asctoint(rdata[12:16])} - elif type == 'MX': - exchange, i = self.getdomainname(msgdata,i+2) - return {'preference':asctoint(rdata[:2]), - 'exchange':exchange} - elif type == 'NS': - nsdname, i = self.getdomainname(msgdata,i) - return {'nsdname':nsdname} - elif type == 'PTR': - ptrdname, i = self.getdomainname(msgdata,i) - return {'ptrdname':ptrdname} - elif type == 'RP': - mboxdname, i = self.getdomainname(msgdata,i) - txtdname, i = self.getdomainname(msgdata,i) - return {'mboxdname':mboxdname, - 'txtdname':txtdname} - elif type == 'SOA': - mname, i = self.getdomainname(msgdata,i) - rname, i = self.getdomainname(msgdata,i) - return {'mname':mname, - 'rname':rname, - 'serial':asctoint(msgdata[i:i+4]), - 'refresh':asctoint(msgdata[i+4:i+8]), - 'retry':asctoint(msgdata[i+8:i+12]), - 'expire':asctoint(msgdata[i+12:i+16]), - 'minimum':asctoint(msgdata[i+16:i+20])} - elif type == 'SRV': - target, i = self.getdomainname(msgdata,i+6) - return {'priority':asctoint(rdata[0:2]), - 'weight':asctoint(rdata[2:4]), - 'port':asctoint(rdata[4:6]), - 'target':target} - elif type == 'TXT': - return {'txtdata':rdata[1:]} - else: - return {'rdata':rdata} - - def getrr(self, data, i): - log(4,'unpacking RR name') - name, i = self.getdomainname(data, i) - type = asctoint(data[i:i+2]) - type = self.qtypes.get(type,chr(type)) - klass = asctoint(data[i+2:i+4]) - klass = self.qclasses.get(klass,chr(klass)) - ttl = asctoint(data[i+4:i+8]) - rdlength = asctoint(data[i+8:i+10]) - rrdata = self.getrrdata(type,data,rdlength,i+10) - rrdata['ttl'] = ttl - rrdata['class'] = klass - rr = {name:{type:[rrdata]}} - return rr, i+10+rdlength - - def processpkt(self, msgdata): - self.header.id = asctoint(msgdata[:2]) - self.header.qr = ord(msgdata[2]) >> 7 - self.header.opcode = (ord(msgdata[2]) & 127) >> 3 - if self.header.opcode == 5: - # UPDATE packet - log(4,'processing UPDATE packet') - del self.header.aa - del self.header.tc - del self.header.rd - del self.header.ra - del self.header.qdcount - del self.header.ancount - del self.header.nscount - del self.header.arcount - del self.question - self.zone = dnsupdatezone() - del self.answerlist - del self.authlist - del self.addlist - self.header.z = 0 - self.header.rcode = ord(msgdata[3]) & 15 - self.header.zocount = asctoint(msgdata[4:6]) - self.header.prcount = asctoint(msgdata[6:8]) - self.header.upcount = asctoint(msgdata[8:10]) - self.header.arcount = asctoint(msgdata[10:12]) - self.zolist = [] - self.prlist = [] - self.uplist = [] - self.addlist = [] - i = 12 - for x in range(self.header.zocount): - (dn, i) = self.getdomainname(msgdata,i) - self.zone.zname = dn - type = asctoint(msgdata[i:i+2]) - self.zone.ztype = self.qtypes.get(type,chr(type)) - klass = asctoint(msgdata[i+2:i+4]) - self.zone.zclass = self.qclasses.get(klass,chr(klass)) - i = i + 4 - for x in range(self.header.prcount): - rr, i = self.getrr(msgdata,i) - self.prlist.append(rr) - for x in range(self.header.upcount): - rr, i = self.getrr(msgdata,i) - self.uplist.append(rr) - for x in range(self.header.arcount): - rr, i = self.getrr(msgdata,i) - self.adlist.append(rr) - else: - self.header.aa = (ord(msgdata[2]) & 4) >> 2 - self.header.tc = (ord(msgdata[2]) & 2) >> 1 - self.header.rd = ord(msgdata[2]) & 1 - self.header.ra = ord(msgdata[3]) >> 7 - self.header.z = (ord(msgdata[3]) & 112) >> 4 - self.header.rcode = ord(msgdata[3]) & 15 - self.header.qdcount = asctoint(msgdata[4:6]) - self.header.ancount = asctoint(msgdata[6:8]) - self.header.nscount = asctoint(msgdata[8:10]) - self.header.arcount = asctoint(msgdata[10:12]) - i = 12 - for x in range(self.header.qdcount): - log(4,'unpacking question') - (dn, i) = self.getdomainname(msgdata,i) - self.question.qname = dn - rrtype = asctoint(msgdata[i:i+2]) - self.question.qtype = self.qtypes.get(rrtype,chr(rrtype)) - klass = asctoint(msgdata[i+2:i+4]) - self.question.qclass = self.qclasses.get(klass,chr(klass)) - i = i + 4 - for x in range(self.header.ancount): - log(4,'unpacking answer RR') - rr, i = self.getrr(msgdata,i) - self.answerlist.append(rr) - for x in range(self.header.nscount): - log(4,'unpacking auth RR') - rr, i = self.getrr(msgdata,i) - self.authlist.append(rr) - for x in range(self.header.arcount): - log(4,'unpacking additional RR') - rr, i = self.getrr(msgdata,i) - self.addlist.append(rr) - return - - def pds(self, s, l): - # pad string with chr(0)'s so that - # return string length is l - x = l - len(s) - return x*chr(0) + s - - def locsize(self, s): - x1 = ord(s) >> 4 - x2 = ord(s) & 15 - return (x1, x2) - - def packlocsize(self, x): - return chr((x[0] << 4) + x[1]) - - def packdomainname(self, name, i, msgcomp): - log(4,'packing domainname: ' + name) - if name == '': - return chr(0) - if name in msgcomp.keys(): - log(4,'using pointer for: ' + name) - return msgcomp[name] - packedname = '' - tokens = name.split('.') - for j in range(len(tokens)): - packedname = packedname + chr(len(tokens[j])) + tokens[j] - nameleft = '.'.join(tokens[j+1:]) - if nameleft in msgcomp.keys(): - log(4,'using pointer for: ' + nameleft) - return packedname+msgcomp[nameleft] - # haven't used a pointer so put this in the dictionary - pointer = inttoasc(i) - if len(pointer) == 1: - msgcomp[name] = chr(192)+pointer - else: - msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1] - log(4,'added pointer for ' + name + '(' + str(i) + ')') - return packedname + chr(0) - - def packrr(self, rr, i, msgcomp): - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - if self.rqtypes.has_key(rrtype): - typeval = self.rqtypes[rrtype] - else: - typeval = ord(rrtype) - dbrec = rr[rrname][rrtype][0] - ttl = dbrec['ttl'] - rclass = self.rqclasses[dbrec['class']] - packedrr = (self.packdomainname(rrname, i, msgcomp) + - self.pds(inttoasc(typeval),2) + - self.pds(inttoasc(rclass),2) + - self.pds(inttoasc(ttl),4)) - i = i + len(packedrr) + 2 - if rrtype == 'A': - rdata = socket.inet_aton(dbrec['address']) - elif rrtype == 'AAAA': - rdata = ipv6net_aton(dbrec['address']) - elif rrtype == 'CNAME': - rdata = self.packdomainname(dbrec['cname'], i, msgcomp) - elif rrtype == 'HINFO': - rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] + - chr(len(dbrec['os'])) + dbrec['os']) - elif rrtype == 'LOC': - rdata = (chr(dbrec['version']) + - self.packlocsize(dbrec['size']) + - self.packlocsize(dbrec['horiz_pre']) + - self.packlocsize(dbrec['vert_pre']) + - self.pds(inttoasc(dbrec['latitude']),4) + - self.pds(inttoasc(dbrec['longitude']),4) + - self.pds(inttoasc(dbrec['altitude']),4)) - elif rrtype == 'MX': - rdata = (self.pds(inttoasc(dbrec['preference']),2) + - self.packdomainname(dbrec['exchange'], i+2, msgcomp)) - elif rrtype == 'NS': - rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp) - elif rrtype == 'PTR': - rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp) - elif rrtype == 'RP': - rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp) - i = i + len(rdata1) - rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp) - rdata = rdata1 + rdata2 - elif rrtype == 'SOA': - rdata1 = self.packdomainname(dbrec['mname'], i, msgcomp) - i = i + len(rdata1) - rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp) - rdata = (rdata1 + - rdata2 + - self.pds(inttoasc(dbrec['serial']),4) + - self.pds(inttoasc(dbrec['refresh']),4) + - self.pds(inttoasc(dbrec['retry']),4) + - self.pds(inttoasc(dbrec['expire']),4) + - self.pds(inttoasc(dbrec['minimum']),4)) - elif rrtype == 'SRV': - rdata = (self.pds(inttoasc(dbrec['priority']),2) + - self.pds(inttoasc(dbrec['weight']),2) + - self.pds(inttoasc(dbrec['port']),2) + - self.packdomainname(dbrec['target'], i+6, msgcomp)) - elif rrtype == 'TXT': - rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata'] - else: - rdata = dbrec['rdata'] - - return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata - - def buildpkt(self): - # keep dictionary of names packed (so we can use pointers) - msgcomp = {} - # header - if self.header.id > 65535: - log(0,'building packet with bad ID field') - self.header.id = 1 - msgdata = inttoasc(self.header.id) - if len(msgdata) == 1: - msgdata = chr(0) + msgdata - h1 = ((self.header.qr << 7) + - (self.header.opcode << 3) + - (self.header.aa << 2) + - (self.header.tc << 1) + - (self.header.rd)) - h2 = ((self.header.ra << 7) + - (self.header.z << 4) + - (self.header.rcode)) - msgdata = msgdata + chr(h1) + chr(h2) - msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2) - msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2) - msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2) - msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2) - # question - msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp) - if self.rqtypes.has_key(self.question.qtype): - typeval = self.rqtypes[self.question.qtype] - else: - typeval = ord(self.question.qtype) - msgdata = msgdata + self.pds(inttoasc(typeval),2) - if self.rqclasses.has_key(self.question.qclass): - classval = self.rqclasses[self.question.qclass] - else: - classval = ord(self.question.qclass) - msgdata = msgdata + self.pds(inttoasc(classval),2) - # rr's - # RR record format: - # {'name' : {'type' : [rdata, rdata, ...]} - # example: {'test.blah.net': {'A': [{'address': '10.1.1.2', - # 'ttl': 3600L}]}} - for rr in self.answerlist: - log(4,'packing answer RR') - msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp) - for rr in self.authlist: - log(4,'packing auth RR') - msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp) - for rr in self.addlist: - log(4,'packing additional RR') - msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp) - - return msgdata - - def printpkt(self): - print 'ID: ' +str(self.header.id) - if self.header.qr: - print 'QR: RESPONSE' - else: - print 'QR: QUERY' - if self.header.opcode == 0: - print 'OPCODE: STANDARD QUERY' - elif self.header.opcode == 1: - print 'OPCODE: INVERSE QUERY' - elif self.header.opcode == 2: - print 'OPCODE: SERVER STATUS REQUEST' - elif self.header.opcode == 5: - print 'UPDATE REQUEST' - else: - print 'OPCODE: UNKNOWN QUERY TYPE' - if self.header.opcode != 5: - if self.header.aa: - print 'AA: AUTHORITATIVE ANSWER' - else: - print 'AA: NON-AUTHORITATIVE ANSWER' - if self.header.tc: - print 'TC: MESSAGE IS TRUNCATED' - else: - print 'TC: MESSAGE IS NOT TRUNCATED' - if self.header.rd: - print 'RD: RECURSION DESIRED' - else: - print 'RD: RECURSION NOT DESIRED' - if self.header.ra: - print 'RA: RECURSION AVAILABLE' - else: - print 'RA: RECURSION IS NOT AVAILABLE' - if self.header.rcode == 1: - printrcode = 'FORMERR' - elif self.header.rcode == 2: - printrcode = 'SERVFAIL' - elif self.header.rcode == 3: - printrcode = 'NXDOMAIN' - elif self.header.rcode == 4: - printrcode = 'NOTIMP' - elif self.header.rcode == 5: - printrcode = 'REFUSED' - elif self.header.rcode == 6: - printrcode = 'YXDOMAIN' - elif self.header.rcode == 7: - printrcode = 'YXRRSET' - elif self.header.rcode == 8: - printrcode = 'NXRRSET' - elif self.header.rcode == 9: - printrcode = 'NOTAUTH' - elif self.header.rcode == 10: - printrcode = 'NOTZONE' - else: - printrcode = 'NOERROR' - print 'RCODE: ' + printrcode - if self.header.opcode == 5: - print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount) - print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount) - print 'NUMBER OF RRs in the Update Section: ' + str(self.header.upcount) - print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount) - print 'ZONE SECTION:' - print 'zname: ' + self.zone.zname - print 'zonetype: ' + self.zone.ztype - print 'zoneclass: ' + self.zone.zclass - print 'PREREQUISITE RRs:' - for rr in self.prlist: - print rr - print 'UPDATE RRs:' - for rr in self.uplist: - print rr - print 'ADDITIONAL RRs:' - for rr in self.addlist: - print rr - - - else: - print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount) - print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount) - print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount) - print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount) - print 'QUESTION SECTION:' - print 'qname: ' + self.question.qname - print 'querytype: ' + self.question.qtype - print 'queryclass: ' + self.question.qclass - print 'ANSWER RRs:' - for rr in self.answerlist: - print rr - print 'AUTHORITY RRs:' - for rr in self.authlist: - print rr - print 'ADDITIONAL RRs:' - for rr in self.addlist: - print rr - -class zonedb: - def __init__(self, zdict): - self.zdict = zdict - self.updates = {} - for k in self.zdict.keys(): - if self.zdict[k]['type'] == 'slave': - self.zdict[k]['lastupdatetime'] = 0 - - def error(self, id, qname, querytype, queryclass, rcode): - error = message() - error.header.id = id - error.header.rcode = rcode - error.header.qr = 1 - error.question.qname = qname - error.question.qtype = querytype - error.question.qclass = queryclass - return error - - def getorigin(self, zkey): - origin = '' - if self.zdict.has_key(zkey): - origin = self.zdict[zkey]['origin'] - return origin - - def getmasterip(self, zkey): - masterip = '' - if self.zdict.has_key(zkey): - if self.zdict[zkey].has_key('masterip'): - masterip = self.zdict[zkey]['masterip'] - return masterip - - def zonetrans(self, query): - # build a list of messages - # each message contains one rr of the zone - # the first and last message are the - # SOA records - origin = query.question.qname - querytype = query.question.qtype - zkey = '' - for zonekey in self.zdict.keys(): - if self.zdict[zonekey]['origin'] == query.question.qname: - zkey = zonekey - if not zkey: - return [] - zonedata = self.zdict[zkey]['zonedata'] - queryid = query.header.id - soarec = zonedata[origin]['SOA'][0] - soa = {origin:{'SOA':[soarec]}} - curserial = soarec['serial'] - rrlist = [] - if querytype == 'IXFR': - clientserial = query.authlist[0][origin]['SOA'][0]['serial'] - if clientserial < curserial: - for i in range(clientserial,curserial+1): - if self.updates[zkey].has_key(i): - for rr in self.updates[zkey][i]['added']: - rrlist.append(rr) - for rr in self.updates[zkey][i]['removed']: - rrlist.append(rr) - if len(rrlist) > 0: - rrlist.insert(0,soa) - rrlist.append(soa) - else: - rrlist.append(soa) - else: - for nodename in zonedata.keys(): - for rrtype in zonedata[nodename].keys(): - if not (rrtype == 'SOA' and nodename == origin): - for rr in zonedata[nodename][rrtype]: - rrlist.append({nodename:{rrtype:[rr]}}) - rrlist.insert(0,soa) - rrlist.append(soa) - msglist = [] - for rr in rrlist: - msg = message() - msg.header.id = queryid - msg.header.qr = 1 - msg.header.aa = 1 - msg.header.rd = 0 - msg.header.qdcount = 1 - msg.question.qname = origin - msg.question.qtype = querytype - msg.question.qclass = 'IN' - msg.header.ancount = 1 - msg.answerlist.append(rr) - msglist.append(msg) - return msglist - - def update_zone(self, rrlist, params): - zonekey = params[0] - zonedata = {} - soa = rrlist.pop() - origin = soa.keys()[0] - for rr in rrlist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - dbrec = rr[rrname][rrtype][0] - if zonedata.has_key(rrname): - if not zonedata[rrname].has_key(rrtype): - zonedata[rrname][rrtype] = [] - else: - zonedata[rrname] = {} - zonedata[rrname][rrtype] = [] - zonedata[rrname][rrtype].append(dbrec) - self.zdict[zonekey]['zonedata'] = zonedata - curtime = time.time() - self.zdict[zonekey]['lastupdatetime'] = curtime - try: - f = file(self.zdict[zonekey]['filename'],'w') - writezonefile(zonedata, self.zdict[zonekey]['origin'], f) - f.close() - except: - log(0,'unable to write zone ' + zonekey + 'to disk') - log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')') - - def remove_zone(self, zonekey): - if self.zdict.has_key(zonekey): - del self.zdict[zonekey] - - def getslaves(self, curtime): - rlist = [] - for k in self.zdict.keys(): - if self.zdict[k]['type'] == 'slave': - origin = self.zdict[k]['origin'] - refresh = self.zdict[k]['zonedata'][origin]['SOA'][0]['refresh'] - if self.zdict[k]['lastupdatetime'] + refresh < curtime: - rlist.append((k, origin, self.zdict[k]['masterip'])) - return rlist - - def zmatch(self, qname, zkeys): - for zkey in zkeys: - if self.zdict.has_key(zkey): - origin = self.zdict[zkey]['origin'] - if qname.rfind(origin) != -1: - return zkey - return '' - - def getzlist(self, name, zone): - if name == zone: - return - zlist = [] - i = name.rfind(zone) - if i == -1: - return - firstpart = name[:i-1] - partlist = firstpart.split('.') - partlist.reverse() - lastpart = zone - for x in range(len(partlist)): - lastpart = partlist[x] + '.' + lastpart - zlist.append(lastpart) - return zlist - - def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc): - # handle zone transfers seperately - qname = query.question.qname - querytype = query.question.qtype - queryclass = query.question.qclass - if querytype in ['AXFR','IXFR']: - for zkey in self.zdict.keys(): - if zkey in zkeys: - if qname == self.zdict[zkey]['origin']: - answerlist = self.zonetrans(query) - break - else: - answerlist = [] - cbfunc(query, addr, server, dorecursion, flist, answerlist) - else: - zonekey = self.zmatch(qname, zkeys) - if zonekey: - origin = self.zdict[zonekey]['origin'] - zonedict = self.zdict[zonekey]['zonedata'] - referral = 0 - rranswerlist = [] - rrnslist = [] - rraddlist = [] - answer = message() - answer.header.aa = 1 - answer.header.id = query.header.id - answer.header.qr = 1 - answer.header.opcode = query.header.opcode - answer.header.rcode = 4 - answer.header.ra = dorecursion - answer.question.qname = query.question.qname - answer.question.qtype = query.question.qtype - answer.question.qclass = query.question.qclass - answer.header.ra = dorecursion - s = '.servers.csail.mit.edu' - if qname.endswith(s): - host = qname[:-len(s)] - value = sipb_xen_database.NIC.get_by(hostname=host) - if value is None: - pass - else: - ip = value.ip - rranswerlist.append({qname: {'A': [{'address': ip, - 'class': 'IN', - 'ttl': 10}]}}) - if zonedict.has_key(qname): - # found the node, now take care of CNAMEs - if zonedict[qname].has_key('CNAME'): - if querytype != 'CNAME': - nodetype = 'CNAME' - while nodetype == 'CNAME': - rranswerlist.append({qname:{'CNAME':[zonedict[qname]['CNAME'][0]]}}) - qname = zonedict[qname]['CNAME'][0]['cname'] - if zonedict.has_key(qname): - nodetype = zonedict[qname].keys()[0] - else: - # error, shouldn't have a CNAME that points to nothing - return - # if we get this far, then the record has matched and we should return - # a reply that has no error (even if there is no info macthing the qtype) - answer.header.rcode = 0 - answernode = zonedict[qname] - if querytype == 'ANY': - for type in answernode.keys(): - for rec in answernode[type]: - rranswerlist.append({qname:{type:[rec]}}) - elif answernode.has_key(querytype): - for rec in answernode[querytype]: - rranswerlist.append({qname:{querytype:[rec]}}) - # do rrset ordering (cyclic) - if len(answernode[querytype]) > 1: - rec = answernode[querytype].pop(0) - answernode[querytype].append(rec) - else: - # remove all cname rrs from answerlist - rranswerlist = [] - else: - # would check for wildcards here (but aren't because they seem bad) - # see if we need to give a referral - zlist = self.getzlist(qname,origin) - for zonename in zlist: - if zonedict.has_key(zonename): - if zonedict[zonename].has_key('NS'): - answer.header.rcode = 0 - referral = 1 - for rec in zonedict[zonename]['NS']: - rrnslist.append({zonename:{'NS':[rec]}}) - nsdname = rec['nsdname'] - # add glue records if they exist - if zonedict.has_key(nsdname): - if zonedict[nsdname].has_key('A'): - for gluerec in zonedict[nsdname]['A']: - rraddlist.append({nsdname:{'A':[gluerec]}}) - # negative caching stuff - if not referral: - if not rranswerlist: - # NOTE: RFC1034 section 4.3.4 says we should add the SOA record - # to the additional section of the response. BIND adds - # it to the ns section though - answer.header.rcode = 3 - rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}}) - else: - for rec in zonedict[origin]['NS']: - rrnslist.append({origin:{'NS':[rec]}}) - answer.header.ancount = len(rranswerlist) - answer.header.nscount = len(rrnslist) - answer.header.arcount = len(rraddlist) - answer.answerlist = rranswerlist - answer.authlist = rrnslist - answer.addlist = rraddlist - cbfunc(query, addr, server, dorecursion, flist, [answer]) - else: - cbfunc(query, addr, server, dorecursion, flist, []) - - def handle_update(self, msg, addr, ns): - zkey = '' - slaves = [] - for zonekey in self.zdict.keys(): - if (self.zdict[zonekey]['type'] == 'master' and - self.zdict[zonekey]['origin'] == msg.zone.zname): - zkey = zonekey - if not zkey: - log(2,'SENDING NOTAUTH UPDATE ERROR') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 9) - return errormsg, '', slaves - # find the slaves for the zone - if self.zdict[zkey].has_key('slaves'): - slaves = self.zdict[zkey]['slaves'] - origin = self.zdict[zkey]['origin'] - zd = self.zdict[zkey]['zonedata'] - # check the permissions - if not ns.config.allowupdate(msg, addr[0], addr[1]): - log(2,'SENDING REFUSED UPDATE ERROR') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 5) - return errormsg, origin, slaves - # now check the prereqs - temprrset = {} - for rr in msg.prlist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - dbrec = rr[rrname][rrtype][0] - if dbrec['ttl'] != 0: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - if rrname.rfind(msg.zone.zname) == -1: - log(2,'NOTZONE(10)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 10) - return errormsg, origin, slaves - if dbrec['class'] == 'ANY': - if dbrec['rdata']: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - if rrtype == 'ANY': - if not zd.has_key(rrname): - log(2,'NXDOMAIN(3)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 3) - return errormsg, origin, slaves - else: - rrsettest = 0 - if zd.has_key(rrname): - if zd[rrname].has_key(rrtype): - rrsettest = 1 - if not rrsettest: - log(2,'NXRRSET(8)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 8) - return errormsg, origin, slaves - if dbrec['class'] == 'NONE': - if dbrec['rdata']: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - if rrtype == 'ANY': - if zd.has_key(rrname): - log(2,'YXDOMAIN(6)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 6) - return errormsg, origin, slaves - else: - if zd.has_key(rrname): - if zd[rrname].has_key(rrtype): - log(2,'YXRRSET(7)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 7) - return errormsg, origin, slaves - if dbrec['class'] == msg.zone.zclass: - if temprrset.has_key(rrname): - if not temprrset[rrname].has_key(rrtype): - temprrset[rrname][rrtype] = [] - else: - temprrset[rrname] = {} - temprrset[rrname][rrtype] = [] - temprrset[rrname][rrtype].append(dbrec) - else: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - for nodename in temprrset.keys(): - if not self.rrmatch(temprrset[nodename],zd[nodename]): - log(2,'NXRRSET(8)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 8) - return errormsg, origin, slaves - - # update section prescan - for rr in msg.uplist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - dbrec = rr[rrname][rrtype][0] - if rrname.rfind(msg.zone.zname) == -1: - log(2,'NOTZONE(10)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 10) - return errormsg, origin, slaves - if dbrec['class'] == msg.zone.zclass: - if rrtype in ['ANY','MAILA','MAILB','AXFR']: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - elif dbrec['class'] == 'ANY': - if dbrec['ttl'] != 0 or dbrec['rdata'] or rrtype in ['MAILA','MAILB','AXFR']: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - elif dbrec['class'] == 'NONE': - if dbrec['ttl'] != 0 or rrtype in ['ANY','MAILA','MAILB','AXFR']: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - else: - log(2,'FORMERROR(1)') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - return errormsg, origin, slaves - - # now handle actual update - curserial = zd[msg.zone.zname]['SOA'][0]['serial'] - # update the soa serial here - clearupdatehist = 0 - if len(msg.uplist) > 0: - # initialize history structure - if not self.updates.has_key(zkey): - self.updates[zkey] = {} - self.updates[zkey][curserial] = {'removed':[], - 'added':[]} - if curserial == 2**32: - newserial = 2 - clearupdatehist = 1 - else: - newserial = curserial + 1 - self.updates[zkey][newserial] = {'removed':[], - 'added':[]} - zd[msg.zone.zname]['SOA'][0]['serial'] = newserial - for rr in msg.uplist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - dbrec = rr[rrname][rrtype][0] - if dbrec['class'] == msg.zone.zclass: - if rrtype == 'SOA': - if zd.has_key(rrname): - if zd[rrname].has_key('SOA'): - if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']: - del zd[rrname]['SOA'][0] - zd[rrname]['SOA'].append(dbrec) - clearupdatehist = 1 - elif rrtype == 'WKS': - if zd.has_key(rrname): - if zd[rrname].has_key('WKS'): - rdata = zd[rrname]['WKS'][0] - oldrr = {rrname:{'WKS':[rdata]}} - self.updates[zkey][curserial]['removed'].append(oldrr) - del zd[rrname]['WKS'][0] - zd[rrname]['WKS'].append(dbrec) - newrr = {rrname:{'WKS':[dbrec]}} - self.updates[zkey][newserial]['added'].append(newrr) - else: - if zd.has_key(rrname): - if not zd[rrname].has_key(rrtype): - zd[rrname][rrtype] = [] - else: - zd[rrname] = {} - zd[rrname][rrtype] = [] - zd[rrname][rrtype].append(dbrec) - newrr = {rrname:{rrtype:[dbrec]}} - self.updates[zkey][newserial]['added'].append(newrr) - elif dbrec['class'] == 'ANY': - if rrtype == 'ANY': - if rrname == msg.zone.zname: - if zd.has_key(rrname): - for dnstype in zd[rrname].keys(): - if dnstype not in ['SOA','NS']: - for rdata in zd[rrname][dnstype]: - oldrr = {rrname:{dnstype:[rdata]}} - self.updates[zkey][curserial]['removed'].append(oldrr) - del zd[rrname][dnstype] - else: - if zd.has_key(rrname): - for dnstype in zd[rrname].keys(): - for rdata in zd[rrname][dnstype]: - oldrr = {rrname:{dnstype:[rdata]}} - self.updates[zkey][curserial]['removed'].append(oldrr) - del zd[rrname] - else: - if zd.has_key(rrname): - if zd[rrname].has_key(rrtype): - if rrname == msg.zone.zname: - if rrtype not in ['SOA','NS']: - for rdata in zd[rrname][dnstype]: - oldrr = {rrname:{dnstype:[rdata]}} - self.updates[zkey][curserial]['removed'].append(oldrr) - del zd[rrname][rrtype] - else: - for rdata in zd[rrname][dnstype]: - oldrr = {rrname:{dnstype:[rdata]}} - self.updates[zkey][curserial]['removed'].append(oldrr) - del zd[rrname][rrtype] - elif dbrec['class'] == 'NONE': - if not (rrname == msg.zone.zname and rrtype in ['SOA','NS']): - if zd.had_key(rrname): - if zd[rrname].has_key(rrtype): - for i in range(len(zd[rrname][rrtype])): - if dbrec == zd[rrname][rrtype][i]: - rdata = zd[rrname][dnstype][i] - oldrr = {rrname:{dnstype:[rdata]}} - self.updates[zkey][curserial]['removed'].append(oldrr) - del zd[rrname][rrtype][i] - if len(zd[rrname][rrtype]) == 0: - del zd[rrname][rrtype] - if clearupdatehist: - self.updates[zkey] = {} - log(2,'SENDING UPDATE NOERROR MSG') - noerrormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 0) - return noerrormsg, origin, slaves - -class dnscache: - def __init__(self,cachezone): - self.cachedb = cachezone - # go through and set all of the root ttls to zero - for node in self.cachedb.keys(): - for rtype in self.cachedb[node].keys(): - for rr in self.cachedb[node][rtype]: - rr['ttl'] = 0 - if rtype == 'NS': - rr['rtt'] = 0 - # add special entries for localhost - self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]} - self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]} - self.cachedb['']['SOA'] = [] - self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb', - 'rname':'cachedb@localhost','serial':1,'refresh':10800, - 'retry':3600,'expire':604800,'minimum':3600}) - - def hasrdata(self, irrdata, rrdatalist): - # compare everything but ttls - test = 0 - testrrdata = irrdata.copy() - del testrrdata['ttl'] - for rrdata in rrdatalist: - temprrdata = rrdata.copy() - del temprrdata['ttl'] - if temprrdata == testrrdata: - test = 1 - return test - - def add(self, rr, qzone, nsdname): - # NOTE: can't cache records from sites - # that don't own those records (i.e. example.com - # can't give us A records for www.example.net) - name = rr.keys()[0] - if (qzone != '') and (name[-len(qzone):] != qzone): - log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone) - return - rtype = rr[name].keys()[0] - rdata = rr[name][rtype][0] - if rdata['ttl'] < 3600: - log(2,'low ttl: ' + str(rdata['ttl'])) - rdata['ttl'] = 3600 - rdata['ttl'] = int(time.time() + rdata['ttl']) - if rtype == 'NS': - rdata['rtt'] = 0 - name = name.lower() - rtype = rtype.upper() - if self.cachedb.has_key(name): - if self.cachedb[name].has_key(rtype): - if not self.hasrdata(rdata, self.cachedb[name][rtype]): - self.cachedb[name][rtype].append(rdata) - log(3,'appended rdata to ' + - name + '(' + rtype + ') in cache') - else: - log(3,'same rdata for ' + name + '(' + - rtype + ') is already in cache') - else: - self.cachedb[name][rtype] = [rdata] - log(3,'appended ' + rtype + ' and rdata to node ' + - name + ' in cache') - else: - self.cachedb[name] = {rtype:[rdata]} - log(3,'added node ' + name + '(' + rtype + ') to cache') - self.reap() - - def addneg(self, qname, querytype, queryclass): - if not self.cachedb.has_key(qname): - self.cachedb['qname'] = {querytype: [{'ttl':time.time()+3600}]} - else: - if not self.cachedb[qname].has_key(querytype): - self.cachedb[qname][querytype] = [{'ttl':time.time()+3600}] - - def haskey(self, qname, querytype, msg=''): - log(3,'looking for ' + qname + '(' + querytype + ') in cache') - if self.cachedb.has_key(qname): - rranswerlist = [] - rrnslist = [] - rraddlist = [] - if self.cachedb[qname].has_key('CNAME'): - if querytype != 'CNAME': - nodetype = 'CNAME' - while nodetype == 'CNAME': - if len(self.cachedb[qname]['CNAME'][0].keys()) > 1: - log(3,'Adding CNAME to cache answer') - rranswerlist.append({qname:{'CNAME':[self.cachedb[qname]['CNAME'][0]]}}) - qname = self.cachedb[qname]['CNAME'][0]['cname'] - if self.cachedb.has_key(qname): - nodetype = self.cachedb[qname].keys()[0] - else: - # shouldn't have a CNAME that points to nothing - return - if querytype == 'ANY': - for type in self.cache[qname].keys(): - for rec in self.cachedb[qname][type]: - # can't append negative entries - if len(rec.keys()) > 1: - rranswerlist.append({qname:{type:[rec]}}) - elif self.cachedb[qname].has_key(querytype): - for rec in self.cachedb[qname][querytype]: - if len(rec.keys()) > 1: - rranswerlist.append({qname:{querytype:[rec]}}) - if rranswerlist: - if msg: - answer = message() - answer.header.id = msg.header.id - answer.header.qr = 1 - answer.header.opcode = msg.header.opcode - answer.header.ra = 1 - answer.question.qname = msg.question.qname - answer.question.qtype = msg.question.qtype - answer.question.qclass = msg.question.qclass - answer.header.rcode = 0 - answer.header.ancount = len(rranswerlist) - answer.answerlist = rranswerlist - return answer - else: - return 1 - else: - log(3,'Cache has no node for ' + qname) - - def getnslist(self, qname): - # find the best nameserver to ask from the cache - tokens = qname.split('.') - nsdict = {} - curtime = time.time() - for i in range(len(tokens)): - domainname = '.'.join(tokens[i:]) - if self.cachedb.has_key(domainname): - if self.cachedb[domainname].has_key('NS'): - for nsrec in self.cachedb[domainname]['NS']: - badserver = 0 - if nsrec.has_key('badtill'): - if nsrec['badtill'] < curtime: - del nsrec['badtill'] - else: - badserver = 1 - if badserver: - log(2,'BAD SERVER, not using ' + nsrec['nsdname']) - if self.cachedb.has_key(nsrec['nsdname']) and not badserver: - if self.cachedb[nsrec['nsdname']].has_key('A'): - for arec in self.cachedb[nsrec['nsdname']]['A']: - nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'], - 'ip':arec['address']} - if nsdict: - break - if not nsdict: - domainname = '' - # nothing in the cache matches so give back the root servers - for nsrec in self.cachedb['']['NS']: - badserver = 0 - if nsrec.has_key('badtill'): - if curtime > nsrec['badtill']: - del nsrec['badtill'] - else: - badserver = 1 - if not badserver: - for arec in self.cachedb[nsrec['nsdname']]['A']: - nsdict[(nsrec['rtt'])] = {'name':nsrec['nsdname'],'ip':arec['address']} - - return (domainname, nsdict) - - def badns(self, zonename, nsdname): - if self.cachedb.has_key(zonename): - if self.cachedb[zonename].has_key('NS'): - for nsrec in self.cachedb[zonename]['NS']: - if nsrec['nsdname'] == nsdname: - log(2,'Setting ' + nsdname + ' as bad nameserver') - nsrec['badtill'] = time.time() + 3600 - - - def updatertt(self, qname, zone, rtt): - if self.cachedb.has_key(zone): - if self.cachedb[zone].has_key('NS'): - for rr in self.cachedb[zone]['NS']: - if rr['nsdname'] == qname: - log(2,'updating rtt for ' + qname + ' to ' + str(rtt)) - rr['rtt'] = rtt - - def reap(self): - # expire all old records - ntime = time.time() - for nodename in self.cachedb.keys(): - for rrtype in self.cachedb[nodename].keys(): - for rdata in self.cachedb[nodename][rrtype]: - ttl = rdata['ttl'] - if ttl != 0: - if ttl < ntime: - self.cachedb[nodename][rrtype].remove(rdata) - if len(self.cachedb[nodename][rrtype]) == 0: - del self.cachedb[nodename][rrtype] - if len(self.cachedb[nodename]) == 0: - del self.cachedb[nodename] - - return - - def zonetrans(self, queryid): - # build a list of messages - # each message contains one rr of the zone - # the first and last message are the - # SOA records - zonedata = self.cachedb - rrlist = [] - soa = {'':{'SOA':[zonedata['']['SOA'][0]]}} - for nodename in zonedata.keys(): - for rrtype in zonedata[nodename].keys(): - if not (rrtype == 'SOA' and nodename == ''): - for rr in zonedata[nodename][rrtype]: - rrlist.append({nodename:{rrtype:[rr]}}) - rrlist.insert(0,soa) - rrlist.append(soa) - msglist = [] - for rr in rrlist: - msg = message() - msg.header.id = queryid - msg.header.qr = 1 - msg.header.aa = 1 - msg.header.rd = 0 - msg.header.qdcount = 1 - msg.question.qname = 'cache' - msg.question.qtype = 'AXFR' - msg.question.qclass = 'IN' - msg.header.ancount = 1 - msg.answerlist.append(rr) - msglist.append(msg) - return msglist - -class gethostaddr(asyncore.dispatcher): - def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1'): - asyncore.dispatcher.__init__(self) - self.msg = message() - self.msg.question.qname = hostname - self.msg.question.qtype = 'A' - self.cbfunc = cbfunc - self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) - self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) - self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) - self.socket.sendto(self.msg.buildpkt(), (serveraddr,53)) - - def handle_read(self): - replydata, addr = self.socket.recvfrom(1500) - self.close() - try: - replymsg = message(replydata) - except: - log(0,'unable to process packet') - return - answername = replymsg.question.qname - cname = '' - # go through twice to catch cnames after A recs - for rr in replymsg.answerlist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - dbrec = rr[rrname][rrtype][0] - if rrname == answername and rrtype == 'CNAME': - answername = dbrec['cname'] - cname = answername - for rr in replymsg.answerlist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - dbrec = rr[rrname][rrtype][0] - if rrname == answername and rrtype == 'A': - self.cbfunc(dbrec['address']) - return - # if we got a cname and no A send query for cname - if cname: - self.msg = message() - self.msg.question.qname = cname - self.msg.question.qtype = 'A' - self.socket.sendto(self.msg.buildpkt(), (serveraddr,53)) - else: - self.cbfunc('') - - def writable(self): - return 0 - - def handle_write(self): - pass - - def handle_connect(self): - pass - - def handle_close(self): - self.close() - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class simpleudprequest(asyncore.dispatcher): - def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''): - asyncore.dispatcher.__init__(self) - self.gotanswer = 0 - self.msg = msg - self.cbfunc = cbfunc - self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) - self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) - self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) - self.outqkey = outqkey - self.socket.sendto(self.msg.buildpkt(), (serveraddr,53)) - - def handle_read(self): - replydata, addr = self.socket.recvfrom(1500) - self.close() - try: - replymsg = message(replydata) - except: - log(0,'unable to process packet') - return - self.cbfunc(replymsg, self.outqkey) - - def writable(self): - return 0 - - def handle_write(self): - pass - - def handle_connect(self): - pass - - def handle_close(self): - self.close() - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class simpletcprequest(asyncore.dispatcher): - def __init__(self, msg, cbfunc, cbparams=[], serveraddr='127.0.0.1', errorfunc=''): - asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - self.query = msg - self.cbfunc = cbfunc - self.cbparams = cbparams - self.errorfunc = errorfunc - msgdata = msg.buildpkt() - ml = inttoasc(len(msgdata)) - if len(ml) == 1: - ml = chr(0) + ml - self.buffer = ml+msgdata - self.rbuffer = '' - self.rmsgleft = 0 - self.rrlist = [] - log(2,'sending tcp request to ' + serveraddr) - self.connect((serveraddr,53)) - - def recv (self, buffer_size): - try: - data = self.socket.recv (buffer_size) - if not data: - # a closed connection is indicated by signaling - # a read condition, and having recv() return 0. - self.handle_close() - return '' - else: - return data - except socket.error, why: - # winsock sometimes throws ENOTCONN - if why[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]: - self.handle_close() - return '' - else: - raise socket.error, why - - def handle_connect(self): - pass - - def handle_msg(self, msg): - if self.query.question.qtype == 'AXFR': - if len(self.rrlist) == 0: - if len(msg.answerlist) == 0: - if self.errorfunc: - self.errorfunc(self.cbparams[0]) - self.close() - return - rr = msg.answerlist[0] - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - self.rrlist.append(rr) - if rrtype == 'SOA' and len(self.rrlist) > 1: - self.close() - if self.cbparams: - self.cbfunc(self.rrlist, self.cbparams) - else: - self.cbfunc(self.rrlist) - else: - self.close() - if self.cbparams: - self.cbfunc(msg, self.cbparams) - else: - self.cbfunc(msg) - - def handle_read(self): - data = self.recv(8192) - if len(self.rbuffer) == 0: - self.rmsglength = asctoint(data[:2]) - data = data[2:] - self.rbuffer = self.rbuffer + data - while len(self.rbuffer) >= self.rmsglength and self.rmsglength != 0: - msgdata = self.rbuffer[:self.rmsglength] - self.rbuffer = self.rbuffer[self.rmsglength:] - if len(self.rbuffer) == 0: - self.rmsglength = 0 - else: - self.rmsglength = asctoint(self.rbuffer[:2]) - self.rbuffer = self.rbuffer[2:] - try: - self.handle_msg(message(msgdata)) - except: - return - - def writable(self): - return (len(self.buffer) > 0) - - def handle_write(self): - sent = self.send(self.buffer) - self.buffer = self.buffer[sent:] - - def handle_close(self): - if self.errorfunc: - self.errorfunc(self.query.question.qname) - self.close() - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class udpdnsserver(asyncore.dispatcher): - def __init__(self, port, dnsserver): - asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) - self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) - self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) - self.bind(('',port)) - self.dnsserver = dnsserver - self.maxmsgsize = 500 - - def handle_read(self): - try: - while 1: - msgdata, addr = self.socket.recvfrom(1500) - self.dnsserver.handle_packet(msgdata, addr, self) - except socket.error, why: - if why[0] != asyncore.EWOULDBLOCK: - raise socket.error, why - - def sendpackets(self, msglist, addr): - for msg in msglist: - msgdata = msg.buildpkt() - if len(msgdata) > self.maxmsgsize: - msg.header.tc = 1 - # take off all the answers to ensure - # the packet size is small enough - msg.header.ancount = 0 - msg.header.nscount = 0 - msg.header.arcount = 0 - msg.answerlist = [] - msg.authlist = [] - msg.addlist = [] - msgdata = msg.buildpkt() - self.sendto(msgdata, addr) - - def writable(self): - return 0 - - def handle_write(self): - pass - - def handle_connect(self): - pass - - def handle_close(self): - # print '1:In handle close' - return - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class tcpdnschannel(asynchat.async_chat): - def __init__(self, server, s, addr): - asynchat.async_chat.__init__(self, s) - self.server = server - self.addr = addr - self.set_terminator(None) - self.databuffer = '' - self.msglength = 0 - log(3,'Created new tcp channel') - - def collect_incoming_data(self, data): - if self.msglength == 0: - self.msglength = asctoint(data[:2]) - data = data[2:] - self.databuffer = self.databuffer + data - if len(self.databuffer) == self.msglength: - # got entire message - self.server.dnsserver.handle_packet(self.databuffer, self.addr, self) - self.databuffer = '' - - def sendpackets(self, msglist, addr): - for msg in msglist: - x = msg.buildpkt() - ml = inttoasc(len(x)) - if len(ml) == 1: - ml = chr(0) + ml - self.push(ml+x) - self.close() - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class tcpdnsserver(asyncore.dispatcher): - def __init__(self, port, dnsserver): - asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_STREAM) - self.set_reuse_addr() - self.bind(('',port)) - self.listen(5) - self.dnsserver = dnsserver - - def handle_accept(self): - conn, addr = self.accept() - tcpdnschannel(self, conn, addr) - - def handle_close(self): - self.close() - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class nameserver: - def __init__(self, resolver, localconfig): - self.resolver = resolver - self.config = localconfig - self.zdb = self.config.zonedatabase - self.last_reap_time = time.time() - self.maint_int = 10 - self.slavesupdating = [] - self.notifys = [] - self.sentnotify = [] - self.notify_retry_time = 30 - self.notify_retries = 4 - self.askedsoa = {} - self.soatimeout = 10 - - def error(self, id, qname, querytype, queryclass, rcode): - error = message() - error.header.id = id - error.header.rcode = rcode - error.header.qr = 1 - error.question.qname = qname - error.question.qtype = querytype - error.question.qclass = queryclass - return error - - def need_zonetransfer(self, zkey, origin, masterip, trynum=0): - self.askedsoa[zkey] = {'masterip':masterip, - 'senttime':time.time(), - 'origin':origin, - 'trynum':trynum+1} - query = message() - query.header.id = random.randrange(1,32768) - query.header.rd = 0 - query.question.qname = origin - query.question.qtype = 'SOA' - query.question.qclass = 'IN' - log(3,'slave checking for new data in ' + origin) - simpleudprequest(query, self.handle_soaquery, - masterip, zkey) - - def handle_soaquery(self, msg, zkey): - origin = msg.question.qname - masterip = self.askedsoa[zkey]['masterip'] - del self.askedsoa[zkey] - if zkey not in self.slavesupdating: - self.slavesupdating.append(zkey) - query = message() - query.header.id = random.randrange(1,32768) - query.header.rd = 0 - query.question.qname = origin - query.question.qtype = 'AXFR' - query.question.qclass = 'IN' - log(3,'Updating slave zone: ' + zkey) - simpletcprequest(query, self.handle_zonetrans, - [zkey],masterip,self.handle_zterror) - - def handle_zonetrans(self, rrlist, params): - log(1,'handling zone transfer') - zonekey = params[0] - self.zdb.update_zone(rrlist, params) - self.slavesupdating.remove(zonekey) - - def handle_zterror(self, zonekey): - self.slavesupdating.remove(zonekey) - self.zdb.remove_zone(zonekey) - - def rrmatch(self, rrset1, rrset2): - for rrtype in rrset1.keys(): - if rrtype not in rrset2.keys(): - return - else: - if len(rrset1[rrtype]) != len(rrset2[rrtype]): - return - return 1 - - def process_notify(self, msg, ipaddr, port): - (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port) - goodzkey = '' - for zkey in zkeys: - origin = self.zdb.getorigin(zkey) - if origin == msg.question.qname: - masterip = self.zdb.getmasterip(zkey) - if masterip: - goodzkey = zkey - if goodzkey: - log(3,'got NOTIFY from ' + masterip) - self.need_zonetransfer(goodzkey, origin, masterip, 0) - return - - def notify(self): - curtime = time.time() - for origin, ipaddr, trynum, senttime in self.sentnotify: - if senttime + self.notify_retry_time > curtime: - self.notifys.append((origin, ipaddr, trynum)) - self.sentnotify.remove((origin, ipaddr, trynum, senttime)) - for origin, ipaddr, trynum in self.notifys: - msg = message() - msg.question.qname = origin - msg.question.qtype = 'SOA' - msg.question.qclass = 'IN' - msg.header.opcode = 4 - # there probably is a better way to do this - if self.resolver: - self.resolver.send_to([msg],(ipaddr,53)) - if trynum+1 <= self.notify_retries: - self.sentnotify.append((origin,ipaddr,trynum+1,curtime)) - self.notifys = [] - - def handle_packet(self, msgdata, addr, server): - # self.reap() - try: - msg = message(msgdata) - except: - return - # find a matching view - (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1]) - if not msg.header.qr and msg.header.opcode == 5: - log(2,'GOT UPDATE PACKET') - # check the zone section - if (msg.header.zocount != 1 or - msg.zone.ztype != 'SOA' or - msg.zone.zclass != 'IN'): - log(2,'SENDING FORMERR UPDATE ERROR') - errormsg = self.error(msg.header.id, msg.zone.zname, - msg.zone.ztype, msg.zone.zclass, 1) - server.sendpackets([errormsg],addr) - else: - (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self) - if answer.header.rcode == 0: - # schedule NOTIFYs to slaves - for ipaddr in slaves: - self.notifys.append((origin, ipaddr, 0)) - server.sendpackets([answer],addr) - elif msg.header.opcode == 4: - if msg.header.qr: - log(0,'got NOTIFY response') - for origin, ipaddr, trynum, senttime in self.sentnotify: - if ipaddr == addr[0] and msg.question.qname == origin: - self.sentnotify.remove((origin, ipaddr, trynum, senttime)) - else: - log(0,'got NOTIFY') - self.process_notify(msg, addr[0], addr[1]) - elif not msg.header.qr and msg.header.opcode == 0: - # it's a question - qname = msg.question.qname.lower() - log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype + - ') from ' + addr[0]) - # handle special version packet - if (msg.question.qtype == 'TXT' and - msg.question.qclass == 'CH'): - if qname == 'version.bind': - server.sendpackets([getversion(qname, - msg.header.id, - msg.header.rd, - dorecursion, '1.0')],addr) - elif qname == 'version.oak': - server.sendpackets([getversion(qname, - msg.header.id, - msg.header.rd, - dorecursion, '1.0')],addr) - return - self.zdb.lookup(zkeys, msg, addr, server, dorecursion, - flist, self.lookup_callback) - - def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist): - if answerlist: - server.sendpackets(self.config.outpackets(answerlist), addr) - elif dorecursion: - if msg.question.qtype in ['AXFR','IXFR']: - if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR': - if self.resolver: - server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr) - else: - # won't forward zone transfers and - # don't handle recursive zone transfers - server.sendpackets([self.error(msg.header.id, msg.question.qname, - msg.question.qtype, - msg.question.qclass,2)],addr) - else: - self.resolver.handle_query(msg, addr, flist, server.sendpackets) - - def reap(self): - log(4,'in nameserver reap') - # do all maintenence (interval) stuff here - if self.resolver: - self.resolver.reap() - self.notify() - curtime = time.time() - if curtime > (self.last_reap_time + self.maint_int): - self.last_reap_time = curtime - # do zone transfers here if slave server and haven't asked for soa - for (zkey, origin, masterip) in self.zdb.getslaves(curtime): - if not self.askedsoa.has_key(zkey): - self.need_zonetransfer(zkey, origin, masterip) - for zkey in self.askedsoa.keys(): - if curtime > self.askedsoa[zkey]['senttime'] + self.soatimeout: - if self.askedsoa[zkey]['trynum'] > 3: - self.zdb.remove_zone(zkey) - del self.askedsoa[zkey] - else: - masterip = self.askedsoa[zkey]['masterip'] - origin = self.askedsoa[zkey]['origin'] - trynum = self.askedsoa[zkey]['trynum'] - del self.askedsoa[zkey] - self.need_zonetransfer(zkey, origin, masterip, trynum) - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - -class resolver(asyncore.dispatcher): - def __init__(self, cache, port=0): - asyncore.dispatcher.__init__(self) - self.create_socket(socket.AF_INET, socket.SOCK_DGRAM) - self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024) - self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024) - self.bind(('',port)) - self.cache = cache - self.outqnum = 0 - self.outq = {} - self.holdq = {} - self.holdtime = 10 - self.holdqlength = 100 - self.last_reap_time = time.time() - self.maint_int = 10 - self.timeout = 3 - - def getoutqkey(self): - self.outqnum = self.outqnum + 1 - if self.outqnum == 99999: - self.outqnum = 1 - return str(self.outqnum) - - def error(self, id, qname, querytype, queryclass, rcode): - error = message() - error.header.id = id - error.header.rcode = rcode - error.header.qr = 1 - error.question.qname = qname - error.question.qtype = querytype - error.question.qclass = queryclass - return error - - def qpacket(self, id, qname, querytype, queryclass): - # create a question - query = message() - query.header.id = id - query.header.rd = 0 - query.question.qname = qname - query.question.qtype = querytype - query.question.qclass = queryclass - return query - - def send_to(self, msglist, addr): - for msg in msglist: - data = msg.buildpkt() - if len(data) > 512: - # packet to big - msg.header.tc = 1 - msg.header.ancount = 0 - msg.answerlist = [] - msg.header.nscount = 0 - msg.authlist = [] - msg.header.arcount = 0 - msg.addlist = [] - self.socket.sendto(msg.buildpkt(), addr) - else: - self.socket.sendto(data, addr) - - def handle_read(self): - try: - while 1: - msgdata, addr = self.socket.recvfrom(1500) - # should put 'try' here in production server - self.handle_packet(msgdata, addr) - except socket.error, why: - if why[0] != asyncore.EWOULDBLOCK: - raise socket.error, why - - def handle_packet(self, msgdata, addr): - try: - msg = message(msgdata) - except: - return - if not msg.header.qr: - self.handle_query(msg, addr, [], self.send_to) - else: - log(2,'received unsolicited reply') - - - def handle_query(self, msg, addr, flist, cbfunc): - qname = msg.question.qname - querytype = msg.question.qtype - queryclass = msg.question.qclass - # check the cache first - answer = self.cache.haskey(qname,querytype,msg) - if answer: - cbfunc([answer], addr) - log(2,'sent answer for ' + qname + '(' + querytype + - ') from cache') - else: - # check if query is already in progess - for oqkey in self.outq.keys(): - if (self.outq[oqkey]['qname'] == qname and - self.outq[oqkey]['querytype'] == querytype): - log(2,'query already in progress for '+qname+'('+querytype+')') - # put entry in hold queue to try later - hqrec = {'processtime':time.time()+self.holdtime, - 'query':msg,'addr':addr, - 'qname':qname,'querytype':querytype, - 'queryclass':queryclass, - 'cbfunc':cbfunc} - self.putonhold(hqrec) - return - - outqkey = self.getoutqkey()+str(msg.header.id) - self.outq[outqkey] = {'query':msg, - 'addr':addr, - 'qname':qname, - 'querytype':querytype, - 'queryclass':queryclass, - 'cbfunc':cbfunc, - 'answerlist':[], - 'addlist':[], - 'qsent':0} - if flist: - self.outq[outqkey]['flist'] = flist - self.askfns(outqkey) - else: - self.askns(outqkey) - - def putonhold(self,hqrec): - hqid = hqrec['qname']+hqrec['querytype'] - if self.holdq.has_key(hqid): - if len(self.holdq[hqid]) < self.holdqlength: - hqrec['processtime']=time.time()+self.holdtime - self.holdq[hqid].append(hqrec) - - - def askns(self, outqkey): - qname = self.outq[outqkey]['qname'] - querytype = self.outq[outqkey]['querytype'] - queryclass = self.outq[outqkey]['queryclass'] - # don't try more than 10 times to avoid loops - if self.outq[outqkey]['qsent'] == 10: - del self.outq[outqkey] - log(2,'Dropping query for ' + qname + '(' + querytype + ')' + - ' POSSIBLE LOOP') - return - # find the best nameservers to ask from the cache - (qzone, nsdict) = self.cache.getnslist(qname) - if not nsdict: - # there are no good servers - if self.outq[outqkey]['addr'] != 'IQ': - qid = self.outq[outqkey]['query'].header.id - self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2), - self.outq[outqkey]['addr']) - del self.outq[outqkey] - log(2,'Dropping query for ' + qname + '(' + querytype + ')' + - 'no good name servers to ask') - return - # pick the best nameserver - rtts = nsdict.keys() - rtts.sort() - bestnsip = nsdict[rtts[0]]['ip'] - bestnsname = nsdict[rtts[0]]['name'] - # fill in the callback data structure - id=random.randrange(1,32768) - self.outq[outqkey]['nsqueriedlastip'] = bestnsip - self.outq[outqkey]['nsqueriedlastname'] = bestnsname - self.outq[outqkey]['nsdict'] = nsdict - self.outq[outqkey]['qzone'] = qzone - self.outq[outqkey]['qsenttime'] = time.time() - self.outq[outqkey]['qsent'] = self.outq[outqkey]['qsent'] + 1 - # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (bestnsip,53)) - self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass), - self.handle_response, bestnsip, outqkey) - # update rtt so that we ask a different server next time - self.cache.updatertt(bestnsname,qzone,1) - log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname + - ') for ' + qname + '(' + querytype + ')') - - def askfns(self, outqkey): - flist = self.outq[outqkey]['flist'] - qname = self.outq[outqkey]['qname'] - querytype = self.outq[outqkey]['querytype'] - queryclass = self.outq[outqkey]['queryclass'] - self.outq[outqkey]['qsenttime'] = time.time() - id=random.randrange(1,32768) - # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (flist[0],53)) - self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass), - self.handle_fresponse, flist[0], outqkey) - log(2,''+outqkey+'|sent query to forwarder') - - def handle_response(self, msg, outqkey): - # either reponse: - # 1. contains a name error - # 2. answers the question - # (cache data and return it) - # 3. is (contains) a CNAME and qtype isn't - # (cache cname and change qname to it) - # (check if qname and qtype are in any other rrs in the response) - # (must check cache again here) - # 4. contains a better delegation - # (cache the delegation and start again) - # 5. is aserver failure - # (delete server from list and try again) - - # make sure that original question is still outstanding - if not self.outq.has_key(outqkey): - # should never get here - # if we do we aren't doing housekeeping of callbacks very well - log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname) - return - - querytype = self.outq[outqkey]['querytype'] - if msg.header.rcode not in [1,2,4,5]: - # update rtt time - rtt = time.time() - self.outq[outqkey]['qsenttime'] - nsname = self.outq[outqkey]['nsqueriedlastname'] - zone = self.outq[outqkey]['qzone'] - self.cache.updatertt(nsname,zone,rtt) - - if msg.header.rcode == 3: - log(2,outqkey+'|GOT Name Error for ' + msg.question.qname + - '(' + msg.question.qtype + ')') - # name error - # cache negative answer - self.cache.addneg(self.outq[outqkey]['qname'], - self.outq[outqkey]['querytype'], - self.outq[outqkey]['queryclass']) - if self.outq[outqkey]['addr'] != 'IQ': - answer = message() - answer.question.qname = self.outq[outqkey]['query'].question.qname - answer.question.qtype = self.outq[outqkey]['query'].question.qtype - answer.question.qclass = self.outq[outqkey]['query'].question.qclass - answer.header.id = self.outq[outqkey]['query'].header.id - answer.header.qr = 1 - answer.header.opcode = self.outq[outqkey]['query'].header.opcode - answer.header.ra = 1 - self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr']) - del self.outq[outqkey] - - elif msg.header.ancount > 0: - # answer (may be CNAME) - haveanswer = 0 - cname = '' - log(2,'CACHING ANSWERLIST ENTRIES') - for rr in msg.answerlist: - rrname = rr.keys()[0] - rrtype = rr[rrname].keys()[0] - if ((rrname == msg.question.qname or rrname == cname ) and - rrtype == msg.question.qtype): - haveanswer = 1 - if rrname == msg.question.qname and rrtype == 'CNAME': - cname = rr[rrname][rrtype][0]['cname'] - self.cache.add(rr, self.outq[outqkey]['qzone'], - self.outq[outqkey]['nsqueriedlastname']) - if haveanswer: - if self.outq[outqkey]['addr'] != 'IQ': - log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname + - '(' + msg.question.qtype + ')' ) - answer = message() - answer.answerlist = msg.answerlist + self.outq[outqkey]['answerlist'] - answer.header.ancount = len(answer.answerlist) - answer.question.qname = self.outq[outqkey]['query'].question.qname - answer.question.qtype = self.outq[outqkey]['query'].question.qtype - answer.question.qclass = self.outq[outqkey]['query'].question.qclass - answer.header.id = self.outq[outqkey]['query'].header.id - answer.header.qr = 1 - answer.header.opcode = self.outq[outqkey]['query'].header.opcode - answer.header.ra = 1 - self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr']) - log(2,outqkey+'|sent answer retrieved from remote server for ' + - self.outq[outqkey]['query'].question.qname) - else: - log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' + - msg.question.qtype + ')') - del self.outq[outqkey] - elif cname: - log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')') - self.outq[outqkey]['answerlist'] = self.outq[outqkey]['answerlist'] + msg.answerlist - self.outq[outqkey]['qname'] = cname - self.askns(outqkey) - else: - log(2,outqkey+'|GOT BOGUS answer for ' + msg.question.qname + '(' + - msg.question.qtype + ')') - del self.outq[outqkey] - - elif msg.header.nscount > 0 and msg.header.ancount == 0: - log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')') - # delegation - # cache the nameserver rrs and start over - # if there are no glue records for nameservers must fetch them first - log(2,'CACHING AUTHLIST ENTRIES') - for rr in msg.authlist: - self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname']) - log(2,'CACHING ADDLIST ENTRIES') - for rr in msg.addlist: - self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname']) - rrlist = msg.authlist+msg.addlist - fetchglue = 0 - nscount = 0 - for rr in msg.authlist: - nodename = rr.keys()[0] - if rr[nodename].keys()[0] == 'NS': - nscount = nscount + 1 - nsdname = rr[nodename]['NS'][0]['nsdname'] - if not self.cache.haskey(nsdname,'A'): - log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)') - fetchglue = fetchglue + 1 - # need to fetch A rec - noutqkey = self.getoutqkey()+str(random.randrange(1,32768)) - self.outq[noutqkey] = {'query':'', - 'addr':'IQ', - 'qname':nsdname, - 'querytype':'A', - 'queryclass':'IN', - 'qsent':0} - log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)') - self.askns(noutqkey) - if not nscount: - log(2,outqkey+'|Dropping query (no ns recs) for ' + - msg.question.qname + '(' + msg.question.qtype + ')' ) - del self.outq[outqkey] - elif fetchglue == nscount: - log(2,outqkey+'|Stalling query (no glue recs) for ' + - msg.question.qname + '(' + msg.question.qtype + ')') - self.putonhold(self.outq[outqkey]) - del self.outq[outqkey] - else: - log(2,outqkey+'|got (some) glue with delegation') - self.askns(outqkey) - - elif msg.header.rcode in [1,2,4,5]: - log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode)) - log(2,'SERVER ' + self.outq[outqkey]['nsqueriedlastname'] + '(' + - self.outq[outqkey]['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname) - # don't ask this server for a while - self.cache.badns(self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname']) - self.askns(outqkey) - else: - log(2,outqkey+'|GOT UNPARSEABLE REPLY') - msg.printpkt() - - def handle_fresponse(self, msg, outqkey): - if msg.header.rcode in [1,2,4,5]: - self.outq[outqkey]['flist'].pop(0) - if len(self.outq[outqkey]['flist']) == 0: - qid = self.outq[outqkey]['query'].header.id - qname = self.outq[outqkey]['qname'] - querytype = self.outq[outqkey]['querytype'] - queryclass = self.outq[outqkey]['queryclass'] - self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2), - self.outq[outqkey]['addr']) - del self.outq[outqkey] - else: - self.askfns(outqkey) - else: - answer = message() - answer.header.id = self.outq[outqkey]['query'].header.id - answer.header.qr = 1 - answer.header.opcode = self.outq[outqkey]['query'].header.opcode - answer.header.ra = 1 - answer.question.qname = self.outq[outqkey]['query'].question.qname - answer.question.qtype = self.outq[outqkey]['query'].question.qtype - answer.question.qclass = self.outq[outqkey]['query'].question.qclass - answer.header.ancount = msg.header.ancount - answer.header.nscount = msg.header.nscount - answer.header.arcount = msg.header.arcount - answer.answerlist = msg.answerlist - answer.authlist = msg.authlist - answer.addlist = msg.addlist - if msg.header.rcode == 3: - # name error - # cache negative answer - self.cache.addneg(self.outq[outqkey]['qname'], - self.outq[outqkey]['querytype'], - self.outq[outqkey]['queryclass']) - else: - # cache all rrs - for rr in msg.answerlist: - self.cache.add(rr,'','forwarder') - for rr in msg.authlist: - self.cache.add(rr,'','forwarder') - for rr in msg.addlist: - self.cache.add(rr,'','forwarder') - self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr']) - del self.outq[outqkey] - - def writable(self): - return 0 - - def handle_write(self): - pass - - def handle_connect(self): - pass - - def handle_close(self): - # print '1:In handle close' - return - - def process_holdq(self): - curtime = time.time() - for hqkey in self.holdq.keys(): - for hqrec in self.holdq[hqkey]: - if curtime >= hqrec['processtime']: - log(2,'processing held query') - answer = self.cache.haskey(hqrec['qname'], - hqrec['querytype'], - hqrec['query']) - if answer: - hqrec['cbfunc']([answer], hqrec['addr']) - log(2,'sent answer for ' + hqrec['qname'] + - '(' + hqrec['querytype'] + ') from cache') - self.holdq[hqkey].remove(hqrec) - if len(self.holdq[hqkey]) == 0: - del self.holdq[hqkey] - - def reap(self): - self.process_holdq() - curtime = time.time() - log(3,timestamp() + 'processed HOLDQ (sockets: ' + - str(len(asyncore.socket_map.keys()))+')') - if curtime > (self.last_reap_time + self.maint_int): - self.last_reap_time = curtime - for outqkey in self.outq.keys(): - if curtime > self.outq[outqkey]['qsenttime'] + self.timeout: - log(2,'query for '+self.outq[outqkey]['qname']+'('+ - self.outq[outqkey]['querytype']+') expired') - # don't set forwarders as bad - if not self.outq[outqkey].has_key('flist'): - self.cache.badns(self.outq[outqkey]['qzone'], - self.outq[outqkey]['nsqueriedlastname']) - if self.outq[outqkey].has_key('request'): - log(3,'closing socket for expired query') - self.outq[outqkey]['request'].close() - del self.outq[outqkey] - return - - def log_info (self, message, type='info'): - if __debug__ or type != 'info': - log(0,'%s: %s' % (type, message)) - - -def run(configobj): - global loglevel - r = resolver(dnscache(configobj.cached)) - ns = nameserver(r, configobj) - udpds = udpdnsserver(53, ns) - tcpds = tcpdnsserver(53, ns) - loglevel = configobj.loglevel - try: - loop(ns.reap) - except KeyboardInterrupt: - print 'server done' - -if __name__ == '__main__': - sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen') - zonedict = {'example.net':{'origin':'example.net', - 'filename':'db.example.net', - 'type':'master', - 'slaves':[]}} - - - zonedict = {'servers.csail.mit.edu':{'origin':'servers.csail.mit.edu', - 'filename':'db.servers.csail.mit.edu', - 'type':'master', - 'slaves':[]}} - - zonedict2 = {'example.net':{'origin':'example.net', - 'filename':'db.example.net', - 'type':'slave', - 'masterip':'127.0.0.1'}} - readzonefiles(zonedict) - lconfig = dnsconfig() - lconfig.zonedatabase = zonedb(zonedict) - pr = zonefileparser() - pr.parse('','db.ca') - lconfig.cached = pr.getzdict() - lconfig.loglevel = 3 - - run(lconfig)