Fix XVM's classic networking problem
[invirt/packages/invirt-dhcp.git] / invirt-dhcpserver
1 #!/usr/bin/python
2 import os.path
3 import sys
4 import pydhcplib
5 import pydhcplib.dhcp_network
6 from pydhcplib.dhcp_packet import *
7 from pydhcplib.type_hw_addr import hwmac
8 from pydhcplib.type_ipv4 import ipv4
9 from pydhcplib.type_strlist import strlist
10 import socket
11 import IN
12 from Queue import Queue
13 from threading import Thread
14 from subprocess import PIPE, Popen
15 import netifaces as ni
16 sys.path.append('/usr/lib/xen-default/lib/python/')
17 from xen.lowlevel import xs
18
19 import syslog as s
20
21 import time
22 from invirt import database
23 from invirt.config import structs as config
24
25 dhcp_options = {'domain_name_server': ','.join(config.dhcp.dns),
26                 'ip_address_lease_time': config.dhcp.get('leasetime', 60*60*24)}
27
28 class Interfaces(object):
29     @staticmethod
30     def primary_ip(name):
31         """primary_ip returns an interface's primary IP address.
32
33         This is the first IPv4 address returned by "ip addr show $name"
34         """
35         # TODO(quentin): The netifaces module is a pile of crappy C.
36         # Figure out a way to do this in native Python.
37         return ni.ifaddresses(name)[ni.AF_INET][0]['addr']
38
39     @staticmethod
40     def exists(name):
41         """exists checks if an interface exists.
42
43         Args:
44         name: Interface name
45         """
46         return os.path.exists("/sys/class/net/"+name)
47
48
49 class DhcpBackend:
50     def __init__(self, queue):
51         database.connect()
52         self.queue = queue
53         self.main_ip = Interfaces.primary_ip(config.xen.iface)
54     def add_route_and_arp(self, ip, intf, gateway):
55         try:
56             p = Popen(['ip', 'route', 'add', ip, 'dev', intf, 'src', self.main_ip, 'metric', '2' if intf.startswith('vif') else '1'], stdout=PIPE, stderr=PIPE)
57             (out, err) = p.communicate()
58             if p.returncode == 0:
59                 s.syslog(s.LOG_INFO, "Added route for IP %s to interface %s" % (ip, intf))
60                 self.queue.put((ip, gateway))
61             sys.stderr.write(err)
62             sys.stdout.write(out)
63         except Exception as e:
64             s.syslog(s.LOG_ERR, "Could not add route for IP %s: %s" % (ip, e))
65     def findNIC(self, mac):
66         database.clear_cache()
67         return database.NIC.query.filter_by(mac_addr=mac).first()
68     def find_interface(self, packet):
69         chaddr = hwmac(packet.GetHardwareAddress())
70         nic = self.findNIC(str(chaddr))
71         return self.find_interface_by_nic(nic)
72     def find_interface_by_nic(self, nic):
73         if nic is None or nic.ip is None:
74             return None
75         ipstr = ''.join(reversed(['%02X' % i for i in ipv4(nic.ip.encode("utf-8")).list()]))
76         for line in open('/proc/net/route'):
77             parts = line.split()
78             if parts[1] == ipstr:
79                 s.syslog(s.LOG_DEBUG, "find_interface found "+str(nic.ip)+" on "+parts[0])
80                 return parts[0]
81         # Either the machine isn't running, or the route is missing.  We can
82         # fix the latter.
83         try:
84             xsc = xs.xs()
85             domid = xsc.read('', '/vm/%s/device/vif/0/frontend-id' % (nic.machine.uuid))
86             # If we didn't find the domid, the machine is either off or the
87             # UUID in xenstore isn't right.  Try slightly harder.
88             if not domid:
89                 for uuid in xsc.ls('', '/vm'):
90                     if xsc.read('', '/vm/%s/name' % (uuid)) == 'd_' + nic.machine.name:
91                         domid = xsc.read('', '/vm/%s/device/vif/0/frontend-id' % (uuid))
92             if not domid:
93                 xsc.close()
94                 return None
95             for vifnum in xsc.ls('', '/local/domain/0/backend/vif/%s' % (domid)):
96                 if xsc.read('', '/local/domain/0/backend/vif/%s/%s/mac' % (domid, vifnum)) == nic.mac_addr:
97                     # Prefer the tap if it exists; paravirtualized HVMs will
98                     # have already unplugged it, so if it's there, it's the one
99                     # in use.
100                     for viftype in ('tap', 'vif'):
101                         vif = '%s%s.%s' % (viftype, domid, vifnum)
102                         if Interfaces.exists(vif):
103                             self.add_route_and_arp(nic.ip, vif, nic.gateway)
104                             xsc.close()
105                             return vif
106             xsc.close()
107         except Exception as e:
108             try:
109                 xsc.close()
110             except Exception as e2:
111                 s.syslog(s.LOG_ERR, "Could not close connection to xenstore: %s" % (e2))
112             s.syslog(s.LOG_ERR, "Could not find interface and add missing route: %s" % (e))
113         return None
114                             
115     def getParameters(self, **extra):
116         all_options=dict(dhcp_options)
117         all_options.update(extra)
118         options = {}
119         for parameter, value in all_options.iteritems():
120             if value is None:
121                 continue
122             option_type = DhcpOptionsTypes[DhcpOptions[parameter]]
123
124             if option_type == "ipv4" :
125                 # this is a single ip address
126                 options[parameter] = map(int,value.split("."))
127             elif option_type == "ipv4+" :
128                 # this is multiple ip address
129                 iplist = value.split(",")
130                 opt = []
131                 for single in iplist :
132                     opt.extend(ipv4(single).list())
133                 options[parameter] = opt
134             elif option_type == "32-bits" :
135                 # This is probably a number...
136                 digit = int(value)
137                 options[parameter] = [digit>>24&0xFF,(digit>>16)&0xFF,(digit>>8)&0xFF,digit&0xFF]
138             elif option_type == "16-bits" :
139                 digit = int(value)
140                 options[parameter] = [(digit>>8)&0xFF,digit&0xFF]
141
142             elif option_type == "char" :
143                 digit = int(value)
144                 options[parameter] = [digit&0xFF]
145
146             elif option_type == "bool" :
147                 if value=="False" or value=="false" or value==0 :
148                     options[parameter] = [0]
149                 else : options[parameter] = [1]
150                     
151             elif option_type == "string" :
152                 options[parameter] = strlist(value).list()
153             
154             elif option_type == "RFC3397" :
155                 parsed_value = ""
156                 for item in value:
157                     components = item.split('.')
158                     item_fmt = "".join(chr(len(elt)) + elt for elt in components) + "\x00"
159                     parsed_value += item_fmt
160                 
161                 options[parameter] = strlist(parsed_value).list()
162             
163             else :
164                 options[parameter] = strlist(value).list()
165         return options
166
167     def Discover(self, packet):
168         s.syslog(s.LOG_DEBUG, "dhcp_backend : Discover ")
169         chaddr = hwmac(packet.GetHardwareAddress())
170         nic = self.findNIC(str(chaddr))
171         if nic is None or nic.machine is None:
172             return False
173         ip = nic.ip.encode("utf-8")
174         if ip is None:  #Deactivated?
175             return False
176
177         options = {}
178         options['subnet_mask'] = nic.netmask.encode("utf-8")
179         options['router'] = nic.gateway.encode("utf-8")
180         if nic.hostname and '.' in nic.hostname:
181             options['host_name'], options['domain_name'] = nic.hostname.encode('utf-8').split('.', 1)
182         elif nic.machine.name:
183             options['host_name'] = nic.machine.name.encode('utf-8')
184             options['domain_name'] = config.dns.domains[0]
185         else:
186             hostname = None
187         if DhcpOptions['domain_search'] in packet.GetOption('parameter_request_list'):
188             options['host_name'] += '.' + options['domain_name']
189             del options['domain_name']
190             options['domain_search'] = [config.dhcp.search_domain]
191         ip = ipv4(ip)
192         s.syslog(s.LOG_DEBUG,"dhcp_backend : Discover result = "+str(ip))
193         packet_parameters = self.getParameters(**options)
194
195         # FIXME: Other offer parameters go here
196         packet_parameters["yiaddr"] = ip.list()
197
198         packet.SetMultipleOptions(packet_parameters)
199         return True
200         
201     def Request(self, packet):
202         s.syslog(s.LOG_DEBUG, "dhcp_backend : Request")
203         
204         discover = self.Discover(packet)
205         
206         chaddr = hwmac(packet.GetHardwareAddress())
207         request = packet.GetOption("request_ip_address")
208         if not request:
209             request = packet.GetOption("ciaddr")
210         yiaddr = packet.GetOption("yiaddr")
211
212         if not discover:
213             s.syslog(s.LOG_INFO,"Unknown MAC address: "+str(chaddr))
214             return False
215         
216         if yiaddr!="0.0.0.0" and yiaddr == request :
217             s.syslog(s.LOG_INFO,"Ack ip "+str(yiaddr)+" for "+str(chaddr))
218             n = self.findNIC(str(chaddr))
219             intf = self.find_interface_by_nic(n)
220             s.syslog(s.LOG_ERR, "Interface is %s" % (intf))
221             # Don't perform "other" actions if the machine isn't running
222             other_action = n.other_action if n.other_action and intf else ''
223             if other_action in ('renumber', 'renumber_dhcp'):
224                 (n.ip, n.netmask, n.gateway, n.other_ip, n.other_netmask,
225                  n.other_gateway) = (
226                  n.other_ip, n.other_netmask, n.other_gateway, n.ip,
227                  n.netmask, n.gateway)
228                 other_action = n.other_action = 'dnat'
229                 database.session.add(n)
230                 database.session.flush()
231             if other_action == 'dnat':
232                 # If the machine was booted in 'dnat' mode, then both
233                 # routes were already added by the invirt-database script.
234                 # If the machine was already on and has just been set to
235                 # 'dnat' mode, we need to add the route for the 'other' IP.
236                 # If the machine has just been 'renumbered' by us above,
237                 # the IPs will be swapped and only the route for the main
238                 # IP needs to be added.  Just try adding both of them, and
239                 # arp for whichever of them turns out to be new.
240                 for parms in [(n.ip, n.gateway), (n.other_ip, n.other_gateway)]:
241                     self.add_route_and_arp(parms[0], intf, parms[1])
242                 try:
243                     # iptables will let you add the same rule again and again;
244                     # let's not do that.
245                     p = Popen(['iptables', '-t', 'nat', '-C', 'PREROUTING', '-d', n.other_ip, '-j', 'DNAT', '--to-destination', n.ip], stdout=PIPE, stderr=PIPE)
246                     (out, err) = p.communicate()
247                     sys.stderr.write(err)
248                     sys.stdout.write(out)
249                     if p.returncode != 0:
250                         p2 = Popen(['iptables', '-t', 'nat', '-A', 'PREROUTING', '-d', n.other_ip, '-j', 'DNAT', '--to-destination', n.ip], stdout=PIPE, stderr=PIPE)
251                         (out, err) = p2.communicate()
252                         sys.stderr.write(err)
253                         sys.stdout.write(out)
254                         if p2.returncode == 0:
255                             s.syslog(s.LOG_INFO, "Added DNAT for IP %s to %s" % (n.other_ip, n.ip))
256                         else:
257                             s.syslog(s.LOG_ERR, "Could not add DNAT for IP %s to %s" % (n.other_ip, n.ip))
258                 except Exception as e:
259                     s.syslog(s.LOG_ERR, "Could not check and/or add DNAT for IP %s to %s: %s" % (n.other_ip, n.ip, e))
260             if other_action == 'remove':
261                 try:
262                     p = Popen(['ip', 'route', 'del', n.other_ip, 'dev', intf], stdout=PIPE, stderr=PIPE)
263                     (out, err) = p.communicate()
264                     sys.stderr.write(err)
265                     sys.stderr.write(out)
266                     if p.returncode == 0:
267                         s.syslog(s.LOG_INFO, "Removed route for IP %s" % (n.other_ip))
268                     else:
269                         s.syslog(s.LOG_ERR, "Could not remove route for IP %s" % (n.other_ip))
270                 except Exception as e:
271                     s.syslog(s.LOG_ERR, "Could not run ip to remove route for IP %s: %s" % (n.other_ip, e))
272                 try:
273                     p = Popen(['iptables', '-t', 'nat', '-D', 'PREROUTING', '-d', n.other_ip, '-j', 'DNAT', '--to-destination', n.ip], stdout=PIPE, stderr=PIPE)
274                     (out, err) = p.communicate()
275                     sys.stderr.write(err)
276                     sys.stdout.write(out)
277                     if p.returncode == 0:
278                         s.syslog(s.LOG_INFO, "Removed DNAT for IP %s" % (n.other_ip))
279                     else:
280                         s.syslog(s.LOG_ERR, "Could not remove DNAT for IP %s" % (n.other_ip))
281                 except Exception as e:
282                     s.syslog(s.LOG_ERR, "Could not run iptables to remove DNAT for IP %s: %s" % (n.other_ip, e))
283                 n.other_ip = n.other_netmask = n.other_gateway = n.other_action = None
284                 database.session.add(n)
285                 database.session.flush()
286             # We went through the DISCOVER codepath already to populate some
287             # of the packet's parameters.  If we renumbered the VM just above,
288             # the packet is set to offer them what they asked for - the old
289             # address.  So, we'll send them a DHCPNACK and they'll come right
290             # back and be offered the new address.  The code above won't be
291             # able to add duplicate routes, won't insert a duplicate DNAT,
292             # and won't ARP again because the routes will exist, so this won't
293             # incur much extra work.
294             if request != map(int, n.ip.split('.')):
295                 return False
296             return True
297         else:
298             s.syslog(s.LOG_INFO,"Requested ip "+str(request)+" not available for "+str(chaddr))
299         return False
300
301     def Decline(self, packet):
302         pass
303     def Release(self, packet):
304         pass
305     
306
307 class DhcpServer(pydhcplib.dhcp_network.DhcpServer):
308     def __init__(self, backend, options = {'client_listenport':68,'server_listenport':67}):
309         pydhcplib.dhcp_network.DhcpServer.__init__(self,"0.0.0.0",options["client_listen_port"],options["server_listen_port"],)
310         self.backend = backend
311         s.syslog(s.LOG_DEBUG, "__init__ DhcpServer")
312
313     def SendDhcpPacketTo(self, To, packet):
314         intf = self.backend.find_interface(packet)
315         if intf:
316             self.dhcp_socket.setsockopt(socket.SOL_SOCKET, IN.SO_BINDTODEVICE, intf)
317             ret = self.dhcp_socket.sendto(packet.EncodePacket(), (To,self.emit_port))
318             self.dhcp_socket.setsockopt(socket.SOL_SOCKET, IN.SO_BINDTODEVICE, '')
319             return ret
320         else:
321             return self.dhcp_socket.sendto(packet.EncodePacket(),(To,self.emit_port))
322
323     def SendPacket(self, packet):
324         """Encode and send the packet."""
325         
326         giaddr = packet.GetOption('giaddr')
327
328         # in all case, if giaddr is set, send packet to relay_agent
329         # network address defines by giaddr
330         if giaddr!=[0,0,0,0] :
331             agent_ip = ".".join(map(str,giaddr))
332             self.SendDhcpPacketTo(agent_ip,packet)
333             s.syslog(s.LOG_DEBUG, "SendPacket to agent : "+agent_ip)
334
335         # FIXME: This shouldn't broadcast if it has an IP address to send
336         # it to instead. See RFC2131 part 4.1 for full details
337         else :
338             s.syslog(s.LOG_DEBUG, "No agent, broadcast packet.")
339             self.SendDhcpPacketTo("255.255.255.255",packet)
340             
341
342     def HandleDhcpDiscover(self, packet):
343         """Build and send DHCPOFFER packet in response to DHCPDISCOVER
344         packet."""
345
346         logmsg = "Get DHCPDISCOVER packet from " + hwmac(packet.GetHardwareAddress()).str()
347
348         s.syslog(s.LOG_INFO, logmsg)
349         offer = DhcpPacket()
350         offer.CreateDhcpOfferPacketFrom(packet)
351         
352         if self.backend.Discover(offer):
353             self.SendPacket(offer)
354         # FIXME : what if false ?
355
356
357     def HandleDhcpRequest(self, packet):
358         """Build and send DHCPACK or DHCPNACK packet in response to
359         DHCPREQUEST packet. 4 types of DHCPREQUEST exists."""
360
361         ip = packet.GetOption("request_ip_address")
362         sid = packet.GetOption("server_identifier")
363         ciaddr = packet.GetOption("ciaddr")
364         #packet.PrintHeaders()
365         #packet.PrintOptions()
366
367         if sid != [0,0,0,0] and ciaddr == [0,0,0,0] :
368             s.syslog(s.LOG_INFO, "Get DHCPREQUEST_SELECTING_STATE packet")
369
370         elif sid == [0,0,0,0] and ciaddr == [0,0,0,0] and ip :
371             s.syslog(s.LOG_INFO, "Get DHCPREQUEST_INITREBOOT_STATE packet")
372
373         elif sid == [0,0,0,0] and ciaddr != [0,0,0,0] and not ip :
374             s.syslog(s.LOG_INFO,"Get DHCPREQUEST_INITREBOOT_STATE packet")
375
376         else : s.syslog(s.LOG_INFO,"Get DHCPREQUEST_UNKNOWN_STATE packet : not implemented")
377
378         if self.backend.Request(packet):
379             packet.TransformToDhcpAckPacket()
380             self.SendPacket(packet)
381         elif self.backend.Discover(packet):
382             packet.TransformToDhcpNackPacket()
383             self.SendPacket(packet)
384         else:
385             pass # We aren't authoritative, so don't reply if we don't know them.
386
387     # FIXME: These are not yet implemented.
388     def HandleDhcpDecline(self, packet):
389         s.syslog(s.LOG_INFO, "Get DHCPDECLINE packet")
390         self.backend.Decline(packet)
391         
392     def HandleDhcpRelease(self, packet):
393         s.syslog(s.LOG_INFO,"Get DHCPRELEASE packet")
394         self.backend.Release(packet)
395         
396     def HandleDhcpInform(self, packet):
397         s.syslog(s.LOG_INFO, "Get DHCPINFORM packet")
398
399         if self.backend.Request(packet) :
400             packet.TransformToDhcpAckPacket()
401             # FIXME : Remove lease_time from options
402             self.SendPacket(packet)
403
404         # FIXME : what if false ?
405
406 class ArpspoofWorker(Thread):
407     def __init__(self, queue):
408         Thread.__init__(self)
409         self.queue = queue
410         self.iface = config.xen.iface
411
412     def run(self):
413         while True:
414             (ip, gw) = self.queue.get()
415             try:
416                 p = Popen(['timeout', '-s', 'KILL', '5', 'arpspoof', '-i', self.iface, '-t', gw, ip], stdout=PIPE, stderr=PIPE)
417                 (out, err) = p.communicate()
418                 if p.returncode != 124:
419                     s.syslog(s.LOG_ERR, "arpspoof returned %s for IP %s gateway %s" % (p.returncode, ip, gw))
420                 else:
421                     s.syslog(s.LOG_INFO, "aprspoof'd for IP %s gateway %s" % (ip, gw))
422                 sys.stderr.write(err)
423                 sys.stdout.write(out)
424             except Exception as e:
425                 s.syslog(s.LOG_ERR, "Could not run arpspoof for IP %s gateway %s: %s" % (ip, gw, e))
426             self.queue.task_done()
427
428 if '__main__' == __name__:
429     options = { "server_listen_port":67,
430                 "client_listen_port":68,
431                 "listen_address":"0.0.0.0"}
432
433     myip = socket.gethostbyname(socket.gethostname())
434     if not myip:
435         print "invirt-dhcpserver: cannot determine local IP address by looking up %s" % socket.gethostname()
436         sys.exit(1)
437     
438     dhcp_options['server_identifier'] = ipv4(myip).int()
439
440     queue = Queue()
441
442     backend = DhcpBackend(queue)
443     server = DhcpServer(backend, options)
444
445     for x in range(2):
446         worker = ArpspoofWorker(queue)
447         worker.daemon = True
448         worker.start()
449
450     while True : server.GetNextDhcpPacket()