--- /dev/null
+#!/usr/bin/python
+# Python Domain Name Server
+# Copyright (C) 2002 Digital Lumber, Inc.
+
+# This library is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 2.1 of the License, or (at your option) any later version.
+
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+
+# You should have received a copy of the GNU Lesser General Public
+# License along with this library; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+
+import socket
+import asyncore
+import asynchat
+import select
+import types
+import random
+import time
+import signal
+import string
+import sys
+import sipb_xen_database
+from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
+ ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT
+
+# EXAMPLE ZONE FILE DATA STRUCTURE
+
+# NOTE:
+# There are no trailing dots in the internal data
+# structure. Although it's hard to tell by reading
+# the RFC's, the dots on the end of names are just
+# used internally by the resolvers and servers to
+# see if they need to append a domain name onto
+# the end of names. There are no trailing dots
+# on names in queries on the network.
+
+examplenet = {'example.net':{'SOA':[{'class':'IN',
+ 'ttl':10,
+ 'mname':'ns1.example.net',
+ 'rname':'hostmaster.example.net',
+ 'serial':1,
+ 'refresh':10800,
+ 'retry':3600,
+ 'expire':604800,
+ 'minimum':3600}],
+ 'NS':[{'class':'IN',
+ 'ttl':10,
+ 'nsdname':'ns1.example.net'},
+ {'ttl':10,
+ 'nsdname':'ns2.example.net'}],
+ 'MX':[{'class':'IN',
+ 'ttl':10,
+ 'preference':10,
+ 'exchange':'mail.example.net'}]},
+ 'server1.example.net':{'A':[{'class':'IN',
+ 'ttl':10,
+ 'address':'10.1.2.3'}]},
+ 'www.example.net':{'CNAME':[{'class':'IN',
+ 'ttl':10,
+ 'cname':'server1.example.net'}]},
+ 'router.example.net':{'A':[{'class':'IN',
+ 'ttl':10,
+ 'address':'10.1.2.1'},
+ {'class':'IN',
+ 'ttl':10,
+ 'address':'10.2.1.1'}]}
+
+ }
+
+# setup logging defaults
+loglevel = 0
+logfile = sys.stdout
+
+try:
+ file
+except NameError:
+ def file(name, mode='r', buffer=0):
+ return open(name, mode, buffer)
+
+def log(level,msg):
+ if level <= loglevel:
+ logfile.write(msg+'\n')
+
+def timestamp():
+ return time.strftime('%m/%d/%y %H:%M:%S')+ '-'
+
+def inttoasc(number):
+ try:
+ hs = hex(number)[2:]
+ except:
+ log(0,'inttoasc cannot convert ' + repr(number))
+ if hs[-1:].upper() == 'L':
+ hs = hs[:-1]
+ result = ''
+ while len(hs) > 2:
+ result = chr(int(hs[-2:],16)) + result
+ hs = hs[:-2]
+ result = chr(int(hs,16)) + result
+
+ return result
+
+def asctoint(ascnum):
+ rascnum = ''
+ for i in range(len(ascnum)-1,-1,-1):
+ rascnum = rascnum + ascnum[i]
+ result = 0
+ count = 0
+ for c in rascnum:
+ x = ord(c) << (8*count)
+ result = result + x
+ count = count + 1
+
+ return result
+
+def ipv6net_aton(ip_string):
+ packed_ip = ''
+ # first account for shorthand syntax
+ pieces = ip_string.split(':')
+ pcount = 0
+ for part in pieces:
+ if part != '':
+ pcount = pcount + 1
+ if pcount < 8:
+ rs = '0:'*(8-pcount)
+ ip_string = ip_string.replace('::',':'+rs)
+ if ip_string[0] == ':':
+ ip_string = ip_string[1:]
+ pieces = ip_string.split(':')
+ for part in pieces:
+ # pad with the zeros
+ i = 4-len(part)
+ part = i*'0'+part
+ packed_ip = packed_ip + chr(int(part[:2],16))+ chr(int(part[2:],16))
+ return packed_ip
+
+def ipv6net_ntoa(packed_ip):
+ ip_string = ''
+ count = 0
+ for c in packed_ip:
+ ip_string = ip_string + hex(ord(c))[2:]
+ count = count + 1
+ if count == 2:
+ ip_string = ip_string + ':'
+ count = 0
+ return ip_string[:-1]
+
+def getversion(qname, id, rd, ra, versionstr):
+ msg = message()
+ msg.header.id = id
+ msg.header.qr = 1
+ msg.header.aa = 1
+ msg.header.rd = rd
+ msg.header.ra = ra
+ msg.header.rcode = 0
+ msg.question.qname = qname
+ msg.question.qtype = 'TXT'
+ msg.question.qclass = 'CH'
+ if qname == 'version.bind':
+ msg.header.ancount = 2
+ msg.answerlist.append({qname:{'CNAME':[{'cname':'version.oak',
+ 'ttl':360000,
+ 'class':'CH'}]}})
+ msg.answerlist.append({'version.oak':{'TXT':[{'txtdata':versionstr,
+ 'ttl':360000,
+ 'class':'CH'}]}})
+ else:
+ msg.header.ancount = 1
+ msg.answerlist.append({qname:{'TXT':[{'txtdata':versionstr,
+ 'ttl':360000,
+ 'class':'CH'}]}})
+ return msg
+
+def getrcode(rcode):
+ if rcode == 0:
+ rcodestr = 'NOERROR(No error condition)'
+ elif rcode == 1:
+ rcodestr = 'FORMERR(Format Error)'
+ elif rcode == 2:
+ rcodestr = 'SERVFAIL(Internal failure)'
+ elif rcode == 3:
+ rcodestr = 'NXDOMAIN(Name does not exist)'
+ elif rcode == 4:
+ rcodestr = 'NOTIMP(Not Implemented)'
+ elif rcode == 5:
+ rcodestr = 'REFUSED(Security violation)'
+ elif rcode == 6:
+ rcodestr = 'YXDOMAIN(Name exists)'
+ elif rcode == 7:
+ rcodestr = 'YXRRSET(RR exists)'
+ elif rcode == 8:
+ rcodestr = 'NXRRSET(RR does not exist)'
+ elif rcode == 9:
+ rcodestr = 'NOTAUTH(Server not Authoritative)'
+ elif rcode == 10:
+ rcodestr = 'NOTZONE(Name not in zone)'
+ else:
+ rcodestr = 'Unknown RCODE(' + str(rcode) + ')'
+ return rcodestr
+
+def printrdata(dnstype, rdata):
+ if dnstype == 'A':
+ return rdata['address']
+ elif dnstype == 'MX':
+ return str(rdata['preference'])+'\t'+rdata['exchange']+'.'
+ elif dnstype == 'NS':
+ return rdata['nsdname']+'.'
+ elif dnstype == 'PTR':
+ return rdata['ptrdname']+'.'
+ elif dnstype == 'CNAME':
+ return rdata['cname']+'.'
+ elif dnstype == 'SOA':
+ return (rdata['mname']+'.\t'+rdata['rname']+'. (\n'+35*' '+str(rdata['serial'])+'\n'+
+ 35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+
+ str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )')
+
+def makezonedatalist(zonedata, origin):
+ # unravel structure into list
+ zonedatalist = []
+ # get soa first
+ soanode = zonedata[origin]
+ zonedatalist.append([origin+'.','SOA',soanode['SOA'][0]])
+ for item in soanode.keys():
+ if item != 'SOA':
+ for listitem in soanode[item]:
+ zonedatalist.append([origin+'.', item, listitem])
+ for nodename in zonedata.keys():
+ if nodename != origin:
+ for item in zonedata[nodename].keys():
+ for listitem in zonedata[nodename][item]:
+ zonedatalist.append([nodename+'.', item, listitem])
+ return zonedatalist
+
+def writezonefile(zonedata, origin, file):
+ zonedatalist = makezonedatalist(zonedata, origin)
+ for rr in zonedatalist:
+ owner = rr[0]
+ dnstype = rr[1]
+ line = (owner + (35-len(owner))*' ' + str(rr[2]['ttl']) + '\t\tIN\t' +
+ dnstype + '\t' + printrdata(dnstype, rr[2]))
+ file.write(line + '\n')
+
+def readzonefiles(zonedict):
+ for k in zonedict.keys():
+ filepath = zonedict[k]['filename']
+ try:
+ pr = zonefileparser()
+ pr.parse(zonedict[k]['origin'],filepath)
+ zonedict[k]['zonedata'] = pr.getzdict()
+ except ZonefileError, lineno:
+ log(0,'Error reading zone file ' + filepath + ' at line ' +
+ str(lineno) + '\n')
+ del zonedict[k]
+
+def slowloop(tofunc='',timeout=5.0):
+ if not tofunc:
+ def tofunc(): return
+ map = asyncore.socket_map
+ while map:
+ r = []; w=[]; e=[]
+ for fd, obj in map.items():
+ if obj.readable():
+ r.append(fd)
+ if obj.writable():
+ w.append(fd)
+ try:
+ starttime = time.time()
+ r,w,e = select.select(r,w,e,timeout)
+ endtime = time.time()
+ if endtime-starttime >= timeout:
+ tofunc()
+ except select.error, err:
+ if err[0] != EINTR:
+ raise
+ r=[]; w=[]; e=[]
+ log(0,'ERROR in select')
+
+ for fd in r:
+ try:
+ obj=map[fd]
+ except KeyError:
+ log(0,'KeyError in socket map')
+ continue
+ try:
+ obj.handle_read_event()
+ except:
+ log(0,'calling HANDLE ERROR from loop')
+ log(0,repr(obj))
+ obj.handle_error()
+ for fd in w:
+ try:
+ obj=map[fd]
+ except KeyError:
+ log(0,'KeyError in socket map')
+ continue
+ try:
+ obj.handle_read_event()
+ except:
+ log(0,'calling HANDLE ERROR from loop')
+ log(0,repr(obj))
+ obj.handle_error()
+
+def fastloop(tofunc='',timeout=5.0):
+ if not tofunc:
+ def tofunc(): return
+ polltimeout = timeout*1000
+ map = asyncore.socket_map
+ while map:
+ regfds = 0
+ pollobj = select.poll()
+ for fd, obj in map.items():
+ flags = 0
+ if obj.readable():
+ flags = select.POLLIN
+ if obj.writable():
+ flags = flags | select.POLLOUT
+ if flags:
+ pollobj.register(fd, flags)
+ regfds = regfds + 1
+ try:
+ starttime = time.time()
+ r = pollobj.poll(polltimeout)
+ endtime = time.time()
+ if endtime-starttime >= timeout:
+ tofunc()
+ except select.error, err:
+ if err[0] != EINTR:
+ raise
+ r = []
+ log(0,'ERROR in select')
+ for fd, flags in r:
+ try:
+ obj = map[fd]
+ badvals = (select.POLLPRI + select.POLLERR +
+ select.POLLHUP + select.POLLNVAL)
+ if (flags & badvals):
+ if (flags & select.POLLPRI):
+ log(0,'POLLPRI')
+ if (flags & select.POLLERR):
+ log(0,'POLLERR')
+ if (flags & select.POLLHUP):
+ log(0,'POLLHUP')
+ if (flags & select.POLLNVAL):
+ log(0,'POLLNVAL')
+ obj.handle_error()
+ else:
+ if (flags & select.POLLIN):
+ obj.handle_read_event()
+ if (flags & select.POLLOUT):
+ obj.handle_write_event()
+ except KeyError:
+ log(0,'KeyError in socket map')
+ continue
+ except:
+ # print traceback
+ sf = StringIO.StringIO()
+ traceback.print_exc(file=sf)
+ log(0,'ERROR IN LOOP:')
+ log(0,sf.getvalue())
+ sf.close()
+ log(0,repr(obj))
+ obj.handle_error()
+
+if hasattr(select,'poll'):
+ loop = fastloop
+else:
+ loop = slowloop
+
+class ZonefileError(Exception):
+ def __init__(self, linenum, errordesc=''):
+ self.linenum = linenum
+ self.errordesc = errordesc
+ def __str__(self):
+ return str(self.linenum) + ' (' + self.errordesc + ')'
+
+class zonefileparser:
+ def __init__(self):
+ self.zonedata = {}
+ self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX',
+ 'NS','PTR','RP','SOA','SRV','TXT']
+
+ def stripcomments(self, line):
+ i = line.find(';')
+ if i >= 0:
+ line = line[:i]
+ return line
+
+ def strip(self, line):
+ # strip trailing linefeeds
+ if line[-1:] == '\n':
+ line = line[:-1]
+ return line
+
+ def getzdict(self):
+ return self.zonedata
+
+ def addorigin(self, origin, name):
+ if name[-1:] != '.':
+ return name + '.' + origin
+ else:
+ return name[:-1]
+
+ def getstrings(self, s):
+ if s.find('"') == -1:
+ return s.split()
+ else:
+ x = s.split('"')
+ rlist = []
+ for i in x:
+ if i != '' and i != ' ':
+ rlist.append(i)
+ return rlist
+
+ def getlocsize(self, s):
+ if s[-1:] == 'm':
+ size = float(s[:-1])*100
+ else:
+ size = float(s)*100
+ i = 0
+ while size > 9:
+ size = size/10
+ i = i + 1
+ return (int(size),i)
+
+ def getloclat(self, l,c):
+ deg = float(l[0])
+ min = 0
+ secs = 0
+ if len(l) == 3:
+ min = float(l[1])
+ secs = float(l[2])
+ elif len(l) == 2:
+ min = float(l[1])
+ rval = ((((deg *60) + min) * 60) + secs) * 1000
+ if c in ['N','E']:
+ rval = rval + (2**31)
+ elif c in ['S','W']:
+ rval = (2**31) - rval
+ else:
+ log(0,'ERROR: unsupported latitude/longitude direction')
+ return long(rval)
+
+ def getgname(self, name, iter):
+ if name == '0' or name == 'O':
+ return ''
+ start = 0
+ offset = 0
+ width = 0
+ base = 'd'
+ for x in range(name.count('$')):
+ i = name.find('$',start)
+ j = i
+ start = i+1
+ if i>0:
+ if name[i-1] == '\\':
+ continue
+ if len(name)>i+1:
+ if name[i+1] == '$':
+ continue
+ if name[i+1] == '{':
+ j = name.find('}',i+1)
+ owb = name[i+2:j].split(',')
+ if len(owb) == 1:
+ offset = int(owb[0])
+ elif len(owb) == 2:
+ offset = int(owb[0])
+ width = int(owb[1])
+ elif len(owb) == 3:
+ offset = int(owb[0])
+ width = int(owb[1])
+ base = owb[2]
+ val = iter - offset
+ if base == 'd':
+ rs = str(val)
+ elif base == 'o':
+ rs = oct(val)
+ elif base == 'x':
+ rs = hex(val)[2:].lower()
+ elif base == 'X':
+ rs = hex(val)[2:].upper()
+ else:
+ rs = ''
+ if len(rs) > width:
+ rs = (width-len(rs))*'0'+rs
+ name = name[:i]+rs+name[j+1:]
+ start = i+len(rs)+1
+
+ return name
+
+ def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens):
+ rdata = {}
+ rdata['class'] = dnsclass
+ rdata['ttl'] = ttl
+ if dnstype == 'A':
+ rdata['address'] = tokens[0]
+ elif dnstype == 'AAAA':
+ rdata['address'] = tokens[0]
+ elif dnstype == 'CNAME':
+ rdata['cname'] = self.addorigin(origin,tokens[0].lower())
+ elif dnstype == 'HINFO':
+ sl = self.getstrings(' '.join(tokens))
+ rdata['cpu'] = sl[0]
+ rdata['os'] = sl[1]
+ elif dnstype == 'LOC':
+ if 'N' in tokens:
+ i = tokens.index('N')
+ else:
+ i = tokens.index('S')
+ lat = self.getloclat(tokens[0:i],tokens[i])
+ if 'E' in tokens:
+ j = tokens.index('E')
+ else:
+ j = tokens.index('W')
+ lng = self.getloclat(tokens[i+1:j],tokens[j])
+ size = self.getlocsize('1m')
+ horiz_pre = self.getlocsize('1000m')
+ vert_pre = self.getlocsize('10m')
+ if len(tokens[j+1:]) == 2:
+ size = self.getlocsize(tokens[-1:][0])
+ elif len(tokens[j+1:]) == 3:
+ size = self.getlocsize(tokens[-2:-1][0])
+ horiz_pre = self.getlocsize(tokens[-1:][0])
+ elif len(tokens[j+1:]) == 4:
+ size = self.getlocsize(tokens[-3:-2][0])
+ horiz_pre = self.getlocsize(tokens[-2:-1][0])
+ vert_pre = self.getlocsize(tokens[-1:][0])
+ if tokens[j+1][-1:] == 'm':
+ alt = int((float(tokens[j+1][:-1])*100)+10000000)
+ else:
+ size = int((float(tokens[j+1])*100)+10000000)
+ rdata['version'] = 0
+ rdata['size'] = size
+ rdata['horiz_pre'] = horiz_pre
+ rdata['vert_pre'] = vert_pre
+ rdata['latitude'] = lat
+ rdata['longitude'] = lng
+ rdata['altitude'] = 0
+ elif dnstype == 'MX':
+ rdata['preference'] = int(tokens[0])
+ rdata['exchange'] = self.addorigin(origin,tokens[1].lower())
+ elif dnstype == 'NS':
+ rdata['nsdname'] = self.addorigin(origin,tokens[0].lower())
+ elif dnstype == 'PTR':
+ rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower())
+ elif dnstype == 'RP':
+ rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower())
+ rdata['txtdname'] = self.addorigin(origin,tokens[1].lower())
+ elif dnstype == 'SOA':
+ rdata['mname'] = self.addorigin(origin,tokens[0].lower())
+ rdata['rname'] = self.addorigin(origin,tokens[1].lower())
+ rdata['serial'] = int(tokens[2])
+ rdata['refresh'] = int(tokens[3])
+ rdata['retry'] = int(tokens[4])
+ rdata['expire'] = int(tokens[5])
+ rdata['minimum'] = int(tokens[6])
+ elif dnstype == 'SRV':
+ rdata['priority'] = int(tokens[0])
+ rdata['weight'] = int(tokens[1])
+ rdata['port'] = int(tokens[2])
+ rdata['target'] = self.addorigin(origin,tokens[3].lower())
+ elif dnstype == 'TXT':
+ rdata['txtdata'] = self.getstrings(' '.join(tokens))[0]
+ else:
+ raise ZonefileError(lineno,'bad DNS type')
+ return rdata
+
+ def addrec(self, owner, dnstype, rrdata):
+ if self.zonedata.has_key(owner):
+ if not self.zonedata[owner].has_key(dnstype):
+ self.zonedata[owner][dnstype] = []
+ else:
+ self.zonedata[owner] = {}
+ self.zonedata[owner][dnstype] = []
+ self.zonedata[owner][dnstype].append(rrdata)
+
+ def parse(self, origin, f):
+ closefile = 0
+ if type(f) != types.FileType:
+ # must be a path
+ try:
+ f = file(f)
+ closefile = 1
+ except:
+ log(0,'Invalid path to zonefile')
+ return
+ lastowner = ''
+ lastdnsclass = ''
+ lastttl = 3600
+ lineno = 0
+ while 1:
+ line = f.readline()
+ if not line:
+ break
+ lineno = lineno + 1
+ line = self.stripcomments(line)
+ line = self.strip(line)
+ if not line:
+ continue
+ if line.find('(') >= 0:
+ # grab lines until end paren
+ if line.find(')') == -1:
+ line2 = self.stripcomments(f.readline())
+ lineno = lineno + 1
+ line2 = self.strip(line2)
+ line = line + line2
+ while line2.find(')') == -1:
+ line2 = self.strip(self.stripcomments(f.readline()))
+ lineno = lineno + 1
+ line = line + line2
+ # now strip the parenthesis
+ line = line.replace(')','')
+ line = line.replace('(','')
+ # now line equals the entire RR entry
+ tokens = line.split()
+ if tokens[0].upper() == '$ORIGIN':
+ try:
+ origin = tokens[1].lower()
+ except:
+ raise ZonefileError(lineno, 'bad origin')
+ elif tokens[0].upper() == '$INCLUDE':
+ try:
+ f2 = file(tokens[1].lower())
+ if len(tokens) > 2:
+ self.parse(tokens[2].lower(), f2)
+ else:
+ self.parse(origin, f2)
+ f2.close()
+ except:
+ raise ZonefileError(lineno, 'bad INCLUDE directive')
+ elif tokens[0].upper() == '$TTL':
+ try:
+ lastttl = int(tokens[1])
+ except:
+ raise ZonefileError(lineno, 'bad TTL directive')
+ elif tokens[0].upper() == '$GENERATE':
+ try:
+ lhs = tokens[2].lower()
+ dnstype = tokens[3].upper()
+ rhs = tokens[4].lower()
+ rng = tokens[1].split('-')
+ start = int(rng[0])
+ i = rng[1].find('/')
+ if i != -1:
+ stop = int(rng[1][:i])+1
+ step = int(rng[1][i+1:])
+ else:
+ stop = int(rng[1])+1
+ step = 1
+ for i in range(start,stop,step):
+ grhs = self.getgname(rhs,i)
+ if dnstype in ['NS','CNAME','PTR']:
+ grhs = self.addorigin(origin,grhs)
+ rrdata = self.getrrdata(origin, dnstype, 'IN', lastttl,
+ [grhs])
+ glhs = self.addorigin(origin,self.getgname(lhs,i))
+ self.addrec(glhs,dnstype, rrdata)
+ except KeyError:
+ raise ZonefileError(lineno, 'bad GENERATE directive')
+ else:
+ try:
+ # if line begins with blank then owner is last owner
+ if line[0] in string.whitespace:
+ owner = lastowner
+ else:
+ owner = tokens[0].lower()
+ tokens = tokens[1:]
+ if owner == '@':
+ owner = origin
+ elif owner[-1:] != '.':
+ owner = owner + '.' + origin
+ else:
+ owner = owner[:-1] # strip off trailing dot
+ # line format is either: [class] [ttl] type RDATA
+ # or [ttl] [class] type RDATA
+ # - items in brackets are optional
+ #
+ # need to figure out which token is type
+ # and backfill the missing data
+ count = 0
+ for token in tokens:
+ if token.upper() in self.dnstypes:
+ break
+ count = count + 1
+ # the following strips off the ttl and class if they exist
+ if count == 0:
+ ttl = lastttl
+ dnsclass = lastdnsclass
+ elif count == 1:
+ if tokens[0].isdigit():
+ ttl = int(tokens[0])
+ dnsclass = lastdnsclass
+ else:
+ ttl = lastttl
+ dnsclass = tokens[0].upper()
+ tokens = tokens[1:]
+ elif count == 2:
+ if tokens[0].isdigit():
+ ttl = int(tokens[0])
+ dnsclass = tokens[1].upper()
+ else:
+ ttl = int(tokens[1])
+ dnsclass = tokens[0].upper()
+ tokens = tokens[2:]
+ else:
+ raise ZonefileError(lineno,'bad ttl or class')
+ dnstype = tokens[0]
+ # make sure all of the structure is there
+ rrdata = self.getrrdata(origin, dnstype, dnsclass,
+ ttl, tokens[1:])
+ self.addrec(owner, dnstype, rrdata)
+ lastowner = owner
+ lastttl = ttl
+ lastdnsclass = dnsclass
+ except:
+ raise ZonefileError(lineno,'unable to parse line')
+ if closefile:
+ f.close()
+
+class dnsconfig:
+ def __init__(self):
+ # self.zonedb = zonedb({})
+ self.cached = {}
+ self.loglevel = 0
+
+ def getview(self, msg, address, port):
+ # return:
+ # 1. a list of zone keys
+ # 2. whether or not to use the resolver
+ # (i.e. answer recursive queries)
+ # 3. a list of forwarder addresses
+ return ['servers.csail.mit.edu'], 1, []
+
+ def allowupdate(self, msg, address, port):
+ # return 1 if updates are allowed
+ # NOTE: can only update the zones
+ # returned by the getview func
+ return 1
+
+ def outpackets(self, packetlist):
+ # modify outgoing packets
+ return packetlist
+
+class dnsheader:
+ def __init__(self, id=1):
+ self.id = id # 16bit identifier generated by queryer
+ self.qr = 0 # one bit field specifying query(0) or response(1)
+ self.opcode = 0 # 4bit field specifying type of query
+ self.aa = 0 # authoritative answer
+ self.tc = 0 # message is not truncated
+ self.rd = 1 # recursion desired
+ self.ra = 0 # recursion available?
+ self.z = 0 # reserved for future use
+ self.rcode = 0 # response code (set in response)
+ self.qdcount = 1 # number of questions, only 1 is supported
+ self.ancount = 0 # number of rrs in the answer section
+ self.nscount = 0 # number of name server rrs in authority section
+ self.arcount = 0 # number or rrs in the additional section
+
+class dnsquestion:
+ def __init__(self):
+ self.qname = 'localhost'
+ self.qtype = 'A'
+ self.qclass = 'IN'
+
+class dnsupdatezone:
+ pass
+
+class message:
+ def __init__(self, msgdata=''):
+ if msgdata:
+ self.header = dnsheader()
+ else:
+ self.header = dnsheader(id=random.randrange(1,32768))
+ self.question = dnsquestion()
+ self.answerlist = []
+ self.authlist = []
+ self.addlist = []
+ self.u = ''
+ self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA',
+ 7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS',
+ 12:'PTR',13:'HINFO',14:'MINFO',15:'MX',
+ 16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV',
+ 38:'A6',39:'DNAME',251:'IXFR',252:'AXFR',
+ 253:'MAILB',254:'MAILA',255:'ANY'}
+ self.rqtypes = {}
+ for key in self.qtypes.keys():
+ self.rqtypes[self.qtypes[key]] = key
+ self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'}
+ self.rqclasses = {}
+ for key in self.qclasses.keys():
+ self.rqclasses[self.qclasses[key]] = key
+
+ if msgdata:
+ self.processpkt(msgdata)
+
+ def getdomainname(self, data, i):
+ log(4,'IN GETDOMAINNAME')
+ domainname = ''
+ gotpointer = 0
+ labellength= ord(data[i])
+ log(4,'labellength:' + str(labellength))
+ i = i + 1
+ while labellength != 0:
+ while labellength >= 192:
+ # pointer
+ if not gotpointer:
+ rindex = i + 1
+ gotpointer = 1
+ log(4,'got pointer')
+ i = asctoint(chr(ord(data[i-1]) & 63)+data[i])
+ log(4,'new index:'+str(i))
+ labellength = ord(data[i])
+ log(4,'labellength:' + str(labellength))
+ i = i + 1
+ if domainname:
+ domainname = domainname + '.' + data[i:i+labellength]
+ else:
+ domainname = data[i:i+labellength]
+ log(4,'domainname:'+domainname)
+ i = i + labellength
+ labellength = ord(data[i])
+ log(4,'labellength:' + str(labellength))
+ i = i + 1
+ if not gotpointer:
+ rindex = i
+
+ return domainname.lower(), rindex
+
+ def getrrdata(self, type, msgdata, rdlength, i):
+ log(4,'unpacking RR data')
+ rdata = msgdata[i:i+rdlength]
+ if type == 'A':
+ return {'address':socket.inet_ntoa(rdata)}
+ elif type == 'AAAA':
+ return {'address':ipv6net_ntoa(rdata)}
+ elif type == 'CNAME':
+ cname, i = self.getdomainname(msgdata,i)
+ return {'cname':cname}
+ elif type == 'HINFO':
+ cpulen = ord(rdata[0])
+ cpu = rdata[1:cpulen+1]
+ return {'cpu':cpu,
+ 'os':rdata[cpulen+2:]}
+ elif type == 'LOC':
+ return {'version':ord(rdata[0]),
+ 'size':self.locsize(rdata[1]),
+ 'horiz_pre':self.locsize(rdata[2]),
+ 'vert_pre':self.locsize(rdata[3]),
+ 'latitude':asctoint(rdata[4:8]),
+ 'longitude':asctoint(rdata[8:12]),
+ 'altitude':asctoint(rdata[12:16])}
+ elif type == 'MX':
+ exchange, i = self.getdomainname(msgdata,i+2)
+ return {'preference':asctoint(rdata[:2]),
+ 'exchange':exchange}
+ elif type == 'NS':
+ nsdname, i = self.getdomainname(msgdata,i)
+ return {'nsdname':nsdname}
+ elif type == 'PTR':
+ ptrdname, i = self.getdomainname(msgdata,i)
+ return {'ptrdname':ptrdname}
+ elif type == 'RP':
+ mboxdname, i = self.getdomainname(msgdata,i)
+ txtdname, i = self.getdomainname(msgdata,i)
+ return {'mboxdname':mboxdname,
+ 'txtdname':txtdname}
+ elif type == 'SOA':
+ mname, i = self.getdomainname(msgdata,i)
+ rname, i = self.getdomainname(msgdata,i)
+ return {'mname':mname,
+ 'rname':rname,
+ 'serial':asctoint(msgdata[i:i+4]),
+ 'refresh':asctoint(msgdata[i+4:i+8]),
+ 'retry':asctoint(msgdata[i+8:i+12]),
+ 'expire':asctoint(msgdata[i+12:i+16]),
+ 'minimum':asctoint(msgdata[i+16:i+20])}
+ elif type == 'SRV':
+ target, i = self.getdomainname(msgdata,i+6)
+ return {'priority':asctoint(rdata[0:2]),
+ 'weight':asctoint(rdata[2:4]),
+ 'port':asctoint(rdata[4:6]),
+ 'target':target}
+ elif type == 'TXT':
+ return {'txtdata':rdata[1:]}
+ else:
+ return {'rdata':rdata}
+
+ def getrr(self, data, i):
+ log(4,'unpacking RR name')
+ name, i = self.getdomainname(data, i)
+ type = asctoint(data[i:i+2])
+ type = self.qtypes.get(type,chr(type))
+ klass = asctoint(data[i+2:i+4])
+ klass = self.qclasses.get(klass,chr(klass))
+ ttl = asctoint(data[i+4:i+8])
+ rdlength = asctoint(data[i+8:i+10])
+ rrdata = self.getrrdata(type,data,rdlength,i+10)
+ rrdata['ttl'] = ttl
+ rrdata['class'] = klass
+ rr = {name:{type:[rrdata]}}
+ return rr, i+10+rdlength
+
+ def processpkt(self, msgdata):
+ self.header.id = asctoint(msgdata[:2])
+ self.header.qr = ord(msgdata[2]) >> 7
+ self.header.opcode = (ord(msgdata[2]) & 127) >> 3
+ if self.header.opcode == 5:
+ # UPDATE packet
+ log(4,'processing UPDATE packet')
+ del self.header.aa
+ del self.header.tc
+ del self.header.rd
+ del self.header.ra
+ del self.header.qdcount
+ del self.header.ancount
+ del self.header.nscount
+ del self.header.arcount
+ del self.question
+ self.zone = dnsupdatezone()
+ del self.answerlist
+ del self.authlist
+ del self.addlist
+ self.header.z = 0
+ self.header.rcode = ord(msgdata[3]) & 15
+ self.header.zocount = asctoint(msgdata[4:6])
+ self.header.prcount = asctoint(msgdata[6:8])
+ self.header.upcount = asctoint(msgdata[8:10])
+ self.header.arcount = asctoint(msgdata[10:12])
+ self.zolist = []
+ self.prlist = []
+ self.uplist = []
+ self.addlist = []
+ i = 12
+ for x in range(self.header.zocount):
+ (dn, i) = self.getdomainname(msgdata,i)
+ self.zone.zname = dn
+ type = asctoint(msgdata[i:i+2])
+ self.zone.ztype = self.qtypes.get(type,chr(type))
+ klass = asctoint(msgdata[i+2:i+4])
+ self.zone.zclass = self.qclasses.get(klass,chr(klass))
+ i = i + 4
+ for x in range(self.header.prcount):
+ rr, i = self.getrr(msgdata,i)
+ self.prlist.append(rr)
+ for x in range(self.header.upcount):
+ rr, i = self.getrr(msgdata,i)
+ self.uplist.append(rr)
+ for x in range(self.header.arcount):
+ rr, i = self.getrr(msgdata,i)
+ self.adlist.append(rr)
+ else:
+ self.header.aa = (ord(msgdata[2]) & 4) >> 2
+ self.header.tc = (ord(msgdata[2]) & 2) >> 1
+ self.header.rd = ord(msgdata[2]) & 1
+ self.header.ra = ord(msgdata[3]) >> 7
+ self.header.z = (ord(msgdata[3]) & 112) >> 4
+ self.header.rcode = ord(msgdata[3]) & 15
+ self.header.qdcount = asctoint(msgdata[4:6])
+ self.header.ancount = asctoint(msgdata[6:8])
+ self.header.nscount = asctoint(msgdata[8:10])
+ self.header.arcount = asctoint(msgdata[10:12])
+ i = 12
+ for x in range(self.header.qdcount):
+ log(4,'unpacking question')
+ (dn, i) = self.getdomainname(msgdata,i)
+ self.question.qname = dn
+ rrtype = asctoint(msgdata[i:i+2])
+ self.question.qtype = self.qtypes.get(rrtype,chr(rrtype))
+ klass = asctoint(msgdata[i+2:i+4])
+ self.question.qclass = self.qclasses.get(klass,chr(klass))
+ i = i + 4
+ for x in range(self.header.ancount):
+ log(4,'unpacking answer RR')
+ rr, i = self.getrr(msgdata,i)
+ self.answerlist.append(rr)
+ for x in range(self.header.nscount):
+ log(4,'unpacking auth RR')
+ rr, i = self.getrr(msgdata,i)
+ self.authlist.append(rr)
+ for x in range(self.header.arcount):
+ log(4,'unpacking additional RR')
+ rr, i = self.getrr(msgdata,i)
+ self.addlist.append(rr)
+ return
+
+ def pds(self, s, l):
+ # pad string with chr(0)'s so that
+ # return string length is l
+ x = l - len(s)
+ return x*chr(0) + s
+
+ def locsize(self, s):
+ x1 = ord(s) >> 4
+ x2 = ord(s) & 15
+ return (x1, x2)
+
+ def packlocsize(self, x):
+ return chr((x[0] << 4) + x[1])
+
+ def packdomainname(self, name, i, msgcomp):
+ log(4,'packing domainname: ' + name)
+ if name == '':
+ return chr(0)
+ if name in msgcomp.keys():
+ log(4,'using pointer for: ' + name)
+ return msgcomp[name]
+ packedname = ''
+ tokens = name.split('.')
+ for j in range(len(tokens)):
+ packedname = packedname + chr(len(tokens[j])) + tokens[j]
+ nameleft = '.'.join(tokens[j+1:])
+ if nameleft in msgcomp.keys():
+ log(4,'using pointer for: ' + nameleft)
+ return packedname+msgcomp[nameleft]
+ # haven't used a pointer so put this in the dictionary
+ pointer = inttoasc(i)
+ if len(pointer) == 1:
+ msgcomp[name] = chr(192)+pointer
+ else:
+ msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1]
+ log(4,'added pointer for ' + name + '(' + str(i) + ')')
+ return packedname + chr(0)
+
+ def packrr(self, rr, i, msgcomp):
+ rrname = rr.keys()[0]
+ rrtype = rr[rrname].keys()[0]
+ if self.rqtypes.has_key(rrtype):
+ typeval = self.rqtypes[rrtype]
+ else:
+ typeval = ord(rrtype)
+ dbrec = rr[rrname][rrtype][0]
+ ttl = dbrec['ttl']
+ rclass = self.rqclasses[dbrec['class']]
+ packedrr = (self.packdomainname(rrname, i, msgcomp) +
+ self.pds(inttoasc(typeval),2) +
+ self.pds(inttoasc(rclass),2) +
+ self.pds(inttoasc(ttl),4))
+ i = i + len(packedrr) + 2
+ if rrtype == 'A':
+ rdata = socket.inet_aton(dbrec['address'])
+ elif rrtype == 'AAAA':
+ rdata = ipv6net_aton(dbrec['address'])
+ elif rrtype == 'CNAME':
+ rdata = self.packdomainname(dbrec['cname'], i, msgcomp)
+ elif rrtype == 'HINFO':
+ rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] +
+ chr(len(dbrec['os'])) + dbrec['os'])
+ elif rrtype == 'LOC':
+ rdata = (chr(dbrec['version']) +
+ self.packlocsize(dbrec['size']) +
+ self.packlocsize(dbrec['horiz_pre']) +
+ self.packlocsize(dbrec['vert_pre']) +
+ self.pds(inttoasc(dbrec['latitude']),4) +
+ self.pds(inttoasc(dbrec['longitude']),4) +
+ self.pds(inttoasc(dbrec['altitude']),4))
+ elif rrtype == 'MX':
+ rdata = (self.pds(inttoasc(dbrec['preference']),2) +
+ self.packdomainname(dbrec['exchange'], i+2, msgcomp))
+ elif rrtype == 'NS':
+ rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp)
+ elif rrtype == 'PTR':
+ rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp)
+ elif rrtype == 'RP':
+ rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
+ i = i + len(rdata1)
+ rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
+ rdata = rdata1 + rdata2
+ elif rrtype == 'SOA':
+ rdata1 = self.packdomainname(dbrec['mname'], i, msgcomp)
+ i = i + len(rdata1)
+ rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp)
+ rdata = (rdata1 +
+ rdata2 +
+ self.pds(inttoasc(dbrec['serial']),4) +
+ self.pds(inttoasc(dbrec['refresh']),4) +
+ self.pds(inttoasc(dbrec['retry']),4) +
+ self.pds(inttoasc(dbrec['expire']),4) +
+ self.pds(inttoasc(dbrec['minimum']),4))
+ elif rrtype == 'SRV':
+ rdata = (self.pds(inttoasc(dbrec['priority']),2) +
+ self.pds(inttoasc(dbrec['weight']),2) +
+ self.pds(inttoasc(dbrec['port']),2) +
+ self.packdomainname(dbrec['target'], i+6, msgcomp))
+ elif rrtype == 'TXT':
+ rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata']
+ else:
+ rdata = dbrec['rdata']
+
+ return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata
+
+ def buildpkt(self):
+ # keep dictionary of names packed (so we can use pointers)
+ msgcomp = {}
+ # header
+ if self.header.id > 65535:
+ log(0,'building packet with bad ID field')
+ self.header.id = 1
+ msgdata = inttoasc(self.header.id)
+ if len(msgdata) == 1:
+ msgdata = chr(0) + msgdata
+ h1 = ((self.header.qr << 7) +
+ (self.header.opcode << 3) +
+ (self.header.aa << 2) +
+ (self.header.tc << 1) +
+ (self.header.rd))
+ h2 = ((self.header.ra << 7) +
+ (self.header.z << 4) +
+ (self.header.rcode))
+ msgdata = msgdata + chr(h1) + chr(h2)
+ msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2)
+ msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2)
+ msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2)
+ msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2)
+ # question
+ msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp)
+ if self.rqtypes.has_key(self.question.qtype):
+ typeval = self.rqtypes[self.question.qtype]
+ else:
+ typeval = ord(self.question.qtype)
+ msgdata = msgdata + self.pds(inttoasc(typeval),2)
+ if self.rqclasses.has_key(self.question.qclass):
+ classval = self.rqclasses[self.question.qclass]
+ else:
+ classval = ord(self.question.qclass)
+ msgdata = msgdata + self.pds(inttoasc(classval),2)
+ # rr's
+ # RR record format:
+ # {'name' : {'type' : [rdata, rdata, ...]}
+ # example: {'test.blah.net': {'A': [{'address': '10.1.1.2',
+ # 'ttl': 3600L}]}}
+ for rr in self.answerlist:
+ log(4,'packing answer RR')
+ msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
+ for rr in self.authlist:
+ log(4,'packing auth RR')
+ msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
+ for rr in self.addlist:
+ log(4,'packing additional RR')
+ msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
+
+ return msgdata
+
+ def printpkt(self):
+ print 'ID: ' +str(self.header.id)
+ if self.header.qr:
+ print 'QR: RESPONSE'
+ else:
+ print 'QR: QUERY'
+ if self.header.opcode == 0:
+ print 'OPCODE: STANDARD QUERY'
+ elif self.header.opcode == 1:
+ print 'OPCODE: INVERSE QUERY'
+ elif self.header.opcode == 2:
+ print 'OPCODE: SERVER STATUS REQUEST'
+ elif self.header.opcode == 5:
+ print 'UPDATE REQUEST'
+ else:
+ print 'OPCODE: UNKNOWN QUERY TYPE'
+ if self.header.opcode != 5:
+ if self.header.aa:
+ print 'AA: AUTHORITATIVE ANSWER'
+ else:
+ print 'AA: NON-AUTHORITATIVE ANSWER'
+ if self.header.tc:
+ print 'TC: MESSAGE IS TRUNCATED'
+ else:
+ print 'TC: MESSAGE IS NOT TRUNCATED'
+ if self.header.rd:
+ print 'RD: RECURSION DESIRED'
+ else:
+ print 'RD: RECURSION NOT DESIRED'
+ if self.header.ra:
+ print 'RA: RECURSION AVAILABLE'
+ else:
+ print 'RA: RECURSION IS NOT AVAILABLE'
+ if self.header.rcode == 1:
+ printrcode = 'FORMERR'
+ elif self.header.rcode == 2:
+ printrcode = 'SERVFAIL'
+ elif self.header.rcode == 3:
+ printrcode = 'NXDOMAIN'
+ elif self.header.rcode == 4:
+ printrcode = 'NOTIMP'
+ elif self.header.rcode == 5:
+ printrcode = 'REFUSED'
+ elif self.header.rcode == 6:
+ printrcode = 'YXDOMAIN'
+ elif self.header.rcode == 7:
+ printrcode = 'YXRRSET'
+ elif self.header.rcode == 8:
+ printrcode = 'NXRRSET'
+ elif self.header.rcode == 9:
+ printrcode = 'NOTAUTH'
+ elif self.header.rcode == 10:
+ printrcode = 'NOTZONE'
+ else:
+ printrcode = 'NOERROR'
+ print 'RCODE: ' + printrcode
+ if self.header.opcode == 5:
+ print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount)
+ print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount)
+ print 'NUMBER OF RRs in the Update Section: ' + str(self.header.upcount)
+ print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount)
+ print 'ZONE SECTION:'
+ print 'zname: ' + self.zone.zname
+ print 'zonetype: ' + self.zone.ztype
+ print 'zoneclass: ' + self.zone.zclass
+ print 'PREREQUISITE RRs:'
+ for rr in self.prlist:
+ print rr
+ print 'UPDATE RRs:'
+ for rr in self.uplist:
+ print rr
+ print 'ADDITIONAL RRs:'
+ for rr in self.addlist:
+ print rr
+
+
+ else:
+ print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount)
+ print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount)
+ print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount)
+ print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount)
+ print 'QUESTION SECTION:'
+ print 'qname: ' + self.question.qname
+ print 'querytype: ' + self.question.qtype
+ print 'queryclass: ' + self.question.qclass
+ print 'ANSWER RRs:'
+ for rr in self.answerlist:
+ print rr
+ print 'AUTHORITY RRs:'
+ for rr in self.authlist:
+ print rr
+ print 'ADDITIONAL RRs:'
+ for rr in self.addlist:
+ print rr
+
class zonedb:
    """In-memory database of the zones this server answers for.

    ``zdict`` maps a zone key to a dict describing one zone.  The keys
    read by this class are 'type' (compared against 'master' and
    'slave'), 'origin', 'zonedata', 'filename' and, when present,
    'masterip' and 'slaves'.  'zonedata' maps a node name to
    ``{rrtype: [rdata dict, ...]}``; RRs passed around on their own are
    ``{name: {rrtype: [rdata dict]}}``.
    """

    def __init__(self, zdict):
        self.zdict = zdict
        # dynamic-update history used for IXFR:
        # {zonekey: {serial: {'added': [rr, ...], 'removed': [rr, ...]}}}
        self.updates = {}
        for k in self.zdict.keys():
            if self.zdict[k]['type'] == 'slave':
                # never transferred yet, so getslaves() sees it as stale
                self.zdict[k]['lastupdatetime'] = 0

    # ------------------------------------------------------------------
    # small helpers

    def _inzone(self, name, origin):
        """Return 1 if name is origin itself or a name below it.

        The match is made on a label boundary, so 'badexample.net' and
        'example.net.evil.org' are not inside 'example.net'.  An empty
        origin contains every name.
        """
        if origin == '' or name == origin or name.endswith('.' + origin):
            return 1
        return 0

    def _splitrr(self, rr):
        """Return (name, type, rdata) for a single-record RR dict."""
        rrname = list(rr.keys())[0]
        rrtype = list(rr[rrname].keys())[0]
        return rrname, rrtype, rr[rrname][rrtype][0]

    def _rdatakey(self, rec):
        """Return a copy of rec without its 'ttl' and 'class' entries."""
        key = rec.copy()
        for field in ('ttl', 'class'):
            if field in key:
                del key[field]
        return key

    def _prunenode(self, zd, rrname, zname):
        """Drop node rrname from zd once it holds no RRsets (never the apex)."""
        if rrname != zname and rrname in zd and not zd[rrname]:
            del zd[rrname]

    # ------------------------------------------------------------------

    def error(self, id, qname, querytype, queryclass, rcode):
        """Return a response message carrying only a question and rcode."""
        error = message()
        error.header.id = id
        error.header.rcode = rcode
        error.header.qr = 1
        error.question.qname = qname
        error.question.qtype = querytype
        error.question.qclass = queryclass
        return error

    def getorigin(self, zkey):
        """Return the origin of zone zkey, or '' if the key is unknown."""
        if zkey in self.zdict:
            return self.zdict[zkey]['origin']
        return ''

    def getmasterip(self, zkey):
        """Return the 'masterip' of zone zkey, or '' if there is none."""
        if zkey in self.zdict and 'masterip' in self.zdict[zkey]:
            return self.zdict[zkey]['masterip']
        return ''

    def zonetrans(self, query):
        """Build the reply messages for an AXFR or IXFR query.

        Returns a list of messages, each carrying one RR of the zone
        whose origin equals the query name; the first and last carry
        the SOA record.  For IXFR the recorded update history between
        the client's serial and the current one is sent; when there is
        nothing to send the reply is the SOA alone.  Returns [] when no
        zone has that origin.
        """
        origin = query.question.qname
        querytype = query.question.qtype
        zkey = ''
        for zonekey in self.zdict.keys():
            if self.zdict[zonekey]['origin'] == origin:
                zkey = zonekey
        if not zkey:
            return []
        zonedata = self.zdict[zkey]['zonedata']
        queryid = query.header.id
        soarec = zonedata[origin]['SOA'][0]
        soa = {origin:{'SOA':[soarec]}}
        curserial = soarec['serial']
        rrlist = []

        incremental = (querytype == 'IXFR')
        if incremental:
            try:
                clientserial = query.authlist[0][origin]['SOA'][0]['serial']
            except (IndexError, KeyError):
                # no usable SOA in the authority section: send the
                # whole zone rather than fail
                incremental = 0

        if incremental:
            if clientserial < curserial:
                # a zone that was never updated has no history entry
                history = self.updates.get(zkey, {})
                serials = list(history.keys())
                serials.sort()
                # walk the recorded serials only, not the whole
                # (possibly enormous) numeric range
                for serial in serials:
                    if clientserial <= serial <= curserial:
                        rrlist.extend(history[serial]['added'])
                        rrlist.extend(history[serial]['removed'])
            if len(rrlist) > 0:
                rrlist.insert(0,soa)
                rrlist.append(soa)
            else:
                rrlist.append(soa)
        else:
            for nodename in zonedata.keys():
                for rrtype in zonedata[nodename].keys():
                    if not (rrtype == 'SOA' and nodename == origin):
                        for rr in zonedata[nodename][rrtype]:
                            rrlist.append({nodename:{rrtype:[rr]}})
            rrlist.insert(0,soa)
            rrlist.append(soa)

        msglist = []
        for rr in rrlist:
            msg = message()
            msg.header.id = queryid
            msg.header.qr = 1
            msg.header.aa = 1
            msg.header.rd = 0
            msg.header.qdcount = 1
            msg.question.qname = origin
            msg.question.qtype = querytype
            msg.question.qclass = 'IN'
            msg.header.ancount = 1
            msg.answerlist.append(rr)
            msglist.append(msg)
        return msglist

    def update_zone(self, rrlist, params):
        """Replace a zone's data with the RRs of a finished transfer.

        ``rrlist`` is the list of transferred RRs; its last element (the
        closing SOA) is ignored.  ``params[0]`` is the zone key.  The
        new data is installed, 'lastupdatetime' is set, and the zone is
        written to its 'filename' on a best-effort basis.
        """
        zonekey = params[0]
        zonedata = {}
        for rr in rrlist[:-1]:
            rrname, rrtype, dbrec = self._splitrr(rr)
            zonedata.setdefault(rrname, {}).setdefault(rrtype, []).append(dbrec)
        self.zdict[zonekey]['zonedata'] = zonedata
        curtime = time.time()
        self.zdict[zonekey]['lastupdatetime'] = curtime
        try:
            f = open(self.zdict[zonekey]['filename'],'w')
            try:
                writezonefile(zonedata, self.zdict[zonekey]['origin'], f)
            finally:
                # close the file even when writing fails
                f.close()
        except Exception:
            log(0,'unable to write zone ' + zonekey + ' to disk')
        log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')')

    def remove_zone(self, zonekey):
        """Forget zone zonekey; unknown keys are ignored."""
        if zonekey in self.zdict:
            del self.zdict[zonekey]

    def getslaves(self, curtime):
        """Return (zonekey, origin, masterip) for each stale slave zone.

        A slave zone is stale when its 'lastupdatetime' plus the SOA
        'refresh' value is earlier than curtime.
        """
        rlist = []
        for k in self.zdict.keys():
            zone = self.zdict[k]
            if zone['type'] == 'slave':
                origin = zone['origin']
                refresh = zone['zonedata'][origin]['SOA'][0]['refresh']
                if zone['lastupdatetime'] + refresh < curtime:
                    rlist.append((k, origin, zone['masterip']))
        return rlist

    def zmatch(self, qname, zkeys):
        """Return the first key in zkeys whose zone contains qname, else ''."""
        for zkey in zkeys:
            if zkey in self.zdict:
                if self._inzone(qname, self.zdict[zkey]['origin']):
                    return zkey
        return ''

    def getzlist(self, name, zone):
        """List the names between zone and name, nearest to zone first.

        For name 'a.b.example.net' and zone 'example.net' the result is
        ['b.example.net', 'a.b.example.net'].  Returns an empty list
        when name equals zone or does not lie inside it.
        """
        zlist = []
        if name == zone or not self._inzone(name, zone):
            return zlist
        prefix = name[:len(name) - len(zone)]
        if prefix.endswith('.'):
            prefix = prefix[:-1]
        partlist = prefix.split('.')
        partlist.reverse()
        lastpart = zone
        for part in partlist:
            if lastpart:
                lastpart = part + '.' + lastpart
            else:
                lastpart = part
            zlist.append(lastpart)
        return zlist

    def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc):
        """Answer query from local zone data and hand the result to cbfunc.

        cbfunc is called as
        ``cbfunc(query, addr, server, dorecursion, flist, msglist)``.
        msglist is the zone-transfer message list for AXFR/IXFR, a
        one-element list holding the answer when one of zkeys contains
        the name, and [] otherwise.  When a CNAME chain leaves the zone
        data (or loops) the method returns without calling cbfunc.
        """
        qname = query.question.qname
        querytype = query.question.qtype

        # handle zone transfers separately
        if querytype in ['AXFR','IXFR']:
            answerlist = []
            for zkey in self.zdict.keys():
                if zkey in zkeys and qname == self.zdict[zkey]['origin']:
                    answerlist = self.zonetrans(query)
                    break
            cbfunc(query, addr, server, dorecursion, flist, answerlist)
            return

        zonekey = self.zmatch(qname, zkeys)
        if not zonekey:
            cbfunc(query, addr, server, dorecursion, flist, [])
            return

        origin = self.zdict[zonekey]['origin']
        zonedict = self.zdict[zonekey]['zonedata']
        referral = 0
        nodefound = 0
        rranswerlist = []
        rrnslist = []
        rraddlist = []
        answer = message()
        answer.header.aa = 1
        answer.header.id = query.header.id
        answer.header.qr = 1
        answer.header.opcode = query.header.opcode
        answer.header.rcode = 4
        answer.header.ra = dorecursion
        answer.question.qname = query.question.qname
        answer.question.qtype = query.question.qtype
        answer.question.qclass = query.question.qclass

        # hosts under this suffix are looked up in the external database
        s = '.servers.csail.mit.edu'
        if qname.endswith(s):
            host = qname[:-len(s)]
            value = sipb_xen_database.NIC.get_by(hostname=host)
            if value is not None:
                rranswerlist.append({qname: {'A': [{'address': value.ip,
                                                    'class': 'IN',
                                                    'ttl': 10}]}})
                # we do have an answer, so do not leave the initial
                # NOTIMP rcode in place
                answer.header.rcode = 0

        if qname in zonedict:
            nodefound = 1
            # found the node, now take care of CNAMEs
            if 'CNAME' in zonedict[qname] and querytype != 'CNAME':
                seen = {}
                nodetype = 'CNAME'
                while nodetype == 'CNAME':
                    if qname in seen:
                        # a chain that comes back on itself would
                        # otherwise never end
                        log(0,'CNAME loop at ' + qname)
                        return
                    seen[qname] = 1
                    rranswerlist.append({qname:{'CNAME':[zonedict[qname]['CNAME'][0]]}})
                    qname = zonedict[qname]['CNAME'][0]['cname']
                    if qname in zonedict:
                        nodetype = list(zonedict[qname].keys())[0]
                    else:
                        # error, shouldn't have a CNAME that points to nothing
                        return
            # if we get this far, then the record has matched and we should return
            # a reply that has no error (even if there is no info matching the qtype)
            answer.header.rcode = 0
            answernode = zonedict[qname]
            if querytype == 'ANY':
                for rrtype in answernode.keys():
                    for rec in answernode[rrtype]:
                        rranswerlist.append({qname:{rrtype:[rec]}})
            elif querytype in answernode:
                for rec in answernode[querytype]:
                    rranswerlist.append({qname:{querytype:[rec]}})
                # do rrset ordering (cyclic)
                if len(answernode[querytype]) > 1:
                    rec = answernode[querytype].pop(0)
                    answernode[querytype].append(rec)
            else:
                # remove all cname rrs from answerlist
                rranswerlist = []
        else:
            # would check for wildcards here (but aren't because they seem bad)
            # see if we need to give a referral
            for zonename in self.getzlist(qname,origin):
                if zonename in zonedict and 'NS' in zonedict[zonename]:
                    answer.header.rcode = 0
                    referral = 1
                    for rec in zonedict[zonename]['NS']:
                        rrnslist.append({zonename:{'NS':[rec]}})
                        nsdname = rec['nsdname']
                        # add glue records if they exist
                        if nsdname in zonedict and 'A' in zonedict[nsdname]:
                            for gluerec in zonedict[nsdname]['A']:
                                rraddlist.append({nsdname:{'A':[gluerec]}})

        # negative caching stuff
        if not referral:
            if not rranswerlist:
                # NOTE: RFC1034 section 4.3.4 says we should add the SOA record
                #       to the additional section of the response.  BIND adds
                #       it to the ns section though
                if not nodefound:
                    # only a missing name is NXDOMAIN; an existing
                    # name without data of this type stays NOERROR
                    answer.header.rcode = 3
                rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}})
            else:
                for rec in zonedict[origin]['NS']:
                    rrnslist.append({origin:{'NS':[rec]}})
        answer.header.ancount = len(rranswerlist)
        answer.header.nscount = len(rrnslist)
        answer.header.arcount = len(rraddlist)
        answer.answerlist = rranswerlist
        answer.authlist = rrnslist
        answer.addlist = rraddlist
        cbfunc(query, addr, server, dorecursion, flist, [answer])

    def rrmatch(self, wanted, node):
        """Return 1 if every RRset in wanted equals the one in node.

        Both arguments map rrtype to a list of rdata dicts.  Records
        are compared without their 'ttl' and 'class' entries.
        NOTE(review): assumes prerequisite rdata dicts use the same
        field names as zone data; confirm against the packet parser.
        """
        for rrtype in wanted.keys():
            if rrtype not in node:
                return 0
            want = [self._rdatakey(rec) for rec in wanted[rrtype]]
            have = [self._rdatakey(rec) for rec in node[rrtype]]
            if len(want) != len(have):
                return 0
            for rec in want:
                if rec not in have:
                    return 0
        return 1

    def _checkprereqs(self, msg, zd):
        """Check the prerequisite section of update msg against zone zd.

        Returns None when every prerequisite holds, otherwise a tuple
        (rcode, logtext) describing the first failure.
        """
        zname = msg.zone.zname
        temprrset = {}
        for rr in msg.prlist:
            rrname, rrtype, dbrec = self._splitrr(rr)
            if dbrec['ttl'] != 0:
                return 1, 'FORMERROR(1)'
            if not self._inzone(rrname, zname):
                return 10, 'NOTZONE(10)'
            rrclass = dbrec['class']
            hasrrset = rrname in zd and rrtype in zd[rrname]
            # the classes are alternatives; only an unknown class is a
            # format error
            if rrclass == 'ANY':
                if dbrec.get('rdata'):
                    return 1, 'FORMERROR(1)'
                if rrtype == 'ANY':
                    if rrname not in zd:
                        return 3, 'NXDOMAIN(3)'
                elif not hasrrset:
                    return 8, 'NXRRSET(8)'
            elif rrclass == 'NONE':
                if dbrec.get('rdata'):
                    return 1, 'FORMERROR(1)'
                if rrtype == 'ANY':
                    if rrname in zd:
                        return 6, 'YXDOMAIN(6)'
                elif hasrrset:
                    return 7, 'YXRRSET(7)'
            elif rrclass == msg.zone.zclass:
                temprrset.setdefault(rrname, {}).setdefault(rrtype, []).append(dbrec)
            else:
                return 1, 'FORMERROR(1)'
        for nodename in temprrset.keys():
            # a node missing from the zone simply does not match
            if not self.rrmatch(temprrset[nodename], zd.get(nodename, {})):
                return 8, 'NXRRSET(8)'
        return None

    def _prescanupdates(self, msg):
        """Validate the update section; return None or (rcode, logtext)."""
        zname = msg.zone.zname
        for rr in msg.uplist:
            rrname, rrtype, dbrec = self._splitrr(rr)
            if not self._inzone(rrname, zname):
                return 10, 'NOTZONE(10)'
            rrclass = dbrec['class']
            if rrclass == msg.zone.zclass:
                bad = rrtype in ['ANY','MAILA','MAILB','AXFR']
            elif rrclass == 'ANY':
                bad = (dbrec['ttl'] != 0 or dbrec.get('rdata') or
                       rrtype in ['MAILA','MAILB','AXFR'])
            elif rrclass == 'NONE':
                bad = (dbrec['ttl'] != 0 or
                       rrtype in ['ANY','MAILA','MAILB','AXFR'])
            else:
                bad = 1
            if bad:
                return 1, 'FORMERROR(1)'
        return None

    def _applyupdate(self, zd, zname, zclass, rr, removed, added):
        """Apply one update RR to zone data zd.

        Deleted records are appended to ``removed`` and new ones to
        ``added`` (both lists of single-record RR dicts).  Returns 1
        when the update history must be discarded (the SOA was
        replaced), else 0.
        """
        rrname, rrtype, dbrec = self._splitrr(rr)
        rrclass = dbrec['class']
        atapex = (rrname == zname)

        if rrclass == zclass:
            if rrtype == 'SOA':
                if rrname in zd and 'SOA' in zd[rrname]:
                    if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']:
                        del zd[rrname]['SOA'][0]
                        zd[rrname]['SOA'].append(dbrec)
                        return 1
            elif rrtype == 'WKS':
                # a WKS record replaces the existing one, if any
                rrset = zd.setdefault(rrname, {}).setdefault('WKS', [])
                if rrset:
                    removed.append({rrname:{'WKS':[rrset[0]]}})
                    del rrset[0]
                rrset.append(dbrec)
                added.append({rrname:{'WKS':[dbrec]}})
            else:
                zd.setdefault(rrname, {}).setdefault(rrtype, []).append(dbrec)
                added.append({rrname:{rrtype:[dbrec]}})

        elif rrclass == 'ANY':
            if rrname not in zd:
                return 0
            node = zd[rrname]
            if rrtype == 'ANY':
                # delete every RRset at the name; the apex keeps SOA and NS
                for dnstype in list(node.keys()):
                    if atapex and dnstype in ['SOA','NS']:
                        continue
                    for rdata in node[dnstype]:
                        removed.append({rrname:{dnstype:[rdata]}})
                    del node[dnstype]
            elif rrtype in node:
                if not (atapex and rrtype in ['SOA','NS']):
                    for rdata in node[rrtype]:
                        removed.append({rrname:{rrtype:[rdata]}})
                    del node[rrtype]
            self._prunenode(zd, rrname, zname)

        elif rrclass == 'NONE':
            if atapex and rrtype in ['SOA','NS']:
                return 0
            if rrname in zd and rrtype in zd[rrname]:
                # the request carries class NONE and ttl 0, so compare
                # the remaining fields only
                target = self._rdatakey(dbrec)
                kept = []
                for rdata in zd[rrname][rrtype]:
                    if self._rdatakey(rdata) == target:
                        removed.append({rrname:{rrtype:[rdata]}})
                    else:
                        kept.append(rdata)
                if kept:
                    zd[rrname][rrtype][:] = kept
                else:
                    del zd[rrname][rrtype]
                    self._prunenode(zd, rrname, zname)
        return 0

    def handle_update(self, msg, addr, ns):
        """Process dynamic update msg received from addr.

        Returns ``(reply, origin, slaves)``: the response message, the
        origin of the updated zone ('' when no master zone matches) and
        the zone's 'slaves' entry ([] if it has none).  Permission is
        checked with ``ns.config.allowupdate(msg, addr[0], addr[1])``.
        When the update section is not empty the SOA serial is
        incremented and the changes are recorded in ``self.updates``.
        """
        zname = msg.zone.zname
        zclass = msg.zone.zclass
        slaves = []
        zkey = ''
        for zonekey in self.zdict.keys():
            if (self.zdict[zonekey]['type'] == 'master' and
                self.zdict[zonekey]['origin'] == zname):
                zkey = zonekey

        def reply(rcode):
            return self.error(msg.header.id, zname,
                              msg.zone.ztype, zclass, rcode)

        if not zkey:
            log(2,'SENDING NOTAUTH UPDATE ERROR')
            return reply(9), '', slaves
        # find the slaves for the zone
        if 'slaves' in self.zdict[zkey]:
            slaves = self.zdict[zkey]['slaves']
        origin = self.zdict[zkey]['origin']
        zd = self.zdict[zkey]['zonedata']

        # check the permissions
        if not ns.config.allowupdate(msg, addr[0], addr[1]):
            log(2,'SENDING REFUSED UPDATE ERROR')
            return reply(5), origin, slaves

        # now check the prereqs, then prescan the update section
        failure = self._checkprereqs(msg, zd)
        if failure is None:
            failure = self._prescanupdates(msg)
        if failure is not None:
            log(2,failure[1])
            return reply(failure[0]), origin, slaves

        # now handle actual update
        curserial = zd[zname]['SOA'][0]['serial']
        clearupdatehist = 0
        if len(msg.uplist) > 0:
            # initialize history structure; keep what an earlier update
            # already recorded under the current serial
            history = self.updates.setdefault(zkey, {})
            if curserial not in history:
                history[curserial] = {'removed':[], 'added':[]}
            # update the soa serial here
            if curserial >= 2**32 - 1:
                # the serial is at its maximum: wrap around
                newserial = 2
                clearupdatehist = 1
            else:
                newserial = curserial + 1
            if newserial not in history:
                history[newserial] = {'removed':[], 'added':[]}
            zd[zname]['SOA'][0]['serial'] = newserial
            removed = history[curserial]['removed']
            added = history[newserial]['added']
            for rr in msg.uplist:
                if self._applyupdate(zd, zname, zclass, rr, removed, added):
                    clearupdatehist = 1
        if clearupdatehist:
            self.updates[zkey] = {}
        log(2,'SENDING UPDATE NOERROR MSG')
        return reply(0), origin, slaves
+
class dnscache:
    """Cache of resource records learned from other servers.

    ``cachedb`` maps a (lower-case) node name to
    ``{rrtype: [rdata dict, ...]}``.  A record's 'ttl' holds its
    absolute expiry time; 0 means the record never expires.  A record
    holding nothing but a 'ttl' is a negative entry.  NS records also
    carry 'rtt' and, while a server is considered bad, 'badtill'.
    """

    def __init__(self,cachezone):
        self.cachedb = cachezone
        # go through and set all of the root ttls to zero
        for node in self.cachedb.keys():
            for rtype in self.cachedb[node].keys():
                for rr in self.cachedb[node][rtype]:
                    rr['ttl'] = 0
                    if rtype == 'NS':
                        rr['rtt'] = 0
        # add special entries for localhost
        self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]}
        self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]}
        self.cachedb['']['SOA'] = []
        self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb',
                                        'rname':'cachedb@localhost','serial':1,'refresh':10800,
                                        'retry':3600,'expire':604800,'minimum':3600})

    def hasrdata(self, irrdata, rrdatalist):
        """Return 1 if rrdatalist holds a record equal to irrdata.

        Everything but the 'ttl' entries is compared.
        """
        wanted = irrdata.copy()
        wanted.pop('ttl', None)
        for rrdata in rrdatalist:
            candidate = rrdata.copy()
            candidate.pop('ttl', None)
            if candidate == wanted:
                return 1
        return 0

    def _isbad(self, nsrec, curtime):
        """Return 1 while nsrec is marked bad; clear an expired mark."""
        if 'badtill' in nsrec:
            if nsrec['badtill'] < curtime:
                del nsrec['badtill']
            else:
                return 1
        return 0

    def add(self, rr, qzone, nsdname):
        """Add single-record RR dict rr, received for zone qzone, to the cache.

        Records whose owner name lies outside qzone are rejected.  A
        ttl below 3600 is raised to 3600, then converted in place to an
        absolute expiry time.  Expired records are reaped afterwards.
        """
        # NOTE: can't cache records from sites
        #       that don't own those records (i.e. example.com
        #       can't give us A records for www.example.net)
        name = list(rr.keys())[0]
        if qzone != '':
            lname = name.lower()
            lzone = qzone.lower()
            # compare on a label boundary so that 'evilexample.com'
            # is not accepted for zone 'example.com'
            if not (lname == lzone or lname.endswith('.' + lzone)):
                log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone)
                return
        rtype = list(rr[name].keys())[0]
        rdata = rr[name][rtype][0]
        if rdata['ttl'] < 3600:
            log(2,'low ttl: ' + str(rdata['ttl']))
            rdata['ttl'] = 3600
        rdata['ttl'] = int(time.time() + rdata['ttl'])
        if rtype == 'NS':
            rdata['rtt'] = 0
        name = name.lower()
        rtype = rtype.upper()
        if name in self.cachedb:
            if rtype in self.cachedb[name]:
                if not self.hasrdata(rdata, self.cachedb[name][rtype]):
                    self.cachedb[name][rtype].append(rdata)
                    log(3,'appended rdata to ' +
                        name + '(' + rtype + ') in cache')
                else:
                    log(3,'same rdata for ' + name + '(' +
                        rtype + ') is already in cache')
            else:
                self.cachedb[name][rtype] = [rdata]
                log(3,'appended ' + rtype + ' and rdata to node ' +
                    name + ' in cache')
        else:
            self.cachedb[name] = {rtype:[rdata]}
            log(3,'added node ' + name + '(' + rtype + ') to cache')
        self.reap()

    def addneg(self, qname, querytype, queryclass):
        """Record for an hour that qname has no data of type querytype.

        Nothing is stored when the node already has an entry for that
        type.
        """
        negative = [{'ttl':time.time()+3600}]
        if qname not in self.cachedb:
            # store under the queried name itself
            self.cachedb[qname] = {querytype: negative}
        elif querytype not in self.cachedb[qname]:
            self.cachedb[qname][querytype] = negative

    def haskey(self, qname, querytype, msg=''):
        """Look up qname/querytype in the cache.

        CNAMEs are followed unless CNAME itself is asked for.  With a
        query message in ``msg`` the return value is an answer message
        for it; without one it is 1.  Returns None when nothing usable
        is cached.  Negative entries are never returned as answers.
        """
        log(3,'looking for ' + qname + '(' + querytype + ') in cache')
        if qname not in self.cachedb:
            log(3,'Cache has no node for ' + qname)
            return
        rranswerlist = []
        if 'CNAME' in self.cachedb[qname] and querytype != 'CNAME':
            seen = {}
            nodetype = 'CNAME'
            while nodetype == 'CNAME':
                cnamerec = self.cachedb[qname]['CNAME'][0]
                if len(cnamerec.keys()) <= 1:
                    # negative CNAME entry: there is nothing to follow
                    break
                if qname in seen:
                    # the chain loops back on itself
                    return
                seen[qname] = 1
                log(3,'Adding CNAME to cache answer')
                rranswerlist.append({qname:{'CNAME':[cnamerec]}})
                qname = cnamerec['cname']
                if qname in self.cachedb:
                    nodetype = list(self.cachedb[qname].keys())[0]
                else:
                    # shouldn't have a CNAME that points to nothing
                    return
        node = self.cachedb[qname]
        if querytype == 'ANY':
            for rrtype in node.keys():
                for rec in node[rrtype]:
                    # can't append negative entries
                    if len(rec.keys()) > 1:
                        rranswerlist.append({qname:{rrtype:[rec]}})
        elif querytype in node:
            for rec in node[querytype]:
                if len(rec.keys()) > 1:
                    rranswerlist.append({qname:{querytype:[rec]}})
        if not rranswerlist:
            return
        if not msg:
            return 1
        answer = message()
        answer.header.id = msg.header.id
        answer.header.qr = 1
        answer.header.opcode = msg.header.opcode
        answer.header.ra = 1
        answer.question.qname = msg.question.qname
        answer.question.qtype = msg.question.qtype
        answer.question.qclass = msg.question.qclass
        answer.header.rcode = 0
        answer.header.ancount = len(rranswerlist)
        answer.answerlist = rranswerlist
        return answer

    def getnslist(self, qname):
        """Find the best nameservers to ask about qname.

        Returns ``(domainname, nsdict)``.  domainname is the longest
        suffix of qname for which usable nameservers are cached, or ''
        when the root servers are returned.  nsdict maps a server's
        'rtt' to ``{'name': ..., 'ip': ...}``; servers sharing an rtt
        therefore overwrite one another.
        """
        tokens = qname.split('.')
        nsdict = {}
        curtime = time.time()
        for i in range(len(tokens)):
            domainname = '.'.join(tokens[i:])
            node = self.cachedb.get(domainname)
            if node is not None and 'NS' in node:
                for nsrec in node['NS']:
                    if self._isbad(nsrec, curtime):
                        log(2,'BAD SERVER, not using ' + nsrec['nsdname'])
                        continue
                    nsnode = self.cachedb.get(nsrec['nsdname'])
                    if nsnode is not None and 'A' in nsnode:
                        for arec in nsnode['A']:
                            # skip negative entries, which have no address
                            if 'address' in arec:
                                nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'],
                                                        'ip':arec['address']}
            if nsdict:
                break
        if not nsdict:
            domainname = ''
            # nothing in the cache matches so give back the root servers
            for nsrec in self.cachedb['']['NS']:
                if not self._isbad(nsrec, curtime):
                    for arec in self.cachedb[nsrec['nsdname']]['A']:
                        if 'address' in arec:
                            nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'],
                                                    'ip':arec['address']}

        return (domainname, nsdict)

    def badns(self, zonename, nsdname):
        """Mark nameserver nsdname of zonename as bad for one hour."""
        if zonename in self.cachedb and 'NS' in self.cachedb[zonename]:
            for nsrec in self.cachedb[zonename]['NS']:
                if nsrec['nsdname'] == nsdname:
                    log(2,'Setting ' + nsdname + ' as bad nameserver')
                    nsrec['badtill'] = time.time() + 3600

    def updatertt(self, qname, zone, rtt):
        """Store rtt on the NS record of zone whose nsdname is qname."""
        if zone in self.cachedb and 'NS' in self.cachedb[zone]:
            for rr in self.cachedb[zone]['NS']:
                if rr['nsdname'] == qname:
                    log(2,'updating rtt for ' + qname + ' to ' + str(rtt))
                    rr['rtt'] = rtt

    def reap(self):
        """Expire all old records.

        Records with a non-zero ttl earlier than now are dropped, and
        RRsets and nodes left empty are removed with them.
        """
        ntime = time.time()
        for nodename in list(self.cachedb.keys()):
            node = self.cachedb[nodename]
            for rrtype in list(node.keys()):
                # build the survivor list instead of removing entries
                # from the list being walked, which skips records
                live = [rdata for rdata in node[rrtype]
                        if rdata['ttl'] == 0 or rdata['ttl'] >= ntime]
                if live:
                    node[rrtype][:] = live
                else:
                    del node[rrtype]
            if len(node) == 0:
                del self.cachedb[nodename]
        return

    def zonetrans(self, queryid):
        """Return the whole cache as a list of AXFR-style messages.

        Each message contains one rr; the first and last message carry
        the SOA record of the '' node.
        """
        zonedata = self.cachedb
        rrlist = []
        soa = {'':{'SOA':[zonedata['']['SOA'][0]]}}
        for nodename in zonedata.keys():
            for rrtype in zonedata[nodename].keys():
                if not (rrtype == 'SOA' and nodename == ''):
                    for rr in zonedata[nodename][rrtype]:
                        rrlist.append({nodename:{rrtype:[rr]}})
        rrlist.insert(0,soa)
        rrlist.append(soa)
        msglist = []
        for rr in rrlist:
            msg = message()
            msg.header.id = queryid
            msg.header.qr = 1
            msg.header.aa = 1
            msg.header.rd = 0
            msg.header.qdcount = 1
            msg.question.qname = 'cache'
            msg.question.qtype = 'AXFR'
            msg.question.qclass = 'IN'
            msg.header.ancount = 1
            msg.answerlist.append(rr)
            msglist.append(msg)
        return msglist
+
class gethostaddr(asyncore.dispatcher):
    """Asynchronously resolve a host name to an address over UDP.

    An A query for ``hostname`` is sent to port 53 of ``serveraddr``.
    ``cbfunc`` is called once with the first matching address, or with
    '' when the reply holds none.  A reply holding only a CNAME causes
    a new query for the CNAME target on the same socket.
    """

    # upper bound on follow-up queries so that a CNAME loop ends
    maxcnamequeries = 10

    def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1'):
        asyncore.dispatcher.__init__(self)
        self.cbfunc = cbfunc
        # remembered because handle_read may need to query again
        self.serveraddr = serveraddr
        self.cnamequeries = 0
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        self._sendquery(hostname)

    def _sendquery(self, hostname):
        """Send an A query for hostname to the configured server."""
        self.msg = message()
        self.msg.question.qname = hostname
        self.msg.question.qtype = 'A'
        self.socket.sendto(self.msg.buildpkt(), (self.serveraddr,53))

    def handle_read(self):
        replydata, addr = self.socket.recvfrom(1500)
        try:
            replymsg = message(replydata)
        except Exception:
            log(0,'unable to process packet')
            self.close()
            return
        answername = replymsg.question.qname
        cname = ''
        # go through twice to catch cnames after A recs
        for rr in replymsg.answerlist:
            rrname = list(rr.keys())[0]
            rrtype = list(rr[rrname].keys())[0]
            dbrec = rr[rrname][rrtype][0]
            if rrname == answername and rrtype == 'CNAME':
                answername = dbrec['cname']
                cname = answername
        for rr in replymsg.answerlist:
            rrname = list(rr.keys())[0]
            rrtype = list(rr[rrname].keys())[0]
            dbrec = rr[rrname][rrtype][0]
            if rrname == answername and rrtype == 'A':
                self.close()
                self.cbfunc(dbrec['address'])
                return
        # if we got a cname and no A send query for cname
        if cname and self.cnamequeries < self.maxcnamequeries:
            # the socket is still open here, so it can be reused
            self.cnamequeries = self.cnamequeries + 1
            self._sendquery(cname)
        else:
            self.close()
            self.cbfunc('')

    def writable(self):
        # everything is sent directly with sendto()
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
+
class simpleudprequest(asyncore.dispatcher):
    """Send one prepared message over UDP and report the reply.

    ``msg`` is packed and sent to port 53 of ``serveraddr`` as soon as
    the object is created.  The first datagram received is parsed and
    passed to ``cbfunc(reply, outqkey)``; the socket is closed before
    the reply is parsed.
    """

    def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''):
        asyncore.dispatcher.__init__(self)
        self.gotanswer = 0
        self.msg = msg
        self.cbfunc = cbfunc
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        # same buffer size for the send and the receive side
        for bufopt in (socket.SO_SNDBUF, socket.SO_RCVBUF):
            self.setsockopt(socket.SOL_SOCKET, bufopt, 200 * 1024)
        self.outqkey = outqkey
        target = (serveraddr,53)
        self.socket.sendto(self.msg.buildpkt(), target)

    def handle_read(self):
        packet, sender = self.socket.recvfrom(1500)
        self.close()
        try:
            reply = message(packet)
        except:
            # an unparsable reply is logged and dropped; the callback
            # is not invoked
            log(0,'unable to process packet')
        else:
            self.cbfunc(reply, self.outqkey)

    def writable(self):
        # nothing is ever queued for asyncore to write
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        # 'info' messages are only passed on when __debug__ is set
        if type == 'info' and not __debug__:
            return
        log(0,'%s: %s' % (type, message))
+
class simpletcprequest(asyncore.dispatcher):
    """Send one query over TCP and deliver the reply to a callback.

    The packed ``msg`` is sent, preceded by its two-byte length, to
    port 53 of ``serveraddr``.  For an AXFR query the answer RRs are
    collected until the closing SOA arrives and the callback receives
    the list of RRs; for any other query it receives the first reply
    message.  ``cbparams``, when not empty, is passed to the callback
    as a second argument.  ``errorfunc``, when set, is called if the
    transfer fails or the connection closes early.
    """

    def __init__(self, msg, cbfunc, cbparams=None, serveraddr='127.0.0.1', errorfunc=''):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.query = msg
        self.cbfunc = cbfunc
        # None rather than a shared mutable default list
        if cbparams is None:
            cbparams = []
        self.cbparams = cbparams
        self.errorfunc = errorfunc
        msgdata = msg.buildpkt()
        ml = inttoasc(len(msgdata))
        if len(ml) == 1:
            ml = chr(0) + ml
        self.buffer = ml+msgdata
        # received bytes not yet consumed, length prefixes included
        self.rbuffer = ''
        self.rmsgleft = 0
        self.rmsglength = 0
        self.rrlist = []
        # set once the result has been delivered and the socket closed
        self.finished = 0
        log(2,'sending tcp request to ' + serveraddr)
        self.connect((serveraddr,53))

    def recv (self, buffer_size):
        try:
            data = self.socket.recv (buffer_size)
        except socket.error:
            why = sys.exc_info()[1]
            # winsock sometimes throws ENOTCONN
            if why.args[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]:
                self.handle_close()
                return ''
            raise
        if not data:
            # a closed connection is indicated by signaling
            # a read condition, and having recv() return 0.
            self.handle_close()
            return ''
        return data

    def handle_connect(self):
        pass

    def _finish(self):
        """Close the connection and stop processing buffered data."""
        self.finished = 1
        self.close()

    def _deliver(self, result):
        """Pass result to the callback, with cbparams when there are any."""
        if self.cbparams:
            self.cbfunc(result, self.cbparams)
        else:
            self.cbfunc(result)

    def handle_msg(self, msg):
        """Process one complete reply message."""
        if self.query.question.qtype != 'AXFR':
            self._finish()
            self._deliver(msg)
            return
        if len(self.rrlist) == 0 and len(msg.answerlist) == 0:
            # the first reply carries no records: the transfer failed
            if self.errorfunc:
                if self.cbparams:
                    self.errorfunc(self.cbparams[0])
                else:
                    self.errorfunc(self.query.question.qname)
            self._finish()
            return
        # a server may put one RR or many in each message
        for rr in msg.answerlist:
            rrname = list(rr.keys())[0]
            rrtype = list(rr[rrname].keys())[0]
            self.rrlist.append(rr)
            if rrtype == 'SOA' and len(self.rrlist) > 1:
                # the second SOA marks the end of the zone
                self._finish()
                self._deliver(self.rrlist)
                return

    def handle_read(self):
        data = self.recv(8192)
        if not data:
            return
        self.rbuffer = self.rbuffer + data
        # each message is preceded by its length in two bytes; either
        # part may still be incomplete after this read
        while not self.finished and len(self.rbuffer) >= 2:
            self.rmsglength = asctoint(self.rbuffer[:2])
            if len(self.rbuffer) < 2 + self.rmsglength:
                break
            msgdata = self.rbuffer[2:2 + self.rmsglength]
            self.rbuffer = self.rbuffer[2 + self.rmsglength:]
            if not msgdata:
                continue
            try:
                self.handle_msg(message(msgdata))
            except Exception:
                log(0,'unable to process packet')
                return

    def writable(self):
        return (len(self.buffer) > 0)

    def handle_write(self):
        sent = self.send(self.buffer)
        self.buffer = self.buffer[sent:]

    def handle_close(self):
        if self.errorfunc:
            self.errorfunc(self.query.question.qname)
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
+
class udpdnsserver(asyncore.dispatcher):
    """UDP listener that feeds received packets to the DNS server.

    Every datagram read is passed to
    ``dnsserver.handle_packet(data, addr, self)``; replies are sent
    back through sendpackets().
    """

    def __init__(self, port, dnsserver):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        for bufopt in (socket.SO_SNDBUF, socket.SO_RCVBUF):
            self.setsockopt(socket.SOL_SOCKET, bufopt, 200 * 1024)
        self.bind(('',port))
        self.dnsserver = dnsserver
        # replies that pack to more than this are sent truncated
        self.maxmsgsize = 500

    def handle_read(self):
        # keep reading until the socket reports that it would block
        try:
            while 1:
                packet, sender = self.socket.recvfrom(1500)
                self.dnsserver.handle_packet(packet, sender, self)
        except socket.error:
            err = sys.exc_info()[1]
            if err.args[0] != asyncore.EWOULDBLOCK:
                raise

    def sendpackets(self, msglist, addr):
        """Send each message in msglist to addr.

        A message that packs to more than maxmsgsize has its TC bit
        set and all of its records removed before it is sent.
        """
        for msg in msglist:
            wire = msg.buildpkt()
            if len(wire) > self.maxmsgsize:
                hdr = msg.header
                hdr.tc = 1
                # take off all the answers to ensure
                # the packet size is small enough
                hdr.ancount = 0
                hdr.nscount = 0
                hdr.arcount = 0
                for section in ('answerlist', 'authlist', 'addlist'):
                    setattr(msg, section, [])
                wire = msg.buildpkt()
            self.sendto(wire, addr)

    def writable(self):
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        # the listening socket is deliberately left open
        return None

    def log_info (self, message, type='info'):
        if type == 'info' and not __debug__:
            return
        log(0,'%s: %s' % (type, message))
+
class tcpdnschannel(asynchat.async_chat):
    """One accepted TCP connection carrying a length-prefixed DNS query.

    Incoming data is a two-byte length followed by the message.  Once a
    whole message has arrived it is passed to
    server.dnsserver.handle_packet(), which replies through sendpackets();
    the connection is closed after the replies have been written.

    server -- the listening object; its dnsserver attribute handles packets
    s      -- the connected socket
    addr   -- peer address, passed through to handle_packet()
    """

    def __init__(self, server, s, addr):
        asynchat.async_chat.__init__(self, s)
        self.server = server
        self.addr = addr
        # no terminator: all received data goes to collect_incoming_data()
        self.set_terminator(None)
        # bytes received but not yet consumed
        self.databuffer = ''
        # length announced by the current prefix; 0 = prefix not read yet
        self.msglength = 0
        log(3,'Created new tcp channel')

    def collect_incoming_data(self, data):
        """Buffer data and dispatch the message once it is complete.

        The length prefix is only decoded when both of its bytes have
        arrived, and exactly msglength bytes are handed on, so a prefix or
        message split across several reads is reassembled correctly.
        """
        self.databuffer = self.databuffer + data
        if self.msglength == 0:
            if len(self.databuffer) < 2:
                # wait for the rest of the length prefix
                return
            self.msglength = asctoint(self.databuffer[:2])
            self.databuffer = self.databuffer[2:]
        if len(self.databuffer) >= self.msglength:
            # got entire message
            msgdata = self.databuffer[:self.msglength]
            self.databuffer = self.databuffer[self.msglength:]
            # reset so that the next bytes are read as a new prefix
            self.msglength = 0
            self.server.dnsserver.handle_packet(msgdata, self.addr, self)

    def sendpackets(self, msglist, addr):
        """Queue every message with its two-byte length prefix, then close.

        addr is unused; it is accepted so that this method has the same
        signature as the UDP server's sendpackets().
        """
        for msg in msglist:
            pkt = msg.buildpkt()
            prefix = inttoasc(len(pkt))
            if len(prefix) == 1:
                # pad the length to two bytes
                prefix = chr(0) + prefix
            self.push(prefix + pkt)
        # push() may leave data queued; close_when_done() closes the channel
        # only after the queue has been written, where close() would drop
        # whatever had not been sent yet
        self.close_when_done()

    def handle_close(self):
        # close quietly, like tcpdnsserver.handle_close
        self.close()

    def log_info(self, message, type='info'):
        """Route asyncore's log output through log(); 'info' only if __debug__."""
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
+
class tcpdnsserver(asyncore.dispatcher):
    """Listening TCP socket that creates a tcpdnschannel per connection.

    port      -- TCP port to bind on all interfaces
    dnsserver -- name server object; channels reach it through
                 self.dnsserver
    """

    def __init__(self, port, dnsserver):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(('',port))
        self.listen(5)
        self.dnsserver = dnsserver

    def handle_accept(self):
        """Accept a pending connection and wrap it in a tcpdnschannel."""
        accepted = self.accept()
        if accepted is None:
            # accept() returns None when no connection was actually
            # established; unpacking it would raise TypeError
            return
        conn, addr = accepted
        tcpdnschannel(self, conn, addr)

    def handle_close(self):
        self.close()

    def log_info(self, message, type='info'):
        """Route asyncore's log output through log(); 'info' only if __debug__."""
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
+
class nameserver:
    """Dispatch incoming DNS packets to the zone database and the resolver.

    Besides answering queries it handles UPDATE (opcode 5) and NOTIFY
    (opcode 4) packets, sends NOTIFYs to slaves after a successful update,
    and refreshes slave zones with an SOA query followed by an AXFR.

    resolver    -- resolver object used for recursion and for sending
                   NOTIFYs; tested for truth before most uses
    localconfig -- configuration object providing zonedatabase, getview()
                   and outpackets()
    """

    def __init__(self, resolver, localconfig):
        self.resolver = resolver
        self.config = localconfig
        self.zdb = self.config.zonedatabase
        self.last_reap_time = time.time()
        # seconds between slave-zone maintenance passes in reap()
        self.maint_int = 10
        # zone keys that have a zone transfer in progress
        self.slavesupdating = []
        # (origin, ipaddr, trynum) NOTIFYs waiting to be sent
        self.notifys = []
        # (origin, ipaddr, trynum, senttime) NOTIFYs awaiting a response
        self.sentnotify = []
        self.notify_retry_time = 30
        self.notify_retries = 4
        # zone key -> {'masterip', 'senttime', 'origin', 'trynum'} for
        # SOA queries that have not been answered yet
        self.askedsoa = {}
        self.soatimeout = 10

    def error(self, id, qname, querytype, queryclass, rcode):
        """Return a response message carrying rcode for the given question."""
        error = message()
        error.header.id = id
        error.header.rcode = rcode
        error.header.qr = 1
        error.question.qname = qname
        error.question.qtype = querytype
        error.question.qclass = queryclass
        return error

    def need_zonetransfer(self, zkey, origin, masterip, trynum=0):
        """Send an SOA query for origin to masterip and record the attempt.

        trynum is the number of attempts already made; the stored count is
        trynum + 1.  The reply is delivered to handle_soaquery().
        """
        self.askedsoa[zkey] = {'masterip':masterip,
                               'senttime':time.time(),
                               'origin':origin,
                               'trynum':trynum+1}
        query = message()
        query.header.id = random.randrange(1,32768)
        query.header.rd = 0
        query.question.qname = origin
        query.question.qtype = 'SOA'
        query.question.qclass = 'IN'
        log(3,'slave checking for new data in ' + origin)
        simpleudprequest(query, self.handle_soaquery,
                         masterip, zkey)

    def handle_soaquery(self, msg, zkey):
        """Start an AXFR from the master once it has answered the SOA query.

        NOTE(review): the SOA data in msg is not inspected, so a transfer
        is requested whether or not the serial changed -- confirm that is
        intended.
        """
        pending = self.askedsoa.get(zkey)
        if pending is None:
            # late or duplicate reply: reap() has already dropped or
            # answered this request, so there is nothing to act on
            return
        origin = msg.question.qname
        masterip = pending['masterip']
        del self.askedsoa[zkey]
        if zkey not in self.slavesupdating:
            self.slavesupdating.append(zkey)
            query = message()
            query.header.id = random.randrange(1,32768)
            query.header.rd = 0
            query.question.qname = origin
            query.question.qtype = 'AXFR'
            query.question.qclass = 'IN'
            log(3,'Updating slave zone: ' + zkey)
            simpletcprequest(query, self.handle_zonetrans,
                             [zkey],masterip,self.handle_zterror)

    def handle_zonetrans(self, rrlist, params):
        """Store a completed zone transfer; params[0] is the zone key."""
        log(1,'handling zone transfer')
        zonekey = params[0]
        self.zdb.update_zone(rrlist, params)
        if zonekey in self.slavesupdating:
            self.slavesupdating.remove(zonekey)

    def handle_zterror(self, zonekey):
        """Drop a zone whose transfer failed."""
        # guard the removal: an unknown key must not abort the cleanup
        if zonekey in self.slavesupdating:
            self.slavesupdating.remove(zonekey)
        self.zdb.remove_zone(zonekey)

    def rrmatch(self, rrset1, rrset2):
        """Return 1 if every type in rrset1 is in rrset2 with as many records.

        Only the number of records per type is compared.  Returns None
        when a type is missing or the counts differ.
        """
        for rrtype in rrset1.keys():
            if rrtype not in rrset2.keys():
                return
            if len(rrset1[rrtype]) != len(rrset2[rrtype]):
                return
        return 1

    def process_notify(self, msg, ipaddr, port):
        """React to a NOTIFY by checking the named zone with its master.

        Among the zones visible to the sender, the last one whose origin
        equals the notified name and which has a master IP is refreshed.
        """
        (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port)
        match = None
        for zkey in zkeys:
            origin = self.zdb.getorigin(zkey)
            if origin != msg.question.qname:
                continue
            masterip = self.zdb.getmasterip(zkey)
            if masterip:
                # remember origin and master together with the key, so a
                # later non-matching zone cannot overwrite them
                match = (zkey, origin, masterip)
        if match:
            (goodzkey, goodorigin, goodmasterip) = match
            log(3,'got NOTIFY from ' + goodmasterip)
            self.need_zonetransfer(goodzkey, goodorigin, goodmasterip, 0)
        return

    def notify(self):
        """Send queued NOTIFYs and re-queue those that went unanswered.

        An entry in sentnotify is sent again once notify_retry_time
        seconds have passed since it was sent; a NOTIFY is tracked for
        retry until it has been tried notify_retries times.
        """
        curtime = time.time()
        still_waiting = []
        for origin, ipaddr, trynum, senttime in self.sentnotify:
            if senttime + self.notify_retry_time <= curtime:
                # retry interval has elapsed without a response
                self.notifys.append((origin, ipaddr, trynum))
            else:
                still_waiting.append((origin, ipaddr, trynum, senttime))
        # rebuilt rather than edited in place: removing entries from the
        # list being iterated would skip the entry after each removal
        self.sentnotify = still_waiting
        for origin, ipaddr, trynum in self.notifys:
            msg = message()
            msg.question.qname = origin
            msg.question.qtype = 'SOA'
            msg.question.qclass = 'IN'
            msg.header.opcode = 4
            # there probably is a better way to do this
            if self.resolver:
                self.resolver.send_to([msg],(ipaddr,53))
                if trynum+1 <= self.notify_retries:
                    self.sentnotify.append((origin,ipaddr,trynum+1,curtime))
        self.notifys = []

    def handle_packet(self, msgdata, addr, server):
        """Parse one received packet and act on it.

        msgdata -- raw packet
        addr    -- (ip, port) of the sender
        server  -- object whose sendpackets(msglist, addr) sends replies

        Packets that cannot be parsed are dropped silently.
        """
        # self.reap()
        try:
            msg = message(msgdata)
        except:
            # deliberately broad: whatever the parser raises, drop the packet
            return
        # find a matching view
        (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1])
        if not msg.header.qr and msg.header.opcode == 5:
            log(2,'GOT UPDATE PACKET')
            # check the zone section
            if (msg.header.zocount != 1 or
                msg.zone.ztype != 'SOA' or
                msg.zone.zclass != 'IN'):
                log(2,'SENDING FORMERR UPDATE ERROR')
                errormsg = self.error(msg.header.id, msg.zone.zname,
                                      msg.zone.ztype, msg.zone.zclass, 1)
                server.sendpackets([errormsg],addr)
            else:
                (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self)
                if answer.header.rcode == 0:
                    # schedule NOTIFYs to slaves
                    for ipaddr in slaves:
                        self.notifys.append((origin, ipaddr, 0))
                server.sendpackets([answer],addr)
        elif msg.header.opcode == 4:
            if msg.header.qr:
                log(0,'got NOTIFY response')
                # forget every pending NOTIFY this response acknowledges;
                # the list is rebuilt so that no entry is skipped
                self.sentnotify = [
                    entry for entry in self.sentnotify
                    if not (entry[1] == addr[0] and
                            msg.question.qname == entry[0])]
            else:
                log(0,'got NOTIFY')
                self.process_notify(msg, addr[0], addr[1])
        elif not msg.header.qr and msg.header.opcode == 0:
            # it's a question
            qname = msg.question.qname.lower()
            log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype +
                ') from ' + addr[0])
            # handle special version packet
            if (msg.question.qtype == 'TXT' and
                msg.question.qclass == 'CH'):
                if qname == 'version.bind':
                    server.sendpackets([getversion(qname,
                                                   msg.header.id,
                                                   msg.header.rd,
                                                   dorecursion, '1.0')],addr)
                elif qname == 'version.oak':
                    server.sendpackets([getversion(qname,
                                                   msg.header.id,
                                                   msg.header.rd,
                                                   dorecursion, '1.0')],addr)
                return
            self.zdb.lookup(zkeys, msg, addr, server, dorecursion,
                            flist, self.lookup_callback)

    def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist):
        """Finish a query after the zone database lookup.

        Sends answerlist when the zone database produced one.  Otherwise,
        if recursion is allowed, the query goes to the resolver; zone
        transfer requests are refused with rcode 2, except an AXFR of the
        name 'cache', which returns the resolver's cache.
        """
        if answerlist:
            server.sendpackets(self.config.outpackets(answerlist), addr)
        elif dorecursion:
            if msg.question.qtype in ['AXFR','IXFR']:
                if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR':
                    if self.resolver:
                        server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr)
                else:
                    # won't forward zone transfers and
                    # don't handle recursive zone transfers
                    server.sendpackets([self.error(msg.header.id, msg.question.qname,
                                                   msg.question.qtype,
                                                   msg.question.qclass,2)],addr)
            else:
                self.resolver.handle_query(msg, addr, flist, server.sendpackets)

    def reap(self):
        """Periodic maintenance: resolver upkeep, NOTIFYs and slave zones.

        Every maint_int seconds the slave zones reported by
        zdb.getslaves() are checked with their master.  An SOA query
        unanswered for soatimeout seconds is repeated; after more than
        three tries the zone is removed.
        """
        log(4,'in nameserver reap')
        # do all maintenence (interval) stuff here
        if self.resolver:
            self.resolver.reap()
        self.notify()
        curtime = time.time()
        if curtime > (self.last_reap_time + self.maint_int):
            self.last_reap_time = curtime
            # do zone transfers here if slave server and haven't asked for soa
            for (zkey, origin, masterip) in self.zdb.getslaves(curtime):
                if zkey not in self.askedsoa:
                    self.need_zonetransfer(zkey, origin, masterip)
            # iterate over a copy of the keys: entries are deleted and
            # re-added inside the loop
            for zkey in list(self.askedsoa.keys()):
                asked = self.askedsoa[zkey]
                if curtime > asked['senttime'] + self.soatimeout:
                    if asked['trynum'] > 3:
                        self.zdb.remove_zone(zkey)
                        del self.askedsoa[zkey]
                    else:
                        masterip = asked['masterip']
                        origin = asked['origin']
                        trynum = asked['trynum']
                        del self.askedsoa[zkey]
                        self.need_zonetransfer(zkey, origin, masterip, trynum)

    def log_info(self, message, type='info'):
        """Write a log entry through log(); 'info' entries only if __debug__."""
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
+
class resolver(asyncore.dispatcher):
    """Resolve queries from the cache, through forwarders, or by iteration.

    Outstanding lookups live in self.outq.  Each record holds the
    original query, the address and callback for the reply, and the
    progress of the lookup.  Records whose 'addr' is 'IQ' are internal
    queries started to fetch name server addresses; they have no client
    and no callback.

    cache -- cache object (haskey, getnslist, add, addneg, updatertt, badns)
    port  -- UDP port to bind; 0 lets the system choose
    """

    def __init__(self, cache, port=0):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        self.bind(('',port))
        self.cache = cache
        self.outqnum = 0
        # outstanding lookups, keyed by getoutqkey() + a query id
        self.outq = {}
        # queries waiting on another lookup, keyed by qname + querytype
        self.holdq = {}
        # seconds a held query waits before the cache is checked again
        self.holdtime = 10
        # maximum number of held queries per qname + querytype
        self.holdqlength = 100
        self.last_reap_time = time.time()
        self.maint_int = 10
        # seconds after which an outstanding lookup is dropped by reap()
        self.timeout = 3

    def getoutqkey(self):
        """Return the next counter value as a string (wraps before 99999)."""
        self.outqnum = self.outqnum + 1
        if self.outqnum == 99999:
            self.outqnum = 1
        return str(self.outqnum)

    def error(self, id, qname, querytype, queryclass, rcode):
        """Return a response message carrying rcode for the given question."""
        error = message()
        error.header.id = id
        error.header.rcode = rcode
        error.header.qr = 1
        error.question.qname = qname
        error.question.qtype = querytype
        error.question.qclass = queryclass
        return error

    def qpacket(self, id, qname, querytype, queryclass):
        """Return a non-recursive (rd = 0) query message for the question."""
        # create a question
        query = message()
        query.header.id = id
        query.header.rd = 0
        query.question.qname = qname
        query.question.qtype = querytype
        query.question.qclass = queryclass
        return query

    def send_to(self, msglist, addr):
        """Send every message in msglist to addr over the resolver socket.

        A message longer than 512 bytes is sent with tc set and without
        any answer, authority or additional records.
        """
        for msg in msglist:
            data = msg.buildpkt()
            if len(data) > 512:
                # packet too big
                msg.header.tc = 1
                msg.header.ancount = 0
                msg.answerlist = []
                msg.header.nscount = 0
                msg.authlist = []
                msg.header.arcount = 0
                msg.addlist = []
                data = msg.buildpkt()
            self.socket.sendto(data, addr)

    def handle_read(self):
        """Read and process datagrams until the socket has none left."""
        try:
            while 1:
                msgdata, addr = self.socket.recvfrom(1500)
                # should put 'try' here in production server
                self.handle_packet(msgdata, addr)
        except socket.error as why:
            if why.args[0] != asyncore.EWOULDBLOCK:
                raise why

    def handle_packet(self, msgdata, addr):
        """Treat a datagram received on the resolver socket as a query.

        Unparseable packets are dropped; responses are logged and ignored.
        """
        try:
            msg = message(msgdata)
        except:
            # deliberately broad: whatever the parser raises, drop the packet
            return
        if not msg.header.qr:
            self.handle_query(msg, addr, [], self.send_to)
        else:
            log(2,'received unsolicited reply')

    def handle_query(self, msg, addr, flist, cbfunc):
        """Answer msg from the cache or start a lookup for it.

        msg    -- the query message
        addr   -- passed back to cbfunc with the answer
        flist  -- forwarder addresses; when empty the resolver queries
                  name servers itself
        cbfunc -- called as cbfunc(msglist, addr) to deliver the answer

        If a lookup for the same name and type is already running, the
        query is put on the hold queue instead of being sent again.
        """
        qname = msg.question.qname
        querytype = msg.question.qtype
        queryclass = msg.question.qclass
        # check the cache first
        answer = self.cache.haskey(qname,querytype,msg)
        if answer:
            cbfunc([answer], addr)
            log(2,'sent answer for ' + qname + '(' + querytype +
                ') from cache')
            return
        # check if query is already in progess
        for oqkey in list(self.outq.keys()):
            if (self.outq[oqkey]['qname'] == qname and
                self.outq[oqkey]['querytype'] == querytype):
                log(2,'query already in progress for '+qname+'('+querytype+')')
                # put entry in hold queue to try later
                hqrec = {'processtime':time.time()+self.holdtime,
                         'query':msg,'addr':addr,
                         'qname':qname,'querytype':querytype,
                         'queryclass':queryclass,
                         'cbfunc':cbfunc}
                self.putonhold(hqrec)
                return

        outqkey = self.getoutqkey()+str(msg.header.id)
        self.outq[outqkey] = {'query':msg,
                              'addr':addr,
                              'qname':qname,
                              'querytype':querytype,
                              'queryclass':queryclass,
                              'cbfunc':cbfunc,
                              'answerlist':[],
                              'addlist':[],
                              'qsent':0}
        if flist:
            # keep a private copy: handle_fresponse() pops failed
            # forwarders off this list and must not alter the caller's
            self.outq[outqkey]['flist'] = list(flist)
            self.askfns(outqkey)
        else:
            self.askns(outqkey)

    def putonhold(self,hqrec):
        """Add hqrec to the hold queue for its qname and querytype.

        The queue is created on first use.  The record is dropped when
        holdqlength records are already waiting for that key.
        """
        hqid = hqrec['qname']+hqrec['querytype']
        # setdefault creates the list; without it no key would ever exist
        # and every held query would be discarded
        waiting = self.holdq.setdefault(hqid, [])
        if len(waiting) < self.holdqlength:
            hqrec['processtime']=time.time()+self.holdtime
            waiting.append(hqrec)

    def askns(self, outqkey):
        """Send the lookup's question to the best name server the cache knows.

        The lookup is dropped after ten sends (possible loop).  When the
        cache has no server to ask, the client gets an rcode 2 error and
        the lookup is dropped.
        """
        rec = self.outq[outqkey]
        qname = rec['qname']
        querytype = rec['querytype']
        queryclass = rec['queryclass']
        # don't try more than 10 times to avoid loops
        if rec['qsent'] == 10:
            del self.outq[outqkey]
            log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
                ' POSSIBLE LOOP')
            return
        # find the best nameservers to ask from the cache
        (qzone, nsdict) = self.cache.getnslist(qname)
        if not nsdict:
            # there are no good servers
            if rec['addr'] != 'IQ':
                qid = rec['query'].header.id
                # callbacks take a list of messages
                rec['cbfunc']([self.error(qid,qname,querytype,queryclass,2)],
                              rec['addr'])
            del self.outq[outqkey]
            log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
                'no good name servers to ask')
            return
        # pick the best nameserver: the entry with the smallest key
        # (presumably a round-trip time -- confirm against the cache)
        rtts = sorted(nsdict.keys())
        bestnsip = nsdict[rtts[0]]['ip']
        bestnsname = nsdict[rtts[0]]['name']
        # fill in the callback data structure
        id=random.randrange(1,32768)
        rec['nsqueriedlastip'] = bestnsip
        rec['nsqueriedlastname'] = bestnsname
        rec['nsdict'] = nsdict
        rec['qzone'] = qzone
        rec['qsenttime'] = time.time()
        rec['qsent'] = rec['qsent'] + 1
        # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (bestnsip,53))
        rec['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
                                          self.handle_response, bestnsip, outqkey)
        # update rtt so that we ask a different server next time
        self.cache.updatertt(bestnsname,qzone,1)
        log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname +
            ') for ' + qname + '(' + querytype + ')')

    def askfns(self, outqkey):
        """Send the lookup's question to the first forwarder in its flist."""
        rec = self.outq[outqkey]
        flist = rec['flist']
        qname = rec['qname']
        querytype = rec['querytype']
        queryclass = rec['queryclass']
        rec['qsenttime'] = time.time()
        id=random.randrange(1,32768)
        # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (flist[0],53))
        rec['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
                                          self.handle_fresponse, flist[0], outqkey)
        log(2,''+outqkey+'|sent query to forwarder')

    def handle_response(self, msg, outqkey):
        """Process a name server's reply to the lookup stored under outqkey."""
        # either response:
        #  1. contains a name error
        #  2. answers the question
        #       (cache data and return it)
        #  3. is (contains) a CNAME and qtype isn't
        #       (cache cname and change qname to it)
        #       (check if qname and qtype are in any other rrs in the response)
        #       (must check cache again here)
        #  4. contains a better delegation
        #       (cache the delegation and start again)
        #  5. is a server failure
        #       (delete server from list and try again)

        # make sure that original question is still outstanding
        if outqkey not in self.outq:
            # should never get here
            # if we do we aren't doing housekeeping of callbacks very well
            log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname)
            return

        rec = self.outq[outqkey]
        if msg.header.rcode not in [1,2,4,5]:
            # update rtt time
            rtt = time.time() - rec['qsenttime']
            nsname = rec['nsqueriedlastname']
            zone = rec['qzone']
            self.cache.updatertt(nsname,zone,rtt)

        if msg.header.rcode == 3:
            log(2,outqkey+'|GOT Name Error for ' + msg.question.qname +
                '(' + msg.question.qtype + ')')
            # name error
            # cache negative answer
            self.cache.addneg(rec['qname'],
                              rec['querytype'],
                              rec['queryclass'])
            if rec['addr'] != 'IQ':
                answer = message()
                answer.question.qname = rec['query'].question.qname
                answer.question.qtype = rec['query'].question.qtype
                answer.question.qclass = rec['query'].question.qclass
                answer.header.id = rec['query'].header.id
                answer.header.qr = 1
                answer.header.opcode = rec['query'].header.opcode
                answer.header.ra = 1
                # pass the name error on; otherwise the client would get
                # an empty reply with whatever rcode message() defaults to
                answer.header.rcode = 3
                rec['cbfunc']([answer], rec['addr'])
            del self.outq[outqkey]

        elif msg.header.ancount > 0:
            # answer (may be CNAME)
            haveanswer = 0
            cname = ''
            log(2,'CACHING ANSWERLIST ENTRIES')
            for rr in msg.answerlist:
                rrname = list(rr.keys())[0]
                rrtype = list(rr[rrname].keys())[0]
                if ((rrname == msg.question.qname or rrname == cname ) and
                    rrtype == msg.question.qtype):
                    haveanswer = 1
                if rrname == msg.question.qname and rrtype == 'CNAME':
                    cname = rr[rrname][rrtype][0]['cname']
                self.cache.add(rr, rec['qzone'],
                               rec['nsqueriedlastname'])
            if haveanswer:
                if rec['addr'] != 'IQ':
                    log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname +
                        '(' + msg.question.qtype + ')' )
                    answer = message()
                    answer.answerlist = msg.answerlist + rec['answerlist']
                    answer.header.ancount = len(answer.answerlist)
                    answer.question.qname = rec['query'].question.qname
                    answer.question.qtype = rec['query'].question.qtype
                    answer.question.qclass = rec['query'].question.qclass
                    answer.header.id = rec['query'].header.id
                    answer.header.qr = 1
                    answer.header.opcode = rec['query'].header.opcode
                    answer.header.ra = 1
                    rec['cbfunc']([answer], rec['addr'])
                    log(2,outqkey+'|sent answer retrieved from remote server for ' +
                        rec['query'].question.qname)
                else:
                    log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' +
                        msg.question.qtype + ')')
                del self.outq[outqkey]
            elif cname:
                log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')')
                # keep the CNAME records for the final answer and look up
                # the canonical name instead
                rec['answerlist'] = rec['answerlist'] + msg.answerlist
                rec['qname'] = cname
                self.askns(outqkey)
            else:
                log(2,outqkey+'|GOT BOGUS answer for ' + msg.question.qname + '(' +
                    msg.question.qtype + ')')
                del self.outq[outqkey]

        elif msg.header.nscount > 0 and msg.header.ancount == 0:
            log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')')
            # delegation
            # cache the nameserver rrs and start over
            # if there are no glue records for nameservers must fetch them first
            log(2,'CACHING AUTHLIST ENTRIES')
            for rr in msg.authlist:
                self.cache.add(rr,rec['qzone'],rec['nsqueriedlastname'])
            log(2,'CACHING ADDLIST ENTRIES')
            for rr in msg.addlist:
                self.cache.add(rr,rec['qzone'],rec['nsqueriedlastname'])
            fetchglue = 0
            nscount = 0
            for rr in msg.authlist:
                nodename = list(rr.keys())[0]
                if list(rr[nodename].keys())[0] == 'NS':
                    nscount = nscount + 1
                    nsdname = rr[nodename]['NS'][0]['nsdname']
                    if not self.cache.haskey(nsdname,'A'):
                        log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)')
                        fetchglue = fetchglue + 1
                        # need to fetch A rec
                        noutqkey = self.getoutqkey()+str(random.randrange(1,32768))
                        # internal query: 'answerlist' is included because
                        # the CNAME branch above reads it for every record
                        self.outq[noutqkey] = {'query':'',
                                               'addr':'IQ',
                                               'qname':nsdname,
                                               'querytype':'A',
                                               'queryclass':'IN',
                                               'answerlist':[],
                                               'addlist':[],
                                               'qsent':0}
                        log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)')
                        self.askns(noutqkey)
            if not nscount:
                log(2,outqkey+'|Dropping query (no ns recs) for ' +
                    msg.question.qname + '(' + msg.question.qtype + ')' )
                del self.outq[outqkey]
            elif fetchglue == nscount:
                log(2,outqkey+'|Stalling query (no glue recs) for ' +
                    msg.question.qname + '(' + msg.question.qtype + ')')
                self.putonhold(rec)
                del self.outq[outqkey]
            else:
                log(2,outqkey+'|got (some) glue with delegation')
                self.askns(outqkey)

        elif msg.header.rcode in [1,2,4,5]:
            log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode))
            log(2,'SERVER ' + rec['nsqueriedlastname'] + '(' +
                rec['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname)
            # don't ask this server for a while
            self.cache.badns(rec['qzone'],rec['nsqueriedlastname'])
            self.askns(outqkey)
        else:
            log(2,outqkey+'|GOT UNPARSEABLE REPLY')
            msg.printpkt()

    def handle_fresponse(self, msg, outqkey):
        """Process a forwarder's reply to the lookup stored under outqkey.

        On rcode 1, 2, 4 or 5 the next forwarder is tried; when none is
        left the client gets an rcode 2 error.  Any other reply is cached
        and relayed to the client.
        """
        if outqkey not in self.outq:
            # the lookup was answered or expired before this reply arrived
            log(2,''+outqkey+'|got forwarder response for a question already answered for ' + msg.question.qname)
            return
        rec = self.outq[outqkey]
        if msg.header.rcode in [1,2,4,5]:
            rec['flist'].pop(0)
            if len(rec['flist']) == 0:
                qid = rec['query'].header.id
                qname = rec['qname']
                querytype = rec['querytype']
                queryclass = rec['queryclass']
                # callbacks take a list of messages
                rec['cbfunc']([self.error(qid,qname,querytype,queryclass,2)],
                              rec['addr'])
                del self.outq[outqkey]
            else:
                self.askfns(outqkey)
        else:
            answer = message()
            answer.header.id = rec['query'].header.id
            answer.header.qr = 1
            answer.header.opcode = rec['query'].header.opcode
            answer.header.ra = 1
            # relay the forwarder's result code (e.g. 3 for a name error)
            answer.header.rcode = msg.header.rcode
            answer.question.qname = rec['query'].question.qname
            answer.question.qtype = rec['query'].question.qtype
            answer.question.qclass = rec['query'].question.qclass
            answer.header.ancount = msg.header.ancount
            answer.header.nscount = msg.header.nscount
            answer.header.arcount = msg.header.arcount
            answer.answerlist = msg.answerlist
            answer.authlist = msg.authlist
            answer.addlist = msg.addlist
            if msg.header.rcode == 3:
                # name error
                # cache negative answer
                self.cache.addneg(rec['qname'],
                                  rec['querytype'],
                                  rec['queryclass'])
            else:
                # cache all rrs
                for rr in msg.answerlist:
                    self.cache.add(rr,'','forwarder')
                for rr in msg.authlist:
                    self.cache.add(rr,'','forwarder')
                for rr in msg.addlist:
                    self.cache.add(rr,'','forwarder')
            rec['cbfunc']([answer], rec['addr'])
            del self.outq[outqkey]

    def writable(self):
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        # print '1:In handle close'
        return

    def process_holdq(self):
        """Answer held queries whose hold time is over, then drop them.

        A held query is answered only if the cache now has the answer;
        either way it leaves the hold queue once its processtime has
        passed.  Internal ('IQ') records have no client and are just
        dropped.
        """
        curtime = time.time()
        for hqkey in list(self.holdq.keys()):
            still_held = []
            # build the list of survivors instead of removing entries
            # from the list while iterating over it
            for hqrec in self.holdq[hqkey]:
                if curtime < hqrec['processtime']:
                    still_held.append(hqrec)
                    continue
                log(2,'processing held query')
                if hqrec['addr'] == 'IQ':
                    continue
                answer = self.cache.haskey(hqrec['qname'],
                                           hqrec['querytype'],
                                           hqrec['query'])
                if answer:
                    hqrec['cbfunc']([answer], hqrec['addr'])
                    log(2,'sent answer for ' + hqrec['qname'] +
                        '(' + hqrec['querytype'] + ') from cache')
            if still_held:
                self.holdq[hqkey] = still_held
            else:
                del self.holdq[hqkey]

    def reap(self):
        """Periodic maintenance: serve the hold queue, expire old lookups.

        Every maint_int seconds, lookups sent more than timeout seconds
        ago are dropped.  The name server last asked is marked bad
        (forwarders are not) and the request socket is closed.
        """
        self.process_holdq()
        curtime = time.time()
        log(3,timestamp() + 'processed HOLDQ (sockets: ' +
            str(len(asyncore.socket_map.keys()))+')')
        if curtime > (self.last_reap_time + self.maint_int):
            self.last_reap_time = curtime
            # copy the keys: expired lookups are deleted inside the loop
            for outqkey in list(self.outq.keys()):
                rec = self.outq[outqkey]
                if curtime > rec['qsenttime'] + self.timeout:
                    log(2,'query for '+rec['qname']+'('+
                        rec['querytype']+') expired')
                    # don't set forwarders as bad
                    if 'flist' not in rec:
                        self.cache.badns(rec['qzone'],
                                         rec['nsqueriedlastname'])
                    if 'request' in rec:
                        log(3,'closing socket for expired query')
                        rec['request'].close()
                    del self.outq[outqkey]
        return

    def log_info(self, message, type='info'):
        """Route asyncore's log output through log(); 'info' only if __debug__."""
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
+
+
def run(configobj):
    """Start the DNS server and serve until interrupted from the keyboard.

    configobj supplies the initial cache contents (cached), the log level
    (loglevel) and is itself handed to the name server as its
    configuration.  UDP and TCP listeners are both opened on port 53.
    """
    global loglevel
    recursor = resolver(dnscache(configobj.cached))
    ns = nameserver(recursor, configobj)
    # both listeners are created for their side effects; the tuple only
    # keeps references to them while the loop runs
    listeners = (udpdnsserver(53, ns), tcpdnsserver(53, ns))
    loglevel = configobj.loglevel
    try:
        # ns.reap is passed to loop() so maintenance runs alongside I/O
        loop(ns.reap)
    except KeyboardInterrupt:
        print('server done')
+
if __name__ == '__main__':
    # connect to the database before anything else is set up
    sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')

    # example master zone; replaced by the assignment that follows
    zonedict = {
        'example.net': {
            'origin': 'example.net',
            'filename': 'db.example.net',
            'type': 'master',
            'slaves': [],
        },
    }

    # the zone actually served
    zonedict = {
        'servers.csail.mit.edu': {
            'origin': 'servers.csail.mit.edu',
            'filename': 'db.servers.csail.mit.edu',
            'type': 'master',
            'slaves': [],
        },
    }

    # example slave zone; not used below
    zonedict2 = {
        'example.net': {
            'origin': 'example.net',
            'filename': 'db.example.net',
            'type': 'slave',
            'masterip': '127.0.0.1',
        },
    }

    readzonefiles(zonedict)
    lconfig = dnsconfig()
    lconfig.zonedatabase = zonedb(zonedict)

    # the zone parsed from 'db.ca' becomes the cache's initial contents
    pr = zonefileparser()
    pr.parse('', 'db.ca')
    lconfig.cached = pr.getzdict()
    lconfig.loglevel = 3

    run(lconfig)