+++ /dev/null
-#!/usr/bin/python
-# Python Domain Name Server
-# Copyright (C) 2002 Digital Lumber, Inc.
-
-# This library is free software; you can redistribute it and/or
-# modify it under the terms of the GNU Lesser General Public
-# License as published by the Free Software Foundation; either
-# version 2.1 of the License, or (at your option) any later version.
-
-# This library is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-# Lesser General Public License for more details.
-
-# You should have received a copy of the GNU Lesser General Public
-# License along with this library; if not, write to the Free Software
-# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
-
-import socket
-import asyncore
-import asynchat
-import select
-import types
-import random
-import time
-import signal
-import string
-import sys
-import sipb_xen_database
-from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
- ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT
-
-# EXAMPLE ZONE FILE DATA STRUCTURE
-
-# NOTE:
-# There are no trailing dots in the internal data
-# structure. Although it's hard to tell by reading
-# the RFC's, the dots on the end of names are just
-# used internally by the resolvers and servers to
-# see if they need to append a domain name onto
-# the end of names. There are no trailing dots
-# on names in queries on the network.
-
-examplenet = {'example.net':{'SOA':[{'class':'IN',
- 'ttl':10,
- 'mname':'ns1.example.net',
- 'rname':'hostmaster.example.net',
- 'serial':1,
- 'refresh':10800,
- 'retry':3600,
- 'expire':604800,
- 'minimum':3600}],
- 'NS':[{'class':'IN',
- 'ttl':10,
- 'nsdname':'ns1.example.net'},
- {'ttl':10,
- 'nsdname':'ns2.example.net'}],
- 'MX':[{'class':'IN',
- 'ttl':10,
- 'preference':10,
- 'exchange':'mail.example.net'}]},
- 'server1.example.net':{'A':[{'class':'IN',
- 'ttl':10,
- 'address':'10.1.2.3'}]},
- 'www.example.net':{'CNAME':[{'class':'IN',
- 'ttl':10,
- 'cname':'server1.example.net'}]},
- 'router.example.net':{'A':[{'class':'IN',
- 'ttl':10,
- 'address':'10.1.2.1'},
- {'class':'IN',
- 'ttl':10,
- 'address':'10.2.1.1'}]}
-
- }
-
-# setup logging defaults
-loglevel = 0
-logfile = sys.stdout
-
-try:
- file
-except NameError:
- def file(name, mode='r', buffer=0):
- return open(name, mode, buffer)
-
-def log(level,msg):
- if level <= loglevel:
- logfile.write(msg+'\n')
-
-def timestamp():
- return time.strftime('%m/%d/%y %H:%M:%S')+ '-'
-
-def inttoasc(number):
- try:
- hs = hex(number)[2:]
- except:
- log(0,'inttoasc cannot convert ' + repr(number))
- if hs[-1:].upper() == 'L':
- hs = hs[:-1]
- result = ''
- while len(hs) > 2:
- result = chr(int(hs[-2:],16)) + result
- hs = hs[:-2]
- result = chr(int(hs,16)) + result
-
- return result
-
-def asctoint(ascnum):
- rascnum = ''
- for i in range(len(ascnum)-1,-1,-1):
- rascnum = rascnum + ascnum[i]
- result = 0
- count = 0
- for c in rascnum:
- x = ord(c) << (8*count)
- result = result + x
- count = count + 1
-
- return result
-
-def ipv6net_aton(ip_string):
- packed_ip = ''
- # first account for shorthand syntax
- pieces = ip_string.split(':')
- pcount = 0
- for part in pieces:
- if part != '':
- pcount = pcount + 1
- if pcount < 8:
- rs = '0:'*(8-pcount)
- ip_string = ip_string.replace('::',':'+rs)
- if ip_string[0] == ':':
- ip_string = ip_string[1:]
- pieces = ip_string.split(':')
- for part in pieces:
- # pad with the zeros
- i = 4-len(part)
- part = i*'0'+part
- packed_ip = packed_ip + chr(int(part[:2],16))+ chr(int(part[2:],16))
- return packed_ip
-
-def ipv6net_ntoa(packed_ip):
- ip_string = ''
- count = 0
- for c in packed_ip:
- ip_string = ip_string + hex(ord(c))[2:]
- count = count + 1
- if count == 2:
- ip_string = ip_string + ':'
- count = 0
- return ip_string[:-1]
-
-def getversion(qname, id, rd, ra, versionstr):
- msg = message()
- msg.header.id = id
- msg.header.qr = 1
- msg.header.aa = 1
- msg.header.rd = rd
- msg.header.ra = ra
- msg.header.rcode = 0
- msg.question.qname = qname
- msg.question.qtype = 'TXT'
- msg.question.qclass = 'CH'
- if qname == 'version.bind':
- msg.header.ancount = 2
- msg.answerlist.append({qname:{'CNAME':[{'cname':'version.oak',
- 'ttl':360000,
- 'class':'CH'}]}})
- msg.answerlist.append({'version.oak':{'TXT':[{'txtdata':versionstr,
- 'ttl':360000,
- 'class':'CH'}]}})
- else:
- msg.header.ancount = 1
- msg.answerlist.append({qname:{'TXT':[{'txtdata':versionstr,
- 'ttl':360000,
- 'class':'CH'}]}})
- return msg
-
-def getrcode(rcode):
- if rcode == 0:
- rcodestr = 'NOERROR(No error condition)'
- elif rcode == 1:
- rcodestr = 'FORMERR(Format Error)'
- elif rcode == 2:
- rcodestr = 'SERVFAIL(Internal failure)'
- elif rcode == 3:
- rcodestr = 'NXDOMAIN(Name does not exist)'
- elif rcode == 4:
- rcodestr = 'NOTIMP(Not Implemented)'
- elif rcode == 5:
- rcodestr = 'REFUSED(Security violation)'
- elif rcode == 6:
- rcodestr = 'YXDOMAIN(Name exists)'
- elif rcode == 7:
- rcodestr = 'YXRRSET(RR exists)'
- elif rcode == 8:
- rcodestr = 'NXRRSET(RR does not exist)'
- elif rcode == 9:
- rcodestr = 'NOTAUTH(Server not Authoritative)'
- elif rcode == 10:
- rcodestr = 'NOTZONE(Name not in zone)'
- else:
- rcodestr = 'Unknown RCODE(' + str(rcode) + ')'
- return rcodestr
-
-def printrdata(dnstype, rdata):
- if dnstype == 'A':
- return rdata['address']
- elif dnstype == 'MX':
- return str(rdata['preference'])+'\t'+rdata['exchange']+'.'
- elif dnstype == 'NS':
- return rdata['nsdname']+'.'
- elif dnstype == 'PTR':
- return rdata['ptrdname']+'.'
- elif dnstype == 'CNAME':
- return rdata['cname']+'.'
- elif dnstype == 'SOA':
- return (rdata['mname']+'.\t'+rdata['rname']+'. (\n'+35*' '+str(rdata['serial'])+'\n'+
- 35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+
- str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )')
-
-def makezonedatalist(zonedata, origin):
- # unravel structure into list
- zonedatalist = []
- # get soa first
- soanode = zonedata[origin]
- zonedatalist.append([origin+'.','SOA',soanode['SOA'][0]])
- for item in soanode.keys():
- if item != 'SOA':
- for listitem in soanode[item]:
- zonedatalist.append([origin+'.', item, listitem])
- for nodename in zonedata.keys():
- if nodename != origin:
- for item in zonedata[nodename].keys():
- for listitem in zonedata[nodename][item]:
- zonedatalist.append([nodename+'.', item, listitem])
- return zonedatalist
-
-def writezonefile(zonedata, origin, file):
- zonedatalist = makezonedatalist(zonedata, origin)
- for rr in zonedatalist:
- owner = rr[0]
- dnstype = rr[1]
- line = (owner + (35-len(owner))*' ' + str(rr[2]['ttl']) + '\t\tIN\t' +
- dnstype + '\t' + printrdata(dnstype, rr[2]))
- file.write(line + '\n')
-
-def readzonefiles(zonedict):
- for k in zonedict.keys():
- filepath = zonedict[k]['filename']
- try:
- pr = zonefileparser()
- pr.parse(zonedict[k]['origin'],filepath)
- zonedict[k]['zonedata'] = pr.getzdict()
- except ZonefileError, lineno:
- log(0,'Error reading zone file ' + filepath + ' at line ' +
- str(lineno) + '\n')
- del zonedict[k]
-
-def slowloop(tofunc='',timeout=5.0):
- if not tofunc:
- def tofunc(): return
- map = asyncore.socket_map
- while map:
- r = []; w=[]; e=[]
- for fd, obj in map.items():
- if obj.readable():
- r.append(fd)
- if obj.writable():
- w.append(fd)
- try:
- starttime = time.time()
- r,w,e = select.select(r,w,e,timeout)
- endtime = time.time()
- if endtime-starttime >= timeout:
- tofunc()
- except select.error, err:
- if err[0] != EINTR:
- raise
- r=[]; w=[]; e=[]
- log(0,'ERROR in select')
-
- for fd in r:
- try:
- obj=map[fd]
- except KeyError:
- log(0,'KeyError in socket map')
- continue
- try:
- obj.handle_read_event()
- except:
- log(0,'calling HANDLE ERROR from loop')
- log(0,repr(obj))
- obj.handle_error()
- for fd in w:
- try:
- obj=map[fd]
- except KeyError:
- log(0,'KeyError in socket map')
- continue
- try:
- obj.handle_read_event()
- except:
- log(0,'calling HANDLE ERROR from loop')
- log(0,repr(obj))
- obj.handle_error()
-
-def fastloop(tofunc='',timeout=5.0):
- if not tofunc:
- def tofunc(): return
- polltimeout = timeout*1000
- map = asyncore.socket_map
- while map:
- regfds = 0
- pollobj = select.poll()
- for fd, obj in map.items():
- flags = 0
- if obj.readable():
- flags = select.POLLIN
- if obj.writable():
- flags = flags | select.POLLOUT
- if flags:
- pollobj.register(fd, flags)
- regfds = regfds + 1
- try:
- starttime = time.time()
- r = pollobj.poll(polltimeout)
- endtime = time.time()
- if endtime-starttime >= timeout:
- tofunc()
- except select.error, err:
- if err[0] != EINTR:
- raise
- r = []
- log(0,'ERROR in select')
- for fd, flags in r:
- try:
- obj = map[fd]
- badvals = (select.POLLPRI + select.POLLERR +
- select.POLLHUP + select.POLLNVAL)
- if (flags & badvals):
- if (flags & select.POLLPRI):
- log(0,'POLLPRI')
- if (flags & select.POLLERR):
- log(0,'POLLERR')
- if (flags & select.POLLHUP):
- log(0,'POLLHUP')
- if (flags & select.POLLNVAL):
- log(0,'POLLNVAL')
- obj.handle_error()
- else:
- if (flags & select.POLLIN):
- obj.handle_read_event()
- if (flags & select.POLLOUT):
- obj.handle_write_event()
- except KeyError:
- log(0,'KeyError in socket map')
- continue
- except:
- # print traceback
- sf = StringIO.StringIO()
- traceback.print_exc(file=sf)
- log(0,'ERROR IN LOOP:')
- log(0,sf.getvalue())
- sf.close()
- log(0,repr(obj))
- obj.handle_error()
-
-if hasattr(select,'poll'):
- loop = fastloop
-else:
- loop = slowloop
-
-class ZonefileError(Exception):
- def __init__(self, linenum, errordesc=''):
- self.linenum = linenum
- self.errordesc = errordesc
- def __str__(self):
- return str(self.linenum) + ' (' + self.errordesc + ')'
-
-class zonefileparser:
- def __init__(self):
- self.zonedata = {}
- self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX',
- 'NS','PTR','RP','SOA','SRV','TXT']
-
- def stripcomments(self, line):
- i = line.find(';')
- if i >= 0:
- line = line[:i]
- return line
-
- def strip(self, line):
- # strip trailing linefeeds
- if line[-1:] == '\n':
- line = line[:-1]
- return line
-
- def getzdict(self):
- return self.zonedata
-
- def addorigin(self, origin, name):
- if name[-1:] != '.':
- return name + '.' + origin
- else:
- return name[:-1]
-
- def getstrings(self, s):
- if s.find('"') == -1:
- return s.split()
- else:
- x = s.split('"')
- rlist = []
- for i in x:
- if i != '' and i != ' ':
- rlist.append(i)
- return rlist
-
- def getlocsize(self, s):
- if s[-1:] == 'm':
- size = float(s[:-1])*100
- else:
- size = float(s)*100
- i = 0
- while size > 9:
- size = size/10
- i = i + 1
- return (int(size),i)
-
- def getloclat(self, l,c):
- deg = float(l[0])
- min = 0
- secs = 0
- if len(l) == 3:
- min = float(l[1])
- secs = float(l[2])
- elif len(l) == 2:
- min = float(l[1])
- rval = ((((deg *60) + min) * 60) + secs) * 1000
- if c in ['N','E']:
- rval = rval + (2**31)
- elif c in ['S','W']:
- rval = (2**31) - rval
- else:
- log(0,'ERROR: unsupported latitude/longitude direction')
- return long(rval)
-
- def getgname(self, name, iter):
- if name == '0' or name == 'O':
- return ''
- start = 0
- offset = 0
- width = 0
- base = 'd'
- for x in range(name.count('$')):
- i = name.find('$',start)
- j = i
- start = i+1
- if i>0:
- if name[i-1] == '\\':
- continue
- if len(name)>i+1:
- if name[i+1] == '$':
- continue
- if name[i+1] == '{':
- j = name.find('}',i+1)
- owb = name[i+2:j].split(',')
- if len(owb) == 1:
- offset = int(owb[0])
- elif len(owb) == 2:
- offset = int(owb[0])
- width = int(owb[1])
- elif len(owb) == 3:
- offset = int(owb[0])
- width = int(owb[1])
- base = owb[2]
- val = iter - offset
- if base == 'd':
- rs = str(val)
- elif base == 'o':
- rs = oct(val)
- elif base == 'x':
- rs = hex(val)[2:].lower()
- elif base == 'X':
- rs = hex(val)[2:].upper()
- else:
- rs = ''
- if len(rs) > width:
- rs = (width-len(rs))*'0'+rs
- name = name[:i]+rs+name[j+1:]
- start = i+len(rs)+1
-
- return name
-
- def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens):
- rdata = {}
- rdata['class'] = dnsclass
- rdata['ttl'] = ttl
- if dnstype == 'A':
- rdata['address'] = tokens[0]
- elif dnstype == 'AAAA':
- rdata['address'] = tokens[0]
- elif dnstype == 'CNAME':
- rdata['cname'] = self.addorigin(origin,tokens[0].lower())
- elif dnstype == 'HINFO':
- sl = self.getstrings(' '.join(tokens))
- rdata['cpu'] = sl[0]
- rdata['os'] = sl[1]
- elif dnstype == 'LOC':
- if 'N' in tokens:
- i = tokens.index('N')
- else:
- i = tokens.index('S')
- lat = self.getloclat(tokens[0:i],tokens[i])
- if 'E' in tokens:
- j = tokens.index('E')
- else:
- j = tokens.index('W')
- lng = self.getloclat(tokens[i+1:j],tokens[j])
- size = self.getlocsize('1m')
- horiz_pre = self.getlocsize('1000m')
- vert_pre = self.getlocsize('10m')
- if len(tokens[j+1:]) == 2:
- size = self.getlocsize(tokens[-1:][0])
- elif len(tokens[j+1:]) == 3:
- size = self.getlocsize(tokens[-2:-1][0])
- horiz_pre = self.getlocsize(tokens[-1:][0])
- elif len(tokens[j+1:]) == 4:
- size = self.getlocsize(tokens[-3:-2][0])
- horiz_pre = self.getlocsize(tokens[-2:-1][0])
- vert_pre = self.getlocsize(tokens[-1:][0])
- if tokens[j+1][-1:] == 'm':
- alt = int((float(tokens[j+1][:-1])*100)+10000000)
- else:
- size = int((float(tokens[j+1])*100)+10000000)
- rdata['version'] = 0
- rdata['size'] = size
- rdata['horiz_pre'] = horiz_pre
- rdata['vert_pre'] = vert_pre
- rdata['latitude'] = lat
- rdata['longitude'] = lng
- rdata['altitude'] = 0
- elif dnstype == 'MX':
- rdata['preference'] = int(tokens[0])
- rdata['exchange'] = self.addorigin(origin,tokens[1].lower())
- elif dnstype == 'NS':
- rdata['nsdname'] = self.addorigin(origin,tokens[0].lower())
- elif dnstype == 'PTR':
- rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower())
- elif dnstype == 'RP':
- rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower())
- rdata['txtdname'] = self.addorigin(origin,tokens[1].lower())
- elif dnstype == 'SOA':
- rdata['mname'] = self.addorigin(origin,tokens[0].lower())
- rdata['rname'] = self.addorigin(origin,tokens[1].lower())
- rdata['serial'] = int(tokens[2])
- rdata['refresh'] = int(tokens[3])
- rdata['retry'] = int(tokens[4])
- rdata['expire'] = int(tokens[5])
- rdata['minimum'] = int(tokens[6])
- elif dnstype == 'SRV':
- rdata['priority'] = int(tokens[0])
- rdata['weight'] = int(tokens[1])
- rdata['port'] = int(tokens[2])
- rdata['target'] = self.addorigin(origin,tokens[3].lower())
- elif dnstype == 'TXT':
- rdata['txtdata'] = self.getstrings(' '.join(tokens))[0]
- else:
- raise ZonefileError(lineno,'bad DNS type')
- return rdata
-
- def addrec(self, owner, dnstype, rrdata):
- if self.zonedata.has_key(owner):
- if not self.zonedata[owner].has_key(dnstype):
- self.zonedata[owner][dnstype] = []
- else:
- self.zonedata[owner] = {}
- self.zonedata[owner][dnstype] = []
- self.zonedata[owner][dnstype].append(rrdata)
-
- def parse(self, origin, f):
- closefile = 0
- if type(f) != types.FileType:
- # must be a path
- try:
- f = file(f)
- closefile = 1
- except:
- log(0,'Invalid path to zonefile')
- return
- lastowner = ''
- lastdnsclass = ''
- lastttl = 3600
- lineno = 0
- while 1:
- line = f.readline()
- if not line:
- break
- lineno = lineno + 1
- line = self.stripcomments(line)
- line = self.strip(line)
- if not line:
- continue
- if line.find('(') >= 0:
- # grab lines until end paren
- if line.find(')') == -1:
- line2 = self.stripcomments(f.readline())
- lineno = lineno + 1
- line2 = self.strip(line2)
- line = line + line2
- while line2.find(')') == -1:
- line2 = self.strip(self.stripcomments(f.readline()))
- lineno = lineno + 1
- line = line + line2
- # now strip the parenthesis
- line = line.replace(')','')
- line = line.replace('(','')
- # now line equals the entire RR entry
- tokens = line.split()
- if tokens[0].upper() == '$ORIGIN':
- try:
- origin = tokens[1].lower()
- except:
- raise ZonefileError(lineno, 'bad origin')
- elif tokens[0].upper() == '$INCLUDE':
- try:
- f2 = file(tokens[1].lower())
- if len(tokens) > 2:
- self.parse(tokens[2].lower(), f2)
- else:
- self.parse(origin, f2)
- f2.close()
- except:
- raise ZonefileError(lineno, 'bad INCLUDE directive')
- elif tokens[0].upper() == '$TTL':
- try:
- lastttl = int(tokens[1])
- except:
- raise ZonefileError(lineno, 'bad TTL directive')
- elif tokens[0].upper() == '$GENERATE':
- try:
- lhs = tokens[2].lower()
- dnstype = tokens[3].upper()
- rhs = tokens[4].lower()
- rng = tokens[1].split('-')
- start = int(rng[0])
- i = rng[1].find('/')
- if i != -1:
- stop = int(rng[1][:i])+1
- step = int(rng[1][i+1:])
- else:
- stop = int(rng[1])+1
- step = 1
- for i in range(start,stop,step):
- grhs = self.getgname(rhs,i)
- if dnstype in ['NS','CNAME','PTR']:
- grhs = self.addorigin(origin,grhs)
- rrdata = self.getrrdata(origin, dnstype, 'IN', lastttl,
- [grhs])
- glhs = self.addorigin(origin,self.getgname(lhs,i))
- self.addrec(glhs,dnstype, rrdata)
- except KeyError:
- raise ZonefileError(lineno, 'bad GENERATE directive')
- else:
- try:
- # if line begins with blank then owner is last owner
- if line[0] in string.whitespace:
- owner = lastowner
- else:
- owner = tokens[0].lower()
- tokens = tokens[1:]
- if owner == '@':
- owner = origin
- elif owner[-1:] != '.':
- owner = owner + '.' + origin
- else:
- owner = owner[:-1] # strip off trailing dot
- # line format is either: [class] [ttl] type RDATA
- # or [ttl] [class] type RDATA
- # - items in brackets are optional
- #
- # need to figure out which token is type
- # and backfill the missing data
- count = 0
- for token in tokens:
- if token.upper() in self.dnstypes:
- break
- count = count + 1
- # the following strips off the ttl and class if they exist
- if count == 0:
- ttl = lastttl
- dnsclass = lastdnsclass
- elif count == 1:
- if tokens[0].isdigit():
- ttl = int(tokens[0])
- dnsclass = lastdnsclass
- else:
- ttl = lastttl
- dnsclass = tokens[0].upper()
- tokens = tokens[1:]
- elif count == 2:
- if tokens[0].isdigit():
- ttl = int(tokens[0])
- dnsclass = tokens[1].upper()
- else:
- ttl = int(tokens[1])
- dnsclass = tokens[0].upper()
- tokens = tokens[2:]
- else:
- raise ZonefileError(lineno,'bad ttl or class')
- dnstype = tokens[0]
- # make sure all of the structure is there
- rrdata = self.getrrdata(origin, dnstype, dnsclass,
- ttl, tokens[1:])
- self.addrec(owner, dnstype, rrdata)
- lastowner = owner
- lastttl = ttl
- lastdnsclass = dnsclass
- except:
- raise ZonefileError(lineno,'unable to parse line')
- if closefile:
- f.close()
-
-class dnsconfig:
- def __init__(self):
- # self.zonedb = zonedb({})
- self.cached = {}
- self.loglevel = 0
-
- def getview(self, msg, address, port):
- # return:
- # 1. a list of zone keys
- # 2. whether or not to use the resolver
- # (i.e. answer recursive queries)
- # 3. a list of forwarder addresses
- return ['servers.csail.mit.edu'], 1, []
-
- def allowupdate(self, msg, address, port):
- # return 1 if updates are allowed
- # NOTE: can only update the zones
- # returned by the getview func
- return 1
-
- def outpackets(self, packetlist):
- # modify outgoing packets
- return packetlist
-
-class dnsheader:
- def __init__(self, id=1):
- self.id = id # 16bit identifier generated by queryer
- self.qr = 0 # one bit field specifying query(0) or response(1)
- self.opcode = 0 # 4bit field specifying type of query
- self.aa = 0 # authoritative answer
- self.tc = 0 # message is not truncated
- self.rd = 1 # recursion desired
- self.ra = 0 # recursion available?
- self.z = 0 # reserved for future use
- self.rcode = 0 # response code (set in response)
- self.qdcount = 1 # number of questions, only 1 is supported
- self.ancount = 0 # number of rrs in the answer section
- self.nscount = 0 # number of name server rrs in authority section
- self.arcount = 0 # number or rrs in the additional section
-
-class dnsquestion:
- def __init__(self):
- self.qname = 'localhost'
- self.qtype = 'A'
- self.qclass = 'IN'
-
-class dnsupdatezone:
- pass
-
-class message:
- def __init__(self, msgdata=''):
- if msgdata:
- self.header = dnsheader()
- else:
- self.header = dnsheader(id=random.randrange(1,32768))
- self.question = dnsquestion()
- self.answerlist = []
- self.authlist = []
- self.addlist = []
- self.u = ''
- self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA',
- 7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS',
- 12:'PTR',13:'HINFO',14:'MINFO',15:'MX',
- 16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV',
- 38:'A6',39:'DNAME',251:'IXFR',252:'AXFR',
- 253:'MAILB',254:'MAILA',255:'ANY'}
- self.rqtypes = {}
- for key in self.qtypes.keys():
- self.rqtypes[self.qtypes[key]] = key
- self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'}
- self.rqclasses = {}
- for key in self.qclasses.keys():
- self.rqclasses[self.qclasses[key]] = key
-
- if msgdata:
- self.processpkt(msgdata)
-
- def getdomainname(self, data, i):
- log(4,'IN GETDOMAINNAME')
- domainname = ''
- gotpointer = 0
- labellength= ord(data[i])
- log(4,'labellength:' + str(labellength))
- i = i + 1
- while labellength != 0:
- while labellength >= 192:
- # pointer
- if not gotpointer:
- rindex = i + 1
- gotpointer = 1
- log(4,'got pointer')
- i = asctoint(chr(ord(data[i-1]) & 63)+data[i])
- log(4,'new index:'+str(i))
- labellength = ord(data[i])
- log(4,'labellength:' + str(labellength))
- i = i + 1
- if domainname:
- domainname = domainname + '.' + data[i:i+labellength]
- else:
- domainname = data[i:i+labellength]
- log(4,'domainname:'+domainname)
- i = i + labellength
- labellength = ord(data[i])
- log(4,'labellength:' + str(labellength))
- i = i + 1
- if not gotpointer:
- rindex = i
-
- return domainname.lower(), rindex
-
- def getrrdata(self, type, msgdata, rdlength, i):
- log(4,'unpacking RR data')
- rdata = msgdata[i:i+rdlength]
- if type == 'A':
- return {'address':socket.inet_ntoa(rdata)}
- elif type == 'AAAA':
- return {'address':ipv6net_ntoa(rdata)}
- elif type == 'CNAME':
- cname, i = self.getdomainname(msgdata,i)
- return {'cname':cname}
- elif type == 'HINFO':
- cpulen = ord(rdata[0])
- cpu = rdata[1:cpulen+1]
- return {'cpu':cpu,
- 'os':rdata[cpulen+2:]}
- elif type == 'LOC':
- return {'version':ord(rdata[0]),
- 'size':self.locsize(rdata[1]),
- 'horiz_pre':self.locsize(rdata[2]),
- 'vert_pre':self.locsize(rdata[3]),
- 'latitude':asctoint(rdata[4:8]),
- 'longitude':asctoint(rdata[8:12]),
- 'altitude':asctoint(rdata[12:16])}
- elif type == 'MX':
- exchange, i = self.getdomainname(msgdata,i+2)
- return {'preference':asctoint(rdata[:2]),
- 'exchange':exchange}
- elif type == 'NS':
- nsdname, i = self.getdomainname(msgdata,i)
- return {'nsdname':nsdname}
- elif type == 'PTR':
- ptrdname, i = self.getdomainname(msgdata,i)
- return {'ptrdname':ptrdname}
- elif type == 'RP':
- mboxdname, i = self.getdomainname(msgdata,i)
- txtdname, i = self.getdomainname(msgdata,i)
- return {'mboxdname':mboxdname,
- 'txtdname':txtdname}
- elif type == 'SOA':
- mname, i = self.getdomainname(msgdata,i)
- rname, i = self.getdomainname(msgdata,i)
- return {'mname':mname,
- 'rname':rname,
- 'serial':asctoint(msgdata[i:i+4]),
- 'refresh':asctoint(msgdata[i+4:i+8]),
- 'retry':asctoint(msgdata[i+8:i+12]),
- 'expire':asctoint(msgdata[i+12:i+16]),
- 'minimum':asctoint(msgdata[i+16:i+20])}
- elif type == 'SRV':
- target, i = self.getdomainname(msgdata,i+6)
- return {'priority':asctoint(rdata[0:2]),
- 'weight':asctoint(rdata[2:4]),
- 'port':asctoint(rdata[4:6]),
- 'target':target}
- elif type == 'TXT':
- return {'txtdata':rdata[1:]}
- else:
- return {'rdata':rdata}
-
- def getrr(self, data, i):
- log(4,'unpacking RR name')
- name, i = self.getdomainname(data, i)
- type = asctoint(data[i:i+2])
- type = self.qtypes.get(type,chr(type))
- klass = asctoint(data[i+2:i+4])
- klass = self.qclasses.get(klass,chr(klass))
- ttl = asctoint(data[i+4:i+8])
- rdlength = asctoint(data[i+8:i+10])
- rrdata = self.getrrdata(type,data,rdlength,i+10)
- rrdata['ttl'] = ttl
- rrdata['class'] = klass
- rr = {name:{type:[rrdata]}}
- return rr, i+10+rdlength
-
- def processpkt(self, msgdata):
- self.header.id = asctoint(msgdata[:2])
- self.header.qr = ord(msgdata[2]) >> 7
- self.header.opcode = (ord(msgdata[2]) & 127) >> 3
- if self.header.opcode == 5:
- # UPDATE packet
- log(4,'processing UPDATE packet')
- del self.header.aa
- del self.header.tc
- del self.header.rd
- del self.header.ra
- del self.header.qdcount
- del self.header.ancount
- del self.header.nscount
- del self.header.arcount
- del self.question
- self.zone = dnsupdatezone()
- del self.answerlist
- del self.authlist
- del self.addlist
- self.header.z = 0
- self.header.rcode = ord(msgdata[3]) & 15
- self.header.zocount = asctoint(msgdata[4:6])
- self.header.prcount = asctoint(msgdata[6:8])
- self.header.upcount = asctoint(msgdata[8:10])
- self.header.arcount = asctoint(msgdata[10:12])
- self.zolist = []
- self.prlist = []
- self.uplist = []
- self.addlist = []
- i = 12
- for x in range(self.header.zocount):
- (dn, i) = self.getdomainname(msgdata,i)
- self.zone.zname = dn
- type = asctoint(msgdata[i:i+2])
- self.zone.ztype = self.qtypes.get(type,chr(type))
- klass = asctoint(msgdata[i+2:i+4])
- self.zone.zclass = self.qclasses.get(klass,chr(klass))
- i = i + 4
- for x in range(self.header.prcount):
- rr, i = self.getrr(msgdata,i)
- self.prlist.append(rr)
- for x in range(self.header.upcount):
- rr, i = self.getrr(msgdata,i)
- self.uplist.append(rr)
- for x in range(self.header.arcount):
- rr, i = self.getrr(msgdata,i)
- self.adlist.append(rr)
- else:
- self.header.aa = (ord(msgdata[2]) & 4) >> 2
- self.header.tc = (ord(msgdata[2]) & 2) >> 1
- self.header.rd = ord(msgdata[2]) & 1
- self.header.ra = ord(msgdata[3]) >> 7
- self.header.z = (ord(msgdata[3]) & 112) >> 4
- self.header.rcode = ord(msgdata[3]) & 15
- self.header.qdcount = asctoint(msgdata[4:6])
- self.header.ancount = asctoint(msgdata[6:8])
- self.header.nscount = asctoint(msgdata[8:10])
- self.header.arcount = asctoint(msgdata[10:12])
- i = 12
- for x in range(self.header.qdcount):
- log(4,'unpacking question')
- (dn, i) = self.getdomainname(msgdata,i)
- self.question.qname = dn
- rrtype = asctoint(msgdata[i:i+2])
- self.question.qtype = self.qtypes.get(rrtype,chr(rrtype))
- klass = asctoint(msgdata[i+2:i+4])
- self.question.qclass = self.qclasses.get(klass,chr(klass))
- i = i + 4
- for x in range(self.header.ancount):
- log(4,'unpacking answer RR')
- rr, i = self.getrr(msgdata,i)
- self.answerlist.append(rr)
- for x in range(self.header.nscount):
- log(4,'unpacking auth RR')
- rr, i = self.getrr(msgdata,i)
- self.authlist.append(rr)
- for x in range(self.header.arcount):
- log(4,'unpacking additional RR')
- rr, i = self.getrr(msgdata,i)
- self.addlist.append(rr)
- return
-
- def pds(self, s, l):
- # pad string with chr(0)'s so that
- # return string length is l
- x = l - len(s)
- return x*chr(0) + s
-
- def locsize(self, s):
- x1 = ord(s) >> 4
- x2 = ord(s) & 15
- return (x1, x2)
-
- def packlocsize(self, x):
- return chr((x[0] << 4) + x[1])
-
- def packdomainname(self, name, i, msgcomp):
- log(4,'packing domainname: ' + name)
- if name == '':
- return chr(0)
- if name in msgcomp.keys():
- log(4,'using pointer for: ' + name)
- return msgcomp[name]
- packedname = ''
- tokens = name.split('.')
- for j in range(len(tokens)):
- packedname = packedname + chr(len(tokens[j])) + tokens[j]
- nameleft = '.'.join(tokens[j+1:])
- if nameleft in msgcomp.keys():
- log(4,'using pointer for: ' + nameleft)
- return packedname+msgcomp[nameleft]
- # haven't used a pointer so put this in the dictionary
- pointer = inttoasc(i)
- if len(pointer) == 1:
- msgcomp[name] = chr(192)+pointer
- else:
- msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1]
- log(4,'added pointer for ' + name + '(' + str(i) + ')')
- return packedname + chr(0)
-
- def packrr(self, rr, i, msgcomp):
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- if self.rqtypes.has_key(rrtype):
- typeval = self.rqtypes[rrtype]
- else:
- typeval = ord(rrtype)
- dbrec = rr[rrname][rrtype][0]
- ttl = dbrec['ttl']
- rclass = self.rqclasses[dbrec['class']]
- packedrr = (self.packdomainname(rrname, i, msgcomp) +
- self.pds(inttoasc(typeval),2) +
- self.pds(inttoasc(rclass),2) +
- self.pds(inttoasc(ttl),4))
- i = i + len(packedrr) + 2
- if rrtype == 'A':
- rdata = socket.inet_aton(dbrec['address'])
- elif rrtype == 'AAAA':
- rdata = ipv6net_aton(dbrec['address'])
- elif rrtype == 'CNAME':
- rdata = self.packdomainname(dbrec['cname'], i, msgcomp)
- elif rrtype == 'HINFO':
- rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] +
- chr(len(dbrec['os'])) + dbrec['os'])
- elif rrtype == 'LOC':
- rdata = (chr(dbrec['version']) +
- self.packlocsize(dbrec['size']) +
- self.packlocsize(dbrec['horiz_pre']) +
- self.packlocsize(dbrec['vert_pre']) +
- self.pds(inttoasc(dbrec['latitude']),4) +
- self.pds(inttoasc(dbrec['longitude']),4) +
- self.pds(inttoasc(dbrec['altitude']),4))
- elif rrtype == 'MX':
- rdata = (self.pds(inttoasc(dbrec['preference']),2) +
- self.packdomainname(dbrec['exchange'], i+2, msgcomp))
- elif rrtype == 'NS':
- rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp)
- elif rrtype == 'PTR':
- rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp)
- elif rrtype == 'RP':
- rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
- i = i + len(rdata1)
- rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
- rdata = rdata1 + rdata2
- elif rrtype == 'SOA':
- rdata1 = self.packdomainname(dbrec['mname'], i, msgcomp)
- i = i + len(rdata1)
- rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp)
- rdata = (rdata1 +
- rdata2 +
- self.pds(inttoasc(dbrec['serial']),4) +
- self.pds(inttoasc(dbrec['refresh']),4) +
- self.pds(inttoasc(dbrec['retry']),4) +
- self.pds(inttoasc(dbrec['expire']),4) +
- self.pds(inttoasc(dbrec['minimum']),4))
- elif rrtype == 'SRV':
- rdata = (self.pds(inttoasc(dbrec['priority']),2) +
- self.pds(inttoasc(dbrec['weight']),2) +
- self.pds(inttoasc(dbrec['port']),2) +
- self.packdomainname(dbrec['target'], i+6, msgcomp))
- elif rrtype == 'TXT':
- rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata']
- else:
- rdata = dbrec['rdata']
-
- return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata
-
- def buildpkt(self):
- # keep dictionary of names packed (so we can use pointers)
- msgcomp = {}
- # header
- if self.header.id > 65535:
- log(0,'building packet with bad ID field')
- self.header.id = 1
- msgdata = inttoasc(self.header.id)
- if len(msgdata) == 1:
- msgdata = chr(0) + msgdata
- h1 = ((self.header.qr << 7) +
- (self.header.opcode << 3) +
- (self.header.aa << 2) +
- (self.header.tc << 1) +
- (self.header.rd))
- h2 = ((self.header.ra << 7) +
- (self.header.z << 4) +
- (self.header.rcode))
- msgdata = msgdata + chr(h1) + chr(h2)
- msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2)
- msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2)
- msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2)
- msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2)
- # question
- msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp)
- if self.rqtypes.has_key(self.question.qtype):
- typeval = self.rqtypes[self.question.qtype]
- else:
- typeval = ord(self.question.qtype)
- msgdata = msgdata + self.pds(inttoasc(typeval),2)
- if self.rqclasses.has_key(self.question.qclass):
- classval = self.rqclasses[self.question.qclass]
- else:
- classval = ord(self.question.qclass)
- msgdata = msgdata + self.pds(inttoasc(classval),2)
- # rr's
- # RR record format:
- # {'name' : {'type' : [rdata, rdata, ...]}
- # example: {'test.blah.net': {'A': [{'address': '10.1.1.2',
- # 'ttl': 3600L}]}}
- for rr in self.answerlist:
- log(4,'packing answer RR')
- msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
- for rr in self.authlist:
- log(4,'packing auth RR')
- msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
- for rr in self.addlist:
- log(4,'packing additional RR')
- msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
-
- return msgdata
-
- def printpkt(self):
- print 'ID: ' +str(self.header.id)
- if self.header.qr:
- print 'QR: RESPONSE'
- else:
- print 'QR: QUERY'
- if self.header.opcode == 0:
- print 'OPCODE: STANDARD QUERY'
- elif self.header.opcode == 1:
- print 'OPCODE: INVERSE QUERY'
- elif self.header.opcode == 2:
- print 'OPCODE: SERVER STATUS REQUEST'
- elif self.header.opcode == 5:
- print 'UPDATE REQUEST'
- else:
- print 'OPCODE: UNKNOWN QUERY TYPE'
- if self.header.opcode != 5:
- if self.header.aa:
- print 'AA: AUTHORITATIVE ANSWER'
- else:
- print 'AA: NON-AUTHORITATIVE ANSWER'
- if self.header.tc:
- print 'TC: MESSAGE IS TRUNCATED'
- else:
- print 'TC: MESSAGE IS NOT TRUNCATED'
- if self.header.rd:
- print 'RD: RECURSION DESIRED'
- else:
- print 'RD: RECURSION NOT DESIRED'
- if self.header.ra:
- print 'RA: RECURSION AVAILABLE'
- else:
- print 'RA: RECURSION IS NOT AVAILABLE'
- if self.header.rcode == 1:
- printrcode = 'FORMERR'
- elif self.header.rcode == 2:
- printrcode = 'SERVFAIL'
- elif self.header.rcode == 3:
- printrcode = 'NXDOMAIN'
- elif self.header.rcode == 4:
- printrcode = 'NOTIMP'
- elif self.header.rcode == 5:
- printrcode = 'REFUSED'
- elif self.header.rcode == 6:
- printrcode = 'YXDOMAIN'
- elif self.header.rcode == 7:
- printrcode = 'YXRRSET'
- elif self.header.rcode == 8:
- printrcode = 'NXRRSET'
- elif self.header.rcode == 9:
- printrcode = 'NOTAUTH'
- elif self.header.rcode == 10:
- printrcode = 'NOTZONE'
- else:
- printrcode = 'NOERROR'
- print 'RCODE: ' + printrcode
- if self.header.opcode == 5:
- print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount)
- print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount)
- print 'NUMBER OF RRs in the Update Section: ' + str(self.header.upcount)
- print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount)
- print 'ZONE SECTION:'
- print 'zname: ' + self.zone.zname
- print 'zonetype: ' + self.zone.ztype
- print 'zoneclass: ' + self.zone.zclass
- print 'PREREQUISITE RRs:'
- for rr in self.prlist:
- print rr
- print 'UPDATE RRs:'
- for rr in self.uplist:
- print rr
- print 'ADDITIONAL RRs:'
- for rr in self.addlist:
- print rr
-
-
- else:
- print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount)
- print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount)
- print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount)
- print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount)
- print 'QUESTION SECTION:'
- print 'qname: ' + self.question.qname
- print 'querytype: ' + self.question.qtype
- print 'queryclass: ' + self.question.qclass
- print 'ANSWER RRs:'
- for rr in self.answerlist:
- print rr
- print 'AUTHORITY RRs:'
- for rr in self.authlist:
- print rr
- print 'ADDITIONAL RRs:'
- for rr in self.addlist:
- print rr
-
-class zonedb:
- def __init__(self, zdict):
- self.zdict = zdict
- self.updates = {}
- for k in self.zdict.keys():
- if self.zdict[k]['type'] == 'slave':
- self.zdict[k]['lastupdatetime'] = 0
-
- def error(self, id, qname, querytype, queryclass, rcode):
- error = message()
- error.header.id = id
- error.header.rcode = rcode
- error.header.qr = 1
- error.question.qname = qname
- error.question.qtype = querytype
- error.question.qclass = queryclass
- return error
-
- def getorigin(self, zkey):
- origin = ''
- if self.zdict.has_key(zkey):
- origin = self.zdict[zkey]['origin']
- return origin
-
- def getmasterip(self, zkey):
- masterip = ''
- if self.zdict.has_key(zkey):
- if self.zdict[zkey].has_key('masterip'):
- masterip = self.zdict[zkey]['masterip']
- return masterip
-
- def zonetrans(self, query):
- # build a list of messages
- # each message contains one rr of the zone
- # the first and last message are the
- # SOA records
- origin = query.question.qname
- querytype = query.question.qtype
- zkey = ''
- for zonekey in self.zdict.keys():
- if self.zdict[zonekey]['origin'] == query.question.qname:
- zkey = zonekey
- if not zkey:
- return []
- zonedata = self.zdict[zkey]['zonedata']
- queryid = query.header.id
- soarec = zonedata[origin]['SOA'][0]
- soa = {origin:{'SOA':[soarec]}}
- curserial = soarec['serial']
- rrlist = []
- if querytype == 'IXFR':
- clientserial = query.authlist[0][origin]['SOA'][0]['serial']
- if clientserial < curserial:
- for i in range(clientserial,curserial+1):
- if self.updates[zkey].has_key(i):
- for rr in self.updates[zkey][i]['added']:
- rrlist.append(rr)
- for rr in self.updates[zkey][i]['removed']:
- rrlist.append(rr)
- if len(rrlist) > 0:
- rrlist.insert(0,soa)
- rrlist.append(soa)
- else:
- rrlist.append(soa)
- else:
- for nodename in zonedata.keys():
- for rrtype in zonedata[nodename].keys():
- if not (rrtype == 'SOA' and nodename == origin):
- for rr in zonedata[nodename][rrtype]:
- rrlist.append({nodename:{rrtype:[rr]}})
- rrlist.insert(0,soa)
- rrlist.append(soa)
- msglist = []
- for rr in rrlist:
- msg = message()
- msg.header.id = queryid
- msg.header.qr = 1
- msg.header.aa = 1
- msg.header.rd = 0
- msg.header.qdcount = 1
- msg.question.qname = origin
- msg.question.qtype = querytype
- msg.question.qclass = 'IN'
- msg.header.ancount = 1
- msg.answerlist.append(rr)
- msglist.append(msg)
- return msglist
-
- def update_zone(self, rrlist, params):
- zonekey = params[0]
- zonedata = {}
- soa = rrlist.pop()
- origin = soa.keys()[0]
- for rr in rrlist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- dbrec = rr[rrname][rrtype][0]
- if zonedata.has_key(rrname):
- if not zonedata[rrname].has_key(rrtype):
- zonedata[rrname][rrtype] = []
- else:
- zonedata[rrname] = {}
- zonedata[rrname][rrtype] = []
- zonedata[rrname][rrtype].append(dbrec)
- self.zdict[zonekey]['zonedata'] = zonedata
- curtime = time.time()
- self.zdict[zonekey]['lastupdatetime'] = curtime
- try:
- f = file(self.zdict[zonekey]['filename'],'w')
- writezonefile(zonedata, self.zdict[zonekey]['origin'], f)
- f.close()
- except:
- log(0,'unable to write zone ' + zonekey + 'to disk')
- log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')')
-
- def remove_zone(self, zonekey):
- if self.zdict.has_key(zonekey):
- del self.zdict[zonekey]
-
- def getslaves(self, curtime):
- rlist = []
- for k in self.zdict.keys():
- if self.zdict[k]['type'] == 'slave':
- origin = self.zdict[k]['origin']
- refresh = self.zdict[k]['zonedata'][origin]['SOA'][0]['refresh']
- if self.zdict[k]['lastupdatetime'] + refresh < curtime:
- rlist.append((k, origin, self.zdict[k]['masterip']))
- return rlist
-
- def zmatch(self, qname, zkeys):
- for zkey in zkeys:
- if self.zdict.has_key(zkey):
- origin = self.zdict[zkey]['origin']
- if qname.rfind(origin) != -1:
- return zkey
- return ''
-
- def getzlist(self, name, zone):
- if name == zone:
- return
- zlist = []
- i = name.rfind(zone)
- if i == -1:
- return
- firstpart = name[:i-1]
- partlist = firstpart.split('.')
- partlist.reverse()
- lastpart = zone
- for x in range(len(partlist)):
- lastpart = partlist[x] + '.' + lastpart
- zlist.append(lastpart)
- return zlist
-
- def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc):
- # handle zone transfers seperately
- qname = query.question.qname
- querytype = query.question.qtype
- queryclass = query.question.qclass
- if querytype in ['AXFR','IXFR']:
- for zkey in self.zdict.keys():
- if zkey in zkeys:
- if qname == self.zdict[zkey]['origin']:
- answerlist = self.zonetrans(query)
- break
- else:
- answerlist = []
- cbfunc(query, addr, server, dorecursion, flist, answerlist)
- else:
- zonekey = self.zmatch(qname, zkeys)
- if zonekey:
- origin = self.zdict[zonekey]['origin']
- zonedict = self.zdict[zonekey]['zonedata']
- referral = 0
- rranswerlist = []
- rrnslist = []
- rraddlist = []
- answer = message()
- answer.header.aa = 1
- answer.header.id = query.header.id
- answer.header.qr = 1
- answer.header.opcode = query.header.opcode
- answer.header.rcode = 4
- answer.header.ra = dorecursion
- answer.question.qname = query.question.qname
- answer.question.qtype = query.question.qtype
- answer.question.qclass = query.question.qclass
- answer.header.ra = dorecursion
- s = '.servers.csail.mit.edu'
- if qname.endswith(s):
- host = qname[:-len(s)]
- value = sipb_xen_database.NIC.get_by(hostname=host)
- if value is None:
- pass
- else:
- ip = value.ip
- rranswerlist.append({qname: {'A': [{'address': ip,
- 'class': 'IN',
- 'ttl': 10}]}})
- if zonedict.has_key(qname):
- # found the node, now take care of CNAMEs
- if zonedict[qname].has_key('CNAME'):
- if querytype != 'CNAME':
- nodetype = 'CNAME'
- while nodetype == 'CNAME':
- rranswerlist.append({qname:{'CNAME':[zonedict[qname]['CNAME'][0]]}})
- qname = zonedict[qname]['CNAME'][0]['cname']
- if zonedict.has_key(qname):
- nodetype = zonedict[qname].keys()[0]
- else:
- # error, shouldn't have a CNAME that points to nothing
- return
- # if we get this far, then the record has matched and we should return
- # a reply that has no error (even if there is no info macthing the qtype)
- answer.header.rcode = 0
- answernode = zonedict[qname]
- if querytype == 'ANY':
- for type in answernode.keys():
- for rec in answernode[type]:
- rranswerlist.append({qname:{type:[rec]}})
- elif answernode.has_key(querytype):
- for rec in answernode[querytype]:
- rranswerlist.append({qname:{querytype:[rec]}})
- # do rrset ordering (cyclic)
- if len(answernode[querytype]) > 1:
- rec = answernode[querytype].pop(0)
- answernode[querytype].append(rec)
- else:
- # remove all cname rrs from answerlist
- rranswerlist = []
- else:
- # would check for wildcards here (but aren't because they seem bad)
- # see if we need to give a referral
- zlist = self.getzlist(qname,origin)
- for zonename in zlist:
- if zonedict.has_key(zonename):
- if zonedict[zonename].has_key('NS'):
- answer.header.rcode = 0
- referral = 1
- for rec in zonedict[zonename]['NS']:
- rrnslist.append({zonename:{'NS':[rec]}})
- nsdname = rec['nsdname']
- # add glue records if they exist
- if zonedict.has_key(nsdname):
- if zonedict[nsdname].has_key('A'):
- for gluerec in zonedict[nsdname]['A']:
- rraddlist.append({nsdname:{'A':[gluerec]}})
- # negative caching stuff
- if not referral:
- if not rranswerlist:
- # NOTE: RFC1034 section 4.3.4 says we should add the SOA record
- # to the additional section of the response. BIND adds
- # it to the ns section though
- answer.header.rcode = 3
- rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}})
- else:
- for rec in zonedict[origin]['NS']:
- rrnslist.append({origin:{'NS':[rec]}})
- answer.header.ancount = len(rranswerlist)
- answer.header.nscount = len(rrnslist)
- answer.header.arcount = len(rraddlist)
- answer.answerlist = rranswerlist
- answer.authlist = rrnslist
- answer.addlist = rraddlist
- cbfunc(query, addr, server, dorecursion, flist, [answer])
- else:
- cbfunc(query, addr, server, dorecursion, flist, [])
-
- def handle_update(self, msg, addr, ns):
- zkey = ''
- slaves = []
- for zonekey in self.zdict.keys():
- if (self.zdict[zonekey]['type'] == 'master' and
- self.zdict[zonekey]['origin'] == msg.zone.zname):
- zkey = zonekey
- if not zkey:
- log(2,'SENDING NOTAUTH UPDATE ERROR')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 9)
- return errormsg, '', slaves
- # find the slaves for the zone
- if self.zdict[zkey].has_key('slaves'):
- slaves = self.zdict[zkey]['slaves']
- origin = self.zdict[zkey]['origin']
- zd = self.zdict[zkey]['zonedata']
- # check the permissions
- if not ns.config.allowupdate(msg, addr[0], addr[1]):
- log(2,'SENDING REFUSED UPDATE ERROR')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 5)
- return errormsg, origin, slaves
- # now check the prereqs
- temprrset = {}
- for rr in msg.prlist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- dbrec = rr[rrname][rrtype][0]
- if dbrec['ttl'] != 0:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- if rrname.rfind(msg.zone.zname) == -1:
- log(2,'NOTZONE(10)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 10)
- return errormsg, origin, slaves
- if dbrec['class'] == 'ANY':
- if dbrec['rdata']:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- if rrtype == 'ANY':
- if not zd.has_key(rrname):
- log(2,'NXDOMAIN(3)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 3)
- return errormsg, origin, slaves
- else:
- rrsettest = 0
- if zd.has_key(rrname):
- if zd[rrname].has_key(rrtype):
- rrsettest = 1
- if not rrsettest:
- log(2,'NXRRSET(8)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 8)
- return errormsg, origin, slaves
- if dbrec['class'] == 'NONE':
- if dbrec['rdata']:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- if rrtype == 'ANY':
- if zd.has_key(rrname):
- log(2,'YXDOMAIN(6)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 6)
- return errormsg, origin, slaves
- else:
- if zd.has_key(rrname):
- if zd[rrname].has_key(rrtype):
- log(2,'YXRRSET(7)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 7)
- return errormsg, origin, slaves
- if dbrec['class'] == msg.zone.zclass:
- if temprrset.has_key(rrname):
- if not temprrset[rrname].has_key(rrtype):
- temprrset[rrname][rrtype] = []
- else:
- temprrset[rrname] = {}
- temprrset[rrname][rrtype] = []
- temprrset[rrname][rrtype].append(dbrec)
- else:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- for nodename in temprrset.keys():
- if not self.rrmatch(temprrset[nodename],zd[nodename]):
- log(2,'NXRRSET(8)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 8)
- return errormsg, origin, slaves
-
- # update section prescan
- for rr in msg.uplist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- dbrec = rr[rrname][rrtype][0]
- if rrname.rfind(msg.zone.zname) == -1:
- log(2,'NOTZONE(10)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 10)
- return errormsg, origin, slaves
- if dbrec['class'] == msg.zone.zclass:
- if rrtype in ['ANY','MAILA','MAILB','AXFR']:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- elif dbrec['class'] == 'ANY':
- if dbrec['ttl'] != 0 or dbrec['rdata'] or rrtype in ['MAILA','MAILB','AXFR']:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- elif dbrec['class'] == 'NONE':
- if dbrec['ttl'] != 0 or rrtype in ['ANY','MAILA','MAILB','AXFR']:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
- else:
- log(2,'FORMERROR(1)')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- return errormsg, origin, slaves
-
- # now handle actual update
- curserial = zd[msg.zone.zname]['SOA'][0]['serial']
- # update the soa serial here
- clearupdatehist = 0
- if len(msg.uplist) > 0:
- # initialize history structure
- if not self.updates.has_key(zkey):
- self.updates[zkey] = {}
- self.updates[zkey][curserial] = {'removed':[],
- 'added':[]}
- if curserial == 2**32:
- newserial = 2
- clearupdatehist = 1
- else:
- newserial = curserial + 1
- self.updates[zkey][newserial] = {'removed':[],
- 'added':[]}
- zd[msg.zone.zname]['SOA'][0]['serial'] = newserial
- for rr in msg.uplist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- dbrec = rr[rrname][rrtype][0]
- if dbrec['class'] == msg.zone.zclass:
- if rrtype == 'SOA':
- if zd.has_key(rrname):
- if zd[rrname].has_key('SOA'):
- if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']:
- del zd[rrname]['SOA'][0]
- zd[rrname]['SOA'].append(dbrec)
- clearupdatehist = 1
- elif rrtype == 'WKS':
- if zd.has_key(rrname):
- if zd[rrname].has_key('WKS'):
- rdata = zd[rrname]['WKS'][0]
- oldrr = {rrname:{'WKS':[rdata]}}
- self.updates[zkey][curserial]['removed'].append(oldrr)
- del zd[rrname]['WKS'][0]
- zd[rrname]['WKS'].append(dbrec)
- newrr = {rrname:{'WKS':[dbrec]}}
- self.updates[zkey][newserial]['added'].append(newrr)
- else:
- if zd.has_key(rrname):
- if not zd[rrname].has_key(rrtype):
- zd[rrname][rrtype] = []
- else:
- zd[rrname] = {}
- zd[rrname][rrtype] = []
- zd[rrname][rrtype].append(dbrec)
- newrr = {rrname:{rrtype:[dbrec]}}
- self.updates[zkey][newserial]['added'].append(newrr)
- elif dbrec['class'] == 'ANY':
- if rrtype == 'ANY':
- if rrname == msg.zone.zname:
- if zd.has_key(rrname):
- for dnstype in zd[rrname].keys():
- if dnstype not in ['SOA','NS']:
- for rdata in zd[rrname][dnstype]:
- oldrr = {rrname:{dnstype:[rdata]}}
- self.updates[zkey][curserial]['removed'].append(oldrr)
- del zd[rrname][dnstype]
- else:
- if zd.has_key(rrname):
- for dnstype in zd[rrname].keys():
- for rdata in zd[rrname][dnstype]:
- oldrr = {rrname:{dnstype:[rdata]}}
- self.updates[zkey][curserial]['removed'].append(oldrr)
- del zd[rrname]
- else:
- if zd.has_key(rrname):
- if zd[rrname].has_key(rrtype):
- if rrname == msg.zone.zname:
- if rrtype not in ['SOA','NS']:
- for rdata in zd[rrname][dnstype]:
- oldrr = {rrname:{dnstype:[rdata]}}
- self.updates[zkey][curserial]['removed'].append(oldrr)
- del zd[rrname][rrtype]
- else:
- for rdata in zd[rrname][dnstype]:
- oldrr = {rrname:{dnstype:[rdata]}}
- self.updates[zkey][curserial]['removed'].append(oldrr)
- del zd[rrname][rrtype]
- elif dbrec['class'] == 'NONE':
- if not (rrname == msg.zone.zname and rrtype in ['SOA','NS']):
- if zd.had_key(rrname):
- if zd[rrname].has_key(rrtype):
- for i in range(len(zd[rrname][rrtype])):
- if dbrec == zd[rrname][rrtype][i]:
- rdata = zd[rrname][dnstype][i]
- oldrr = {rrname:{dnstype:[rdata]}}
- self.updates[zkey][curserial]['removed'].append(oldrr)
- del zd[rrname][rrtype][i]
- if len(zd[rrname][rrtype]) == 0:
- del zd[rrname][rrtype]
- if clearupdatehist:
- self.updates[zkey] = {}
- log(2,'SENDING UPDATE NOERROR MSG')
- noerrormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 0)
- return noerrormsg, origin, slaves
-
-class dnscache:
- def __init__(self,cachezone):
- self.cachedb = cachezone
- # go through and set all of the root ttls to zero
- for node in self.cachedb.keys():
- for rtype in self.cachedb[node].keys():
- for rr in self.cachedb[node][rtype]:
- rr['ttl'] = 0
- if rtype == 'NS':
- rr['rtt'] = 0
- # add special entries for localhost
- self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]}
- self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]}
- self.cachedb['']['SOA'] = []
- self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb',
- 'rname':'cachedb@localhost','serial':1,'refresh':10800,
- 'retry':3600,'expire':604800,'minimum':3600})
-
- def hasrdata(self, irrdata, rrdatalist):
- # compare everything but ttls
- test = 0
- testrrdata = irrdata.copy()
- del testrrdata['ttl']
- for rrdata in rrdatalist:
- temprrdata = rrdata.copy()
- del temprrdata['ttl']
- if temprrdata == testrrdata:
- test = 1
- return test
-
- def add(self, rr, qzone, nsdname):
- # NOTE: can't cache records from sites
- # that don't own those records (i.e. example.com
- # can't give us A records for www.example.net)
- name = rr.keys()[0]
- if (qzone != '') and (name[-len(qzone):] != qzone):
- log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone)
- return
- rtype = rr[name].keys()[0]
- rdata = rr[name][rtype][0]
- if rdata['ttl'] < 3600:
- log(2,'low ttl: ' + str(rdata['ttl']))
- rdata['ttl'] = 3600
- rdata['ttl'] = int(time.time() + rdata['ttl'])
- if rtype == 'NS':
- rdata['rtt'] = 0
- name = name.lower()
- rtype = rtype.upper()
- if self.cachedb.has_key(name):
- if self.cachedb[name].has_key(rtype):
- if not self.hasrdata(rdata, self.cachedb[name][rtype]):
- self.cachedb[name][rtype].append(rdata)
- log(3,'appended rdata to ' +
- name + '(' + rtype + ') in cache')
- else:
- log(3,'same rdata for ' + name + '(' +
- rtype + ') is already in cache')
- else:
- self.cachedb[name][rtype] = [rdata]
- log(3,'appended ' + rtype + ' and rdata to node ' +
- name + ' in cache')
- else:
- self.cachedb[name] = {rtype:[rdata]}
- log(3,'added node ' + name + '(' + rtype + ') to cache')
- self.reap()
-
- def addneg(self, qname, querytype, queryclass):
- if not self.cachedb.has_key(qname):
- self.cachedb['qname'] = {querytype: [{'ttl':time.time()+3600}]}
- else:
- if not self.cachedb[qname].has_key(querytype):
- self.cachedb[qname][querytype] = [{'ttl':time.time()+3600}]
-
- def haskey(self, qname, querytype, msg=''):
- log(3,'looking for ' + qname + '(' + querytype + ') in cache')
- if self.cachedb.has_key(qname):
- rranswerlist = []
- rrnslist = []
- rraddlist = []
- if self.cachedb[qname].has_key('CNAME'):
- if querytype != 'CNAME':
- nodetype = 'CNAME'
- while nodetype == 'CNAME':
- if len(self.cachedb[qname]['CNAME'][0].keys()) > 1:
- log(3,'Adding CNAME to cache answer')
- rranswerlist.append({qname:{'CNAME':[self.cachedb[qname]['CNAME'][0]]}})
- qname = self.cachedb[qname]['CNAME'][0]['cname']
- if self.cachedb.has_key(qname):
- nodetype = self.cachedb[qname].keys()[0]
- else:
- # shouldn't have a CNAME that points to nothing
- return
- if querytype == 'ANY':
- for type in self.cache[qname].keys():
- for rec in self.cachedb[qname][type]:
- # can't append negative entries
- if len(rec.keys()) > 1:
- rranswerlist.append({qname:{type:[rec]}})
- elif self.cachedb[qname].has_key(querytype):
- for rec in self.cachedb[qname][querytype]:
- if len(rec.keys()) > 1:
- rranswerlist.append({qname:{querytype:[rec]}})
- if rranswerlist:
- if msg:
- answer = message()
- answer.header.id = msg.header.id
- answer.header.qr = 1
- answer.header.opcode = msg.header.opcode
- answer.header.ra = 1
- answer.question.qname = msg.question.qname
- answer.question.qtype = msg.question.qtype
- answer.question.qclass = msg.question.qclass
- answer.header.rcode = 0
- answer.header.ancount = len(rranswerlist)
- answer.answerlist = rranswerlist
- return answer
- else:
- return 1
- else:
- log(3,'Cache has no node for ' + qname)
-
- def getnslist(self, qname):
- # find the best nameserver to ask from the cache
- tokens = qname.split('.')
- nsdict = {}
- curtime = time.time()
- for i in range(len(tokens)):
- domainname = '.'.join(tokens[i:])
- if self.cachedb.has_key(domainname):
- if self.cachedb[domainname].has_key('NS'):
- for nsrec in self.cachedb[domainname]['NS']:
- badserver = 0
- if nsrec.has_key('badtill'):
- if nsrec['badtill'] < curtime:
- del nsrec['badtill']
- else:
- badserver = 1
- if badserver:
- log(2,'BAD SERVER, not using ' + nsrec['nsdname'])
- if self.cachedb.has_key(nsrec['nsdname']) and not badserver:
- if self.cachedb[nsrec['nsdname']].has_key('A'):
- for arec in self.cachedb[nsrec['nsdname']]['A']:
- nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'],
- 'ip':arec['address']}
- if nsdict:
- break
- if not nsdict:
- domainname = ''
- # nothing in the cache matches so give back the root servers
- for nsrec in self.cachedb['']['NS']:
- badserver = 0
- if nsrec.has_key('badtill'):
- if curtime > nsrec['badtill']:
- del nsrec['badtill']
- else:
- badserver = 1
- if not badserver:
- for arec in self.cachedb[nsrec['nsdname']]['A']:
- nsdict[(nsrec['rtt'])] = {'name':nsrec['nsdname'],'ip':arec['address']}
-
- return (domainname, nsdict)
-
- def badns(self, zonename, nsdname):
- if self.cachedb.has_key(zonename):
- if self.cachedb[zonename].has_key('NS'):
- for nsrec in self.cachedb[zonename]['NS']:
- if nsrec['nsdname'] == nsdname:
- log(2,'Setting ' + nsdname + ' as bad nameserver')
- nsrec['badtill'] = time.time() + 3600
-
-
- def updatertt(self, qname, zone, rtt):
- if self.cachedb.has_key(zone):
- if self.cachedb[zone].has_key('NS'):
- for rr in self.cachedb[zone]['NS']:
- if rr['nsdname'] == qname:
- log(2,'updating rtt for ' + qname + ' to ' + str(rtt))
- rr['rtt'] = rtt
-
- def reap(self):
- # expire all old records
- ntime = time.time()
- for nodename in self.cachedb.keys():
- for rrtype in self.cachedb[nodename].keys():
- for rdata in self.cachedb[nodename][rrtype]:
- ttl = rdata['ttl']
- if ttl != 0:
- if ttl < ntime:
- self.cachedb[nodename][rrtype].remove(rdata)
- if len(self.cachedb[nodename][rrtype]) == 0:
- del self.cachedb[nodename][rrtype]
- if len(self.cachedb[nodename]) == 0:
- del self.cachedb[nodename]
-
- return
-
- def zonetrans(self, queryid):
- # build a list of messages
- # each message contains one rr of the zone
- # the first and last message are the
- # SOA records
- zonedata = self.cachedb
- rrlist = []
- soa = {'':{'SOA':[zonedata['']['SOA'][0]]}}
- for nodename in zonedata.keys():
- for rrtype in zonedata[nodename].keys():
- if not (rrtype == 'SOA' and nodename == ''):
- for rr in zonedata[nodename][rrtype]:
- rrlist.append({nodename:{rrtype:[rr]}})
- rrlist.insert(0,soa)
- rrlist.append(soa)
- msglist = []
- for rr in rrlist:
- msg = message()
- msg.header.id = queryid
- msg.header.qr = 1
- msg.header.aa = 1
- msg.header.rd = 0
- msg.header.qdcount = 1
- msg.question.qname = 'cache'
- msg.question.qtype = 'AXFR'
- msg.question.qclass = 'IN'
- msg.header.ancount = 1
- msg.answerlist.append(rr)
- msglist.append(msg)
- return msglist
-
-class gethostaddr(asyncore.dispatcher):
- def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1'):
- asyncore.dispatcher.__init__(self)
- self.msg = message()
- self.msg.question.qname = hostname
- self.msg.question.qtype = 'A'
- self.cbfunc = cbfunc
- self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
- self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))
-
- def handle_read(self):
- replydata, addr = self.socket.recvfrom(1500)
- self.close()
- try:
- replymsg = message(replydata)
- except:
- log(0,'unable to process packet')
- return
- answername = replymsg.question.qname
- cname = ''
- # go through twice to catch cnames after A recs
- for rr in replymsg.answerlist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- dbrec = rr[rrname][rrtype][0]
- if rrname == answername and rrtype == 'CNAME':
- answername = dbrec['cname']
- cname = answername
- for rr in replymsg.answerlist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- dbrec = rr[rrname][rrtype][0]
- if rrname == answername and rrtype == 'A':
- self.cbfunc(dbrec['address'])
- return
- # if we got a cname and no A send query for cname
- if cname:
- self.msg = message()
- self.msg.question.qname = cname
- self.msg.question.qtype = 'A'
- self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))
- else:
- self.cbfunc('')
-
- def writable(self):
- return 0
-
- def handle_write(self):
- pass
-
- def handle_connect(self):
- pass
-
- def handle_close(self):
- self.close()
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class simpleudprequest(asyncore.dispatcher):
- def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''):
- asyncore.dispatcher.__init__(self)
- self.gotanswer = 0
- self.msg = msg
- self.cbfunc = cbfunc
- self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
- self.outqkey = outqkey
- self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))
-
- def handle_read(self):
- replydata, addr = self.socket.recvfrom(1500)
- self.close()
- try:
- replymsg = message(replydata)
- except:
- log(0,'unable to process packet')
- return
- self.cbfunc(replymsg, self.outqkey)
-
- def writable(self):
- return 0
-
- def handle_write(self):
- pass
-
- def handle_connect(self):
- pass
-
- def handle_close(self):
- self.close()
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class simpletcprequest(asyncore.dispatcher):
- def __init__(self, msg, cbfunc, cbparams=[], serveraddr='127.0.0.1', errorfunc=''):
- asyncore.dispatcher.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.query = msg
- self.cbfunc = cbfunc
- self.cbparams = cbparams
- self.errorfunc = errorfunc
- msgdata = msg.buildpkt()
- ml = inttoasc(len(msgdata))
- if len(ml) == 1:
- ml = chr(0) + ml
- self.buffer = ml+msgdata
- self.rbuffer = ''
- self.rmsgleft = 0
- self.rrlist = []
- log(2,'sending tcp request to ' + serveraddr)
- self.connect((serveraddr,53))
-
- def recv (self, buffer_size):
- try:
- data = self.socket.recv (buffer_size)
- if not data:
- # a closed connection is indicated by signaling
- # a read condition, and having recv() return 0.
- self.handle_close()
- return ''
- else:
- return data
- except socket.error, why:
- # winsock sometimes throws ENOTCONN
- if why[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]:
- self.handle_close()
- return ''
- else:
- raise socket.error, why
-
- def handle_connect(self):
- pass
-
- def handle_msg(self, msg):
- if self.query.question.qtype == 'AXFR':
- if len(self.rrlist) == 0:
- if len(msg.answerlist) == 0:
- if self.errorfunc:
- self.errorfunc(self.cbparams[0])
- self.close()
- return
- rr = msg.answerlist[0]
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- self.rrlist.append(rr)
- if rrtype == 'SOA' and len(self.rrlist) > 1:
- self.close()
- if self.cbparams:
- self.cbfunc(self.rrlist, self.cbparams)
- else:
- self.cbfunc(self.rrlist)
- else:
- self.close()
- if self.cbparams:
- self.cbfunc(msg, self.cbparams)
- else:
- self.cbfunc(msg)
-
- def handle_read(self):
- data = self.recv(8192)
- if len(self.rbuffer) == 0:
- self.rmsglength = asctoint(data[:2])
- data = data[2:]
- self.rbuffer = self.rbuffer + data
- while len(self.rbuffer) >= self.rmsglength and self.rmsglength != 0:
- msgdata = self.rbuffer[:self.rmsglength]
- self.rbuffer = self.rbuffer[self.rmsglength:]
- if len(self.rbuffer) == 0:
- self.rmsglength = 0
- else:
- self.rmsglength = asctoint(self.rbuffer[:2])
- self.rbuffer = self.rbuffer[2:]
- try:
- self.handle_msg(message(msgdata))
- except:
- return
-
- def writable(self):
- return (len(self.buffer) > 0)
-
- def handle_write(self):
- sent = self.send(self.buffer)
- self.buffer = self.buffer[sent:]
-
- def handle_close(self):
- if self.errorfunc:
- self.errorfunc(self.query.question.qname)
- self.close()
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class udpdnsserver(asyncore.dispatcher):
- def __init__(self, port, dnsserver):
- asyncore.dispatcher.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
- self.bind(('',port))
- self.dnsserver = dnsserver
- self.maxmsgsize = 500
-
- def handle_read(self):
- try:
- while 1:
- msgdata, addr = self.socket.recvfrom(1500)
- self.dnsserver.handle_packet(msgdata, addr, self)
- except socket.error, why:
- if why[0] != asyncore.EWOULDBLOCK:
- raise socket.error, why
-
- def sendpackets(self, msglist, addr):
- for msg in msglist:
- msgdata = msg.buildpkt()
- if len(msgdata) > self.maxmsgsize:
- msg.header.tc = 1
- # take off all the answers to ensure
- # the packet size is small enough
- msg.header.ancount = 0
- msg.header.nscount = 0
- msg.header.arcount = 0
- msg.answerlist = []
- msg.authlist = []
- msg.addlist = []
- msgdata = msg.buildpkt()
- self.sendto(msgdata, addr)
-
- def writable(self):
- return 0
-
- def handle_write(self):
- pass
-
- def handle_connect(self):
- pass
-
- def handle_close(self):
- # print '1:In handle close'
- return
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class tcpdnschannel(asynchat.async_chat):
- def __init__(self, server, s, addr):
- asynchat.async_chat.__init__(self, s)
- self.server = server
- self.addr = addr
- self.set_terminator(None)
- self.databuffer = ''
- self.msglength = 0
- log(3,'Created new tcp channel')
-
- def collect_incoming_data(self, data):
- if self.msglength == 0:
- self.msglength = asctoint(data[:2])
- data = data[2:]
- self.databuffer = self.databuffer + data
- if len(self.databuffer) == self.msglength:
- # got entire message
- self.server.dnsserver.handle_packet(self.databuffer, self.addr, self)
- self.databuffer = ''
-
- def sendpackets(self, msglist, addr):
- for msg in msglist:
- x = msg.buildpkt()
- ml = inttoasc(len(x))
- if len(ml) == 1:
- ml = chr(0) + ml
- self.push(ml+x)
- self.close()
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class tcpdnsserver(asyncore.dispatcher):
- def __init__(self, port, dnsserver):
- asyncore.dispatcher.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
- self.set_reuse_addr()
- self.bind(('',port))
- self.listen(5)
- self.dnsserver = dnsserver
-
- def handle_accept(self):
- conn, addr = self.accept()
- tcpdnschannel(self, conn, addr)
-
- def handle_close(self):
- self.close()
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class nameserver:
- def __init__(self, resolver, localconfig):
- self.resolver = resolver
- self.config = localconfig
- self.zdb = self.config.zonedatabase
- self.last_reap_time = time.time()
- self.maint_int = 10
- self.slavesupdating = []
- self.notifys = []
- self.sentnotify = []
- self.notify_retry_time = 30
- self.notify_retries = 4
- self.askedsoa = {}
- self.soatimeout = 10
-
- def error(self, id, qname, querytype, queryclass, rcode):
- error = message()
- error.header.id = id
- error.header.rcode = rcode
- error.header.qr = 1
- error.question.qname = qname
- error.question.qtype = querytype
- error.question.qclass = queryclass
- return error
-
- def need_zonetransfer(self, zkey, origin, masterip, trynum=0):
- self.askedsoa[zkey] = {'masterip':masterip,
- 'senttime':time.time(),
- 'origin':origin,
- 'trynum':trynum+1}
- query = message()
- query.header.id = random.randrange(1,32768)
- query.header.rd = 0
- query.question.qname = origin
- query.question.qtype = 'SOA'
- query.question.qclass = 'IN'
- log(3,'slave checking for new data in ' + origin)
- simpleudprequest(query, self.handle_soaquery,
- masterip, zkey)
-
- def handle_soaquery(self, msg, zkey):
- origin = msg.question.qname
- masterip = self.askedsoa[zkey]['masterip']
- del self.askedsoa[zkey]
- if zkey not in self.slavesupdating:
- self.slavesupdating.append(zkey)
- query = message()
- query.header.id = random.randrange(1,32768)
- query.header.rd = 0
- query.question.qname = origin
- query.question.qtype = 'AXFR'
- query.question.qclass = 'IN'
- log(3,'Updating slave zone: ' + zkey)
- simpletcprequest(query, self.handle_zonetrans,
- [zkey],masterip,self.handle_zterror)
-
- def handle_zonetrans(self, rrlist, params):
- log(1,'handling zone transfer')
- zonekey = params[0]
- self.zdb.update_zone(rrlist, params)
- self.slavesupdating.remove(zonekey)
-
- def handle_zterror(self, zonekey):
- self.slavesupdating.remove(zonekey)
- self.zdb.remove_zone(zonekey)
-
- def rrmatch(self, rrset1, rrset2):
- for rrtype in rrset1.keys():
- if rrtype not in rrset2.keys():
- return
- else:
- if len(rrset1[rrtype]) != len(rrset2[rrtype]):
- return
- return 1
-
- def process_notify(self, msg, ipaddr, port):
- (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port)
- goodzkey = ''
- for zkey in zkeys:
- origin = self.zdb.getorigin(zkey)
- if origin == msg.question.qname:
- masterip = self.zdb.getmasterip(zkey)
- if masterip:
- goodzkey = zkey
- if goodzkey:
- log(3,'got NOTIFY from ' + masterip)
- self.need_zonetransfer(goodzkey, origin, masterip, 0)
- return
-
- def notify(self):
- curtime = time.time()
- for origin, ipaddr, trynum, senttime in self.sentnotify:
- if senttime + self.notify_retry_time > curtime:
- self.notifys.append((origin, ipaddr, trynum))
- self.sentnotify.remove((origin, ipaddr, trynum, senttime))
- for origin, ipaddr, trynum in self.notifys:
- msg = message()
- msg.question.qname = origin
- msg.question.qtype = 'SOA'
- msg.question.qclass = 'IN'
- msg.header.opcode = 4
- # there probably is a better way to do this
- if self.resolver:
- self.resolver.send_to([msg],(ipaddr,53))
- if trynum+1 <= self.notify_retries:
- self.sentnotify.append((origin,ipaddr,trynum+1,curtime))
- self.notifys = []
-
- def handle_packet(self, msgdata, addr, server):
- # self.reap()
- try:
- msg = message(msgdata)
- except:
- return
- # find a matching view
- (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1])
- if not msg.header.qr and msg.header.opcode == 5:
- log(2,'GOT UPDATE PACKET')
- # check the zone section
- if (msg.header.zocount != 1 or
- msg.zone.ztype != 'SOA' or
- msg.zone.zclass != 'IN'):
- log(2,'SENDING FORMERR UPDATE ERROR')
- errormsg = self.error(msg.header.id, msg.zone.zname,
- msg.zone.ztype, msg.zone.zclass, 1)
- server.sendpackets([errormsg],addr)
- else:
- (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self)
- if answer.header.rcode == 0:
- # schedule NOTIFYs to slaves
- for ipaddr in slaves:
- self.notifys.append((origin, ipaddr, 0))
- server.sendpackets([answer],addr)
- elif msg.header.opcode == 4:
- if msg.header.qr:
- log(0,'got NOTIFY response')
- for origin, ipaddr, trynum, senttime in self.sentnotify:
- if ipaddr == addr[0] and msg.question.qname == origin:
- self.sentnotify.remove((origin, ipaddr, trynum, senttime))
- else:
- log(0,'got NOTIFY')
- self.process_notify(msg, addr[0], addr[1])
- elif not msg.header.qr and msg.header.opcode == 0:
- # it's a question
- qname = msg.question.qname.lower()
- log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype +
- ') from ' + addr[0])
- # handle special version packet
- if (msg.question.qtype == 'TXT' and
- msg.question.qclass == 'CH'):
- if qname == 'version.bind':
- server.sendpackets([getversion(qname,
- msg.header.id,
- msg.header.rd,
- dorecursion, '1.0')],addr)
- elif qname == 'version.oak':
- server.sendpackets([getversion(qname,
- msg.header.id,
- msg.header.rd,
- dorecursion, '1.0')],addr)
- return
- self.zdb.lookup(zkeys, msg, addr, server, dorecursion,
- flist, self.lookup_callback)
-
- def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist):
- if answerlist:
- server.sendpackets(self.config.outpackets(answerlist), addr)
- elif dorecursion:
- if msg.question.qtype in ['AXFR','IXFR']:
- if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR':
- if self.resolver:
- server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr)
- else:
- # won't forward zone transfers and
- # don't handle recursive zone transfers
- server.sendpackets([self.error(msg.header.id, msg.question.qname,
- msg.question.qtype,
- msg.question.qclass,2)],addr)
- else:
- self.resolver.handle_query(msg, addr, flist, server.sendpackets)
-
- def reap(self):
- log(4,'in nameserver reap')
- # do all maintenence (interval) stuff here
- if self.resolver:
- self.resolver.reap()
- self.notify()
- curtime = time.time()
- if curtime > (self.last_reap_time + self.maint_int):
- self.last_reap_time = curtime
- # do zone transfers here if slave server and haven't asked for soa
- for (zkey, origin, masterip) in self.zdb.getslaves(curtime):
- if not self.askedsoa.has_key(zkey):
- self.need_zonetransfer(zkey, origin, masterip)
- for zkey in self.askedsoa.keys():
- if curtime > self.askedsoa[zkey]['senttime'] + self.soatimeout:
- if self.askedsoa[zkey]['trynum'] > 3:
- self.zdb.remove_zone(zkey)
- del self.askedsoa[zkey]
- else:
- masterip = self.askedsoa[zkey]['masterip']
- origin = self.askedsoa[zkey]['origin']
- trynum = self.askedsoa[zkey]['trynum']
- del self.askedsoa[zkey]
- self.need_zonetransfer(zkey, origin, masterip, trynum)
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-class resolver(asyncore.dispatcher):
- def __init__(self, cache, port=0):
- asyncore.dispatcher.__init__(self)
- self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
- self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
- self.bind(('',port))
- self.cache = cache
- self.outqnum = 0
- self.outq = {}
- self.holdq = {}
- self.holdtime = 10
- self.holdqlength = 100
- self.last_reap_time = time.time()
- self.maint_int = 10
- self.timeout = 3
-
- def getoutqkey(self):
- self.outqnum = self.outqnum + 1
- if self.outqnum == 99999:
- self.outqnum = 1
- return str(self.outqnum)
-
- def error(self, id, qname, querytype, queryclass, rcode):
- error = message()
- error.header.id = id
- error.header.rcode = rcode
- error.header.qr = 1
- error.question.qname = qname
- error.question.qtype = querytype
- error.question.qclass = queryclass
- return error
-
- def qpacket(self, id, qname, querytype, queryclass):
- # create a question
- query = message()
- query.header.id = id
- query.header.rd = 0
- query.question.qname = qname
- query.question.qtype = querytype
- query.question.qclass = queryclass
- return query
-
- def send_to(self, msglist, addr):
- for msg in msglist:
- data = msg.buildpkt()
- if len(data) > 512:
- # packet to big
- msg.header.tc = 1
- msg.header.ancount = 0
- msg.answerlist = []
- msg.header.nscount = 0
- msg.authlist = []
- msg.header.arcount = 0
- msg.addlist = []
- self.socket.sendto(msg.buildpkt(), addr)
- else:
- self.socket.sendto(data, addr)
-
- def handle_read(self):
- try:
- while 1:
- msgdata, addr = self.socket.recvfrom(1500)
- # should put 'try' here in production server
- self.handle_packet(msgdata, addr)
- except socket.error, why:
- if why[0] != asyncore.EWOULDBLOCK:
- raise socket.error, why
-
- def handle_packet(self, msgdata, addr):
- try:
- msg = message(msgdata)
- except:
- return
- if not msg.header.qr:
- self.handle_query(msg, addr, [], self.send_to)
- else:
- log(2,'received unsolicited reply')
-
-
- def handle_query(self, msg, addr, flist, cbfunc):
- qname = msg.question.qname
- querytype = msg.question.qtype
- queryclass = msg.question.qclass
- # check the cache first
- answer = self.cache.haskey(qname,querytype,msg)
- if answer:
- cbfunc([answer], addr)
- log(2,'sent answer for ' + qname + '(' + querytype +
- ') from cache')
- else:
- # check if query is already in progess
- for oqkey in self.outq.keys():
- if (self.outq[oqkey]['qname'] == qname and
- self.outq[oqkey]['querytype'] == querytype):
- log(2,'query already in progress for '+qname+'('+querytype+')')
- # put entry in hold queue to try later
- hqrec = {'processtime':time.time()+self.holdtime,
- 'query':msg,'addr':addr,
- 'qname':qname,'querytype':querytype,
- 'queryclass':queryclass,
- 'cbfunc':cbfunc}
- self.putonhold(hqrec)
- return
-
- outqkey = self.getoutqkey()+str(msg.header.id)
- self.outq[outqkey] = {'query':msg,
- 'addr':addr,
- 'qname':qname,
- 'querytype':querytype,
- 'queryclass':queryclass,
- 'cbfunc':cbfunc,
- 'answerlist':[],
- 'addlist':[],
- 'qsent':0}
- if flist:
- self.outq[outqkey]['flist'] = flist
- self.askfns(outqkey)
- else:
- self.askns(outqkey)
-
- def putonhold(self,hqrec):
- hqid = hqrec['qname']+hqrec['querytype']
- if self.holdq.has_key(hqid):
- if len(self.holdq[hqid]) < self.holdqlength:
- hqrec['processtime']=time.time()+self.holdtime
- self.holdq[hqid].append(hqrec)
-
-
- def askns(self, outqkey):
- qname = self.outq[outqkey]['qname']
- querytype = self.outq[outqkey]['querytype']
- queryclass = self.outq[outqkey]['queryclass']
- # don't try more than 10 times to avoid loops
- if self.outq[outqkey]['qsent'] == 10:
- del self.outq[outqkey]
- log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
- ' POSSIBLE LOOP')
- return
- # find the best nameservers to ask from the cache
- (qzone, nsdict) = self.cache.getnslist(qname)
- if not nsdict:
- # there are no good servers
- if self.outq[outqkey]['addr'] != 'IQ':
- qid = self.outq[outqkey]['query'].header.id
- self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2),
- self.outq[outqkey]['addr'])
- del self.outq[outqkey]
- log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
- 'no good name servers to ask')
- return
- # pick the best nameserver
- rtts = nsdict.keys()
- rtts.sort()
- bestnsip = nsdict[rtts[0]]['ip']
- bestnsname = nsdict[rtts[0]]['name']
- # fill in the callback data structure
- id=random.randrange(1,32768)
- self.outq[outqkey]['nsqueriedlastip'] = bestnsip
- self.outq[outqkey]['nsqueriedlastname'] = bestnsname
- self.outq[outqkey]['nsdict'] = nsdict
- self.outq[outqkey]['qzone'] = qzone
- self.outq[outqkey]['qsenttime'] = time.time()
- self.outq[outqkey]['qsent'] = self.outq[outqkey]['qsent'] + 1
- # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (bestnsip,53))
- self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
- self.handle_response, bestnsip, outqkey)
- # update rtt so that we ask a different server next time
- self.cache.updatertt(bestnsname,qzone,1)
- log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname +
- ') for ' + qname + '(' + querytype + ')')
-
- def askfns(self, outqkey):
- flist = self.outq[outqkey]['flist']
- qname = self.outq[outqkey]['qname']
- querytype = self.outq[outqkey]['querytype']
- queryclass = self.outq[outqkey]['queryclass']
- self.outq[outqkey]['qsenttime'] = time.time()
- id=random.randrange(1,32768)
- # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (flist[0],53))
- self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
- self.handle_fresponse, flist[0], outqkey)
- log(2,''+outqkey+'|sent query to forwarder')
-
- def handle_response(self, msg, outqkey):
- # either reponse:
- # 1. contains a name error
- # 2. answers the question
- # (cache data and return it)
- # 3. is (contains) a CNAME and qtype isn't
- # (cache cname and change qname to it)
- # (check if qname and qtype are in any other rrs in the response)
- # (must check cache again here)
- # 4. contains a better delegation
- # (cache the delegation and start again)
- # 5. is aserver failure
- # (delete server from list and try again)
-
- # make sure that original question is still outstanding
- if not self.outq.has_key(outqkey):
- # should never get here
- # if we do we aren't doing housekeeping of callbacks very well
- log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname)
- return
-
- querytype = self.outq[outqkey]['querytype']
- if msg.header.rcode not in [1,2,4,5]:
- # update rtt time
- rtt = time.time() - self.outq[outqkey]['qsenttime']
- nsname = self.outq[outqkey]['nsqueriedlastname']
- zone = self.outq[outqkey]['qzone']
- self.cache.updatertt(nsname,zone,rtt)
-
- if msg.header.rcode == 3:
- log(2,outqkey+'|GOT Name Error for ' + msg.question.qname +
- '(' + msg.question.qtype + ')')
- # name error
- # cache negative answer
- self.cache.addneg(self.outq[outqkey]['qname'],
- self.outq[outqkey]['querytype'],
- self.outq[outqkey]['queryclass'])
- if self.outq[outqkey]['addr'] != 'IQ':
- answer = message()
- answer.question.qname = self.outq[outqkey]['query'].question.qname
- answer.question.qtype = self.outq[outqkey]['query'].question.qtype
- answer.question.qclass = self.outq[outqkey]['query'].question.qclass
- answer.header.id = self.outq[outqkey]['query'].header.id
- answer.header.qr = 1
- answer.header.opcode = self.outq[outqkey]['query'].header.opcode
- answer.header.ra = 1
- self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
- del self.outq[outqkey]
-
- elif msg.header.ancount > 0:
- # answer (may be CNAME)
- haveanswer = 0
- cname = ''
- log(2,'CACHING ANSWERLIST ENTRIES')
- for rr in msg.answerlist:
- rrname = rr.keys()[0]
- rrtype = rr[rrname].keys()[0]
- if ((rrname == msg.question.qname or rrname == cname ) and
- rrtype == msg.question.qtype):
- haveanswer = 1
- if rrname == msg.question.qname and rrtype == 'CNAME':
- cname = rr[rrname][rrtype][0]['cname']
- self.cache.add(rr, self.outq[outqkey]['qzone'],
- self.outq[outqkey]['nsqueriedlastname'])
- if haveanswer:
- if self.outq[outqkey]['addr'] != 'IQ':
- log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname +
- '(' + msg.question.qtype + ')' )
- answer = message()
- answer.answerlist = msg.answerlist + self.outq[outqkey]['answerlist']
- answer.header.ancount = len(answer.answerlist)
- answer.question.qname = self.outq[outqkey]['query'].question.qname
- answer.question.qtype = self.outq[outqkey]['query'].question.qtype
- answer.question.qclass = self.outq[outqkey]['query'].question.qclass
- answer.header.id = self.outq[outqkey]['query'].header.id
- answer.header.qr = 1
- answer.header.opcode = self.outq[outqkey]['query'].header.opcode
- answer.header.ra = 1
- self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
- log(2,outqkey+'|sent answer retrieved from remote server for ' +
- self.outq[outqkey]['query'].question.qname)
- else:
- log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' +
- msg.question.qtype + ')')
- del self.outq[outqkey]
- elif cname:
- log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')')
- self.outq[outqkey]['answerlist'] = self.outq[outqkey]['answerlist'] + msg.answerlist
- self.outq[outqkey]['qname'] = cname
- self.askns(outqkey)
- else:
- log(2,outqkey+'|GOT BOGUS answer for ' + msg.question.qname + '(' +
- msg.question.qtype + ')')
- del self.outq[outqkey]
-
- elif msg.header.nscount > 0 and msg.header.ancount == 0:
- log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')')
- # delegation
- # cache the nameserver rrs and start over
- # if there are no glue records for nameservers must fetch them first
- log(2,'CACHING AUTHLIST ENTRIES')
- for rr in msg.authlist:
- self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
- log(2,'CACHING ADDLIST ENTRIES')
- for rr in msg.addlist:
- self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
- rrlist = msg.authlist+msg.addlist
- fetchglue = 0
- nscount = 0
- for rr in msg.authlist:
- nodename = rr.keys()[0]
- if rr[nodename].keys()[0] == 'NS':
- nscount = nscount + 1
- nsdname = rr[nodename]['NS'][0]['nsdname']
- if not self.cache.haskey(nsdname,'A'):
- log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)')
- fetchglue = fetchglue + 1
- # need to fetch A rec
- noutqkey = self.getoutqkey()+str(random.randrange(1,32768))
- self.outq[noutqkey] = {'query':'',
- 'addr':'IQ',
- 'qname':nsdname,
- 'querytype':'A',
- 'queryclass':'IN',
- 'qsent':0}
- log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)')
- self.askns(noutqkey)
- if not nscount:
- log(2,outqkey+'|Dropping query (no ns recs) for ' +
- msg.question.qname + '(' + msg.question.qtype + ')' )
- del self.outq[outqkey]
- elif fetchglue == nscount:
- log(2,outqkey+'|Stalling query (no glue recs) for ' +
- msg.question.qname + '(' + msg.question.qtype + ')')
- self.putonhold(self.outq[outqkey])
- del self.outq[outqkey]
- else:
- log(2,outqkey+'|got (some) glue with delegation')
- self.askns(outqkey)
-
- elif msg.header.rcode in [1,2,4,5]:
- log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode))
- log(2,'SERVER ' + self.outq[outqkey]['nsqueriedlastname'] + '(' +
- self.outq[outqkey]['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname)
- # don't ask this server for a while
- self.cache.badns(self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
- self.askns(outqkey)
- else:
- log(2,outqkey+'|GOT UNPARSEABLE REPLY')
- msg.printpkt()
-
- def handle_fresponse(self, msg, outqkey):
- if msg.header.rcode in [1,2,4,5]:
- self.outq[outqkey]['flist'].pop(0)
- if len(self.outq[outqkey]['flist']) == 0:
- qid = self.outq[outqkey]['query'].header.id
- qname = self.outq[outqkey]['qname']
- querytype = self.outq[outqkey]['querytype']
- queryclass = self.outq[outqkey]['queryclass']
- self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2),
- self.outq[outqkey]['addr'])
- del self.outq[outqkey]
- else:
- self.askfns(outqkey)
- else:
- answer = message()
- answer.header.id = self.outq[outqkey]['query'].header.id
- answer.header.qr = 1
- answer.header.opcode = self.outq[outqkey]['query'].header.opcode
- answer.header.ra = 1
- answer.question.qname = self.outq[outqkey]['query'].question.qname
- answer.question.qtype = self.outq[outqkey]['query'].question.qtype
- answer.question.qclass = self.outq[outqkey]['query'].question.qclass
- answer.header.ancount = msg.header.ancount
- answer.header.nscount = msg.header.nscount
- answer.header.arcount = msg.header.arcount
- answer.answerlist = msg.answerlist
- answer.authlist = msg.authlist
- answer.addlist = msg.addlist
- if msg.header.rcode == 3:
- # name error
- # cache negative answer
- self.cache.addneg(self.outq[outqkey]['qname'],
- self.outq[outqkey]['querytype'],
- self.outq[outqkey]['queryclass'])
- else:
- # cache all rrs
- for rr in msg.answerlist:
- self.cache.add(rr,'','forwarder')
- for rr in msg.authlist:
- self.cache.add(rr,'','forwarder')
- for rr in msg.addlist:
- self.cache.add(rr,'','forwarder')
- self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
- del self.outq[outqkey]
-
- def writable(self):
- return 0
-
- def handle_write(self):
- pass
-
- def handle_connect(self):
- pass
-
- def handle_close(self):
- # print '1:In handle close'
- return
-
- def process_holdq(self):
- curtime = time.time()
- for hqkey in self.holdq.keys():
- for hqrec in self.holdq[hqkey]:
- if curtime >= hqrec['processtime']:
- log(2,'processing held query')
- answer = self.cache.haskey(hqrec['qname'],
- hqrec['querytype'],
- hqrec['query'])
- if answer:
- hqrec['cbfunc']([answer], hqrec['addr'])
- log(2,'sent answer for ' + hqrec['qname'] +
- '(' + hqrec['querytype'] + ') from cache')
- self.holdq[hqkey].remove(hqrec)
- if len(self.holdq[hqkey]) == 0:
- del self.holdq[hqkey]
-
- def reap(self):
- self.process_holdq()
- curtime = time.time()
- log(3,timestamp() + 'processed HOLDQ (sockets: ' +
- str(len(asyncore.socket_map.keys()))+')')
- if curtime > (self.last_reap_time + self.maint_int):
- self.last_reap_time = curtime
- for outqkey in self.outq.keys():
- if curtime > self.outq[outqkey]['qsenttime'] + self.timeout:
- log(2,'query for '+self.outq[outqkey]['qname']+'('+
- self.outq[outqkey]['querytype']+') expired')
- # don't set forwarders as bad
- if not self.outq[outqkey].has_key('flist'):
- self.cache.badns(self.outq[outqkey]['qzone'],
- self.outq[outqkey]['nsqueriedlastname'])
- if self.outq[outqkey].has_key('request'):
- log(3,'closing socket for expired query')
- self.outq[outqkey]['request'].close()
- del self.outq[outqkey]
- return
-
- def log_info (self, message, type='info'):
- if __debug__ or type != 'info':
- log(0,'%s: %s' % (type, message))
-
-
-def run(configobj):
- global loglevel
- r = resolver(dnscache(configobj.cached))
- ns = nameserver(r, configobj)
- udpds = udpdnsserver(53, ns)
- tcpds = tcpdnsserver(53, ns)
- loglevel = configobj.loglevel
- try:
- loop(ns.reap)
- except KeyboardInterrupt:
- print 'server done'
-
-if __name__ == '__main__':
- sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')
- zonedict = {'example.net':{'origin':'example.net',
- 'filename':'db.example.net',
- 'type':'master',
- 'slaves':[]}}
-
-
- zonedict = {'servers.csail.mit.edu':{'origin':'servers.csail.mit.edu',
- 'filename':'db.servers.csail.mit.edu',
- 'type':'master',
- 'slaves':[]}}
-
- zonedict2 = {'example.net':{'origin':'example.net',
- 'filename':'db.example.net',
- 'type':'slave',
- 'masterip':'127.0.0.1'}}
- readzonefiles(zonedict)
- lconfig = dnsconfig()
- lconfig.zonedatabase = zonedb(zonedict)
- pr = zonefileparser()
- pr.parse('','db.ca')
- lconfig.cached = pr.getzdict()
- lconfig.loglevel = 3
-
- run(lconfig)