Put validation behind more abstraction.
authorEric Price <ecprice@mit.edu>
Tue, 3 Jun 2008 03:25:47 +0000 (23:25 -0400)
committerEric Price <ecprice@mit.edu>
Tue, 3 Jun 2008 03:25:47 +0000 (23:25 -0400)
svn path=/trunk/packages/sipb-xen-www/; revision=572

code/controls.py
code/main.py
code/templates/info.tmpl
code/templates/list.tmpl
code/validation.py
code/webcommon.py

index d4afab0..808c988 100644 (file)
@@ -92,14 +92,12 @@ def bootMachine(machine, cdtype):
         raise CodeError('"%s" on "control %s create %s' 
                         % (err, machine.name, cdtype))
 
         raise CodeError('"%s" on "control %s create %s' 
                         % (err, machine.name, cdtype))
 
-def createVm(owner, contact, name, memory, disk_size, machine_type, cdrom, clone_from):
+def createVm(username, state, owner, contact, name, memory, disksize, machine_type, cdrom, clone_from):
     """Create a VM and put it in the database"""
     # put stuff in the table
     transaction = ctx.current.create_transaction()
     try:
     """Create a VM and put it in the database"""
     # put stuff in the table
     transaction = ctx.current.create_transaction()
     try:
-        validation.validMemory(owner, memory)
-        validation.validDisk(owner, disk_size  * 1. / 1024)
-        validation.validAddVm(owner)
+        validation.Validate(username, state, owner=owner, memory=memory, disksize=disksize/1024.)
         res = meta.engine.execute('select nextval('
                                   '\'"machines_machine_id_seq"\')')
         id = res.fetchone()[0]
         res = meta.engine.execute('select nextval('
                                   '\'"machines_machine_id_seq"\')')
         id = res.fetchone()[0]
@@ -115,7 +113,7 @@ def createVm(owner, contact, name, memory, disk_size, machine_type, cdrom, clone
         machine.type_id = machine_type.type_id
         ctx.current.save(machine)
         disk = Disk(machine_id=machine.machine_id,
         machine.type_id = machine_type.type_id
         ctx.current.save(machine)
         disk = Disk(machine_id=machine.machine_id,
-                    guest_device_name='hda', size=disk_size)
+                    guest_device_name='hda', size=disksize)
         open_nics = NIC.select_by(machine_id=None)
         if not open_nics: #No IPs left!
             raise CodeError("No IP addresses left!  "
         open_nics = NIC.select_by(machine_id=None)
         if not open_nics: #No IPs left!
             raise CodeError("No IP addresses left!  "
@@ -140,7 +138,7 @@ def createVm(owner, contact, name, memory, disk_size, machine_type, cdrom, clone
 def getList():
     """Return a dictionary mapping machine names to dicts."""
     value_string = remctl('web', 'listvms')
 def getList():
     """Return a dictionary mapping machine names to dicts."""
     value_string = remctl('web', 'listvms')
-    value_dict = yaml.load(value_string, yaml.CSafeLoader)
+    value_dict = yaml.load(value_string, yaml.SafeLoader)
     return value_dict
 
 def parseStatus(s):
     return value_dict
 
 def parseStatus(s):
@@ -208,9 +206,9 @@ def deleteVM(machine):
     for mname, dname in delete_disk_pairs:
         remctl('web', 'lvremove', mname, dname)
 
     for mname, dname in delete_disk_pairs:
         remctl('web', 'lvremove', mname, dname)
 
-def commandResult(user, fields):
+def commandResult(username, state, fields):
     start_time = 0
     start_time = 0
-    machine = validation.testMachineId(user, fields.getfirst('machine_id'))
+    machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
     action = fields.getfirst('action')
     cdrom = fields.getfirst('cdrom')
     if cdrom is not None and not CDROM.get(cdrom):
     action = fields.getfirst('action')
     cdrom = fields.getfirst('cdrom')
     if cdrom is not None and not CDROM.get(cdrom):
@@ -235,7 +233,7 @@ def commandResult(user, fields):
                 raise CodeError('ERROR on remctl')
                 
     elif action == 'Power on':
                 raise CodeError('ERROR on remctl')
                 
     elif action == 'Power on':
-        if validation.maxMemory(user, machine) < machine.memory:
+        if validation.maxMemory(username, state, machine) < machine.memory:
             raise InvalidInput('action', 'Power on',
                                "You don't have enough free RAM quota "
                                "to turn on this machine.")
             raise InvalidInput('action', 'Power on',
                                "You don't have enough free RAM quota "
                                "to turn on this machine.")
@@ -263,7 +261,7 @@ def commandResult(user, fields):
     elif action == 'Delete VM':
         deleteVM(machine)
 
     elif action == 'Delete VM':
         deleteVM(machine)
 
-    d = dict(user=user,
+    d = dict(user=username,
              command=action,
              machine=machine)
     return d
              command=action,
              machine=machine)
     return d
index 385a2a6..073f81d 100755 (executable)
@@ -41,7 +41,7 @@ import sipb_xen_database
 from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess, Type, Autoinstall
 import validation
 import cache_acls
 from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess, Type, Autoinstall
 import validation
 import cache_acls
-from webcommon import InvalidInput, CodeError, g
+from webcommon import InvalidInput, CodeError, state
 import controls
 
 class Checkpoint:
 import controls
 
 class Checkpoint:
@@ -119,15 +119,15 @@ class Defaults:
 
 DEFAULT_HEADERS = {'Content-Type': 'text/html'}
 
 
 DEFAULT_HEADERS = {'Content-Type': 'text/html'}
 
-def error(op, user, fields, err, emsg):
+def error(op, username, fields, err, emsg):
     """Print an error page when a CodeError occurs"""
     """Print an error page when a CodeError occurs"""
-    d = dict(op=op, user=user, errorMessage=str(err),
+    d = dict(op=op, user=username, errorMessage=str(err),
              stderr=emsg)
     return templates.error(searchList=[d])
 
              stderr=emsg)
     return templates.error(searchList=[d])
 
-def invalidInput(op, user, fields, err, emsg):
+def invalidInput(op, username, fields, err, emsg):
     """Print an error page when an InvalidInput exception occurs"""
     """Print an error page when an InvalidInput exception occurs"""
-    d = dict(op=op, user=user, err_field=err.err_field,
+    d = dict(op=op, user=username, err_field=err.err_field,
              err_value=str(err.err_value), stderr=emsg,
              errorMessage=str(err))
     return templates.invalid(searchList=[d])
              err_value=str(err.err_value), stderr=emsg,
              errorMessage=str(err))
     return templates.invalid(searchList=[d])
@@ -142,49 +142,25 @@ def hasVnc(status):
             return 'location' in d
     return False
 
             return 'location' in d
     return False
 
-def parseCreate(user, fields):
-    name = fields.getfirst('name')
-    if not validation.validMachineName(name):
-        raise InvalidInput('name', name, 'You must provide a machine name.  Max 22 chars, alnum plus \'-\' and \'_\'.')
-    name = name.lower()
+def parseCreate(username, state, fields):
+    kws = dict([(kw, fields.getfirst(kw)) for kw in 'name owner memory disksize vmtype cdrom clone_from'.split()])
+    validate = validation.Validate(username, state, **kws)
+    return dict(contact=username, name=validate.name, memory=validate.memory,
+                disksize=validate.disksize, owner=validate.owner, machine_type=validate.vmtype,
+                cdrom=getattr(validate, 'cdrom', None),
+                clone_from=getattr(validate, 'clone_from', None))
 
 
-    if Machine.get_by(name=name):
-        raise InvalidInput('name', name,
-                           "Name already exists.")
-
-    owner = validation.testOwner(user, fields.getfirst('owner'))
-
-    memory = fields.getfirst('memory')
-    memory = validation.validMemory(owner, memory, on=True)
-
-    disk_size = fields.getfirst('disk')
-    disk_size = validation.validDisk(owner, disk_size)
-
-    vm_type = fields.getfirst('vmtype')
-    vm_type = validation.validVmType(vm_type)
-
-    cdrom = fields.getfirst('cdrom')
-    if cdrom is not None and not CDROM.get(cdrom):
-        raise CodeError("Invalid cdrom type '%s'" % cdrom)
-
-    clone_from = fields.getfirst('clone_from')
-    if clone_from and clone_from != 'ice3':
-        raise CodeError("Invalid clone image '%s'" % clone_from)
-
-    return dict(contact=user, name=name, memory=memory, disk_size=disk_size,
-                owner=owner, machine_type=vm_type, cdrom=cdrom, clone_from=clone_from)
-
-def create(user, fields):
+def create(username, state, fields):
     """Handler for create requests."""
     try:
     """Handler for create requests."""
     try:
-        parsed_fields = parseCreate(user, fields)
-        machine = controls.createVm(**parsed_fields)
+        parsed_fields = parseCreate(username, state, fields)
+        machine = controls.createVm(username, **parsed_fields)
     except InvalidInput, err:
         pass
     else:
         err = None
     except InvalidInput, err:
         pass
     else:
         err = None
-    g.clear() #Changed global state
-    d = getListDict(user)
+    state.clear() #Changed global state
+    d = getListDict(username)
     d['err'] = err
     if err:
         for field in fields.keys():
     d['err'] = err
     if err:
         for field in fields.keys():
@@ -194,16 +170,16 @@ def create(user, fields):
     return templates.list(searchList=[d])
 
 
     return templates.list(searchList=[d])
 
 
-def getListDict(user):
+def getListDict(username, state):
     """Gets the list of local variables used by list.tmpl."""
     checkpoint.checkpoint('Starting')
     """Gets the list of local variables used by list.tmpl."""
     checkpoint.checkpoint('Starting')
-    machines = g.machines
+    machines = state.machines
     checkpoint.checkpoint('Got my machines')
     on = {}
     has_vnc = {}
     checkpoint.checkpoint('Got my machines')
     on = {}
     has_vnc = {}
-    xmlist = g.xmlist
+    xmlist = state.xmlist
     checkpoint.checkpoint('Got uptimes')
     checkpoint.checkpoint('Got uptimes')
-    can_clone = 'ice3' not in g.xmlist_raw
+    can_clone = 'ice3' not in state.xmlist_raw
     for m in machines:
         if m not in xmlist:
             has_vnc[m] = 'Off'
     for m in machines:
         if m not in xmlist:
             has_vnc[m] = 'Off'
@@ -216,19 +192,19 @@ def getListDict(user):
                 has_vnc[m] = "WTF?"
             else:
                 has_vnc[m] = "ParaVM"+helppopup("ParaVM Console")
                 has_vnc[m] = "WTF?"
             else:
                 has_vnc[m] = "ParaVM"+helppopup("ParaVM Console")
-    max_memory = validation.maxMemory(user)
-    max_disk = validation.maxDisk(user)
+    max_memory = validation.maxMemory(username, state)
+    max_disk = validation.maxDisk(username)
     checkpoint.checkpoint('Got max mem/disk')
     defaults = Defaults(max_memory=max_memory,
                         max_disk=max_disk,
     checkpoint.checkpoint('Got max mem/disk')
     defaults = Defaults(max_memory=max_memory,
                         max_disk=max_disk,
-                        owner=user,
+                        owner=username,
                         cdrom='gutsy-i386')
     checkpoint.checkpoint('Got defaults')
     def sortkey(machine):
                         cdrom='gutsy-i386')
     checkpoint.checkpoint('Got defaults')
     def sortkey(machine):
-        return (machine.owner != user, machine.owner, machine.name)
+        return (machine.owner != username, machine.owner, machine.name)
     machines = sorted(machines, key=sortkey)
     machines = sorted(machines, key=sortkey)
-    d = dict(user=user,
-             cant_add_vm=validation.cantAddVm(user),
+    d = dict(user=username,
+             cant_add_vm=validation.cantAddVm(username, state),
              max_memory=max_memory,
              max_disk=max_disk,
              defaults=defaults,
              max_memory=max_memory,
              max_disk=max_disk,
              defaults=defaults,
@@ -237,14 +213,14 @@ def getListDict(user):
              can_clone=can_clone)
     return d
 
              can_clone=can_clone)
     return d
 
-def listVms(user, fields):
+def listVms(username, state, fields):
     """Handler for list requests."""
     checkpoint.checkpoint('Getting list dict')
     """Handler for list requests."""
     checkpoint.checkpoint('Getting list dict')
-    d = getListDict(user)
+    d = getListDict(username, state)
     checkpoint.checkpoint('Got list dict')
     return templates.list(searchList=[d])
 
     checkpoint.checkpoint('Got list dict')
     return templates.list(searchList=[d])
 
-def vnc(user, fields):
+def vnc(username, state, fields):
     """VNC applet page.
 
     Note that due to same-domain restrictions, the applet connects to
     """VNC applet page.
 
     Note that due to same-domain restrictions, the applet connects to
@@ -264,12 +240,12 @@ def vnc(user, fields):
     Remember to enable iptables!
     echo 1 > /proc/sys/net/ipv4/ip_forward
     """
     Remember to enable iptables!
     echo 1 > /proc/sys/net/ipv4/ip_forward
     """
-    machine = validation.testMachineId(user, fields.getfirst('machine_id'))
+    machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
 
     TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
 
     data = {}
 
     TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
 
     data = {}
-    data["user"] = user
+    data["user"] = username
     data["machine"] = machine.name
     data["expires"] = time.time()+(5*60)
     pickled_data = cPickle.dumps(data)
     data["machine"] = machine.name
     data["expires"] = time.time()+(5*60)
     pickled_data = cPickle.dumps(data)
@@ -282,7 +258,7 @@ def vnc(user, fields):
     status = controls.statusInfo(machine)
     has_vnc = hasVnc(status)
 
     status = controls.statusInfo(machine)
     has_vnc = hasVnc(status)
 
-    d = dict(user=user,
+    d = dict(user=username,
              on=status,
              has_vnc=has_vnc,
              machine=machine,
              on=status,
              has_vnc=has_vnc,
              machine=machine,
@@ -341,36 +317,36 @@ def getDiskInfo(data_dict, machine):
         data_dict['%s_size' % name] = "%0.1f GiB" % (disk.size / 1024.)
     return disk_fields
 
         data_dict['%s_size' % name] = "%0.1f GiB" % (disk.size / 1024.)
     return disk_fields
 
-def command(user, fields):
+def command(username, state, fields):
     """Handler for running commands like boot and delete on a VM."""
     back = fields.getfirst('back')
     try:
     """Handler for running commands like boot and delete on a VM."""
     back = fields.getfirst('back')
     try:
-        d = controls.commandResult(user, fields)
+        d = controls.commandResult(username, state, fields)
         if d['command'] == 'Delete VM':
             back = 'list'
     except InvalidInput, err:
         if not back:
             raise
         if d['command'] == 'Delete VM':
             back = 'list'
     except InvalidInput, err:
         if not back:
             raise
-        #print >> sys.stderr, err
+        print >> sys.stderr, err
         result = err
     else:
         result = 'Success!'
         if not back:
             return templates.command(searchList=[d])
     if back == 'list':
         result = err
     else:
         result = 'Success!'
         if not back:
             return templates.command(searchList=[d])
     if back == 'list':
-        g.clear() #Changed global state
-        d = getListDict(user)
+        state.clear() #Changed global state
+        d = getListDict(username)
         d['result'] = result
         return templates.list(searchList=[d])
     elif back == 'info':
         d['result'] = result
         return templates.list(searchList=[d])
     elif back == 'info':
-        machine = validation.testMachineId(user, fields.getfirst('machine_id'))
+        machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
         return ({'Status': '302',
                  'Location': '/info?machine_id=%d' % machine.machine_id},
                 "You shouldn't see this message.")
     else:
         raise InvalidInput('back', back, 'Not a known back page.')
 
         return ({'Status': '302',
                  'Location': '/info?machine_id=%d' % machine.machine_id},
                 "You shouldn't see this message.")
     else:
         raise InvalidInput('back', back, 'Not a known back page.')
 
-def modifyDict(user, fields):
+def modifyDict(username, state, fields):
     """Modify a machine as specified by CGI arguments.
 
     Return a list of local variables for modify.tmpl.
     """Modify a machine as specified by CGI arguments.
 
     Return a list of local variables for modify.tmpl.
@@ -378,28 +354,20 @@ def modifyDict(user, fields):
     olddisk = {}
     transaction = ctx.current.create_transaction()
     try:
     olddisk = {}
     transaction = ctx.current.create_transaction()
     try:
-        machine = validation.testMachineId(user, fields.getfirst('machine_id'))
-        owner = validation.testOwner(user, fields.getfirst('owner'), machine)
-        admin = validation.testAdmin(user, fields.getfirst('administrator'),
-                                     machine)
-        contact = validation.testContact(user, fields.getfirst('contact'),
-                                         machine)
-        name = validation.testName(user, fields.getfirst('name'), machine)
+        kws = dict([(kw, fields.getfirst(kw)) for kw in 'machine_id owner admin contact name memory vmtype disksize'.split()])
+        validate = validation.Validate(username, state, **kws)
+        machine = validate.machine
+        print >> sys.stderr, machine, machine.administrator, kws['admin']
         oldname = machine.name
         oldname = machine.name
-        command = "modify"
 
 
-        memory = fields.getfirst('memory')
-        if memory is not None:
-            memory = validation.validMemory(owner, memory, machine, on=False)
-            machine.memory = memory
+        if hasattr(validate, 'memory'):
+            machine.memory = validate.memory
 
 
-        vm_type = validation.validVmType(fields.getfirst('vmtype'))
-        if vm_type is not None:
-            machine.type = vm_type
+        if hasattr(validate, 'vmtype'):
+            machine.type = validate.vmtype
 
 
-        disksize = validation.testDisk(owner, fields.getfirst('disk'))
-        if disksize is not None:
-            disksize = validation.validDisk(owner, disksize, machine)
+        if hasattr(validate, 'disksize'):
+            disksize = validate.disksize
             disk = machine.disks[0]
             if disk.size != disksize:
                 olddisk[disk.guest_device_name] = disksize
             disk = machine.disks[0]
             if disk.size != disksize:
                 olddisk[disk.guest_device_name] = disksize
@@ -407,19 +375,20 @@ def modifyDict(user, fields):
                 ctx.current.save(disk)
 
         update_acl = False
                 ctx.current.save(disk)
 
         update_acl = False
-        if owner is not None and owner != machine.owner:
-            machine.owner = owner
+        if hasattr(validate, 'owner') and validate.owner != machine.owner:
+            machine.owner = validate.owner
             update_acl = True
             update_acl = True
-        if name is not None:
+        if hasattr(validate, 'name'):
             machine.name = name
             machine.name = name
-        if admin is not None and admin != machine.administrator:
-            machine.administrator = admin
+        if hasattr(validate, 'admin') and validate.admin != machine.administrator:
+            machine.administrator = validate.admin
             update_acl = True
             update_acl = True
-        if contact is not None:
-            machine.contact = contact
+        if hasattr(validate, 'contact'):
+            machine.contact = validate.contact
 
         ctx.current.save(machine)
         if update_acl:
 
         ctx.current.save(machine)
         if update_acl:
+            print >> sys.stderr, machine, machine.administrator
             cache_acls.refreshMachine(machine)
         transaction.commit()
     except:
             cache_acls.refreshMachine(machine)
         transaction.commit()
     except:
@@ -427,24 +396,24 @@ def modifyDict(user, fields):
         raise
     for diskname in olddisk:
         controls.resizeDisk(oldname, diskname, str(olddisk[diskname]))
         raise
     for diskname in olddisk:
         controls.resizeDisk(oldname, diskname, str(olddisk[diskname]))
-    if name is not None:
-        controls.renameMachine(machine, oldname, name)
-    return dict(user=user,
-                command=command,
+    if hasattr(validate, 'name'):
+        controls.renameMachine(machine, oldname, validate.name)
+    return dict(user=username,
+                command="modify",
                 machine=machine)
 
                 machine=machine)
 
-def modify(user, fields):
+def modify(username, state, fields):
     """Handler for modifying attributes of a machine."""
     try:
     """Handler for modifying attributes of a machine."""
     try:
-        modify_dict = modifyDict(user, fields)
+        modify_dict = modifyDict(username, state, fields)
     except InvalidInput, err:
         result = None
     except InvalidInput, err:
         result = None
-        machine = validation.testMachineId(user, fields.getfirst('machine_id'))
+        machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
     else:
         machine = modify_dict['machine']
         result = 'Success!'
         err = None
     else:
         machine = modify_dict['machine']
         result = 'Success!'
         err = None
-    info_dict = infoDict(user, machine)
+    info_dict = infoDict(username, machine)
     info_dict['err'] = err
     if err:
         for field in fields.keys():
     info_dict['err'] = err
     if err:
         for field in fields.keys():
@@ -453,7 +422,7 @@ def modify(user, fields):
     return templates.info(searchList=[info_dict])
 
 
     return templates.info(searchList=[info_dict])
 
 
-def helpHandler(user, fields):
+def helpHandler(username, state, fields):
     """Handler for help messages."""
     simple = fields.getfirst('simple')
     subjects = fields.getlist('subject')
     """Handler for help messages."""
     simple = fields.getfirst('simple')
     subjects = fields.getlist('subject')
@@ -496,7 +465,7 @@ console will suffer artifacts.
     if not subjects:
         subjects = sorted(help_mapping.keys())
 
     if not subjects:
         subjects = sorted(help_mapping.keys())
 
-    d = dict(user=user,
+    d = dict(user=username,
              simple=simple,
              subjects=subjects,
              mapping=help_mapping)
              simple=simple,
              subjects=subjects,
              mapping=help_mapping)
@@ -508,7 +477,7 @@ def badOperation(u, e):
     """Function called when accessing an unknown URI."""
     raise CodeError("Unknown operation")
 
     """Function called when accessing an unknown URI."""
     raise CodeError("Unknown operation")
 
-def infoDict(user, machine):
+def infoDict(username, machine):
     """Get the variables used by info.tmpl."""
     status = controls.statusInfo(machine)
     checkpoint.checkpoint('Getting status info')
     """Get the variables used by info.tmpl."""
     status = controls.statusInfo(machine)
     checkpoint.checkpoint('Getting status info')
@@ -579,7 +548,7 @@ def infoDict(user, machine):
     checkpoint.checkpoint('Got fields')
 
 
     checkpoint.checkpoint('Got fields')
 
 
-    max_mem = validation.maxMemory(machine.owner, machine, False)
+    max_mem = validation.maxMemory(machine.owner, state, machine, False)
     checkpoint.checkpoint('Got mem')
     max_disk = validation.maxDisk(machine.owner, machine)
     defaults = Defaults()
     checkpoint.checkpoint('Got mem')
     max_disk = validation.maxDisk(machine.owner, machine)
     defaults = Defaults()
@@ -588,7 +557,7 @@ def infoDict(user, machine):
     defaults.type = machine.type.type_id
     defaults.disk = "%0.2f" % (machine.disks[0].size/1024.)
     checkpoint.checkpoint('Got defaults')
     defaults.type = machine.type.type_id
     defaults.disk = "%0.2f" % (machine.disks[0].size/1024.)
     checkpoint.checkpoint('Got defaults')
-    d = dict(user=user,
+    d = dict(user=username,
              on=status is not None,
              machine=machine,
              defaults=defaults,
              on=status is not None,
              machine=machine,
              defaults=defaults,
@@ -601,14 +570,14 @@ def infoDict(user, machine):
              fields = fields)
     return d
 
              fields = fields)
     return d
 
-def info(user, fields):
+def info(username, state, fields):
     """Handler for info on a single VM."""
     """Handler for info on a single VM."""
-    machine = validation.testMachineId(user, fields.getfirst('machine_id'))
-    d = infoDict(user, machine)
+    machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
+    d = infoDict(username, machine)
     checkpoint.checkpoint('Got infodict')
     return templates.info(searchList=[d])
 
     checkpoint.checkpoint('Got infodict')
     return templates.info(searchList=[d])
 
-def unauthFront(_, fields):
+def unauthFront(_, _2, fields):
     """Information for unauth'd users."""
     return templates.unauth(searchList=[{'simple' : True}])
 
     """Information for unauth'd users."""
     return templates.unauth(searchList=[{'simple' : True}])
 
@@ -628,14 +597,16 @@ def printHeaders(headers):
     print
 
 
     print
 
 
-def getUser():
+def getUser(environ):
     """Return the current user based on the SSL environment variables"""
     """Return the current user based on the SSL environment variables"""
-    email = os.environ.get('SSL_CLIENT_S_DN_Email', None)
+    email = environ.get('SSL_CLIENT_S_DN_Email', None)
     if email is None:
         return None
     if email is None:
         return None
-    return email.split("@")[0]
+    if not email.endswith('@MIT.EDU'):
+        return None
+    return email[:-8]
 
 
-def main(operation, user, fields):
+def main(operation, username, state, fields):
     start_time = time.time()
     fun = mapping.get(operation, badOperation)
 
     start_time = time.time()
     fun = mapping.get(operation, badOperation)
 
@@ -643,7 +614,7 @@ def main(operation, user, fields):
         connect('postgres://sipb-xen@sipb-xen-dev.mit.edu/sipb_xen')
     try:
         checkpoint.checkpoint('Before')
         connect('postgres://sipb-xen@sipb-xen-dev.mit.edu/sipb_xen')
     try:
         checkpoint.checkpoint('Before')
-        output = fun(u, fields)
+        output = fun(username, state, fields)
         checkpoint.checkpoint('After')
 
         headers = dict(DEFAULT_HEADERS)
         checkpoint.checkpoint('After')
 
         headers = dict(DEFAULT_HEADERS)
@@ -652,24 +623,31 @@ def main(operation, user, fields):
             headers.update(new_headers)
         e = revertStandardError()
         if e:
             headers.update(new_headers)
         e = revertStandardError()
         if e:
+            if isinstance(output, basestring):
+                sys.stderr = StringIO()
+                x = str(output)
+                print >> sys.stderr, x
+                print >> sys.stderr, 'XXX'
+                print >> sys.stderr, e
+                raise Exception()
             output.addError(e)
         printHeaders(headers)
         output_string =  str(output)
         checkpoint.checkpoint('output as a string')
         print output_string
         if fields.has_key('timedebug'):
             output.addError(e)
         printHeaders(headers)
         output_string =  str(output)
         checkpoint.checkpoint('output as a string')
         print output_string
         if fields.has_key('timedebug'):
-            print '<pre>%s</pre>' % checkpoint
+            print '<pre>%s</pre>' % cgi.escape(checkpoint)
     except Exception, err:
         if not fields.has_key('js'):
             if isinstance(err, CodeError):
                 print 'Content-Type: text/html\n'
                 e = revertStandardError()
     except Exception, err:
         if not fields.has_key('js'):
             if isinstance(err, CodeError):
                 print 'Content-Type: text/html\n'
                 e = revertStandardError()
-                print error(operation, u, fields, err, e)
+                print error(operation, state.username, fields, err, e)
                 sys.exit(1)
             if isinstance(err, InvalidInput):
                 print 'Content-Type: text/html\n'
                 e = revertStandardError()
                 sys.exit(1)
             if isinstance(err, InvalidInput):
                 print 'Content-Type: text/html\n'
                 e = revertStandardError()
-                print invalidInput(operation, u, fields, err, e)
+                print invalidInput(operation, state.username, fields, err, e)
                 sys.exit(1)
         print 'Content-Type: text/plain\n'
         print 'Uh-oh!  We experienced an error.'
                 sys.exit(1)
         print 'Content-Type: text/plain\n'
         print 'Uh-oh!  We experienced an error.'
@@ -689,17 +667,15 @@ if __name__ == '__main__':
         logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
         logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.INFO)
 
         logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
         logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.INFO)
 
-    u = getUser()
-    g.user = u
+    username = getUser(os.environ)
+    state.username = username
     operation = os.environ.get('PATH_INFO', '')
     if not operation:
         print "Status: 301 Moved Permanently"
         print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
         sys.exit(0)
     operation = os.environ.get('PATH_INFO', '')
     if not operation:
         print "Status: 301 Moved Permanently"
         print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
         sys.exit(0)
-
-    if u is None:
+    if username is None:
         operation = 'unauth'
         operation = 'unauth'
-
     if operation.startswith('/'):
         operation = operation[1:]
     if not operation:
     if operation.startswith('/'):
         operation = operation[1:]
     if not operation:
@@ -707,6 +683,6 @@ if __name__ == '__main__':
 
     if os.getenv("SIPB_XEN_PROFILE"):
         import profile
 
     if os.getenv("SIPB_XEN_PROFILE"):
         import profile
-        profile.run('main(operation, u, fields)', 'log-'+operation)
+        profile.run('main(operation, username, state, fields)', 'log-'+operation)
     else:
     else:
-        main(operation, u, fields)
+        main(operation, username, state, fields)
index 64b9445..b65299c 100644 (file)
@@ -77,7 +77,7 @@ $errorRow('owner', $err)
 #filter None
 $helppopup("Administrator")#slurp
 #end filter
 #filter None
 $helppopup("Administrator")#slurp
 #end filter
-:</td><td><input type="text" name="administrator", value="$defaults.administrator"/></td></tr>
+:</td><td><input type="text" name="admin", value="$defaults.administrator"/></td></tr>
 #filter None
 $errorRow('administrator', $err)
 #end filter
 #filter None
 $errorRow('administrator', $err)
 #end filter
@@ -106,7 +106,7 @@ $vmTypeList($defaults.type)#slurp
 #filter None
 $errorRow('memory', $err)
 #end filter
 #filter None
 $errorRow('memory', $err)
 #end filter
-    <tr><td>Disk:</td><td><input type="text" size=3 name="disk" value="$defaults.disk"/>GiB (max $max_disk)</td><td>WARNING: Modifying disk size may corrupt your data.</td></tr>
+    <tr><td>Disk:</td><td><input type="text" size=3 name="disksize" value="$defaults.disk"/>GiB (max $max_disk)</td><td>WARNING: Modifying disk size may corrupt your data.</td></tr>
 #filter None
 $errorRow('disk', $err)
 #end filter
 #filter None
 $errorRow('disk', $err)
 #end filter
index 8310024..2242c2f 100644 (file)
@@ -39,7 +39,7 @@ $errorRow('memory', $err)
 #end filter
        <tr>
          <td>Disk</td>
 #end filter
        <tr>
          <td>Disk</td>
-         <td><input type="text" name="disk" value="$defaults.disk" size=3/> GiB (${"%0.1f" % ($max_disk-0.05)} max)</td>
+         <td><input type="text" name="disksize" value="$defaults.disk" size=3/> GiB (${"%0.1f" % ($max_disk-0.05)} max)</td>
        </tr>
 #filter None
 $errorRow('disk', $err)
        </tr>
 #filter None
 $errorRow('disk', $err)
index c0e3aeb..df5bdcc 100644 (file)
@@ -5,7 +5,7 @@ import getafsgroups
 import re
 import string
 from sipb_xen_database import Machine, NIC, Type, Disk
 import re
 import string
 from sipb_xen_database import Machine, NIC, Type, Disk
-from webcommon import InvalidInput, g
+from webcommon import InvalidInput
 
 MAX_MEMORY_TOTAL = 512
 MAX_MEMORY_SINGLE = 256
 
 MAX_MEMORY_TOTAL = 512
 MAX_MEMORY_SINGLE = 256
@@ -16,7 +16,49 @@ MIN_DISK_SINGLE = 0.1
 MAX_VMS_TOTAL = 10
 MAX_VMS_ACTIVE = 4
 
 MAX_VMS_TOTAL = 10
 MAX_VMS_ACTIVE = 4
 
-def getMachinesByOwner(user, machine=None):
+class Validate:
+    def __init__(self, username, state, machine_id=None, name=None, owner=None,
+                 admin=None, contact=None, memory=None, disksize=None,
+                 vmtype=None, cdrom=None, clone_from=None):
+        # XXX Successive quota checks aren't a good idea, since you
+        # can't necessarily change the locker and disk size at the
+        # same time.
+        created_new = (machine_id is None)
+
+        if machine_id is not None:
+            self.machine = testMachineId(username, machine_id)
+        machine = getattr(self, 'machine', None)
+
+        owner = testOwner(username, owner, machine)
+        if owner is not None:
+            self.owner = owner
+        admin = testAdmin(username, admin, machine)
+        if admin is not None:
+            self.admin = admin
+        contact = testContact(username, contact, machine)
+        if contact is not None:
+            self.contact = contact
+        name = testName(username, name, machine)
+        if name is not None:
+            self.name = name
+        if memory is not None:
+            self.memory = validMemory(self.owner, state, memory, machine,
+                                      on=not created_new)
+        if disksize is not None:
+            self.disksize = validDisk(self.owner, disksize, machine)
+        if vmtype is not None:
+            self.vmtype = validVmType(vmtype)
+        if cdrom is not None:
+            if not CDROM.get(cdrom):
+                raise CodeError("Invalid cdrom type '%s'" % cdrom)
+            self.cdrom = cdrom
+        if clone_from is not None:
+            if clone_from not in ('ice3', ):
+                raise CodeError("Invalid clone image '%s'" % clone_from)
+            self.clone_from = clone_from
+
+
+def getMachinesByOwner(owner, machine=None):
     """Return the machines owned by the same as a machine.
 
     If the machine is None, return the machines owned by the same
     """Return the machines owned by the same as a machine.
 
     If the machine is None, return the machines owned by the same
@@ -24,11 +66,9 @@ def getMachinesByOwner(user, machine=None):
     """
     if machine:
         owner = machine.owner
     """
     if machine:
         owner = machine.owner
-    else:
-        owner = user
     return Machine.select_by(owner=owner)
 
     return Machine.select_by(owner=owner)
 
-def maxMemory(user, machine=None, on=True):
+def maxMemory(owner, g, machine=None, on=True):
     """Return the maximum memory for a machine or a user.
 
     If machine is None, return the memory available for a new
     """Return the maximum memory for a machine or a user.
 
     If machine is None, return the memory available for a new
@@ -43,12 +83,12 @@ def maxMemory(user, machine=None, on=True):
         return machine.memory
     if not on:
         return MAX_MEMORY_SINGLE
         return machine.memory
     if not on:
         return MAX_MEMORY_SINGLE
-    machines = getMachinesByOwner(user, machine)
+    machines = getMachinesByOwner(owner, machine)
     active_machines = [x for x in machines if g.xmlist.get(x)]
     mem_usage = sum([x.memory for x in active_machines if x != machine])
     return min(MAX_MEMORY_SINGLE, MAX_MEMORY_TOTAL-mem_usage)
 
     active_machines = [x for x in machines if g.xmlist.get(x)]
     mem_usage = sum([x.memory for x in active_machines if x != machine])
     return min(MAX_MEMORY_SINGLE, MAX_MEMORY_TOTAL-mem_usage)
 
-def maxDisk(user, machine=None):
+def maxDisk(owner, machine=None):
     """Return the maximum disk that a machine can reach.
 
     If machine is None, the maximum disk for a new machine. Otherwise,
     """Return the maximum disk that a machine can reach.
 
     If machine is None, the maximum disk for a new machine. Otherwise,
@@ -59,11 +99,11 @@ def maxDisk(user, machine=None):
     else:
         machine_id = None
     disk_usage = Disk.query().filter_by(Disk.c.machine_id != machine_id,
     else:
         machine_id = None
     disk_usage = Disk.query().filter_by(Disk.c.machine_id != machine_id,
-                                        owner=user).sum(Disk.c.size) or 0
+                                        owner=owner).sum(Disk.c.size) or 0
     return min(MAX_DISK_SINGLE, MAX_DISK_TOTAL-disk_usage/1024.)
 
     return min(MAX_DISK_SINGLE, MAX_DISK_TOTAL-disk_usage/1024.)
 
-def cantAddVm(user):
-    machines = getMachinesByOwner(user)
+def cantAddVm(owner, g):
+    machines = getMachinesByOwner(owner)
     active_machines = [x for x in machines if g.xmlist.get(x)]
     if len(machines) >= MAX_VMS_TOTAL:
         return 'You have too many VMs to create a new one.'
     active_machines = [x for x in machines if g.xmlist.get(x)]
     if len(machines) >= MAX_VMS_TOTAL:
         return 'You have too many VMs to create a new one.'
@@ -72,12 +112,6 @@ def cantAddVm(user):
                 'To create more, turn one off.')
     return False
 
                 'To create more, turn one off.')
     return False
 
-def validAddVm(user):
-    reason = cantAddVm(user)
-    if reason:
-        raise InvalidInput('create', True, reason)
-    return True
-
 def haveAccess(user, machine):
     """Return whether a user has administrative access to a machine"""
     return user in cache_acls.accessList(machine)
 def haveAccess(user, machine):
     """Return whether a user has administrative access to a machine"""
     return user in cache_acls.accessList(machine)
@@ -98,8 +132,8 @@ def validMachineName(name):
             return False
     return True
 
             return False
     return True
 
-def validMemory(user, memory, machine=None, on=True):
-    """Parse and validate limits for memory for a given user and machine.
+def validMemory(owner, g, memory, machine=None, on=True):
+    """Parse and validate limits for memory for a given owner and machine.
 
     on is whether the memory must be valid after the machine is
     switched on.
 
     on is whether the memory must be valid after the machine is
     switched on.
@@ -111,19 +145,19 @@ def validMemory(user, memory, machine=None, on=True):
     except ValueError:
         raise InvalidInput('memory', memory,
                            "Minimum %s MiB" % MIN_MEMORY_SINGLE)
     except ValueError:
         raise InvalidInput('memory', memory,
                            "Minimum %s MiB" % MIN_MEMORY_SINGLE)
-    if memory > maxMemory(user, machine, on):
+    max_val = maxMemory(owner, g, machine, on)
+    if memory > max_val:
         raise InvalidInput('memory', memory,
         raise InvalidInput('memory', memory,
-                           'Maximum %s MiB for %s' % (maxMemory(user, machine),
-                                                      user))
+                           'Maximum %s MiB for %s' % (max_val, owner))
     return memory
 
     return memory
 
-def validDisk(user, disk, machine=None):
-    """Parse and validate limits for disk for a given user and machine."""
+def validDisk(owner, disk, machine=None):
+    """Parse and validate limits for disk for a given owner and machine."""
     try:
         disk = float(disk)
     try:
         disk = float(disk)
-        if disk > maxDisk(user, machine):
+        if disk > maxDisk(owner, machine):
             raise InvalidInput('disk', disk,
             raise InvalidInput('disk', disk,
-                               "Maximum %s G" % maxDisk(user, machine))
+                               "Maximum %s G" % maxDisk(owner, machine))
         disk = int(disk * 1024)
         if disk < MIN_DISK_SINGLE * 1024:
             raise ValueError
         disk = int(disk * 1024)
         if disk < MIN_DISK_SINGLE * 1024:
             raise ValueError
@@ -190,8 +224,10 @@ def testOwner(user, owner, machine=None):
 
     If machine is None, this is the owner of a new machine.
     """
 
     If machine is None, this is the owner of a new machine.
     """
-    if owner == user or machine is not None and owner == machine.owner:
+    if owner == user:
         return owner
         return owner
+    if machine is not None and owner in (machine.owner, None):
+        return None
     if owner is None:
         raise InvalidInput('owner', owner, "Owner must be specified")
     try:
     if owner is None:
         raise InvalidInput('owner', owner, "Owner must be specified")
     try:
@@ -213,9 +249,13 @@ def testDisk(user, disksize, machine=None):
     return disksize
 
 def testName(user, name, machine=None):
     return disksize
 
 def testName(user, name, machine=None):
-    if name in (None, machine.name):
+    if name is None:
+        return None
+    if machine is not None and name == machine.name:
         return None
     if not Machine.select_by(name=name):
         return None
     if not Machine.select_by(name=name):
+        if not validMachineName(name):
+            raise InvalidInput('name', name, 'You must provide a machine name.  Max 22 chars, alnum plus \'-\' and \'_\'.')
         return name
     raise InvalidInput('name', name, "Name is already taken.")
 
         return name
     raise InvalidInput('name', name, "Name is already taken.")
 
index 58d9333..5911787 100644 (file)
@@ -36,13 +36,13 @@ def cachedproperty(func):
             return value
     return property(getter)
 
             return value
     return property(getter)
 
-class Global(object):
-    """Global state of the system, to avoid duplicate remctls to get state"""
+class State(object):
+    """State for a request"""
     def __init__(self, user):
     def __init__(self, user):
-        self.user = user
+        self.username = user
 
     machines = cachedproperty(lambda self:
 
     machines = cachedproperty(lambda self:
-                                  Machine.query().join('acl').select_by(user=self.user))
+                                  Machine.query().join('acl').select_by(user=self.username))
     xmlist_raw = cachedproperty(lambda self: controls.getList())
     xmlist = cachedproperty(lambda self:
                                 dict((m, self.xmlist_raw[m.name])
     xmlist_raw = cachedproperty(lambda self: controls.getList())
     xmlist = cachedproperty(lambda self:
                                 dict((m, self.xmlist_raw[m.name])
@@ -55,4 +55,4 @@ class Global(object):
             if attr.startswith('__cache_'):
                 delattr(self, attr)
 
             if attr.startswith('__cache_'):
                 delattr(self, attr)
 
-g = Global(None)
+state = State(None)