Use correct DHCP options, and use SO_BINDTODEVICE to send packets out of the right device
[invirt/packages/invirt-dhcp.git] / dhcpserver.py
index 5c7721c..6171843 100644 (file)
 #!/usr/bin/python
+import sys
+sys.path.append('pydhcplib/')
 import pydhcplib
+import pydhcplib.dhcp_network
 from pydhcplib.dhcp_packet import *
 from pydhcplib.type_hw_addr import hwmac
 from pydhcplib.type_ipv4 import ipv4
+from pydhcplib.type_strlist import strlist
+import socket
+import IN
 
+import event_logger
+if '__main__' == __name__:  # NOTE(review): only when run as a script; done before `from event_logger import Log` below, presumably because Log needs init first
+    event_logger.init("stdout", 'DEBUG', {})  # presumably: log to stdout at DEBUG level
 from event_logger import Log
 
+import psycopg2
+import time
 import sipb_xen_database
+from sqlalchemy import create_engine
+
+dhcp_options = {'subnet_mask': '255.255.0.0',
+                'router': '18.181.0.1',
+                'domain_name_server': '18.70.0.160,18.71.0.151,18.72.0.3',
+                'domain_name': 'mit.edu',
+                'ip_address_lease_time': 60*60*24}
 
 class DhcpBackend:
     def __init__(self, database=None):
         if database is not None:
-            sipb_xen_database.connect(database)
-    def findIP(self, mac):
-        value = sipb_xen_database.NIC.get_by(mac_addr=mac)
-        if value is None:
-            return None
-        ip = value.ip
-        if ip is None:  #Deactivated?
-            return None
-        return ip
+            self.database = database
+            sipb_xen_database.connect(create_engine(database))
+    def findNIC(self, mac):
+        for i in range(3):
+            try:
+                value = sipb_xen_database.NIC.get_by(mac_addr=mac)
+            except psycopg2.OperationalError:
+                time.sleep(0.5)
+                if i == 2:  #Try twice to reconnect.
+                    raise
+                #Sigh.  SQLAlchemy should do this itself.
+                sipb_xen_database.connect(create_engine(self.database))
+            else:
+                break
+        return value
+    def find_interface(self, packet):
+        chaddr = hwmac(packet.GetHardwareAddress())
+        nic = self.findNIC(str(chaddr))
+        if nic is None or nic.ip is None:
+            return ("18.181.0.60", None)
+        ipstr = ''.join(reversed(['%02X' % i for i in ipv4(nic.ip).list()]))
+        for line in open('/proc/net/route'):
+            parts = line.split()
+            if parts[1] == ipstr:
+                Log.Output(Log.debug, "find_interface found "+str(nic.ip)+" on "+parts[0])
+                return ("18.181.0.60", parts[0])
+        return ("18.181.0.60", None)
+                            
+    def getParameters(self, **extra):
+        all_options=dict(dhcp_options)
+        all_options.update(extra)
+        options = {}
+        for parameter, value in all_options.iteritems():
+            if value is None:
+                continue
+            option_type = DhcpOptionsTypes[DhcpOptions[parameter]]
+
+            if option_type == "ipv4" :
+                # this is a single ip address
+                options[parameter] = map(int,value.split("."))
+            elif option_type == "ipv4+" :
+                # this is multiple ip address
+                iplist = value.split(",")
+                opt = []
+                for single in iplist :
+                    opt.extend(ipv4(single).list())
+                options[parameter] = opt
+            elif option_type == "32-bits" :
+                # This is probably a number...
+                digit = int(value)
+                options[parameter] = [digit>>24&0xFF,(digit>>16)&0xFF,(digit>>8)&0xFF,digit&0xFF]
+            elif option_type == "16-bits" :
+                digit = int(value)
+                options[parameter] = [(digit>>8)&0xFF,digit&0xFF]
+
+            elif option_type == "char" :
+                digit = int(value)
+                options[parameter] = [digit&0xFF]
+
+            elif option_type == "bool" :
+                if value=="False" or value=="false" or value==0 :
+                    options[parameter] = [0]
+                else : options[parameter] = [1]
+                    
+            elif option_type == "string" :
+                options[parameter] = strlist(value).list()
+                
+            else :
+                options[parameter] = strlist(value).list()
+        return options
 
     def Discover(self, packet):
         Log.Output(Log.debug,"dhcp_backend : Discover ")
         chaddr = hwmac(packet.GetHardwareAddress())
-        ip = self.findIP(str(chaddr))
+        nic = self.findNIC(str(chaddr))
+        if nic is None:
+            return False
+        ip = nic.ip
+        if ip is None:  #Deactivated?
+            return False
+        hostname = nic.hostname
+        if hostname is not None:
+            hostname += ".servers.csail.mit.edu"
         if ip is not None:
             ip = ipv4(ip)
             Log.Output(Log.debug,"dhcp_backend : Discover result = "+str(ip))
-            packet_parameters = {}
+            packet_parameters = self.getParameters(host_name=hostname)
 
             # FIXME: Other offer parameters go here
             packet_parameters["yiaddr"] = ip.list()
@@ -44,6 +131,8 @@ class DhcpBackend:
         
         chaddr = hwmac(packet.GetHardwareAddress())
         request = packet.GetOption("request_ip_address")
+        if not request:
+            request = packet.GetOption("ciaddr")
         yiaddr = packet.GetOption("yiaddr")
 
         if not discover:
@@ -69,8 +158,21 @@ class DhcpServer(pydhcplib.dhcp_network.DhcpServer):
         self.backend = backend
         Log.Output(Log.debug, "__init__ DhcpServer")
 
+    def SendDhcpPacketTo(self, To, packet):
+        (ip, intf) = self.backend.find_interface(packet)
+        if intf:
+            out_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+            out_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST,1)
+            out_socket.setsockopt(socket.SOL_SOCKET, IN.SO_BINDTODEVICE, intf)
+            #out_socket.bind((ip, self.listen_port))
+            ret = out_socket.sendto(packet.EncodePacket(), (To,self.emit_port))
+            out_socket.close()
+            return ret
+        else:
+            return self.dhcp_socket.sendto(packet.EncodePacket(),(To,self.emit_port))
+
     def SendPacket(self, packet):
-            """Encode and send the packet."""
+        """Encode and send the packet."""
         
         giaddr = packet.GetOption('giaddr')
 
@@ -98,7 +200,7 @@ class DhcpServer(pydhcplib.dhcp_network.DhcpServer):
         offer = DhcpPacket()
         offer.CreateDhcpOfferPacketFrom(packet)
         
-        if self.backend.Discover(offer) :
+        if self.backend.Discover(offer):
             self.SendPacket(offer)
         # FIXME : what if false ?
 
@@ -110,6 +212,8 @@ class DhcpServer(pydhcplib.dhcp_network.DhcpServer):
         ip = packet.GetOption("request_ip_address")
         sid = packet.GetOption("server_identifier")
         ciaddr = packet.GetOption("ciaddr")
+        #packet.PrintHeaders()
+        #packet.PrintOptions()
 
         if sid != [0,0,0,0] and ciaddr == [0,0,0,0] :
             Log.Output(Log.info, "Get DHCPREQUEST_SELECTING_STATE packet")
@@ -149,7 +253,6 @@ class DhcpServer(pydhcplib.dhcp_network.DhcpServer):
         # FIXME : what if false ?
 
 if '__main__' == __name__:
-    event_logger.init("stdout", event_logger.INFO, {})
     options = { "server_listen_port":67,
                 "client_listen_port":68,
                 "listen_address":"0.0.0.0"}