658a6623136997b6d3e6cb3acab6925bfad71862
[invirt/packages/invirt-dns.git] / nameserver.py
1 #!/usr/bin/python
2 #    Python Domain Name Server
3 #    Copyright (C) 2002  Digital Lumber, Inc.
4
5 #    This library is free software; you can redistribute it and/or
6 #    modify it under the terms of the GNU Lesser General Public
7 #    License as published by the Free Software Foundation; either
8 #    version 2.1 of the License, or (at your option) any later version.
9
10 #    This library is distributed in the hope that it will be useful,
11 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
12 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 #    Lesser General Public License for more details.
14
15 #    You should have received a copy of the GNU Lesser General Public
16 #    License along with this library; if not, write to the Free Software
17 #    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18
19 import socket
20 import asyncore
21 import asynchat
22 import select
23 import types
24 import random
25 import time
26 import signal
27 import string
28 import sys
29 import sipb_xen_database
30 from errno import EALREADY, EINPROGRESS, EWOULDBLOCK, ECONNRESET, \
31      ENOTCONN, ESHUTDOWN, EINTR, EISCONN, ETIMEDOUT
32
33 # EXAMPLE ZONE FILE DATA STRUCTURE
34
35 # NOTE:
# There are no trailing dots in the internal data
# structure.  Although it's hard to tell by reading
# the RFCs, the dot on the end of a name is only
# used by resolvers and servers to decide whether
# they need to append a domain name (the origin)
# to the name.  There are no trailing dots on
# names in queries on the network.
43
# Sample zone in the in-memory layout used throughout this module:
# owner name -> {record type -> [rdata dict, ...]}.  Every rdata dict
# carries 'class' and 'ttl' plus the fields specific to its type.
examplenet = {
    'example.net': {
        'SOA': [{'class': 'IN',
                 'ttl': 10,
                 'mname': 'ns1.example.net',
                 'rname': 'hostmaster.example.net',
                 'serial': 1,
                 'refresh': 10800,
                 'retry': 3600,
                 'expire': 604800,
                 'minimum': 3600}],
        'NS': [{'class': 'IN',
                'ttl': 10,
                'nsdname': 'ns1.example.net'},
               # previously lacked 'class', unlike every other record
               {'class': 'IN',
                'ttl': 10,
                'nsdname': 'ns2.example.net'}],
        'MX': [{'class': 'IN',
                'ttl': 10,
                'preference': 10,
                'exchange': 'mail.example.net'}]},
    'server1.example.net': {
        'A': [{'class': 'IN',
               'ttl': 10,
               'address': '10.1.2.3'}]},
    'www.example.net': {
        'CNAME': [{'class': 'IN',
                   'ttl': 10,
                   'cname': 'server1.example.net'}]},
    'router.example.net': {
        'A': [{'class': 'IN',
               'ttl': 10,
               'address': '10.1.2.1'},
              {'class': 'IN',
               'ttl': 10,
               'address': '10.2.1.1'}]},
    }
76
# Logging defaults: messages at level <= loglevel go to logfile.
loglevel, logfile = 0, sys.stdout
80
try:
    file
except NameError:
    def file(name, mode='r', buffer=-1):
        """Stand-in for the Python 2 builtin file() where it is missing.

        buffer follows the builtin's meaning and default: -1 selects the
        system default buffering.  (A default of 0 made every default call
        fail on Python 3, where unbuffered text-mode files raise ValueError.)
        """
        return open(name, mode, buffer)
86
def log(level,msg):
    """Write msg followed by a newline to logfile.

    The message is written only when level is at most the module-level
    loglevel threshold; otherwise it is dropped.
    """
    if level > loglevel:
        return
    logfile.write(msg + '\n')
90
def timestamp():
    """Return the local time as 'MM/DD/YY HH:MM:SS-' for log prefixes."""
    now = time.strftime('%m/%d/%y %H:%M:%S')
    return now + '-'
93
def inttoasc(number):
    """Convert a non-negative integer to a big-endian byte string.

    The result is as short as possible but at least one byte long, so
    inttoasc(0) == '\\x00' and inttoasc(0x123) == '\\x01\\x23'.

    Raises TypeError (after logging it) when number is not an integer;
    previously the error was logged and then masked by an UnboundLocalError.
    """
    try:
        hs = hex(number)[2:]
    except TypeError:
        log(0,'inttoasc cannot convert ' + repr(number))
        raise
    if hs[-1:].upper() == 'L':
        # Python 2 longs format with a trailing 'L'
        hs = hs[:-1]
    if len(hs) % 2:
        # left-pad to whole bytes so the pairs below line up
        hs = '0' + hs
    return ''.join([chr(int(hs[k:k+2], 16)) for k in range(0, len(hs), 2)])
108     
def asctoint(ascnum):
    """Interpret a string of byte characters as a big-endian unsigned int.

    asctoint('') == 0 and asctoint('\\x01\\x00') == 256.
    """
    value = 0
    for ch in ascnum:
        # shift previous digits up one byte and add the next one
        value = (value << 8) + ord(ch)
    return value
121
def ipv6net_aton(ip_string):
    """Pack a textual IPv6 address into its 16-byte network form.

    The '::' shorthand is expanded wherever it appears, including the
    edge cases '::', '::1' and '1::' (the old expansion left a trailing
    empty group for the latter two and produced 18 bytes).  Embedded
    IPv4 notation such as '::ffff:1.2.3.4' is not handled.
    Returns a string with one character per byte.
    """
    if '::' in ip_string:
        head, tail = ip_string.split('::', 1)
        headgroups = []
        if head:
            headgroups = head.split(':')
        tailgroups = []
        if tail:
            tailgroups = tail.split(':')
        # '::' stands for however many all-zero groups are missing
        missing = 8 - len(headgroups) - len(tailgroups)
        groups = headgroups + ['0'] * missing + tailgroups
    else:
        groups = ip_string.split(':')
    packed = []
    for group in groups:
        group = group.zfill(4)
        packed.append(chr(int(group[:2], 16)) + chr(int(group[2:], 16)))
    return ''.join(packed)
142
def ipv6net_ntoa(packed_ip):
    """Format a packed IPv6 address (one character per byte) as text.

    Each pair of bytes becomes one colon-separated group in lowercase hex
    without leading zeros, e.g. '2001:db8:0:0:0:0:0:1'; no '::'
    compression is applied.  Previously each byte was hex-formatted on
    its own without padding, so '\\x20\\x01' came out as '201'.
    """
    groups = []
    for k in range(0, len(packed_ip), 2):
        value = 0
        for ch in packed_ip[k:k+2]:
            value = (value << 8) + ord(ch)
        groups.append('%x' % value)
    return ':'.join(groups)
153
def getversion(qname, id, rd, ra, versionstr):
    """Build an authoritative CHAOS-class TXT reply carrying versionstr.

    qname      -- the queried name; 'version.bind' is answered with a CNAME
                  to 'version.oak' plus the TXT record, anything else with
                  a single TXT record for qname
    id, rd, ra -- copied into the reply header
    Returns the populated message object.
    """
    msg = message()
    hdr = msg.header
    hdr.id = id
    hdr.qr = 1
    hdr.aa = 1
    hdr.rd = rd
    hdr.ra = ra
    hdr.rcode = 0
    q = msg.question
    q.qname, q.qtype, q.qclass = qname, 'TXT', 'CH'

    def chaosrr(field, value):
        # one CH-class rdata dict with the fixed TTL used for version replies
        return {field: value, 'ttl': 360000, 'class': 'CH'}

    if qname != 'version.bind':
        hdr.ancount = 1
        msg.answerlist.append({qname: {'TXT': [chaosrr('txtdata', versionstr)]}})
    else:
        hdr.ancount = 2
        msg.answerlist.append({qname: {'CNAME': [chaosrr('cname', 'version.oak')]}})
        msg.answerlist.append({'version.oak': {'TXT': [chaosrr('txtdata', versionstr)]}})
    return msg
179
def getrcode(rcode):
    """Return a readable 'NAME(description)' string for a DNS RCODE value."""
    known = (
        (0, 'NOERROR(No error condition)'),
        (1, 'FORMERR(Format Error)'),
        (2, 'SERVFAIL(Internal failure)'),
        (3, 'NXDOMAIN(Name does not exist)'),
        (4, 'NOTIMP(Not Implemented)'),
        (5, 'REFUSED(Security violation)'),
        (6, 'YXDOMAIN(Name exists)'),
        (7, 'YXRRSET(RR exists)'),
        (8, 'NXRRSET(RR does not exist)'),
        (9, 'NOTAUTH(Server not Authoritative)'),
        (10, 'NOTZONE(Name not in zone)'),
        )
    for code, text in known:
        if rcode == code:
            return text
    return 'Unknown RCODE(' + str(rcode) + ')'
206
def printrdata(dnstype, rdata):
    """Return the master-file text for one record's RDATA.

    dnstype -- record type mnemonic such as 'A' or 'MX'
    rdata   -- rdata dict in the internal format (names without trailing dots)
    Domain names are written fully qualified, with a trailing dot.  Text
    fields (TXT, HINFO) are double-quoted, which is the form
    zonefileparser.getstrings reads back.  Returns None for types with no
    text form here (e.g. LOC).
    """
    if dnstype == 'A' or dnstype == 'AAAA':
        return rdata['address']
    elif dnstype == 'MX':
        return str(rdata['preference'])+'\t'+rdata['exchange']+'.'
    elif dnstype == 'NS':
        return rdata['nsdname']+'.'
    elif dnstype == 'PTR':
        return rdata['ptrdname']+'.'
    elif dnstype == 'CNAME':
        return rdata['cname']+'.'
    elif dnstype == 'SOA':
        return (rdata['mname']+'.\t'+rdata['rname']+'. (\n'+35*' '+str(rdata['serial'])+'\n'+
                35*' '+str(rdata['refresh'])+'\n'+35*' '+str(rdata['retry'])+'\n'+35*' '+
                str(rdata['expire'])+'\n'+35*' '+str(rdata['minimum'])+' )')
    elif dnstype == 'TXT':
        return '"' + rdata['txtdata'] + '"'
    elif dnstype == 'HINFO':
        return '"' + rdata['cpu'] + '" "' + rdata['os'] + '"'
    elif dnstype == 'RP':
        return rdata['mboxdname'] + '.\t' + rdata['txtdname'] + '.'
    elif dnstype == 'SRV':
        return (str(rdata['priority']) + '\t' + str(rdata['weight']) + '\t' +
                str(rdata['port']) + '\t' + rdata['target'] + '.')
    return None
222
def makezonedatalist(zonedata, origin):
    """Flatten zone data into a list of [owner., type, rdata] records.

    The first SOA record of the origin node comes first, then the origin's
    other records, then the records of every other node.  Owners are given
    a trailing dot.  Only the first SOA record of the origin is emitted.
    """
    apexname = origin + '.'
    apex = zonedata[origin]
    records = [[apexname, 'SOA', apex['SOA'][0]]]
    records.extend([[apexname, rrtype, rr]
                    for rrtype in apex.keys() if rrtype != 'SOA'
                    for rr in apex[rrtype]])
    for name in zonedata.keys():
        if name == origin:
            continue
        node = zonedata[name]
        records.extend([[name + '.', rrtype, rr]
                        for rrtype in node.keys()
                        for rr in node[rrtype]])
    return records
239
def writezonefile(zonedata, origin, file):
    """Write zonedata to the open file object `file` in master-file format.

    One line per record: the owner padded to 35 columns, the TTL, the class
    (always written as IN), the type, and the RDATA text from printrdata.
    """
    for owner, dnstype, rdata in makezonedatalist(zonedata, origin):
        ttltext = str(rdata['ttl'])
        rdatatext = printrdata(dnstype, rdata)
        file.write(owner.ljust(35) + ttltext + '\t\tIN\t' + dnstype +
                   '\t' + rdatatext + '\n')
248
def readzonefiles(zonedict):
    """Parse the zone file of every entry in zonedict, updating it in place.

    Each value must provide 'filename' and 'origin'; on success its
    'zonedata' is set to the parsed structure.  An entry whose file raises
    ZonefileError is logged and removed from zonedict.
    """
    # iterate over a snapshot of the keys because entries may be deleted
    for k in list(zonedict.keys()):
        filepath = zonedict[k]['filename']
        try:
            pr = zonefileparser()
            pr.parse(zonedict[k]['origin'],filepath)
            zonedict[k]['zonedata'] = pr.getzdict()
        except ZonefileError:
            err = sys.exc_info()[1]
            # str(err) is 'LINENUM (DESCRIPTION)'; log() adds the newline
            log(0,'Error reading zone file ' + filepath + ' at line ' +
                str(err))
            del zonedict[k]
260
def slowloop(tofunc='',timeout=5.0):
    """Run the asyncore dispatch loop using select().

    tofunc  -- optional callable run whenever one select() call lasts at
               least `timeout` seconds (normally: it timed out idle);
               '' means no callback
    timeout -- seconds to wait in each select() call
    Runs until asyncore.socket_map is empty.  Exceptions from a channel's
    handlers are logged and passed to its handle_error(); KeyboardInterrupt,
    SystemExit and asyncore.ExitNow propagate, as in asyncore's own loop.
    """
    if not tofunc:
        def tofunc(): return
    socketmap = asyncore.socket_map
    while socketmap:
        r = []; w = []; e = []
        for fd, obj in list(socketmap.items()):
            if obj.readable():
                r.append(fd)
            if obj.writable():
                w.append(fd)
        try:
            starttime = time.time()
            r, w, e = select.select(r, w, e, timeout)
            endtime = time.time()
            if endtime-starttime >= timeout:
                tofunc()
        except select.error:
            err = sys.exc_info()[1]
            if err.args[0] != EINTR:
                raise
            r = []; w = []; e = []
            log(0,'ERROR in select')

        # writable fds must get handle_write_event (previously they were
        # wrongly given a second handle_read_event)
        for fds, handler in ((r, 'handle_read_event'),
                             (w, 'handle_write_event')):
            for fd in fds:
                try:
                    obj = socketmap[fd]
                except KeyError:
                    log(0,'KeyError in socket map')
                    continue
                try:
                    getattr(obj, handler)()
                except (KeyboardInterrupt, SystemExit, asyncore.ExitNow):
                    raise
                except:
                    log(0,'calling HANDLE ERROR from loop')
                    log(0,repr(obj))
                    obj.handle_error()
308
def fastloop(tofunc='',timeout=5.0):
    """Run the asyncore dispatch loop using poll().

    tofunc  -- optional callable run whenever one poll() call lasts at
               least `timeout` seconds (normally: it timed out idle);
               '' means no callback
    timeout -- seconds to wait in each poll() call
    Runs until asyncore.socket_map is empty.  Exceptions from a channel's
    handlers are logged with a traceback and passed to its handle_error();
    KeyboardInterrupt, SystemExit and asyncore.ExitNow propagate.
    """
    # The error path previously used StringIO and traceback without
    # importing them, so any handler exception crashed the loop with a
    # NameError.
    import traceback
    if not tofunc:
        def tofunc(): return
    polltimeout = int(timeout*1000)   # poll() takes milliseconds
    # NOTE: asyncore itself routes POLLPRI to handle_expt_event; this loop
    # treats it like the error conditions.
    badvals = (select.POLLPRI | select.POLLERR |
               select.POLLHUP | select.POLLNVAL)
    socketmap = asyncore.socket_map
    while socketmap:
        pollobj = select.poll()
        for fd, obj in list(socketmap.items()):
            flags = 0
            if obj.readable():
                flags = select.POLLIN
            if obj.writable():
                flags = flags | select.POLLOUT
            if flags:
                pollobj.register(fd, flags)
        try:
            starttime = time.time()
            r = pollobj.poll(polltimeout)
            endtime = time.time()
            if endtime-starttime >= timeout:
                tofunc()
        except select.error:
            err = sys.exc_info()[1]
            if err.args[0] != EINTR:
                raise
            r = []
            log(0,'ERROR in select')
        for fd, flags in r:
            # look the channel up separately so a KeyError raised inside a
            # handler is not mistaken for a missing socket-map entry
            try:
                obj = socketmap[fd]
            except KeyError:
                log(0,'KeyError in socket map')
                continue
            try:
                if flags & badvals:
                    if flags & select.POLLPRI:
                        log(0,'POLLPRI')
                    if flags & select.POLLERR:
                        log(0,'POLLERR')
                    if flags & select.POLLHUP:
                        log(0,'POLLHUP')
                    if flags & select.POLLNVAL:
                        log(0,'POLLNVAL')
                    obj.handle_error()
                else:
                    if flags & select.POLLIN:
                        obj.handle_read_event()
                    if flags & select.POLLOUT:
                        obj.handle_write_event()
            except (KeyboardInterrupt, SystemExit, asyncore.ExitNow):
                raise
            except:
                log(0,'ERROR IN LOOP:')
                log(0,traceback.format_exc())
                log(0,repr(obj))
                obj.handle_error()
369
# Use the poll()-based loop where the platform provides select.poll,
# otherwise fall back to the select()-based one.
loop = slowloop
if hasattr(select, 'poll'):
    loop = fastloop
374
class ZonefileError(Exception):
    """Raised when a zone file cannot be parsed.

    linenum   -- line number at which parsing failed
    errordesc -- short description of the problem (may be empty)
    str() of the exception is 'LINENUM (ERRORDESC)'.
    """
    def __init__(self, linenum, errordesc=''):
        self.linenum, self.errordesc = linenum, errordesc

    def __str__(self):
        return ''.join([str(self.linenum), ' (', self.errordesc, ')'])
381
class zonefileparser:
    """Parser for master-format (RFC 1035 section 5) zone files.

    Builds self.zonedata, a dict mapping owner name -> {type -> [rdata, ...]}
    in the same layout as the module's examplenet structure.  Names are
    stored lowercase and without trailing dots.
    """
    def __init__(self):
        self.zonedata = {}
        self.dnstypes = ['A','AAAA','CNAME','HINFO','LOC','MX',
                         'NS','PTR','RP','SOA','SRV','TXT']
        # line number of the entry being parsed; labels errors raised
        # from getrrdata
        self.lineno = 0

    def stripcomments(self, line):
        """Return line truncated at the first ';'.

        NOTE: a ';' inside a quoted TXT string is treated as a comment too.
        """
        i = line.find(';')
        if i >= 0:
            line = line[:i]
        return line

    def strip(self, line):
        """Return line without its trailing line terminator (LF or CRLF)."""
        return line.rstrip('\r\n')

    def getzdict(self):
        """Return the zone data parsed so far."""
        return self.zonedata

    def addorigin(self, origin, name):
        """Make name absolute and return it without a trailing dot.

        '@' stands for origin itself, a name ending in '.' is already
        absolute, and any other name is taken relative to origin.
        """
        if name == '@':
            return origin
        if name[-1:] != '.':
            return name + '.' + origin
        return name[:-1]

    def getstrings(self, s):
        """Split s into words, or into its double-quoted strings if it has any."""
        if s.find('"') == -1:
            return s.split()
        else:
            x = s.split('"')
            rlist = []
            for i in x:
                if i != '' and i != ' ':
                    rlist.append(i)
            return rlist

    def getlocsize(self, s):
        """Encode a LOC size/precision such as '30m' as (mantissa, exponent).

        The value is converted to centimetres and expressed as
        mantissa * 10**exponent with a single-digit (truncated) mantissa.
        """
        if s[-1:] == 'm':
            size = float(s[:-1])*100
        else:
            size = float(s)*100
        i = 0
        while size > 9:
            size = size/10
            i = i + 1
        return (int(size),i)

    def getloclat(self, l, c):
        """Encode a LOC latitude or longitude.

        l -- [degrees [, minutes [, seconds]]] as strings
        c -- hemisphere letter 'N', 'E', 'S' or 'W'
        Returns thousandths of an arc second, offset from 2**31 (added for
        N/E, subtracted for S/W).
        """
        deg = float(l[0])
        minutes = 0
        secs = 0
        if len(l) == 3:
            minutes = float(l[1])
            secs = float(l[2])
        elif len(l) == 2:
            minutes = float(l[1])
        rval = ((((deg *60) + minutes) * 60) + secs) * 1000
        if c in ['N','E']:
            rval = rval + (2**31)
        elif c in ['S','W']:
            rval = (2**31) - rval
        else:
            log(0,'ERROR: unsupported latitude/longitude direction')
        try:
            return long(rval)
        except NameError:
            # Python 3 has no separate long type
            return int(rval)

    def getgname(self, name, iter):
        """Expand a $GENERATE template for iterator value iter.

        Each unescaped '$' becomes the iterator value.  '${offset[,width[,base]]}'
        subtracts offset, zero-pads to width and formats in base d, o, x or X;
        these modifiers apply to that one occurrence only.  '$$' and '\\$'
        produce a literal '$'.  The special names '0' and 'O' expand to ''.
        Raises ValueError for a malformed modifier.
        """
        if name == '0' or name == 'O':
            return ''
        out = []
        n = len(name)
        i = 0
        while i < n:
            c = name[i]
            if c == '\\' and name[i+1:i+2] == '$':
                out.append('$')
                i = i + 2
            elif c != '$':
                out.append(c)
                i = i + 1
            elif name[i+1:i+2] == '$':
                out.append('$')
                i = i + 2
            else:
                offset = 0
                width = 0
                base = 'd'
                i = i + 1
                if name[i:i+1] == '{':
                    j = name.find('}', i)
                    if j == -1:
                        raise ValueError('unterminated ${ in ' + repr(name))
                    owb = name[i+1:j].split(',')
                    offset = int(owb[0])
                    if len(owb) > 1:
                        width = int(owb[1])
                    if len(owb) > 2:
                        base = owb[2].strip()
                    i = j + 1
                out.append(self._genformat(iter - offset, width, base))
        return ''.join(out)

    def _genformat(self, val, width, base):
        """Format one $GENERATE value in base d/o/x/X, zero-padded to width."""
        if base == 'd':
            rs = '%d' % val
        elif base == 'o':
            rs = '%o' % val
        elif base == 'x':
            rs = '%x' % val
        elif base == 'X':
            rs = '%X' % val
        else:
            raise ValueError('unsupported $GENERATE base ' + repr(base))
        return rs.zfill(width)

    def getrrdata(self, origin, dnstype, dnsclass, ttl, tokens):
        """Build the rdata dict for one record from its RDATA tokens.

        Domain names in the RDATA are lowercased and qualified with origin.
        Raises ZonefileError for an unsupported dnstype; malformed tokens
        raise IndexError or ValueError.
        """
        rdata = {}
        rdata['class'] = dnsclass
        rdata['ttl'] = ttl
        if dnstype == 'A' or dnstype == 'AAAA':
            rdata['address'] = tokens[0]
        elif dnstype == 'CNAME':
            rdata['cname'] = self.addorigin(origin,tokens[0].lower())
        elif dnstype == 'HINFO':
            sl = self.getstrings(' '.join(tokens))
            rdata['cpu'] = sl[0]
            rdata['os'] = sl[1]
        elif dnstype == 'LOC':
            self._getlocdata(tokens, rdata)
        elif dnstype == 'MX':
            rdata['preference'] = int(tokens[0])
            rdata['exchange'] = self.addorigin(origin,tokens[1].lower())
        elif dnstype == 'NS':
            rdata['nsdname'] = self.addorigin(origin,tokens[0].lower())
        elif dnstype == 'PTR':
            rdata['ptrdname'] = self.addorigin(origin,tokens[0].lower())
        elif dnstype == 'RP':
            rdata['mboxdname'] = self.addorigin(origin,tokens[0].lower())
            rdata['txtdname'] = self.addorigin(origin,tokens[1].lower())
        elif dnstype == 'SOA':
            rdata['mname'] = self.addorigin(origin,tokens[0].lower())
            rdata['rname'] = self.addorigin(origin,tokens[1].lower())
            rdata['serial'] = int(tokens[2])
            rdata['refresh'] = int(tokens[3])
            rdata['retry'] = int(tokens[4])
            rdata['expire'] = int(tokens[5])
            rdata['minimum'] = int(tokens[6])
        elif dnstype == 'SRV':
            rdata['priority'] = int(tokens[0])
            rdata['weight'] = int(tokens[1])
            rdata['port'] = int(tokens[2])
            rdata['target'] = self.addorigin(origin,tokens[3].lower())
        elif dnstype == 'TXT':
            rdata['txtdata'] = self.getstrings(' '.join(tokens))[0]
        else:
            # previously referenced an undefined `lineno` (NameError)
            raise ZonefileError(getattr(self, 'lineno', 0), 'bad DNS type')
        return rdata

    def _getlocdata(self, tokens, rdata):
        """Fill rdata with the fields of an RFC 1876 LOC record.

        tokens: d1 [m1 [s1]] {N|S} d2 [m2 [s2]] {E|W} alt[m] [siz[m] [hp[m] [vp[m]]]]
        """
        if 'N' in tokens:
            i = tokens.index('N')
        else:
            i = tokens.index('S')
        lat = self.getloclat(tokens[0:i],tokens[i])
        if 'E' in tokens:
            j = tokens.index('E')
        else:
            j = tokens.index('W')
        lng = self.getloclat(tokens[i+1:j],tokens[j])
        # RFC 1876 defaults: size 1m, horizontal precision 10000m,
        # vertical precision 10m
        size = self.getlocsize('1m')
        horiz_pre = self.getlocsize('10000m')
        vert_pre = self.getlocsize('10m')
        optional = tokens[j+2:]
        if len(optional) >= 1:
            size = self.getlocsize(optional[0])
        if len(optional) >= 2:
            horiz_pre = self.getlocsize(optional[1])
        if len(optional) >= 3:
            vert_pre = self.getlocsize(optional[2])
        alt = tokens[j+1]
        if alt[-1:] == 'm':
            alt = alt[:-1]
        rdata['version'] = 0
        rdata['size'] = size
        rdata['horiz_pre'] = horiz_pre
        rdata['vert_pre'] = vert_pre
        rdata['latitude'] = lat
        rdata['longitude'] = lng
        # RFC 1876 altitude: centimetres above a base 100000m (10000000cm)
        # below the reference spheroid.  It used to be computed and then
        # discarded in favour of 0.
        rdata['altitude'] = int((float(alt)*100)+10000000)

    def addrec(self, owner, dnstype, rrdata):
        """Append rrdata to owner's list of dnstype records."""
        self.zonedata.setdefault(owner, {}).setdefault(dnstype, []).append(rrdata)

    def parse(self, origin, f):
        """Parse a zone file and merge its records into self.zonedata.

        origin -- domain (without trailing dot) relative names are
                  qualified with; $ORIGIN directives may change it
        f      -- an open file-like object with readline(), or a path
        A path that cannot be opened is logged and ignored.  Malformed
        input raises ZonefileError carrying the line number.
        """
        closefile = 0
        if not hasattr(f, 'readline'):
            # must be a path
            try:
                f = file(f)
                closefile = 1
            except (IOError, OSError, TypeError, ValueError):
                log(0,'Invalid path to zonefile')
                return
        try:
            self._parsefile(origin, f)
        finally:
            if closefile:
                f.close()

    def _parsefile(self, origin, f):
        """Parse every entry read from the open file-like object f."""
        lastowner = ''
        lastdnsclass = ''
        lastttl = 3600
        lineno = 0
        while 1:
            line = f.readline()
            if not line:
                break
            lineno = lineno + 1
            line = self.strip(self.stripcomments(line))
            if not line:
                continue
            if line.find('(') >= 0:
                line, lineno = self._joinparens(f, line, lineno)
            # now line holds the entire entry
            tokens = line.split()
            if not tokens:
                # whitespace-only line, e.g. an indented comment
                continue
            directive = tokens[0].upper()
            if directive == '$ORIGIN':
                if len(tokens) < 2:
                    raise ZonefileError(lineno, 'bad origin')
                # strip the trailing dot; relative origins extend the current one
                origin = self.addorigin(origin, tokens[1].lower())
            elif directive == '$INCLUDE':
                self._include(origin, tokens, lineno)
            elif directive == '$TTL':
                try:
                    lastttl = int(tokens[1])
                except (IndexError, ValueError):
                    raise ZonefileError(lineno, 'bad TTL directive')
            elif directive == '$GENERATE':
                try:
                    self._generate(origin, tokens, lastttl, lineno)
                except (IndexError, ValueError, KeyError):
                    raise ZonefileError(lineno, 'bad GENERATE directive')
            else:
                lastowner, lastttl, lastdnsclass = self._parserecord(
                    origin, line, tokens, lastowner, lastttl, lastdnsclass,
                    lineno)

    def _joinparens(self, f, line, lineno):
        """Read continuation lines until the entry's '(' is closed.

        Returns (entry with parentheses removed, updated lineno).  Raises
        ZonefileError at end of file instead of looping forever.
        """
        while line.find(')') == -1:
            nextline = f.readline()
            if not nextline:
                raise ZonefileError(lineno, 'unterminated parenthesis')
            lineno = lineno + 1
            # join with a space so tokens on adjacent lines never merge
            line = line + ' ' + self.strip(self.stripcomments(nextline))
        line = line.replace(')','')
        line = line.replace('(','')
        return line, lineno

    def _include(self, origin, tokens, lineno):
        """Handle '$INCLUDE path [origin]' by parsing the named file."""
        try:
            # the path is used as written: lowercasing it broke
            # case-sensitive filesystems
            f2 = file(tokens[1])
        except Exception:
            raise ZonefileError(lineno, 'bad INCLUDE directive')
        try:
            try:
                if len(tokens) > 2:
                    origin = self.addorigin(origin, tokens[2].lower())
                self.parse(origin, f2)
            except Exception:
                raise ZonefileError(lineno, 'bad INCLUDE directive')
        finally:
            f2.close()

    def _generate(self, origin, tokens, ttl, lineno):
        """Handle '$GENERATE start-stop[/step] lhs type rhs'."""
        lhs = tokens[2].lower()
        dnstype = tokens[3].upper()
        rhs = tokens[4].lower()
        rng = tokens[1].split('-')
        start = int(rng[0])
        stopstep = rng[1].split('/')
        stop = int(stopstep[0])+1
        step = 1
        if len(stopstep) > 1:
            step = int(stopstep[1])
        for i in range(start,stop,step):
            self.lineno = lineno
            # getrrdata qualifies NS/CNAME/PTR targets itself; qualifying
            # them here as well appended the origin twice
            rrdata = self.getrrdata(origin, dnstype, 'IN', ttl,
                                    [self.getgname(rhs,i)])
            owner = self.addorigin(origin,self.getgname(lhs,i))
            self.addrec(owner, dnstype, rrdata)

    def _parserecord(self, origin, line, tokens, lastowner, lastttl,
                     lastdnsclass, lineno):
        """Parse one resource-record entry and add it to the zone.

        Returns (owner, ttl, class), the defaults for the next entry.
        """
        try:
            # if line begins with blank then owner is last owner
            if line[0] in string.whitespace:
                owner = lastowner
            else:
                owner = self.addorigin(origin, tokens[0].lower())
                tokens = tokens[1:]
            # line format is either: [class] [ttl] type RDATA
            #                     or [ttl] [class] type RDATA
            # - items in brackets are optional; find the type token and
            # backfill whatever is missing
            count = 0
            for token in tokens:
                if token.upper() in self.dnstypes:
                    break
                count = count + 1
            if count == 0:
                ttl = lastttl
                dnsclass = lastdnsclass
            elif count == 1:
                if tokens[0].isdigit():
                    ttl = int(tokens[0])
                    dnsclass = lastdnsclass
                else:
                    ttl = lastttl
                    dnsclass = tokens[0].upper()
            elif count == 2:
                if tokens[0].isdigit():
                    ttl = int(tokens[0])
                    dnsclass = tokens[1].upper()
                else:
                    ttl = int(tokens[1])
                    dnsclass = tokens[0].upper()
            else:
                raise ZonefileError(lineno,'bad ttl or class')
            tokens = tokens[count:]
            # the type was matched case-insensitively above
            dnstype = tokens[0].upper()
            self.lineno = lineno
            rrdata = self.getrrdata(origin, dnstype, dnsclass,
                                    ttl, tokens[1:])
            self.addrec(owner, dnstype, rrdata)
        except ZonefileError:
            raise
        except Exception:
            raise ZonefileError(lineno,'unable to parse line')
        return owner, ttl, dnsclass
724         
class dnsconfig:
    """Server policy hooks: per-client views, update permission, and a
    filter applied to outgoing packets."""
    def __init__(self):
        self.cached = {}
        self.loglevel = 0

    def getview(self, msg, address, port):
        """Return (zone keys, use-resolver flag, forwarder addresses).

        The flag says whether recursive queries are answered.  This
        configuration ignores the client and always serves the
        'servers.csail.mit.edu' zone, resolver on, no forwarders.
        """
        zonekeys = ['servers.csail.mit.edu']
        useresolver = 1
        forwarders = []
        return zonekeys, useresolver, forwarders

    def allowupdate(self, msg, address, port):
        """Return 1 if this client may update the zones its view returns."""
        # NOTE(review): always permits updates, whatever the client address;
        # confirm that is intended before exposing update handling.
        allowed = 1
        return allowed

    def outpackets(self, packetlist):
        """Hook to modify outgoing packets; returns packetlist unchanged."""
        return packetlist
748
class dnsheader:
    """DNS message header fields (RFC 1035 section 4.1.1) with query defaults."""
    def __init__(self, id=1):
        self.id = id                  # 16-bit identifier chosen by the querier
        # flag bits: qr (0 query / 1 response), 4-bit opcode, aa
        # (authoritative answer), tc (truncated), rd (recursion desired),
        # ra (recursion available), z (reserved), rcode (set in responses)
        self.qr = self.opcode = self.aa = self.tc = 0
        self.rd = 1
        self.ra = self.z = self.rcode = 0
        # section counts; only a single question is supported
        self.qdcount = 1
        self.ancount = self.nscount = self.arcount = 0
764
class dnsquestion:
    """The question of a DNS message; defaults to 'localhost', type A, class IN."""
    def __init__(self):
        self.qname, self.qtype, self.qclass = 'localhost', 'A', 'IN'
770
class dnsupdatezone:
    """Empty placeholder class; it defines no attributes or behavior here."""
773
774 class message:
775     def __init__(self, msgdata=''):
776         if msgdata:
777             self.header = dnsheader()
778         else:
779             self.header = dnsheader(id=random.randrange(1,32768))
780         self.question = dnsquestion()
781         self.answerlist = []
782         self.authlist = []
783         self.addlist = []
784         self.u = ''
785         self.qtypes = {1:'A',2:'NS',3:'MD',4:'MF',5:'CNAME',6:'SOA',
786                        7:'MB',8:'MG',9:'MR',10:'NULL',11:'WKS',
787                        12:'PTR',13:'HINFO',14:'MINFO',15:'MX',
788                        16:'TXT',17:'RP',28:'AAAA',29:'LOC',33:'SRV',
789                        38:'A6',39:'DNAME',251:'IXFR',252:'AXFR',
790                        253:'MAILB',254:'MAILA',255:'ANY'}
791         self.rqtypes = {}
792         for key in self.qtypes.keys():
793             self.rqtypes[self.qtypes[key]] = key
794         self.qclasses = {1:'IN',2:'CS',3:'CH',4:'HS',254:'NONE',255:'ANY'}
795         self.rqclasses = {}
796         for key in self.qclasses.keys():
797             self.rqclasses[self.qclasses[key]] = key
798
799         if msgdata:
800             self.processpkt(msgdata)
801
802     def getdomainname(self, data, i):
803         log(4,'IN GETDOMAINNAME')
804         domainname = ''
805         gotpointer = 0
806         labellength= ord(data[i])
807         log(4,'labellength:' + str(labellength))
808         i = i + 1
809         while labellength != 0:
810             while labellength >= 192:
811                 # pointer
812                 if not gotpointer:
813                     rindex = i + 1
814                     gotpointer = 1
815                     log(4,'got pointer')
816                 i = asctoint(chr(ord(data[i-1]) & 63)+data[i])
817                 log(4,'new index:'+str(i))
818                 labellength = ord(data[i])
819                 log(4,'labellength:' + str(labellength))
820                 i = i + 1
821             if domainname:
822                 domainname = domainname + '.' + data[i:i+labellength]
823             else:
824                 domainname = data[i:i+labellength]
825             log(4,'domainname:'+domainname)
826             i = i + labellength
827             labellength = ord(data[i])
828             log(4,'labellength:' + str(labellength))
829             i = i + 1
830         if not gotpointer:
831             rindex = i
832
833         return domainname.lower(), rindex
834
835     def getrrdata(self, type, msgdata, rdlength, i):
836         log(4,'unpacking RR data')
837         rdata = msgdata[i:i+rdlength]
838         if type == 'A':
839             return {'address':socket.inet_ntoa(rdata)}
840         elif type == 'AAAA':
841             return {'address':ipv6net_ntoa(rdata)}
842         elif type == 'CNAME':
843             cname, i = self.getdomainname(msgdata,i)
844             return {'cname':cname}
845         elif type == 'HINFO':
846             cpulen = ord(rdata[0])
847             cpu = rdata[1:cpulen+1]
848             return {'cpu':cpu,
849                     'os':rdata[cpulen+2:]}
850         elif type == 'LOC':
851             return {'version':ord(rdata[0]),
852                     'size':self.locsize(rdata[1]),
853                     'horiz_pre':self.locsize(rdata[2]),
854                     'vert_pre':self.locsize(rdata[3]),
855                     'latitude':asctoint(rdata[4:8]),
856                     'longitude':asctoint(rdata[8:12]),
857                     'altitude':asctoint(rdata[12:16])}
858         elif type == 'MX':
859             exchange, i = self.getdomainname(msgdata,i+2)
860             return {'preference':asctoint(rdata[:2]),
861                     'exchange':exchange}
862         elif type == 'NS':
863             nsdname, i = self.getdomainname(msgdata,i)
864             return {'nsdname':nsdname}
865         elif type == 'PTR':
866             ptrdname, i = self.getdomainname(msgdata,i)
867             return {'ptrdname':ptrdname}
868         elif type == 'RP':
869             mboxdname, i = self.getdomainname(msgdata,i)
870             txtdname, i = self.getdomainname(msgdata,i)
871             return {'mboxdname':mboxdname,
872                     'txtdname':txtdname}
873         elif type == 'SOA':
874             mname, i = self.getdomainname(msgdata,i)
875             rname, i = self.getdomainname(msgdata,i)
876             return {'mname':mname,
877                     'rname':rname,
878                     'serial':asctoint(msgdata[i:i+4]),
879                     'refresh':asctoint(msgdata[i+4:i+8]),
880                     'retry':asctoint(msgdata[i+8:i+12]),
881                     'expire':asctoint(msgdata[i+12:i+16]),
882                     'minimum':asctoint(msgdata[i+16:i+20])}
883         elif type == 'SRV':
884             target, i = self.getdomainname(msgdata,i+6)            
885             return {'priority':asctoint(rdata[0:2]),
886                     'weight':asctoint(rdata[2:4]),
887                     'port':asctoint(rdata[4:6]),
888                     'target':target}
889         elif type == 'TXT':
890             return {'txtdata':rdata[1:]}
891         else:
892             return {'rdata':rdata}
893         
894     def getrr(self, data, i):
895         log(4,'unpacking RR name')
896         name, i = self.getdomainname(data, i)
897         type = asctoint(data[i:i+2])
898         type = self.qtypes.get(type,chr(type))
899         klass = asctoint(data[i+2:i+4])
900         klass = self.qclasses.get(klass,chr(klass))
901         ttl = asctoint(data[i+4:i+8])
902         rdlength = asctoint(data[i+8:i+10])
903         rrdata = self.getrrdata(type,data,rdlength,i+10)
904         rrdata['ttl'] = ttl
905         rrdata['class'] = klass
906         rr = {name:{type:[rrdata]}}
907         return rr, i+10+rdlength
908
909     def processpkt(self, msgdata):
910         self.header.id = asctoint(msgdata[:2])
911         self.header.qr = ord(msgdata[2]) >> 7
912         self.header.opcode = (ord(msgdata[2]) & 127) >> 3
913         if self.header.opcode == 5:
914             # UPDATE packet
915             log(4,'processing UPDATE packet')
916             del self.header.aa
917             del self.header.tc
918             del self.header.rd
919             del self.header.ra
920             del self.header.qdcount
921             del self.header.ancount
922             del self.header.nscount
923             del self.header.arcount
924             del self.question
925             self.zone = dnsupdatezone()
926             del self.answerlist
927             del self.authlist
928             del self.addlist
929             self.header.z = 0
930             self.header.rcode = ord(msgdata[3]) & 15
931             self.header.zocount = asctoint(msgdata[4:6])
932             self.header.prcount = asctoint(msgdata[6:8])
933             self.header.upcount = asctoint(msgdata[8:10])
934             self.header.arcount = asctoint(msgdata[10:12])
935             self.zolist = []
936             self.prlist = []
937             self.uplist = []
938             self.addlist = []
939             i = 12
940             for x in range(self.header.zocount):
941                 (dn, i) = self.getdomainname(msgdata,i)
942                 self.zone.zname = dn
943                 type = asctoint(msgdata[i:i+2])
944                 self.zone.ztype = self.qtypes.get(type,chr(type))
945                 klass = asctoint(msgdata[i+2:i+4])
946                 self.zone.zclass = self.qclasses.get(klass,chr(klass))
947                 i = i + 4
948             for x in range(self.header.prcount):
949                 rr, i  = self.getrr(msgdata,i)
950                 self.prlist.append(rr)
951             for x in range(self.header.upcount):
952                 rr, i  = self.getrr(msgdata,i)
953                 self.uplist.append(rr)
954             for x in range(self.header.arcount):
955                 rr, i  = self.getrr(msgdata,i)
956                 self.adlist.append(rr)
957         else:
958             self.header.aa = (ord(msgdata[2]) & 4) >> 2
959             self.header.tc = (ord(msgdata[2]) & 2) >> 1
960             self.header.rd = ord(msgdata[2]) & 1
961             self.header.ra = ord(msgdata[3]) >> 7
962             self.header.z = (ord(msgdata[3]) & 112) >> 4
963             self.header.rcode = ord(msgdata[3]) & 15
964             self.header.qdcount = asctoint(msgdata[4:6])
965             self.header.ancount = asctoint(msgdata[6:8])
966             self.header.nscount = asctoint(msgdata[8:10])
967             self.header.arcount = asctoint(msgdata[10:12])
968             i = 12
969             for x in range(self.header.qdcount):
970                 log(4,'unpacking question')
971                 (dn, i) = self.getdomainname(msgdata,i)
972                 self.question.qname = dn
973                 rrtype = asctoint(msgdata[i:i+2])
974                 self.question.qtype = self.qtypes.get(rrtype,chr(rrtype))
975                 klass = asctoint(msgdata[i+2:i+4])
976                 self.question.qclass = self.qclasses.get(klass,chr(klass))
977                 i = i + 4
978             for x in range(self.header.ancount):
979                 log(4,'unpacking answer RR')
980                 rr, i = self.getrr(msgdata,i)
981                 self.answerlist.append(rr)
982             for x in range(self.header.nscount):
983                 log(4,'unpacking auth RR')
984                 rr, i = self.getrr(msgdata,i)            
985                 self.authlist.append(rr)
986             for x in range(self.header.arcount):
987                 log(4,'unpacking additional RR')
988                 rr, i = self.getrr(msgdata,i)            
989                 self.addlist.append(rr)
990         return
991
992     def pds(self, s, l):
993         # pad string with chr(0)'s so that
994         # return string length is l
995         x = l - len(s)
996         return x*chr(0) + s
997
998     def locsize(self, s):
999         x1 = ord(s) >> 4
1000         x2 = ord(s) & 15
1001         return (x1, x2)
1002
1003     def packlocsize(self, x):
1004         return chr((x[0] << 4) + x[1])
1005
1006     def packdomainname(self, name, i, msgcomp):
1007         log(4,'packing domainname: ' + name)
1008         if name == '':
1009             return chr(0)
1010         if name in msgcomp.keys():
1011             log(4,'using pointer for: ' + name)
1012             return msgcomp[name]
1013         packedname = ''
1014         tokens = name.split('.')
1015         for j in range(len(tokens)):
1016             packedname = packedname + chr(len(tokens[j])) + tokens[j]
1017             nameleft = '.'.join(tokens[j+1:])
1018             if nameleft in msgcomp.keys():
1019                 log(4,'using pointer for: ' + nameleft)
1020                 return packedname+msgcomp[nameleft]
1021         # haven't used a pointer so put this in the dictionary
1022         pointer = inttoasc(i)
1023         if len(pointer) == 1:
1024             msgcomp[name] = chr(192)+pointer
1025         else:
1026             msgcomp[name] = chr(192|ord(pointer[0])) + pointer[1]
1027         log(4,'added pointer for ' + name + '(' + str(i) + ')')
1028         return packedname + chr(0)
1029
1030     def packrr(self, rr, i, msgcomp):
1031         rrname = rr.keys()[0]
1032         rrtype = rr[rrname].keys()[0]
1033         if self.rqtypes.has_key(rrtype):
1034             typeval = self.rqtypes[rrtype]
1035         else:
1036             typeval = ord(rrtype)
1037         dbrec = rr[rrname][rrtype][0]
1038         ttl = dbrec['ttl']
1039         rclass = self.rqclasses[dbrec['class']]
1040         packedrr = (self.packdomainname(rrname, i, msgcomp) +
1041                     self.pds(inttoasc(typeval),2) +
1042                     self.pds(inttoasc(rclass),2) +
1043                     self.pds(inttoasc(ttl),4))
1044         i = i + len(packedrr) + 2
1045         if rrtype == 'A':
1046             rdata = socket.inet_aton(dbrec['address'])
1047         elif rrtype == 'AAAA':
1048             rdata = ipv6net_aton(dbrec['address'])
1049         elif rrtype == 'CNAME':
1050             rdata = self.packdomainname(dbrec['cname'], i, msgcomp)
1051         elif rrtype == 'HINFO':
1052             rdata = (chr(len(dbrec['cpu'])) + dbrec['cpu'] +
1053                      chr(len(dbrec['os'])) + dbrec['os'])
1054         elif rrtype == 'LOC':
1055             rdata = (chr(dbrec['version']) +
1056                      self.packlocsize(dbrec['size']) +
1057                      self.packlocsize(dbrec['horiz_pre']) +
1058                      self.packlocsize(dbrec['vert_pre']) +
1059                      self.pds(inttoasc(dbrec['latitude']),4) +
1060                      self.pds(inttoasc(dbrec['longitude']),4) +
1061                      self.pds(inttoasc(dbrec['altitude']),4))
1062         elif rrtype == 'MX':
1063             rdata = (self.pds(inttoasc(dbrec['preference']),2) +
1064                      self.packdomainname(dbrec['exchange'], i+2, msgcomp))
1065         elif rrtype == 'NS':
1066             rdata = self.packdomainname(dbrec['nsdname'], i, msgcomp)
1067         elif rrtype == 'PTR':
1068             rdata = self.packdomainname(dbrec['ptrdname'], i, msgcomp)
1069         elif rrtype == 'RP':
1070             rdata1 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
1071             i = i + len(rdata1)
1072             rdata2 = self.packdomainname(dbrec['mboxdname'], i , msgcomp)
1073             rdata = rdata1 + rdata2
1074         elif rrtype == 'SOA':
1075             rdata1 = self.packdomainname(dbrec['mname'], i, msgcomp)
1076             i = i + len(rdata1)
1077             rdata2 = self.packdomainname(dbrec['rname'], i, msgcomp)
1078             rdata = (rdata1 +
1079                      rdata2 +
1080                      self.pds(inttoasc(dbrec['serial']),4) +
1081                      self.pds(inttoasc(dbrec['refresh']),4) +
1082                      self.pds(inttoasc(dbrec['retry']),4) +
1083                      self.pds(inttoasc(dbrec['expire']),4) +
1084                      self.pds(inttoasc(dbrec['minimum']),4))
1085         elif rrtype == 'SRV':
1086             rdata = (self.pds(inttoasc(dbrec['priority']),2) +
1087                      self.pds(inttoasc(dbrec['weight']),2) +
1088                      self.pds(inttoasc(dbrec['port']),2) +
1089                      self.packdomainname(dbrec['target'], i+6, msgcomp))
1090         elif rrtype == 'TXT':
1091             rdata = chr(len(dbrec['txtdata'])) + dbrec['txtdata']
1092         else:
1093             rdata = dbrec['rdata']
1094
1095         return packedrr+self.pds(inttoasc(len(rdata)),2)+rdata
1096
1097     def buildpkt(self):
1098         # keep dictionary of names packed (so we can use pointers)
1099         msgcomp = {}
1100         # header
1101         if self.header.id > 65535:
1102             log(0,'building packet with bad ID field')
1103             self.header.id = 1
1104         msgdata = inttoasc(self.header.id)
1105         if len(msgdata) == 1:
1106             msgdata = chr(0) + msgdata
1107         h1 = ((self.header.qr << 7) +
1108               (self.header.opcode << 3) +
1109               (self.header.aa << 2) +
1110               (self.header.tc << 1) +
1111               (self.header.rd))
1112         h2 = ((self.header.ra << 7) +
1113               (self.header.z << 4) +
1114               (self.header.rcode))
1115         msgdata = msgdata + chr(h1) + chr(h2)
1116         msgdata = msgdata + self.pds(inttoasc(self.header.qdcount),2)
1117         msgdata = msgdata + self.pds(inttoasc(self.header.ancount),2)
1118         msgdata = msgdata + self.pds(inttoasc(self.header.nscount),2)
1119         msgdata = msgdata + self.pds(inttoasc(self.header.arcount),2)
1120         # question
1121         msgdata = msgdata + self.packdomainname(self.question.qname, len(msgdata), msgcomp)
1122         if self.rqtypes.has_key(self.question.qtype):
1123             typeval = self.rqtypes[self.question.qtype]
1124         else:
1125             typeval = ord(self.question.qtype)
1126         msgdata = msgdata + self.pds(inttoasc(typeval),2)
1127         if self.rqclasses.has_key(self.question.qclass):
1128             classval = self.rqclasses[self.question.qclass]
1129         else:
1130             classval = ord(self.question.qclass)
1131         msgdata = msgdata + self.pds(inttoasc(classval),2)
1132         # rr's
1133         # RR record format:
1134         # {'name' : {'type' : [rdata, rdata, ...]}
1135         # example: {'test.blah.net': {'A': [{'address': '10.1.1.2',
1136         #                                    'ttl': 3600L}]}}
1137         for rr in self.answerlist:
1138             log(4,'packing answer RR')
1139             msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
1140         for rr in self.authlist:
1141             log(4,'packing auth RR')
1142             msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
1143         for rr in self.addlist:
1144             log(4,'packing additional RR')
1145             msgdata = msgdata + self.packrr(rr, len(msgdata), msgcomp)
1146             
1147         return msgdata
1148
1149     def printpkt(self):
1150         print 'ID: ' +str(self.header.id)
1151         if self.header.qr:
1152             print 'QR: RESPONSE'
1153         else:
1154             print 'QR: QUERY'
1155         if self.header.opcode == 0:
1156             print 'OPCODE: STANDARD QUERY'
1157         elif self.header.opcode == 1:
1158             print 'OPCODE: INVERSE QUERY'
1159         elif self.header.opcode == 2:
1160             print 'OPCODE: SERVER STATUS REQUEST'
1161         elif self.header.opcode == 5:
1162             print 'UPDATE REQUEST'
1163         else:
1164             print 'OPCODE: UNKNOWN QUERY TYPE'
1165         if self.header.opcode != 5:
1166             if self.header.aa:
1167                 print 'AA: AUTHORITATIVE ANSWER'
1168             else:
1169                 print 'AA: NON-AUTHORITATIVE ANSWER'
1170             if self.header.tc:
1171                 print 'TC: MESSAGE IS TRUNCATED'
1172             else:
1173                 print 'TC: MESSAGE IS NOT TRUNCATED'
1174             if self.header.rd:
1175                 print 'RD: RECURSION DESIRED'
1176             else:
1177                 print 'RD: RECURSION NOT DESIRED'
1178             if self.header.ra:
1179                 print 'RA: RECURSION AVAILABLE'
1180             else:
1181                 print 'RA: RECURSION IS NOT AVAILABLE'
1182         if self.header.rcode == 1:
1183             printrcode =  'FORMERR'
1184         elif self.header.rcode == 2:
1185             printrcode =  'SERVFAIL'
1186         elif self.header.rcode == 3:
1187             printrcode =  'NXDOMAIN'
1188         elif self.header.rcode == 4:
1189             printrcode =  'NOTIMP'
1190         elif self.header.rcode == 5:
1191             printrcode =  'REFUSED'
1192         elif self.header.rcode == 6:
1193             printrcode = 'YXDOMAIN'
1194         elif self.header.rcode == 7:
1195             printrcode = 'YXRRSET'
1196         elif self.header.rcode == 8:
1197             printrcode = 'NXRRSET'
1198         elif self.header.rcode == 9:
1199             printrcode = 'NOTAUTH'
1200         elif self.header.rcode == 10:
1201             printrcode = 'NOTZONE'
1202         else:
1203             printrcode =  'NOERROR'
1204         print 'RCODE: ' + printrcode
1205         if self.header.opcode == 5:
1206             print 'NUMBER OF RRs in the Zone Section: ' + str(self.header.zocount)
1207             print 'NUMBER OF RRs in the Prerequisite Section: ' + str(self.header.prcount)
1208             print 'NUMBER OF RRs in the Update Section: ' + str(self.header.upcount)
1209             print 'NUMBER OF RRs in the Additional Data Section: ' + str(self.header.arcount)
1210             print 'ZONE SECTION:'
1211             print 'zname: ' + self.zone.zname
1212             print 'zonetype: ' + self.zone.ztype
1213             print 'zoneclass: ' + self.zone.zclass
1214             print 'PREREQUISITE RRs:'
1215             for rr in self.prlist:
1216                 print rr
1217             print 'UPDATE RRs:'        
1218             for rr in self.uplist:
1219                 print rr
1220             print 'ADDITIONAL RRs:'        
1221             for rr in self.addlist:
1222                 print rr
1223
1224
1225         else:
1226             print 'NUMBER OF QUESTION RRs: ' + str(self.header.qdcount)
1227             print 'NUMBER OF ANSWER RRs: ' + str(self.header.ancount)
1228             print 'NUMBER OF NAME SERVER RRs: ' + str(self.header.nscount)
1229             print 'NUMBER OF ADDITIONAL RRs: ' + str(self.header.arcount)
1230             print 'QUESTION SECTION:'
1231             print 'qname: ' + self.question.qname
1232             print 'querytype: ' + self.question.qtype
1233             print 'queryclass: ' + self.question.qclass
1234             print 'ANSWER RRs:'
1235             for rr in self.answerlist:
1236                 print rr
1237             print 'AUTHORITY RRs:'        
1238             for rr in self.authlist:
1239                 print rr
1240             print 'ADDITIONAL RRs:'        
1241             for rr in self.addlist:
1242                 print rr
1243
1244 class zonedb:
1245     def __init__(self, zdict):
1246         self.zdict = zdict
1247         self.updates = {}
1248         for k in self.zdict.keys():
1249             if self.zdict[k]['type'] == 'slave':
1250                 self.zdict[k]['lastupdatetime'] = 0
1251
1252     def error(self, id, qname, querytype, queryclass, rcode):
1253         error = message()
1254         error.header.id = id
1255         error.header.rcode = rcode
1256         error.header.qr = 1
1257         error.question.qname = qname
1258         error.question.qtype = querytype
1259         error.question.qclass = queryclass
1260         return error
1261
1262     def getorigin(self, zkey):
1263         origin = ''
1264         if self.zdict.has_key(zkey):
1265             origin = self.zdict[zkey]['origin']
1266         return origin
1267
1268     def getmasterip(self, zkey):
1269         masterip = ''
1270         if self.zdict.has_key(zkey):
1271             if self.zdict[zkey].has_key('masterip'):
1272                 masterip = self.zdict[zkey]['masterip']
1273         return masterip
1274
1275     def zonetrans(self, query):
1276         # build a list of messages
1277         # each message contains one rr of the zone
1278         # the first and last message are the
1279         # SOA records
1280         origin = query.question.qname
1281         querytype = query.question.qtype
1282         zkey = ''
1283         for zonekey in self.zdict.keys():
1284             if self.zdict[zonekey]['origin'] == query.question.qname:
1285                 zkey = zonekey
1286         if not zkey:
1287             return []
1288         zonedata = self.zdict[zkey]['zonedata']
1289         queryid = query.header.id
1290         soarec = zonedata[origin]['SOA'][0]
1291         soa = {origin:{'SOA':[soarec]}}
1292         curserial = soarec['serial']
1293         rrlist = []
1294         if querytype == 'IXFR':
1295             clientserial = query.authlist[0][origin]['SOA'][0]['serial']
1296             if clientserial < curserial:
1297                 for i in range(clientserial,curserial+1):
1298                     if self.updates[zkey].has_key(i):
1299                         for rr in self.updates[zkey][i]['added']:
1300                             rrlist.append(rr)
1301                         for rr in self.updates[zkey][i]['removed']:
1302                             rrlist.append(rr)
1303                 if len(rrlist) > 0:
1304                     rrlist.insert(0,soa)
1305                 rrlist.append(soa)
1306             else:
1307                 rrlist.append(soa)
1308         else:
1309             for nodename in zonedata.keys():
1310                 for rrtype in zonedata[nodename].keys():
1311                     if not (rrtype == 'SOA' and nodename == origin):
1312                         for rr in zonedata[nodename][rrtype]:
1313                             rrlist.append({nodename:{rrtype:[rr]}})
1314             rrlist.insert(0,soa)
1315             rrlist.append(soa)
1316         msglist = []
1317         for rr in rrlist:
1318             msg = message()
1319             msg.header.id = queryid
1320             msg.header.qr = 1
1321             msg.header.aa = 1
1322             msg.header.rd = 0
1323             msg.header.qdcount = 1
1324             msg.question.qname = origin
1325             msg.question.qtype = querytype
1326             msg.question.qclass = 'IN'
1327             msg.header.ancount = 1
1328             msg.answerlist.append(rr)
1329             msglist.append(msg)
1330         return msglist
1331
1332     def update_zone(self, rrlist, params):
1333         zonekey = params[0]
1334         zonedata = {}
1335         soa = rrlist.pop()
1336         origin = soa.keys()[0]
1337         for rr in rrlist:
1338             rrname = rr.keys()[0]
1339             rrtype = rr[rrname].keys()[0]
1340             dbrec = rr[rrname][rrtype][0]
1341             if zonedata.has_key(rrname):
1342                 if not zonedata[rrname].has_key(rrtype):
1343                     zonedata[rrname][rrtype] = []
1344             else:
1345                 zonedata[rrname] = {}
1346                 zonedata[rrname][rrtype] = []
1347             zonedata[rrname][rrtype].append(dbrec)
1348         self.zdict[zonekey]['zonedata'] = zonedata
1349         curtime = time.time()
1350         self.zdict[zonekey]['lastupdatetime'] = curtime
1351         try:
1352             f = file(self.zdict[zonekey]['filename'],'w')
1353             writezonefile(zonedata, self.zdict[zonekey]['origin'], f)
1354             f.close()
1355         except:
1356             log(0,'unable to write zone ' + zonekey + 'to disk')
1357         log(1,'finished zone transfer for: ' + zonekey + ' (' + str(curtime) + ')')
1358
1359     def remove_zone(self, zonekey):
1360         if self.zdict.has_key(zonekey):
1361             del self.zdict[zonekey]
1362
1363     def getslaves(self, curtime):
1364         rlist = []
1365         for k in self.zdict.keys():
1366             if self.zdict[k]['type'] == 'slave':
1367                 origin = self.zdict[k]['origin']
1368                 refresh = self.zdict[k]['zonedata'][origin]['SOA'][0]['refresh']
1369                 if self.zdict[k]['lastupdatetime'] + refresh < curtime:
1370                     rlist.append((k, origin, self.zdict[k]['masterip']))
1371         return rlist
1372
1373     def zmatch(self, qname, zkeys):
1374         for zkey in zkeys:
1375             if self.zdict.has_key(zkey):
1376                 origin = self.zdict[zkey]['origin']
1377                 if qname.rfind(origin) != -1:
1378                     return zkey
1379         return ''
1380
1381     def getzlist(self, name, zone):
1382         if name == zone:
1383             return
1384         zlist = []
1385         i = name.rfind(zone)
1386         if i == -1:
1387             return
1388         firstpart = name[:i-1]
1389         partlist = firstpart.split('.')
1390         partlist.reverse()
1391         lastpart = zone
1392         for x in range(len(partlist)):
1393             lastpart = partlist[x] + '.' + lastpart
1394             zlist.append(lastpart)
1395         return zlist
1396
1397     def lookup(self, zkeys, query, addr, server, dorecursion, flist, cbfunc):
1398         # handle zone transfers seperately
1399         qname = query.question.qname
1400         querytype = query.question.qtype
1401         queryclass = query.question.qclass
1402         if querytype in ['AXFR','IXFR']:
1403             for zkey in self.zdict.keys():
1404                 if zkey in zkeys:
1405                     if qname == self.zdict[zkey]['origin']:
1406                         answerlist = self.zonetrans(query)
1407                         break
1408             else:
1409                 answerlist = []
1410             cbfunc(query, addr, server, dorecursion, flist, answerlist)
1411         else:
1412             zonekey = self.zmatch(qname, zkeys)
1413             if zonekey:
1414                 origin = self.zdict[zonekey]['origin']
1415                 zonedict = self.zdict[zonekey]['zonedata']
1416                 referral = 0
1417                 rranswerlist = []
1418                 rrnslist = []
1419                 rraddlist = []
1420                 answer = message()
1421                 answer.header.aa = 1
1422                 answer.header.id = query.header.id
1423                 answer.header.qr = 1
1424                 answer.header.opcode = query.header.opcode
1425                 answer.header.rcode = 4
1426                 answer.header.ra = dorecursion
1427                 answer.question.qname = query.question.qname
1428                 answer.question.qtype = query.question.qtype
1429                 answer.question.qclass = query.question.qclass
1430                 answer.header.ra = dorecursion
1431                 s = '.servers.csail.mit.edu'
1432                 if qname.endswith(s):
1433                     host = qname[:-len(s)]
1434                     value = sipb_xen_database.NIC.get_by(hostname=host)
1435                     if value is None:
1436                         pass
1437                     else:
1438                         ip = value.ip
1439                         rranswerlist.append({qname: {'A': [{'address': ip, 
1440                                                             'class': 'IN', 
1441                                                             'ttl': 10}]}})
1442                 if zonedict.has_key(qname):
1443                     # found the node, now take care of CNAMEs
1444                     if zonedict[qname].has_key('CNAME'):
1445                         if querytype != 'CNAME':
1446                             nodetype = 'CNAME'
1447                             while nodetype == 'CNAME':
1448                                 rranswerlist.append({qname:{'CNAME':[zonedict[qname]['CNAME'][0]]}})
1449                                 qname = zonedict[qname]['CNAME'][0]['cname']
1450                                 if zonedict.has_key(qname):
1451                                     nodetype = zonedict[qname].keys()[0]
1452                                 else:
1453                                     # error, shouldn't have a CNAME that points to nothing
1454                                     return
1455                     # if we get this far, then the record has matched and we should return
1456                     # a reply that has no error (even if there is no info macthing the qtype)
1457                     answer.header.rcode = 0
1458                     answernode = zonedict[qname]
1459                     if querytype == 'ANY':
1460                         for type in answernode.keys():
1461                             for rec in answernode[type]:
1462                                 rranswerlist.append({qname:{type:[rec]}})
1463                     elif answernode.has_key(querytype):
1464                         for rec in answernode[querytype]:
1465                             rranswerlist.append({qname:{querytype:[rec]}})
1466                         # do rrset ordering (cyclic)
1467                         if len(answernode[querytype]) > 1:
1468                             rec = answernode[querytype].pop(0)
1469                             answernode[querytype].append(rec)
1470                     else:
1471                         # remove all cname rrs from answerlist
1472                         rranswerlist = []
1473                 else:
1474                     # would check for wildcards here (but aren't because they seem bad)
1475                     # see if we need to give a referral
1476                     zlist = self.getzlist(qname,origin)
1477                     for zonename in zlist:
1478                         if zonedict.has_key(zonename):
1479                             if zonedict[zonename].has_key('NS'):
1480                                 answer.header.rcode = 0
1481                                 referral = 1
1482                                 for rec in zonedict[zonename]['NS']:
1483                                     rrnslist.append({zonename:{'NS':[rec]}})
1484                                     nsdname = rec['nsdname']
1485                                     # add glue records if they exist
1486                                     if zonedict.has_key(nsdname):
1487                                         if zonedict[nsdname].has_key('A'):
1488                                             for gluerec in zonedict[nsdname]['A']:
1489                                                 rraddlist.append({nsdname:{'A':[gluerec]}})
1490                     # negative caching stuff
1491                     if not referral:
1492                         if not rranswerlist:
1493                             # NOTE: RFC1034 section 4.3.4 says we should add the SOA record
1494                             #       to the additional section of the response.  BIND adds
1495                             #       it to the ns section though
1496                             answer.header.rcode = 3
1497                             rrnslist.append({origin:{'SOA':[zonedict[origin]['SOA'][0]]}})
1498                         else:
1499                             for rec in zonedict[origin]['NS']:
1500                                 rrnslist.append({origin:{'NS':[rec]}})
1501                 answer.header.ancount = len(rranswerlist)
1502                 answer.header.nscount = len(rrnslist)
1503                 answer.header.arcount = len(rraddlist)
1504                 answer.answerlist = rranswerlist
1505                 answer.authlist = rrnslist
1506                 answer.addlist = rraddlist
1507                 cbfunc(query, addr, server, dorecursion, flist, [answer])
1508             else:
1509                 cbfunc(query, addr, server, dorecursion, flist, [])
1510
1511     def handle_update(self, msg, addr, ns):
1512         zkey = ''
1513         slaves = []
1514         for zonekey in self.zdict.keys():
1515             if (self.zdict[zonekey]['type'] == 'master' and
1516                 self.zdict[zonekey]['origin'] == msg.zone.zname):
1517                 zkey = zonekey
1518         if not zkey:
1519             log(2,'SENDING NOTAUTH UPDATE ERROR')
1520             errormsg = self.error(msg.header.id, msg.zone.zname,
1521                                   msg.zone.ztype, msg.zone.zclass, 9)
1522             return errormsg, '', slaves
1523         # find the slaves for the zone
1524         if self.zdict[zkey].has_key('slaves'):
1525             slaves = self.zdict[zkey]['slaves']
1526         origin = self.zdict[zkey]['origin']
1527         zd = self.zdict[zkey]['zonedata']
1528         # check the permissions
1529         if not ns.config.allowupdate(msg, addr[0], addr[1]):
1530             log(2,'SENDING REFUSED UPDATE ERROR')
1531             errormsg = self.error(msg.header.id, msg.zone.zname,
1532                                   msg.zone.ztype, msg.zone.zclass, 5)
1533             return errormsg, origin, slaves
1534         # now check the prereqs
1535         temprrset = {}
1536         for rr in msg.prlist:
1537             rrname = rr.keys()[0]
1538             rrtype = rr[rrname].keys()[0]
1539             dbrec = rr[rrname][rrtype][0]
1540             if dbrec['ttl'] != 0:
1541                 log(2,'FORMERROR(1)')
1542                 errormsg = self.error(msg.header.id, msg.zone.zname,
1543                                       msg.zone.ztype, msg.zone.zclass, 1)
1544                 return errormsg, origin, slaves
1545             if rrname.rfind(msg.zone.zname) == -1:
1546                 log(2,'NOTZONE(10)')
1547                 errormsg = self.error(msg.header.id, msg.zone.zname,
1548                                       msg.zone.ztype, msg.zone.zclass, 10)
1549                 return errormsg, origin, slaves
1550             if dbrec['class'] == 'ANY':
1551                 if dbrec['rdata']:
1552                     log(2,'FORMERROR(1)')
1553                     errormsg = self.error(msg.header.id, msg.zone.zname,
1554                                           msg.zone.ztype, msg.zone.zclass, 1)
1555                     return errormsg, origin, slaves
1556                 if rrtype == 'ANY':
1557                     if not zd.has_key(rrname):
1558                         log(2,'NXDOMAIN(3)')
1559                         errormsg = self.error(msg.header.id, msg.zone.zname,
1560                                               msg.zone.ztype, msg.zone.zclass, 3)
1561                         return errormsg, origin, slaves
1562                 else:
1563                     rrsettest = 0
1564                     if zd.has_key(rrname):
1565                         if zd[rrname].has_key(rrtype):
1566                             rrsettest = 1
1567                     if not rrsettest:
1568                         log(2,'NXRRSET(8)')
1569                         errormsg = self.error(msg.header.id, msg.zone.zname,
1570                                               msg.zone.ztype, msg.zone.zclass, 8)
1571                         return errormsg, origin, slaves
1572             if dbrec['class'] == 'NONE':
1573                 if dbrec['rdata']:
1574                     log(2,'FORMERROR(1)')
1575                     errormsg = self.error(msg.header.id, msg.zone.zname,
1576                                           msg.zone.ztype, msg.zone.zclass, 1)
1577                     return errormsg, origin, slaves
1578                 if rrtype == 'ANY':
1579                     if zd.has_key(rrname):
1580                         log(2,'YXDOMAIN(6)')
1581                         errormsg = self.error(msg.header.id, msg.zone.zname,
1582                                               msg.zone.ztype, msg.zone.zclass, 6)
1583                         return errormsg, origin, slaves
1584                 else:
1585                     if zd.has_key(rrname):
1586                         if zd[rrname].has_key(rrtype):
1587                             log(2,'YXRRSET(7)')
1588                             errormsg = self.error(msg.header.id, msg.zone.zname,
1589                                                   msg.zone.ztype, msg.zone.zclass, 7)
1590                             return errormsg, origin, slaves
1591             if dbrec['class'] == msg.zone.zclass:
1592                 if temprrset.has_key(rrname):
1593                     if not temprrset[rrname].has_key(rrtype):
1594                         temprrset[rrname][rrtype] = []
1595                 else:
1596                     temprrset[rrname] = {}
1597                     temprrset[rrname][rrtype] = []
1598                 temprrset[rrname][rrtype].append(dbrec)
1599             else:
1600                 log(2,'FORMERROR(1)')
1601                 errormsg = self.error(msg.header.id, msg.zone.zname,
1602                                       msg.zone.ztype, msg.zone.zclass, 1)
1603                 return errormsg, origin, slaves
1604         for nodename in temprrset.keys():
1605             if not self.rrmatch(temprrset[nodename],zd[nodename]):
1606                 log(2,'NXRRSET(8)')
1607                 errormsg = self.error(msg.header.id, msg.zone.zname,
1608                                       msg.zone.ztype, msg.zone.zclass, 8)
1609                 return errormsg, origin, slaves
1610
1611         # update section prescan
1612         for rr in msg.uplist:
1613             rrname = rr.keys()[0]
1614             rrtype = rr[rrname].keys()[0]
1615             dbrec = rr[rrname][rrtype][0]
1616             if rrname.rfind(msg.zone.zname) == -1:
1617                 log(2,'NOTZONE(10)')
1618                 errormsg = self.error(msg.header.id, msg.zone.zname,
1619                                       msg.zone.ztype, msg.zone.zclass, 10)
1620                 return errormsg, origin, slaves
1621             if dbrec['class'] == msg.zone.zclass:
1622                 if rrtype in ['ANY','MAILA','MAILB','AXFR']:
1623                     log(2,'FORMERROR(1)')
1624                     errormsg = self.error(msg.header.id, msg.zone.zname,
1625                                           msg.zone.ztype, msg.zone.zclass, 1)
1626                     return errormsg, origin, slaves
1627             elif dbrec['class'] == 'ANY':
1628                 if dbrec['ttl'] != 0 or dbrec['rdata'] or rrtype in ['MAILA','MAILB','AXFR']:
1629                     log(2,'FORMERROR(1)')
1630                     errormsg = self.error(msg.header.id, msg.zone.zname,
1631                                           msg.zone.ztype, msg.zone.zclass, 1)
1632                     return errormsg, origin, slaves
1633             elif dbrec['class'] == 'NONE':
1634                 if dbrec['ttl'] != 0 or rrtype in ['ANY','MAILA','MAILB','AXFR']:
1635                     log(2,'FORMERROR(1)')
1636                     errormsg = self.error(msg.header.id, msg.zone.zname,
1637                                           msg.zone.ztype, msg.zone.zclass, 1)
1638                     return errormsg, origin, slaves
1639             else:
1640                 log(2,'FORMERROR(1)')
1641                 errormsg = self.error(msg.header.id, msg.zone.zname,
1642                                       msg.zone.ztype, msg.zone.zclass, 1)
1643                 return errormsg, origin, slaves
1644
1645         # now handle actual update
1646         curserial = zd[msg.zone.zname]['SOA'][0]['serial']
1647         # update the soa serial here
1648         clearupdatehist = 0
1649         if len(msg.uplist) > 0:
1650             # initialize history structure
1651             if not self.updates.has_key(zkey):
1652                 self.updates[zkey] = {}
1653                 self.updates[zkey][curserial] = {'removed':[],
1654                                                  'added':[]}
1655             if curserial == 2**32:
1656                 newserial = 2
1657                 clearupdatehist = 1
1658             else:
1659                 newserial = curserial + 1
1660             self.updates[zkey][newserial] = {'removed':[],
1661                                              'added':[]}
1662             zd[msg.zone.zname]['SOA'][0]['serial'] = newserial
1663         for rr in msg.uplist:
1664             rrname = rr.keys()[0]
1665             rrtype = rr[rrname].keys()[0]
1666             dbrec = rr[rrname][rrtype][0]
1667             if dbrec['class'] == msg.zone.zclass:
1668                 if rrtype == 'SOA':
1669                     if zd.has_key(rrname):
1670                         if zd[rrname].has_key('SOA'):
1671                             if dbrec['serial'] > zd[rrname]['SOA'][0]['serial']:
1672                                 del zd[rrname]['SOA'][0]
1673                                 zd[rrname]['SOA'].append(dbrec)
1674                                 clearupdatehist = 1
1675                 elif rrtype == 'WKS':
1676                     if zd.has_key(rrname):
1677                         if zd[rrname].has_key('WKS'):
1678                             rdata = zd[rrname]['WKS'][0]
1679                             oldrr = {rrname:{'WKS':[rdata]}}
1680                             self.updates[zkey][curserial]['removed'].append(oldrr)
1681                             del zd[rrname]['WKS'][0]
1682                             zd[rrname]['WKS'].append(dbrec)
1683                             newrr = {rrname:{'WKS':[dbrec]}}
1684                             self.updates[zkey][newserial]['added'].append(newrr)
1685                 else:
1686                     if zd.has_key(rrname):
1687                         if not zd[rrname].has_key(rrtype):
1688                             zd[rrname][rrtype] = []
1689                     else:
1690                         zd[rrname] = {}
1691                         zd[rrname][rrtype] = []
1692                     zd[rrname][rrtype].append(dbrec)
1693                     newrr = {rrname:{rrtype:[dbrec]}}
1694                     self.updates[zkey][newserial]['added'].append(newrr)
1695             elif dbrec['class'] == 'ANY':
1696                 if rrtype == 'ANY':
1697                     if rrname == msg.zone.zname:
1698                         if zd.has_key(rrname):
1699                             for dnstype in zd[rrname].keys():
1700                                 if dnstype not in ['SOA','NS']:
1701                                     for rdata in zd[rrname][dnstype]:
1702                                         oldrr = {rrname:{dnstype:[rdata]}}
1703                                         self.updates[zkey][curserial]['removed'].append(oldrr)
1704                                     del zd[rrname][dnstype]                                    
1705                     else:
1706                         if zd.has_key(rrname):
1707                             for dnstype in zd[rrname].keys():
1708                                 for rdata in zd[rrname][dnstype]:
1709                                     oldrr = {rrname:{dnstype:[rdata]}}
1710                                     self.updates[zkey][curserial]['removed'].append(oldrr)
1711                             del zd[rrname]
1712                 else:
1713                     if zd.has_key(rrname):
1714                         if zd[rrname].has_key(rrtype):
1715                             if rrname == msg.zone.zname:
1716                                 if rrtype not in ['SOA','NS']:
1717                                     for rdata in zd[rrname][dnstype]:
1718                                         oldrr = {rrname:{dnstype:[rdata]}}
1719                                         self.updates[zkey][curserial]['removed'].append(oldrr)
1720                                     del zd[rrname][rrtype]
1721                             else:
1722                                 for rdata in zd[rrname][dnstype]:
1723                                     oldrr = {rrname:{dnstype:[rdata]}}
1724                                     self.updates[zkey][curserial]['removed'].append(oldrr)
1725                                 del zd[rrname][rrtype]
1726             elif dbrec['class'] == 'NONE':
1727                 if not (rrname == msg.zone.zname and rrtype in ['SOA','NS']):
1728                     if zd.had_key(rrname):
1729                         if zd[rrname].has_key(rrtype):
1730                             for i in range(len(zd[rrname][rrtype])):
1731                                 if dbrec == zd[rrname][rrtype][i]:
1732                                     rdata = zd[rrname][dnstype][i]
1733                                     oldrr = {rrname:{dnstype:[rdata]}}
1734                                     self.updates[zkey][curserial]['removed'].append(oldrr)
1735                                     del zd[rrname][rrtype][i]
1736                             if len(zd[rrname][rrtype]) == 0:
1737                                 del zd[rrname][rrtype]
1738         if clearupdatehist:
1739             self.updates[zkey] = {}
1740         log(2,'SENDING UPDATE NOERROR MSG')
1741         noerrormsg = self.error(msg.header.id, msg.zone.zname,
1742                               msg.zone.ztype, msg.zone.zclass, 0)
1743         return noerrormsg, origin, slaves
1744
class dnscache:
    """Cache of RRs learned while resolving queries for clients.

    cachedb maps node name -> {rrtype: [rdata, ...]}, the same layout the
    zone data uses.  Each rdata dict's 'ttl' holds an absolute expiry time
    (seconds since the epoch, from time.time()), or 0 for entries that
    never expire (the initial contents and the built-in localhost records).
    NS rdata additionally carries 'rtt' and, while marked bad, 'badtill'.
    A negative entry is an rdata dict holding nothing but 'ttl'.
    """

    # per-record fields maintained by the cache itself, not part of rdata
    _bookkeeping = ('ttl', 'rtt', 'badtill')

    def __init__(self,cachezone):
        # cachezone is adopted (not copied) and must already hold a root
        # ('') node; presumably it is the root hints zone -- confirm
        self.cachedb = cachezone
        # initial entries never expire; their NS servers start with rtt 0
        for node in self.cachedb.keys():
            for rtype in self.cachedb[node].keys():
                for rr in self.cachedb[node][rtype]:
                    rr['ttl'] = 0
                    if rtype == 'NS':
                        rr['rtt'] = 0
        # add special entries for localhost
        self.cachedb['localhost'] = {'A':[{'address':'127.0.0.1', 'ttl':0, 'class':'IN'}]}
        self.cachedb['1.0.0.127.in-addr.arpa'] = {'PTR':[{'ptrdname':'localhost', 'ttl':0,'class':'IN'}]}
        self.cachedb['']['SOA'] = []
        self.cachedb['']['SOA'].append({'class':'IN','ttl':0,'mname':'cachedb',
                                        'rname':'cachedb@localhost','serial':1,'refresh':10800,
                                        'retry':3600,'expire':604800,'minimum':3600})

    def _rdataonly(self, rdata):
        # copy of rdata without the cache's bookkeeping fields
        stripped = rdata.copy()
        for key in self._bookkeeping:
            if key in stripped:
                del stripped[key]
        return stripped

    def _inzone(self, name, zone):
        # True if name equals zone or lies below it; a bare suffix test
        # would also accept e.g. 'badexample.com' for zone 'example.com'
        name = name.lower()
        zone = zone.lower()
        return name == zone or name.endswith('.' + zone)

    def hasrdata(self, irrdata, rrdatalist):
        """Return 1 if rrdatalist already holds irrdata's rdata, else 0.

        ttl, NS rtt and badtill are ignored: they change while a record is
        cached and comparing them would let duplicates in.
        """
        target = self._rdataonly(irrdata)
        for rrdata in rrdatalist:
            if self._rdataonly(rrdata) == target:
                return 1
        return 0

    def add(self, rr, qzone, nsdname):
        """Cache the single RR in rr ({name: {rtype: [rdata]}}) that was
        returned by a server queried for zone qzone.

        Records outside qzone are rejected as possible cache poisoning.  The
        rdata dict itself is stored (not copied): its 'ttl' is raised to at
        least 3600 seconds and rewritten as an absolute expiry time.
        nsdname is accepted for interface compatibility but unused.
        """
        # NOTE: can't cache records from sites
        # that don't own those records (i.e. example.com
        # can't give us A records for www.example.net)
        name = list(rr.keys())[0]
        if (qzone != '') and not self._inzone(name, qzone):
            log(2,'cache GOT possible POISON: ' + name + ' for zone ' + qzone)
            return
        rtype = list(rr[name].keys())[0]
        rdata = rr[name][rtype][0]
        if rdata['ttl'] < 3600:
            log(2,'low ttl: ' + str(rdata['ttl']))
            rdata['ttl'] = 3600
        rdata['ttl'] = int(time.time() + rdata['ttl'])
        if rtype == 'NS':
            rdata['rtt'] = 0
        name = name.lower()
        rtype = rtype.upper()
        if name not in self.cachedb:
            self.cachedb[name] = {rtype:[rdata]}
            log(3,'added node ' + name + '(' + rtype + ') to cache')
        elif rtype not in self.cachedb[name]:
            self.cachedb[name][rtype] = [rdata]
            log(3,'appended ' + rtype + ' and rdata to node ' +
                name + ' in cache')
        elif not self.hasrdata(rdata, self.cachedb[name][rtype]):
            self.cachedb[name][rtype].append(rdata)
            log(3,'appended rdata to ' +
                name + '(' + rtype + ') in cache')
        else:
            log(3,'same rdata for ' + name + '(' +
                rtype + ') is already in cache')
        self.reap()

    def addneg(self, qname, querytype, queryclass):
        """Remember for an hour that qname has no querytype records.

        Existing entries are left alone.  queryclass is unused.
        """
        expires = time.time() + 3600
        if qname not in self.cachedb:
            self.cachedb[qname] = {querytype: [{'ttl':expires}]}
        elif querytype not in self.cachedb[qname]:
            self.cachedb[qname][querytype] = [{'ttl':expires}]

    def haskey(self, qname, querytype, msg=''):
        """Look up qname/querytype in the cache, following cached CNAMEs
        unless querytype is 'CNAME'.

        When positive records are found, returns an answer message for the
        query msg if one is given, otherwise 1.  Returns None when nothing
        usable is cached, when a CNAME points outside the cache, or when
        the cached CNAMEs form a loop.
        """
        log(3,'looking for ' + qname + '(' + querytype + ') in cache')
        if qname not in self.cachedb:
            log(3,'Cache has no node for ' + qname)
            return None
        rranswerlist = []
        if querytype != 'CNAME':
            visited = {}
            while 'CNAME' in self.cachedb[qname]:
                cnamerec = self.cachedb[qname]['CNAME'][0]
                if 'cname' not in cnamerec:
                    # negative CNAME entry: there is no alias to follow
                    break
                if qname in visited:
                    log(2,'CNAME loop in cache at ' + qname)
                    return None
                visited[qname] = 1
                log(3,'Adding CNAME to cache answer')
                rranswerlist.append({qname:{'CNAME':[cnamerec]}})
                qname = cnamerec['cname']
                if qname not in self.cachedb:
                    # shouldn't have a CNAME that points to nothing
                    return None
        node = self.cachedb[qname]
        if querytype == 'ANY':
            rtypes = list(node.keys())
        elif querytype in node:
            rtypes = [querytype]
        else:
            rtypes = []
        for rtype in rtypes:
            for rec in node[rtype]:
                # negative entries hold only a 'ttl' and are never answers
                if len(rec.keys()) > 1:
                    rranswerlist.append({qname:{rtype:[rec]}})
        if not rranswerlist:
            return None
        if not msg:
            return 1
        answer = message()
        answer.header.id = msg.header.id
        answer.header.qr = 1
        answer.header.opcode = msg.header.opcode
        answer.header.ra = 1
        answer.question.qname = msg.question.qname
        answer.question.qtype = msg.question.qtype
        answer.question.qclass = msg.question.qclass
        answer.header.rcode = 0
        answer.header.ancount = len(rranswerlist)
        answer.answerlist = rranswerlist
        return answer

    def _addservers(self, nsrecs, curtime, nsdict, logbad):
        # Add rtt -> {'name', 'ip'} entries to nsdict for each usable server
        # in nsrecs.  A server whose 'badtill' has passed is usable again;
        # servers without cached A records are skipped.
        for nsrec in nsrecs:
            if 'badtill' in nsrec:
                if nsrec['badtill'] < curtime:
                    del nsrec['badtill']
                else:
                    if logbad:
                        log(2,'BAD SERVER, not using ' + nsrec['nsdname'])
                    continue
            glue = self.cachedb.get(nsrec['nsdname'], {})
            for arec in glue.get('A', []):
                if 'address' in arec:   # skip negative entries
                    nsdict[nsrec['rtt']] = {'name':nsrec['nsdname'],
                                            'ip':arec['address']}

    def getnslist(self, qname):
        """Find the closest cached nameservers for qname.

        Walks from qname toward the root and returns (zonename, nsdict) for
        the first enclosing zone with a usable server, falling back to the
        root ('') servers.  nsdict maps rtt -> {'name': ..., 'ip': ...}, so
        servers with equal rtt overwrite each other.
        """
        tokens = qname.split('.')
        nsdict = {}
        curtime = time.time()
        domainname = ''
        for i in range(len(tokens)):
            domainname = '.'.join(tokens[i:])
            node = self.cachedb.get(domainname)
            if node and 'NS' in node:
                self._addservers(node['NS'], curtime, nsdict, 1)
                if nsdict:
                    break
        if not nsdict:
            # nothing in the cache matches so give back the root servers
            domainname = ''
            self._addservers(self.cachedb['']['NS'], curtime, nsdict, 0)
        return (domainname, nsdict)

    def badns(self, zonename, nsdname):
        """Mark nameserver nsdname of zonename as bad for the next hour."""
        if zonename in self.cachedb and 'NS' in self.cachedb[zonename]:
            for nsrec in self.cachedb[zonename]['NS']:
                if nsrec['nsdname'] == nsdname:
                    log(2,'Setting ' + nsdname + ' as bad nameserver')
                    nsrec['badtill'] = time.time() + 3600

    def updatertt(self, qname, zone, rtt):
        """Store the measured round-trip time of nameserver qname for zone."""
        if zone in self.cachedb and 'NS' in self.cachedb[zone]:
            for rr in self.cachedb[zone]['NS']:
                if rr['nsdname'] == qname:
                    log(2,'updating rtt for ' + qname + ' to ' + str(rtt))
                    rr['rtt'] = rtt

    def reap(self):
        """Expire old records, then drop emptied RRsets and nodes.

        Records with ttl 0 never expire.
        """
        ntime = time.time()
        for nodename in list(self.cachedb.keys()):
            node = self.cachedb[nodename]
            for rrtype in list(node.keys()):
                # build the survivors first: removing from the list while
                # iterating over it would skip the record after each removal
                live = [rdata for rdata in node[rrtype]
                        if rdata['ttl'] == 0 or rdata['ttl'] >= ntime]
                if live:
                    node[rrtype][:] = live
                else:
                    del node[rrtype]
            if not node:
                del self.cachedb[nodename]

    def zonetrans(self, queryid):
        """Dump the cache as an AXFR-style list of reply messages.

        Each message holds one RR; the first and last carry the cache's
        root SOA record.
        """
        zonedata = self.cachedb
        rrlist = []
        soa = {'':{'SOA':[zonedata['']['SOA'][0]]}}
        for nodename in zonedata.keys():
            for rrtype in zonedata[nodename].keys():
                if not (rrtype == 'SOA' and nodename == ''):
                    for rr in zonedata[nodename][rrtype]:
                        rrlist.append({nodename:{rrtype:[rr]}})
        rrlist.insert(0,soa)
        rrlist.append(soa)
        msglist = []
        for rr in rrlist:
            msg = message()
            msg.header.id = queryid
            msg.header.qr = 1
            msg.header.aa = 1
            msg.header.rd = 0
            msg.header.qdcount = 1
            msg.question.qname = 'cache'
            msg.question.qtype = 'AXFR'
            msg.question.qclass = 'IN'
            msg.header.ancount = 1
            msg.answerlist.append(rr)
            msglist.append(msg)
        return msglist
1971
class gethostaddr(asyncore.dispatcher):
    """One-shot asynchronous A-record lookup over UDP.

    Sends an A query for hostname to serveraddr port 53.  When a reply
    arrives, cbfunc is called exactly once: with the 'address' of the
    matching A record, or with '' if none could be found.  If the reply
    only holds a CNAME, the alias target is queried in turn (at most
    maxcnamehops times) on the same socket.
    """
    # cap on CNAME re-queries so an alias loop cannot query forever
    maxcnamehops = 8

    def __init__(self, hostname, cbfunc, serveraddr='127.0.0.1'):
        asyncore.dispatcher.__init__(self)
        self.msg = message()
        self.msg.question.qname = hostname
        self.msg.question.qtype = 'A'
        self.cbfunc = cbfunc
        # kept so handle_read can re-query the same server for a CNAME target
        self.serveraddr = serveraddr
        self.cnamehops = 0
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
        self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)
        self.socket.sendto(self.msg.buildpkt(), (serveraddr,53))

    def handle_read(self):
        replydata, addr = self.socket.recvfrom(1500)
        try:
            replymsg = message(replydata)
        except:
            # the parser's exception types are not known here
            log(0,'unable to process packet')
            self.close()
            self.cbfunc('')
            return
        answername = replymsg.question.qname
        cname = ''
        # go through twice to catch cnames after A recs
        for rr in replymsg.answerlist:
            rrname = rr.keys()[0]
            rrtype = rr[rrname].keys()[0]
            dbrec = rr[rrname][rrtype][0]
            if rrname == answername and rrtype == 'CNAME':
                answername = dbrec['cname']
                cname = answername
        for rr in replymsg.answerlist:
            rrname = rr.keys()[0]
            rrtype = rr[rrname].keys()[0]
            dbrec = rr[rrname][rrtype][0]
            if rrname == answername and rrtype == 'A':
                self.close()
                self.cbfunc(dbrec['address'])
                return
        # if we got a cname and no A send query for cname; the socket stays
        # open because the reply to that query arrives on it
        if cname and self.cnamehops < self.maxcnamehops:
            self.cnamehops += 1
            self.msg = message()
            self.msg.question.qname = cname
            self.msg.question.qtype = 'A'
            self.socket.sendto(self.msg.buildpkt(), (self.serveraddr,53))
        else:
            self.close()
            self.cbfunc('')

    def writable(self):
        # replies are only read; the query is sent directly in __init__
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
2033
class simpleudprequest(asyncore.dispatcher):
    """Send a single DNS query over UDP and hand the parsed reply onward.

    The reply is delivered as cbfunc(reply_message, outqkey), where outqkey
    is an opaque tag chosen by the caller.  A reply that cannot be parsed
    is logged and dropped without invoking the callback.
    """
    def __init__(self, msg, cbfunc, serveraddr='127.0.0.1', outqkey=''):
        asyncore.dispatcher.__init__(self)
        self.msg, self.cbfunc = msg, cbfunc
        self.gotanswer = 0
        self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
        bufsize = 200 * 1024
        for bufopt in (socket.SO_SNDBUF, socket.SO_RCVBUF):
            self.setsockopt(socket.SOL_SOCKET, bufopt, bufsize)
        self.outqkey = outqkey
        packet = self.msg.buildpkt()
        self.socket.sendto(packet, (serveraddr,53))

    def handle_read(self):
        packet, _sender = self.socket.recvfrom(1500)
        # a single reply is expected, so stop watching the socket first
        self.close()
        try:
            reply = message(packet)
        except:
            log(0,'unable to process packet')
        else:
            self.cbfunc(reply, self.outqkey)

    def writable(self):
        # nothing is queued for writing; the query went out in __init__
        return 0

    def handle_write(self):
        pass

    def handle_connect(self):
        pass

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        # informational chatter is suppressed when running under -O
        if type == 'info' and not __debug__:
            return
        log(0,'%s: %s' % (type, message))
2071
2072 class simpletcprequest(asyncore.dispatcher):
2073     def __init__(self, msg, cbfunc, cbparams=[], serveraddr='127.0.0.1', errorfunc=''):
2074         asyncore.dispatcher.__init__(self)
2075         self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
2076         self.query = msg
2077         self.cbfunc = cbfunc
2078         self.cbparams = cbparams
2079         self.errorfunc = errorfunc
2080         msgdata = msg.buildpkt()
2081         ml = inttoasc(len(msgdata))
2082         if len(ml) == 1:
2083             ml = chr(0) + ml
2084         self.buffer = ml+msgdata
2085         self.rbuffer = ''
2086         self.rmsgleft = 0
2087         self.rrlist = []
2088         log(2,'sending tcp request to ' + serveraddr)
2089         self.connect((serveraddr,53))
2090
2091     def recv (self, buffer_size):
2092         try:
2093             data = self.socket.recv (buffer_size)
2094             if not data:
2095                 # a closed connection is indicated by signaling
2096                 # a read condition, and having recv() return 0.
2097                 self.handle_close()
2098                 return ''
2099             else:
2100                 return data
2101         except socket.error, why:
2102             # winsock sometimes throws ENOTCONN
2103             if why[0] in [ECONNRESET, ENOTCONN, ESHUTDOWN, ETIMEDOUT]:
2104                 self.handle_close()
2105                 return ''
2106             else:
2107                 raise socket.error, why
2108
2109     def handle_connect(self):
2110         pass
2111
2112     def handle_msg(self, msg):
2113         if self.query.question.qtype == 'AXFR':
2114             if len(self.rrlist) == 0:
2115                 if len(msg.answerlist) == 0:
2116                     if self.errorfunc:
2117                         self.errorfunc(self.cbparams[0])
2118                     self.close()
2119                     return
2120             rr = msg.answerlist[0]
2121             rrname = rr.keys()[0]
2122             rrtype = rr[rrname].keys()[0]
2123             self.rrlist.append(rr)
2124             if rrtype == 'SOA' and len(self.rrlist) > 1:
2125                 self.close()
2126                 if self.cbparams:
2127                     self.cbfunc(self.rrlist, self.cbparams)
2128                 else:
2129                     self.cbfunc(self.rrlist)
2130         else:
2131             self.close()
2132             if self.cbparams:
2133                 self.cbfunc(msg, self.cbparams)
2134             else:
2135                 self.cbfunc(msg)
2136
2137     def handle_read(self):
2138         data = self.recv(8192)
2139         if len(self.rbuffer) == 0:
2140             self.rmsglength = asctoint(data[:2])
2141             data = data[2:]
2142         self.rbuffer = self.rbuffer + data
2143         while len(self.rbuffer) >= self.rmsglength and self.rmsglength != 0:
2144             msgdata = self.rbuffer[:self.rmsglength]
2145             self.rbuffer = self.rbuffer[self.rmsglength:]
2146             if len(self.rbuffer) == 0:
2147                 self.rmsglength = 0
2148             else:
2149                 self.rmsglength = asctoint(self.rbuffer[:2])
2150                 self.rbuffer = self.rbuffer[2:]
2151             try:
2152                 self.handle_msg(message(msgdata))
2153             except:
2154                 return
2155             
2156     def writable(self):
2157         return (len(self.buffer) > 0)
2158     
2159     def handle_write(self):
2160         sent = self.send(self.buffer)
2161         self.buffer = self.buffer[sent:]
2162
2163     def handle_close(self):
2164         if self.errorfunc:
2165             self.errorfunc(self.query.question.qname)
2166         self.close()
2167
2168     def log_info (self, message, type='info'):
2169         if __debug__ or type != 'info':
2170             log(0,'%s: %s' % (type, message))
2171
2172 class udpdnsserver(asyncore.dispatcher):
2173     def __init__(self, port, dnsserver):
2174         asyncore.dispatcher.__init__(self)
2175         self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
2176         self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
2177         self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)   
2178         self.bind(('',port))
2179         self.dnsserver = dnsserver
2180         self.maxmsgsize = 500
2181
2182     def handle_read(self):
2183         try:
2184             while 1:
2185                 msgdata, addr = self.socket.recvfrom(1500)
2186                 self.dnsserver.handle_packet(msgdata, addr, self)
2187         except socket.error, why:
2188             if why[0] != asyncore.EWOULDBLOCK:
2189                 raise socket.error, why
2190
2191     def sendpackets(self, msglist, addr):
2192         for msg in msglist:
2193             msgdata = msg.buildpkt()
2194             if len(msgdata) > self.maxmsgsize:
2195                 msg.header.tc = 1
2196                 # take off all the answers to ensure
2197                 # the packet size is small enough
2198                 msg.header.ancount = 0
2199                 msg.header.nscount = 0
2200                 msg.header.arcount = 0
2201                 msg.answerlist = []
2202                 msg.authlist = []
2203                 msg.addlist = []
2204                 msgdata = msg.buildpkt()
2205             self.sendto(msgdata, addr)
2206         
2207     def writable(self):
2208         return 0
2209
2210     def handle_write(self):
2211         pass
2212
2213     def handle_connect(self):
2214         pass
2215
2216     def handle_close(self):
2217         # print '1:In handle close'
2218         return
2219
2220     def log_info (self, message, type='info'):
2221         if __debug__ or type != 'info':
2222             log(0,'%s: %s' % (type, message))
2223
class tcpdnschannel(asynchat.async_chat):
    """One accepted DNS-over-TCP connection.

    Incoming bytes are split into length-prefixed DNS messages. Each
    message is passed to server.dnsserver.handle_packet(). Responses are
    written back by sendpackets(), after which the connection is closed
    once everything has been transmitted.
    """
    def __init__(self, server, s, addr):
        asynchat.async_chat.__init__(self, s)
        self.server = server
        self.addr = addr
        # no terminator: every received chunk goes to collect_incoming_data
        self.set_terminator(None)
        self.databuffer = ''
        # length of the message being assembled; 0 until its prefix is read
        self.msglength = 0
        log(3,'Created new tcp channel')

    def collect_incoming_data(self, data):
        """Buffer data and dispatch every complete length-prefixed message.

        Prefixes and message bodies may be split across chunks, and one
        chunk may carry several messages.
        """
        self.databuffer = self.databuffer + data
        while 1:
            if self.msglength == 0:
                if len(self.databuffer) < 2:
                    # the two-byte length prefix is not complete yet
                    return
                self.msglength = asctoint(self.databuffer[:2])
                self.databuffer = self.databuffer[2:]
                if self.msglength == 0:
                    # ignore an empty message and read the next prefix
                    continue
            if len(self.databuffer) < self.msglength:
                # wait for the rest of the message
                return
            msgdata = self.databuffer[:self.msglength]
            self.databuffer = self.databuffer[self.msglength:]
            # reset before dispatching so the next message is framed anew
            self.msglength = 0
            self.server.dnsserver.handle_packet(msgdata, self.addr, self)

    def sendpackets(self, msglist, addr):
        """Queue each message with its two-byte length prefix, then close."""
        for msg in msglist:
            x = msg.buildpkt()
            ml = inttoasc(len(x))
            if len(ml) == 1:
                # pad the length to two bytes
                ml = chr(0) + ml
            self.push(ml+x)
        # close() would discard whatever async_chat still has queued (large
        # zone transfers); close_when_done() closes after it is all sent
        self.close_when_done()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
2256
class tcpdnsserver(asyncore.dispatcher):
    """Listening TCP socket that creates a tcpdnschannel per connection."""
    def __init__(self, port, dnsserver):
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind(('',port))
        self.listen(5)
        # the nameserver that channels hand complete messages to
        self.dnsserver = dnsserver

    def handle_accept(self):
        pair = self.accept()
        if pair is None:
            # asyncore's accept() returns None when the pending connection
            # could not be accepted (EWOULDBLOCK); nothing to do
            return
        conn, addr = pair
        tcpdnschannel(self, conn, addr)

    def handle_close(self):
        self.close()

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
2276
class nameserver:
    """Answer DNS packets from the zone database and maintain slave zones.

    Packets arrive through handle_packet() from the UDP/TCP front ends.
    Ordinary queries are looked up in the zone database. When the matching
    view allows recursion, they are passed on to the resolver. Dynamic
    UPDATE (opcode 5) and NOTIFY (opcode 4) messages are handled here too.
    reap() performs the periodic slave-zone and NOTIFY maintenance.
    """
    def __init__(self, resolver, localconfig):
        self.resolver = resolver
        self.config = localconfig
        self.zdb = self.config.zonedatabase
        self.last_reap_time = time.time()
        # minimum seconds between slave-zone checks in reap()
        self.maint_int = 10
        # zone keys whose AXFR is currently in progress
        self.slavesupdating = []
        # NOTIFYs queued for sending: (origin, ipaddr, trynum)
        self.notifys = []
        # NOTIFYs sent and not yet acknowledged:
        # (origin, ipaddr, trynum, senttime)
        self.sentnotify = []
        # seconds to wait for a NOTIFY reply before sending it again
        self.notify_retry_time = 30
        self.notify_retries = 4
        # outstanding slave SOA checks, keyed by zone key
        self.askedsoa = {}
        # seconds to wait for an SOA reply before retrying
        self.soatimeout = 10

    def error(self, id, qname, querytype, queryclass, rcode):
        """Return a response carrying only the question and the given rcode."""
        error = message()
        error.header.id = id
        error.header.rcode = rcode
        error.header.qr = 1
        error.question.qname = qname
        error.question.qtype = querytype
        error.question.qclass = queryclass
        return error

    def need_zonetransfer(self, zkey, origin, masterip, trynum=0):
        """Ask masterip for origin's SOA before transferring zone zkey.

        The request is recorded in self.askedsoa so reap() can retry it
        when no reply arrives within self.soatimeout seconds.
        """
        self.askedsoa[zkey] = {'masterip':masterip,
                               'senttime':time.time(),
                               'origin':origin,
                               'trynum':trynum+1}
        query = message()
        query.header.id = random.randrange(1,32768)
        query.header.rd = 0
        query.question.qname = origin
        query.question.qtype = 'SOA'
        query.question.qclass = 'IN'
        log(3,'slave checking for new data in ' + origin)
        simpleudprequest(query, self.handle_soaquery,
                         masterip, zkey)

    def handle_soaquery(self, msg, zkey):
        """Start an AXFR of zone zkey once its master answered the SOA query.

        A reply for a check that is no longer outstanding (for example, the
        zone was dropped by reap() after too many retries) is ignored. The
        reply's serial is not compared with the local copy, so every
        answered check starts a transfer.
        """
        asked = self.askedsoa.pop(zkey, None)
        if asked is None:
            return
        origin = msg.question.qname
        masterip = asked['masterip']
        if zkey in self.slavesupdating:
            # a transfer of this zone is already running
            return
        self.slavesupdating.append(zkey)
        query = message()
        query.header.id = random.randrange(1,32768)
        query.header.rd = 0
        query.question.qname = origin
        query.question.qtype = 'AXFR'
        query.question.qclass = 'IN'
        log(3,'Updating slave zone: ' + zkey)
        simpletcprequest(query, self.handle_zonetrans,
                         [zkey],masterip,self.handle_zterror)

    def handle_zonetrans(self, rrlist, params):
        """Install a completed zone transfer; params is [zonekey]."""
        log(1,'handling zone transfer')
        zonekey = params[0]
        self.zdb.update_zone(rrlist, params)
        self.slavesupdating.remove(zonekey)

    def handle_zterror(self, zonekey):
        """Abandon a failed transfer of zonekey and drop the zone."""
        if zonekey in self.slavesupdating:
            self.slavesupdating.remove(zonekey)
        self.zdb.remove_zone(zonekey)

    def rrmatch(self, rrset1, rrset2):
        """Return 1 if each RR type of rrset1 appears in rrset2 with the same
        number of records, otherwise None."""
        for rrtype in rrset1.keys():
            if rrtype not in rrset2:
                return
            if len(rrset1[rrtype]) != len(rrset2[rrtype]):
                return
        return 1

    def process_notify(self, msg, ipaddr, port):
        """Start an SOA check for the slave zone named in a NOTIFY."""
        (zkeys, dorecursion, flist) = self.config.getview(msg, ipaddr, port)
        goodzkey = ''
        origin = None
        masterip = None
        for zkey in zkeys:
            zorigin = self.zdb.getorigin(zkey)
            if zorigin == msg.question.qname:
                zmaster = self.zdb.getmasterip(zkey)
                if zmaster:
                    # keep key, origin and master of the same zone together
                    goodzkey = zkey
                    origin = zorigin
                    masterip = zmaster
        if goodzkey:
            log(3,'got NOTIFY from ' + masterip)
            self.need_zonetransfer(goodzkey, origin, masterip, 0)
        return

    def notify(self):
        """Send queued NOTIFYs and re-queue unacknowledged ones that are due."""
        curtime = time.time()
        # Re-queue NOTIFYs whose retry interval has elapsed without a reply.
        # Iterate over a copy because entries are removed as we go.
        for entry in self.sentnotify[:]:
            (origin, ipaddr, trynum, senttime) = entry
            if curtime >= senttime + self.notify_retry_time:
                self.notifys.append((origin, ipaddr, trynum))
                self.sentnotify.remove(entry)
        for origin, ipaddr, trynum in self.notifys:
            msg = message()
            msg.question.qname = origin
            msg.question.qtype = 'SOA'
            msg.question.qclass = 'IN'
            msg.header.opcode = 4
            # NOTIFYs go out through the resolver's UDP socket; without a
            # resolver they are dropped
            if self.resolver:
                self.resolver.send_to([msg],(ipaddr,53))
                if trynum+1 <= self.notify_retries:
                    self.sentnotify.append((origin,ipaddr,trynum+1,curtime))
        self.notifys = []

    def handle_packet(self, msgdata, addr, server):
        """Parse one incoming packet and dispatch it by opcode.

        server is the front end the packet arrived on. Replies are sent
        through its sendpackets(msglist, addr). Unparsable packets are
        dropped silently.
        """
        try:
            msg = message(msgdata)
        except:
            return
        # find a matching view
        (zkeys, dorecursion, flist) = self.config.getview(msg, addr[0], addr[1])
        if not msg.header.qr and msg.header.opcode == 5:
            log(2,'GOT UPDATE PACKET')
            # the zone section must name exactly one IN-class SOA
            if (msg.header.zocount != 1 or
                msg.zone.ztype != 'SOA' or
                msg.zone.zclass != 'IN'):
                log(2,'SENDING FORMERR UPDATE ERROR')
                errormsg = self.error(msg.header.id, msg.zone.zname,
                                      msg.zone.ztype, msg.zone.zclass, 1)
                server.sendpackets([errormsg],addr)
            else:
                (answer, origin, slaves) = self.zdb.handle_update(msg, addr, self)
                if answer.header.rcode == 0:
                    # schedule NOTIFYs to slaves
                    for ipaddr in slaves:
                        self.notifys.append((origin, ipaddr, 0))
                server.sendpackets([answer],addr)
        elif msg.header.opcode == 4:
            if msg.header.qr:
                log(0,'got NOTIFY response')
                # The slave acknowledged, so stop retrying. Iterate over a
                # copy because matching entries are removed.
                for entry in self.sentnotify[:]:
                    (origin, ipaddr, trynum, senttime) = entry
                    if ipaddr == addr[0] and msg.question.qname == origin:
                        self.sentnotify.remove(entry)
            else:
                log(0,'got NOTIFY')
                self.process_notify(msg, addr[0], addr[1])
        elif not msg.header.qr and msg.header.opcode == 0:
            # it's a question
            qname = msg.question.qname.lower()
            log(2,'GOT QUERY for ' + qname + '(' + msg.question.qtype +
                ') from ' + addr[0])
            # CHAOS-class TXT queries are answered here (version strings)
            if (msg.question.qtype == 'TXT' and
                msg.question.qclass == 'CH'):
                if qname in ('version.bind', 'version.oak'):
                    server.sendpackets([getversion(qname,
                                                   msg.header.id,
                                                   msg.header.rd,
                                                   dorecursion, '1.0')],addr)
                return
            self.zdb.lookup(zkeys, msg, addr, server, dorecursion,
                            flist, self.lookup_callback)

    def lookup_callback(self, msg, addr, server, dorecursion, flist, answerlist):
        """Send zone-database answers, or fall back to recursion if allowed."""
        if answerlist:
            server.sendpackets(self.config.outpackets(answerlist), addr)
        elif dorecursion:
            if msg.question.qtype in ['AXFR','IXFR']:
                if msg.question.qname == 'cache' and msg.question.qtype == 'AXFR':
                    # an AXFR of the pseudo-zone 'cache' dumps the resolver cache
                    if self.resolver:
                        server.sendpackets(self.resolver.cache.zonetrans(msg.header.id),addr)
                else:
                    # won't forward zone transfers and
                    # don't handle recursive zone transfers
                    server.sendpackets([self.error(msg.header.id, msg.question.qname,
                                                   msg.question.qtype,
                                                   msg.question.qclass,2)],addr)
            else:
                self.resolver.handle_query(msg, addr, flist, server.sendpackets)

    def reap(self):
        """Periodic maintenance: resolver upkeep, NOTIFYs and slave SOA checks."""
        log(4,'in nameserver reap')
        if self.resolver:
            self.resolver.reap()
        self.notify()
        curtime = time.time()
        if curtime > (self.last_reap_time + self.maint_int):
            self.last_reap_time = curtime
            # do zone transfers here if slave server and haven't asked for soa
            for (zkey, origin, masterip) in self.zdb.getslaves(curtime):
                if zkey not in self.askedsoa:
                    self.need_zonetransfer(zkey, origin, masterip)
        # Retry, or after more than 3 tries abandon, SOA checks that timed
        # out. Iterate over a snapshot because entries are deleted/re-added.
        for zkey in list(self.askedsoa.keys()):
            asked = self.askedsoa[zkey]
            if curtime > asked['senttime'] + self.soatimeout:
                if asked['trynum'] > 3:
                    self.zdb.remove_zone(zkey)
                    del self.askedsoa[zkey]
                else:
                    del self.askedsoa[zkey]
                    self.need_zonetransfer(zkey, asked['origin'],
                                           asked['masterip'], asked['trynum'])

    def log_info (self, message, type='info'):
        if __debug__ or type != 'info':
            log(0,'%s: %s' % (type, message))
2486
2487 class resolver(asyncore.dispatcher):
2488     def __init__(self, cache, port=0):
2489         asyncore.dispatcher.__init__(self)
2490         self.create_socket(socket.AF_INET, socket.SOCK_DGRAM)
2491         self.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 200 * 1024)
2492         self.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 200 * 1024)   
2493         self.bind(('',port))
2494         self.cache = cache
2495         self.outqnum = 0
2496         self.outq = {}
2497         self.holdq = {}
2498         self.holdtime = 10
2499         self.holdqlength = 100
2500         self.last_reap_time = time.time()
2501         self.maint_int = 10
2502         self.timeout = 3
2503
2504     def getoutqkey(self):
2505         self.outqnum = self.outqnum + 1
2506         if self.outqnum == 99999:
2507             self.outqnum = 1
2508         return str(self.outqnum)
2509
2510     def error(self, id, qname, querytype, queryclass, rcode):
2511         error = message()
2512         error.header.id = id
2513         error.header.rcode = rcode
2514         error.header.qr = 1
2515         error.question.qname = qname
2516         error.question.qtype = querytype
2517         error.question.qclass = queryclass
2518         return error
2519
2520     def qpacket(self, id, qname, querytype, queryclass):
2521         # create a question
2522         query = message()
2523         query.header.id = id
2524         query.header.rd = 0
2525         query.question.qname = qname
2526         query.question.qtype = querytype
2527         query.question.qclass = queryclass
2528         return query
2529
2530     def send_to(self, msglist, addr):
2531         for msg in msglist:
2532             data = msg.buildpkt()
2533             if len(data) > 512:
2534                 # packet to big
2535                 msg.header.tc = 1
2536                 msg.header.ancount = 0
2537                 msg.answerlist = []
2538                 msg.header.nscount = 0
2539                 msg.authlist = []
2540                 msg.header.arcount = 0
2541                 msg.addlist = []
2542                 self.socket.sendto(msg.buildpkt(), addr)
2543             else:
2544                 self.socket.sendto(data, addr)
2545
2546     def handle_read(self):
2547         try:
2548             while 1:
2549                 msgdata, addr = self.socket.recvfrom(1500)
2550                 # should put 'try' here in production server
2551                 self.handle_packet(msgdata, addr)
2552         except socket.error, why:
2553             if why[0] != asyncore.EWOULDBLOCK:
2554                 raise socket.error, why
2555
2556     def handle_packet(self, msgdata, addr):
2557         try:
2558             msg = message(msgdata)
2559         except:
2560             return
2561         if not msg.header.qr:
2562             self.handle_query(msg, addr, [], self.send_to)
2563         else:
2564             log(2,'received unsolicited reply')
2565
2566
2567     def handle_query(self, msg, addr, flist, cbfunc):
2568         qname = msg.question.qname
2569         querytype = msg.question.qtype
2570         queryclass = msg.question.qclass
2571         # check the cache first
2572         answer = self.cache.haskey(qname,querytype,msg)
2573         if answer:
2574             cbfunc([answer], addr)
2575             log(2,'sent answer for ' + qname + '(' + querytype +
2576                 ') from cache')
2577         else:
2578             # check if query is already in progess
2579             for oqkey in self.outq.keys():
2580                 if (self.outq[oqkey]['qname'] == qname and
2581                     self.outq[oqkey]['querytype'] == querytype):
2582                     log(2,'query already in progress for '+qname+'('+querytype+')')
2583                     # put entry in hold queue to try later
2584                     hqrec = {'processtime':time.time()+self.holdtime,
2585                              'query':msg,'addr':addr,
2586                              'qname':qname,'querytype':querytype,
2587                              'queryclass':queryclass,
2588                              'cbfunc':cbfunc}
2589                     self.putonhold(hqrec)
2590                     return
2591                 
2592             outqkey = self.getoutqkey()+str(msg.header.id)                
2593             self.outq[outqkey] = {'query':msg,
2594                                   'addr':addr,
2595                                   'qname':qname,
2596                                   'querytype':querytype,
2597                                   'queryclass':queryclass,
2598                                   'cbfunc':cbfunc,
2599                                   'answerlist':[],
2600                                   'addlist':[],
2601                                   'qsent':0}
2602             if flist:
2603                 self.outq[outqkey]['flist'] = flist
2604                 self.askfns(outqkey)
2605             else:
2606                 self.askns(outqkey)
2607
2608     def putonhold(self,hqrec):
2609         hqid = hqrec['qname']+hqrec['querytype']
2610         if self.holdq.has_key(hqid):
2611             if len(self.holdq[hqid]) < self.holdqlength:
2612                 hqrec['processtime']=time.time()+self.holdtime
2613                 self.holdq[hqid].append(hqrec)
2614         
2615             
2616     def askns(self, outqkey):
2617         qname = self.outq[outqkey]['qname']
2618         querytype = self.outq[outqkey]['querytype']
2619         queryclass = self.outq[outqkey]['queryclass']
2620         # don't try more than 10 times to avoid loops
2621         if self.outq[outqkey]['qsent'] == 10:
2622             del self.outq[outqkey]
2623             log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
2624                    ' POSSIBLE LOOP')
2625             return
2626         # find the best nameservers to ask from the cache
2627         (qzone, nsdict) = self.cache.getnslist(qname)
2628         if not nsdict:
2629             # there are no good servers
2630             if self.outq[outqkey]['addr'] != 'IQ':
2631                 qid = self.outq[outqkey]['query'].header.id
2632                 self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2),
2633                                    self.outq[outqkey]['addr'])
2634             del self.outq[outqkey]
2635             log(2,'Dropping query for ' + qname + '(' + querytype + ')' +
2636                    'no good name servers to ask')
2637             return
2638         # pick the best nameserver
2639         rtts = nsdict.keys()
2640         rtts.sort()
2641         bestnsip = nsdict[rtts[0]]['ip']
2642         bestnsname = nsdict[rtts[0]]['name']
2643         # fill in the callback data structure
2644         id=random.randrange(1,32768)
2645         self.outq[outqkey]['nsqueriedlastip'] = bestnsip
2646         self.outq[outqkey]['nsqueriedlastname'] = bestnsname
2647         self.outq[outqkey]['nsdict'] = nsdict
2648         self.outq[outqkey]['qzone'] = qzone
2649         self.outq[outqkey]['qsenttime'] = time.time()
2650         self.outq[outqkey]['qsent'] = self.outq[outqkey]['qsent'] + 1
2651         # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (bestnsip,53))
2652         self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
2653                                                       self.handle_response, bestnsip, outqkey)
2654         # update rtt so that we ask a different server next time
2655         self.cache.updatertt(bestnsname,qzone,1)
2656         log(2,outqkey+'|sent query to ' + bestnsip + '(' + bestnsname +
2657                ') for ' + qname + '(' + querytype + ')')
2658
2659     def askfns(self, outqkey):
2660         flist = self.outq[outqkey]['flist']
2661         qname = self.outq[outqkey]['qname']
2662         querytype = self.outq[outqkey]['querytype']
2663         queryclass = self.outq[outqkey]['queryclass']
2664         self.outq[outqkey]['qsenttime'] = time.time()        
2665         id=random.randrange(1,32768)
2666         # self.socket.sendto(self.qpacket(id,qname,querytype,queryclass), (flist[0],53))
2667         self.outq[outqkey]['request'] = simpleudprequest(self.qpacket(id,qname,querytype,queryclass),
2668                                                          self.handle_fresponse, flist[0], outqkey)
2669         log(2,''+outqkey+'|sent query to forwarder')
2670
2671     def handle_response(self, msg, outqkey):
2672         # either reponse:
2673         # 1. contains a name error
2674         # 2. answers the question
2675         #    (cache data and return it)
2676         # 3. is (contains) a CNAME and qtype isn't
2677         #    (cache cname and change qname to it)
2678         #    (check if qname and qtype are in any other rrs in the response)
2679         #    (must check cache again here)
2680         # 4. contains a better delegation
2681         #    (cache the delegation and start again)
2682         # 5. is aserver failure
2683         #    (delete server from list and try again)
2684
2685         # make sure that original question is still outstanding
2686         if not self.outq.has_key(outqkey):
2687             # should never get here
2688             # if we do we aren't doing housekeeping of callbacks very well
2689             log(2,''+outqkey+'|got response for a question already answered for ' + msg.question.qname)
2690             return
2691
2692         querytype = self.outq[outqkey]['querytype']
2693         if msg.header.rcode not in [1,2,4,5]:        
2694             # update rtt time
2695             rtt = time.time() - self.outq[outqkey]['qsenttime']
2696             nsname = self.outq[outqkey]['nsqueriedlastname']
2697             zone = self.outq[outqkey]['qzone']
2698             self.cache.updatertt(nsname,zone,rtt)
2699
2700         if msg.header.rcode == 3:
2701             log(2,outqkey+'|GOT Name Error for ' + msg.question.qname +
2702                 '(' + msg.question.qtype + ')')
2703             # name error
2704             # cache negative answer
2705             self.cache.addneg(self.outq[outqkey]['qname'],
2706                               self.outq[outqkey]['querytype'],
2707                               self.outq[outqkey]['queryclass'])
2708             if self.outq[outqkey]['addr'] != 'IQ':
2709                 answer = message()                
2710                 answer.question.qname = self.outq[outqkey]['query'].question.qname
2711                 answer.question.qtype = self.outq[outqkey]['query'].question.qtype
2712                 answer.question.qclass = self.outq[outqkey]['query'].question.qclass
2713                 answer.header.id = self.outq[outqkey]['query'].header.id
2714                 answer.header.qr = 1
2715                 answer.header.opcode = self.outq[outqkey]['query'].header.opcode
2716                 answer.header.ra = 1
2717                 self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
2718             del self.outq[outqkey]
2719             
2720         elif msg.header.ancount > 0:
2721             # answer (may be CNAME)
2722             haveanswer = 0
2723             cname = ''
2724             log(2,'CACHING ANSWERLIST ENTRIES')
2725             for rr in msg.answerlist:
2726                 rrname = rr.keys()[0]
2727                 rrtype = rr[rrname].keys()[0]
2728                 if ((rrname == msg.question.qname or rrname == cname ) and
2729                     rrtype == msg.question.qtype):
2730                     haveanswer = 1
2731                 if rrname == msg.question.qname and rrtype == 'CNAME':
2732                     cname = rr[rrname][rrtype][0]['cname']
2733                 self.cache.add(rr, self.outq[outqkey]['qzone'],
2734                                self.outq[outqkey]['nsqueriedlastname'])
2735             if haveanswer:
2736                 if self.outq[outqkey]['addr'] != 'IQ':
2737                     log(2,''+outqkey+'|GOT Answer for ' + msg.question.qname +
2738                         '(' + msg.question.qtype + ')' )
2739                     answer = message()
2740                     answer.answerlist = msg.answerlist + self.outq[outqkey]['answerlist']
2741                     answer.header.ancount = len(answer.answerlist)
2742                     answer.question.qname = self.outq[outqkey]['query'].question.qname
2743                     answer.question.qtype = self.outq[outqkey]['query'].question.qtype
2744                     answer.question.qclass = self.outq[outqkey]['query'].question.qclass
2745                     answer.header.id = self.outq[outqkey]['query'].header.id
2746                     answer.header.qr = 1
2747                     answer.header.opcode = self.outq[outqkey]['query'].header.opcode
2748                     answer.header.ra = 1
2749                     self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
2750                     log(2,outqkey+'|sent answer retrieved from remote server for ' +
2751                            self.outq[outqkey]['query'].question.qname)
2752                 else:
2753                     log(2,outqkey+'|GOT Answer(IQ) for ' + msg.question.qname + '(' +
2754                         msg.question.qtype + ')')
2755                 del self.outq[outqkey]
2756             elif cname:
2757                 log(2,outqkey+'|GOT CNAME for ' + msg.question.qname + '(' + msg.question.qtype + ')')
2758                 self.outq[outqkey]['answerlist'] = self.outq[outqkey]['answerlist'] + msg.answerlist
2759                 self.outq[outqkey]['qname'] = cname
2760                 self.askns(outqkey)
2761             else:
2762                 log(2,outqkey+'|GOT BOGUS answer for '  + msg.question.qname + '(' +
2763                     msg.question.qtype + ')')
2764                 del self.outq[outqkey]
2765             
2766         elif msg.header.nscount > 0 and msg.header.ancount == 0:
2767             log(2,outqkey+'|GOT DELEGATION for ' + msg.question.qname + '(' + msg.question.qtype + ')')
2768             # delegation
2769             # cache the nameserver rrs and start over
2770             # if there are no glue records for nameservers must fetch them first
2771             log(2,'CACHING AUTHLIST ENTRIES')
2772             for rr in msg.authlist:
2773                 self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
2774             log(2,'CACHING ADDLIST ENTRIES')
2775             for rr in msg.addlist:
2776                 self.cache.add(rr,self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
2777             rrlist = msg.authlist+msg.addlist
2778             fetchglue = 0
2779             nscount = 0
2780             for rr in msg.authlist:
2781                 nodename = rr.keys()[0]
2782                 if rr[nodename].keys()[0] == 'NS':
2783                     nscount = nscount + 1
2784                     nsdname = rr[nodename]['NS'][0]['nsdname']
2785                     if not self.cache.haskey(nsdname,'A'):
2786                         log(2,outqkey+'|Glue record not in cache for ' + nsdname + '(A)')
2787                         fetchglue = fetchglue + 1
2788                         # need to fetch A rec
2789                         noutqkey = self.getoutqkey()+str(random.randrange(1,32768))
2790                         self.outq[noutqkey] = {'query':'',
2791                                                'addr':'IQ',
2792                                                'qname':nsdname,
2793                                                'querytype':'A',
2794                                                'queryclass':'IN',
2795                                                'qsent':0}
2796                         log(2,outqkey+'|sending a query to fetch glue records for ' + nsdname + '(A)')
2797                         self.askns(noutqkey)
2798             if not nscount:
2799                 log(2,outqkey+'|Dropping query (no ns recs) for ' +
2800                        msg.question.qname + '(' + msg.question.qtype + ')' )
2801                 del self.outq[outqkey]
2802             elif fetchglue == nscount:
2803                 log(2,outqkey+'|Stalling query (no glue recs) for ' +
2804                        msg.question.qname + '(' + msg.question.qtype + ')')
2805                 self.putonhold(self.outq[outqkey])
2806                 del self.outq[outqkey]                
2807             else:
2808                 log(2,outqkey+'|got (some) glue with delegation')
2809                 self.askns(outqkey)
2810
2811         elif msg.header.rcode in [1,2,4,5]:
2812             log(2,outqkey+'|GOT ' + getrcode(msg.header.rcode))
2813             log(2,'SERVER ' + self.outq[outqkey]['nsqueriedlastname'] + '(' + 
2814                  self.outq[outqkey]['nsqueriedlastip'] + ') FAILURE for ' + msg.question.qname)
2815             # don't ask this server for a while
2816             self.cache.badns(self.outq[outqkey]['qzone'],self.outq[outqkey]['nsqueriedlastname'])
2817             self.askns(outqkey)
2818         else:
2819             log(2,outqkey+'|GOT UNPARSEABLE REPLY')
2820             msg.printpkt()
2821
2822     def handle_fresponse(self, msg, outqkey):
2823         if msg.header.rcode in [1,2,4,5]:
2824             self.outq[outqkey]['flist'].pop(0)
2825             if len(self.outq[outqkey]['flist']) == 0:
2826                 qid = self.outq[outqkey]['query'].header.id
2827                 qname = self.outq[outqkey]['qname']
2828                 querytype = self.outq[outqkey]['querytype']
2829                 queryclass = self.outq[outqkey]['queryclass']
2830                 self.outq[outqkey]['cbfunc'](self.error(qid,qname,querytype,queryclass,2),
2831                                              self.outq[outqkey]['addr'])
2832                 del self.outq[outqkey]
2833             else:
2834                 self.askfns(outqkey)
2835         else:
2836             answer = message()
2837             answer.header.id = self.outq[outqkey]['query'].header.id
2838             answer.header.qr = 1
2839             answer.header.opcode = self.outq[outqkey]['query'].header.opcode
2840             answer.header.ra = 1
2841             answer.question.qname = self.outq[outqkey]['query'].question.qname
2842             answer.question.qtype = self.outq[outqkey]['query'].question.qtype
2843             answer.question.qclass = self.outq[outqkey]['query'].question.qclass
2844             answer.header.ancount = msg.header.ancount
2845             answer.header.nscount = msg.header.nscount
2846             answer.header.arcount = msg.header.arcount                        
2847             answer.answerlist = msg.answerlist
2848             answer.authlist = msg.authlist
2849             answer.addlist = msg.addlist                
2850             if msg.header.rcode == 3:
2851                 # name error
2852                 # cache negative answer
2853                 self.cache.addneg(self.outq[outqkey]['qname'],
2854                                   self.outq[outqkey]['querytype'],
2855                                   self.outq[outqkey]['queryclass'])
2856             else:
2857                 # cache all rrs
2858                 for rr in msg.answerlist:
2859                     self.cache.add(rr,'','forwarder')
2860                 for rr in msg.authlist:
2861                     self.cache.add(rr,'','forwarder')
2862                 for rr in msg.addlist:
2863                     self.cache.add(rr,'','forwarder')
2864             self.outq[outqkey]['cbfunc']([answer], self.outq[outqkey]['addr'])
2865             del self.outq[outqkey]
2866
2867     def writable(self):
2868         return 0
2869
2870     def handle_write(self):
2871         pass
2872
2873     def handle_connect(self):
2874         pass
2875
2876     def handle_close(self):
2877         # print '1:In handle close'
2878         return
2879
2880     def process_holdq(self):
2881         curtime = time.time()
2882         for hqkey in self.holdq.keys():
2883             for hqrec in self.holdq[hqkey]:
2884                 if curtime >= hqrec['processtime']:
2885                     log(2,'processing held query')
2886                     answer = self.cache.haskey(hqrec['qname'],
2887                                                hqrec['querytype'],
2888                                                hqrec['query'])
2889                     if answer:
2890                         hqrec['cbfunc']([answer], hqrec['addr'])
2891                         log(2,'sent answer for ' + hqrec['qname'] +
2892                             '(' + hqrec['querytype'] +  ') from cache')
2893                     self.holdq[hqkey].remove(hqrec)
2894             if len(self.holdq[hqkey]) == 0:
2895                 del self.holdq[hqkey]
2896
2897     def reap(self):
2898         self.process_holdq()
2899         curtime = time.time()
2900         log(3,timestamp() + 'processed HOLDQ (sockets: ' +
2901             str(len(asyncore.socket_map.keys()))+')')
2902         if curtime > (self.last_reap_time + self.maint_int):
2903             self.last_reap_time = curtime
2904             for outqkey in self.outq.keys():
2905                 if curtime > self.outq[outqkey]['qsenttime'] + self.timeout:
2906                     log(2,'query for '+self.outq[outqkey]['qname']+'('+
2907                         self.outq[outqkey]['querytype']+') expired')
2908                     # don't set forwarders as bad
2909                     if not self.outq[outqkey].has_key('flist'):
2910                         self.cache.badns(self.outq[outqkey]['qzone'],
2911                                          self.outq[outqkey]['nsqueriedlastname'])
2912                     if self.outq[outqkey].has_key('request'):
2913                         log(3,'closing socket for expired query')
2914                         self.outq[outqkey]['request'].close()
2915                     del self.outq[outqkey]
2916         return
2917
2918     def log_info (self, message, type='info'):
2919         if __debug__ or type != 'info':
2920             log(0,'%s: %s' % (type, message))
2921
2922
def run(configobj, port=53):
    """Build the resolver and name server, open the DNS listeners, and loop.

    configobj -- configuration object.  ``configobj.cached`` seeds the
                 resolver's ``dnscache``, ``configobj.loglevel`` becomes the
                 module-wide ``loglevel``, and the object itself is passed to
                 ``nameserver``.
    port      -- port passed to both ``udpdnsserver`` and ``tcpdnsserver``
                 (default 53, the standard DNS port).

    Runs ``loop(ns.reap)`` until a KeyboardInterrupt, then prints
    'server done'.
    """
    global loglevel
    # Apply the configured log level before constructing anything, so that
    # messages logged while the servers are set up honour it as well.
    loglevel = configobj.loglevel
    r = resolver(dnscache(configobj.cached))
    ns = nameserver(r, configobj)
    # The listener objects are kept in locals only so they stay referenced
    # while the loop runs; they are not used directly below.
    udpds = udpdnsserver(port, ns)
    tcpds = tcpdnsserver(port, ns)
    try:
        loop(ns.reap)
    except KeyboardInterrupt:
        print('server done')
2934
if __name__ == '__main__':
    # Script entry point: connect to the sipb_xen database, load the zone(s)
    # below, seed the resolver cache from the records parsed out of db.ca,
    # and start serving at log level 3.
    sipb_xen_database.connect('postgres://sipb-xen@sipb-xen-dev/sipb_xen')

    # Zones to serve, keyed by zone name.  Each entry gives the origin, the
    # zone file passed to readzonefiles, the zone type, and (for masters) a
    # 'slaves' list.  A slave zone would instead name its master, e.g.:
    #     {'example.net': {'origin': 'example.net',
    #                      'filename': 'db.example.net',
    #                      'type': 'slave',
    #                      'masterip': '127.0.0.1'}}
    zonedict = {'servers.csail.mit.edu':{'origin':'servers.csail.mit.edu',
                                         'filename':'db.servers.csail.mit.edu',
                                         'type':'master',
                                         'slaves':[]}}
    readzonefiles(zonedict)
    lconfig = dnsconfig()
    lconfig.zonedatabase = zonedb(zonedict)
    # Records parsed from db.ca become the initial resolver cache contents
    # (run() hands lconfig.cached to dnscache).
    pr = zonefileparser()
    pr.parse('','db.ca')
    lconfig.cached = pr.getzdict()
    lconfig.loglevel = 3

    run(lconfig)