Use joins, new xmlist.py
[invirt/packages/invirt-web.git] / code / main.py
index 8bc9ece..e208a13 100755 (executable)
@@ -11,6 +11,7 @@ import sha
 import simplejson
 import sys
 import time
+import urllib
 from StringIO import StringIO
 
 def revertStandardError():
@@ -36,8 +37,10 @@ sys.path.append('/home/ecprice/.local/lib/python2.5/site-packages')
 
 import templates
 from Cheetah.Template import Template
-from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess
+import sipb_xen_database
+from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess, Type, Autoinstall
 import validation
+import cache_acls
 from webcommon import InvalidInput, CodeError, g
 import controls
 
@@ -56,12 +59,15 @@ class Checkpoint:
 
 checkpoint = Checkpoint()
 
+def jquote(string):
+    return "'" + string.replace('\\', '\\\\').replace("'", "\\'").replace('\n', '\\n') + "'"
 
 def helppopup(subj):
     """Return HTML code for a (?) link to a specified help topic"""
-    return ('<span class="helplink"><a href="help?subject=' + subj + 
-            '&amp;simple=true" target="_blank" ' + 
-            'onclick="return helppopup(\'' + subj + '\')">(?)</a></span>')
+    return ('<span class="helplink"><a href="help?' +
+            cgi.escape(urllib.urlencode(dict(subject=subj, simple='true')))
+            +'" target="_blank" ' +
+            'onclick="return helppopup(' + cgi.escape(jquote(subj)) + ')">(?)</a></span>')
 
 def makeErrorPre(old, addition):
     if addition is None:
@@ -71,6 +77,7 @@ def makeErrorPre(old, addition):
     else:
         return '<p>STDERR:</p><pre>' + str(addition) + '</pre>'
 
+Template.sipb_xen_database = sipb_xen_database
 Template.helppopup = staticmethod(helppopup)
 Template.err = None
 
@@ -96,8 +103,10 @@ class Defaults:
     memory = 256
     disk = 4.0
     cdrom = ''
+    autoinstall = ''
     name = ''
-    vmtype = 'hvm'
+    type = 'linux-hvm'
+
     def __init__(self, max_memory=None, max_disk=None, **kws):
         if max_memory is not None:
             self.memory = min(self.memory, max_memory)
@@ -136,25 +145,23 @@ def hasVnc(status):
 def parseCreate(user, fields):
     name = fields.getfirst('name')
     if not validation.validMachineName(name):
-        raise InvalidInput('name', name, 'You must provide a machine name.')
+        raise InvalidInput('name', name, 'You must provide a machine name.  Max 22 chars, alnum plus \'-\' and \'_\'.')
     name = name.lower()
 
     if Machine.get_by(name=name):
         raise InvalidInput('name', name,
                            "Name already exists.")
-    
+
     owner = validation.testOwner(user, fields.getfirst('owner'))
 
     memory = fields.getfirst('memory')
     memory = validation.validMemory(owner, memory, on=True)
-    
+
     disk_size = fields.getfirst('disk')
     disk_size = validation.validDisk(owner, disk_size)
 
     vm_type = fields.getfirst('vmtype')
-    if vm_type not in ('hvm', 'paravm'):
-        raise CodeError("Invalid vm type '%s'"  % vm_type)    
-    is_hvm = (vm_type == 'hvm')
+    vm_type = validation.validVmType(vm_type)
 
     cdrom = fields.getfirst('cdrom')
     if cdrom is not None and not CDROM.get(cdrom):
@@ -163,9 +170,9 @@ def parseCreate(user, fields):
     clone_from = fields.getfirst('clone_from')
     if clone_from and clone_from != 'ice3':
         raise CodeError("Invalid clone image '%s'" % clone_from)
-    
+
     return dict(contact=user, name=name, memory=memory, disk_size=disk_size,
-                owner=owner, is_hvm=is_hvm, cdrom=cdrom, clone_from=clone_from)
+                owner=owner, machine_type=vm_type, cdrom=cdrom, clone_from=clone_from)
 
 def create(user, fields):
     """Handler for create requests."""
@@ -188,20 +195,26 @@ def create(user, fields):
 
 
 def getListDict(user):
+    """Gets the list of local variables used by list.tmpl."""
+    checkpoint.checkpoint('Starting')
     machines = g.machines
     checkpoint.checkpoint('Got my machines')
     on = {}
     has_vnc = {}
-    on = g.uptimes
+    xmlist = g.xmlist
     checkpoint.checkpoint('Got uptimes')
     for m in machines:
-        m.uptime = g.uptimes.get(m)
-        if not on[m]:
+        if m not in xmlist:
             has_vnc[m] = 'Off'
-        elif m.type.hvm:
-            has_vnc[m] = True
+            m.uptime = None
         else:
-            has_vnc[m] = "ParaVM"+helppopup("paravm_console")
+            m.uptime = xmlist[m]['uptime']
+            if xmlist[m]['console']:
+                has_vnc[m] = True
+            elif m.type.hvm:
+                has_vnc[m] = "WTF?"
+            else:
+                has_vnc[m] = "ParaVM"+helppopup("paravm_console")
     max_memory = validation.maxMemory(user)
     max_disk = validation.maxDisk(user)
     checkpoint.checkpoint('Got max mem/disk')
@@ -210,15 +223,16 @@ def getListDict(user):
                         owner=user,
                         cdrom='gutsy-i386')
     checkpoint.checkpoint('Got defaults')
+    def sortkey(machine):
+        return (machine.owner != user, machine.owner, machine.name)
+    machines = sorted(machines, key=sortkey)
     d = dict(user=user,
              cant_add_vm=validation.cantAddVm(user),
              max_memory=max_memory,
              max_disk=max_disk,
              defaults=defaults,
              machines=machines,
-             has_vnc=has_vnc,
-             uptimes=g.uptimes,
-             cdroms=CDROM.select())
+             has_vnc=has_vnc)
     return d
 
 def listVms(user, fields):
@@ -227,7 +241,7 @@ def listVms(user, fields):
     d = getListDict(user)
     checkpoint.checkpoint('Got list dict')
     return templates.list(searchList=[d])
-            
+
 def vnc(user, fields):
     """VNC applet page.
 
@@ -239,9 +253,9 @@ def vnc(user, fields):
     You might want iptables like:
 
     -t nat -A PREROUTING -s ! 18.181.0.60 -i eth1 -p tcp -m tcp \
-      --dport 10003 -j DNAT --to-destination 18.181.0.60:10003 
+      --dport 10003 -j DNAT --to-destination 18.181.0.60:10003
     -t nat -A POSTROUTING -d 18.181.0.60 -o eth1 -p tcp -m tcp \
-      --dport 10003 -j SNAT --to-source 18.187.7.142 
+      --dport 10003 -j SNAT --to-source 18.187.7.142
     -A FORWARD -d 18.181.0.60 -i eth1 -o eth1 -p tcp -m tcp \
       --dport 10003 -j ACCEPT
 
@@ -249,7 +263,7 @@ def vnc(user, fields):
     echo 1 > /proc/sys/net/ipv4/ip_forward
     """
     machine = validation.testMachineId(user, fields.getfirst('machine_id'))
-    
+
     TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
 
     data = {}
@@ -262,10 +276,10 @@ def vnc(user, fields):
     token = {'data': pickled_data, 'digest': m.digest()}
     token = cPickle.dumps(token)
     token = base64.urlsafe_b64encode(token)
-    
+
     status = controls.statusInfo(machine)
     has_vnc = hasVnc(status)
-    
+
     d = dict(user=user,
              on=status,
              has_vnc=has_vnc,
@@ -275,10 +289,14 @@ def vnc(user, fields):
     return templates.vnc(searchList=[d])
 
 def getHostname(nic):
+    """Find the hostname associated with a NIC.
+
+    XXX this should be merged with the similar logic in DNS and DHCP.
+    """
     if nic.hostname and '.' in nic.hostname:
         return nic.hostname
     elif nic.machine:
-        return nic.machine.name + '.servers.csail.mit.edu'
+        return nic.machine.name + '.xvm.mit.edu'
     else:
         return None
 
@@ -316,7 +334,7 @@ def getDiskInfo(data_dict, machine):
     disk_fields = []
     for disk in machine.disks:
         name = disk.guest_device_name
-        disk_fields.extend([(x % name, y % name) for x, y in 
+        disk_fields.extend([(x % name, y % name) for x, y in
                             disk_fields_template])
         data_dict['%s_size' % name] = "%0.1f GiB" % (disk.size / 1024.)
     return disk_fields
@@ -344,13 +362,17 @@ def command(user, fields):
         return templates.list(searchList=[d])
     elif back == 'info':
         machine = validation.testMachineId(user, fields.getfirst('machine_id'))
-        d = infoDict(user, machine)
-        d['result'] = result
-        return templates.info(searchList=[d])
+        return ({'Status': '302',
+                 'Location': '/info?machine_id=%d' % machine.machine_id},
+                "You shouldn't see this message.")
     else:
         raise InvalidInput('back', back, 'Not a known back page.')
 
 def modifyDict(user, fields):
+    """Modify a machine as specified by CGI arguments.
+
+    Return a list of local variables for modify.tmpl.
+    """
     olddisk = {}
     transaction = ctx.current.create_transaction()
     try:
@@ -368,7 +390,11 @@ def modifyDict(user, fields):
         if memory is not None:
             memory = validation.validMemory(user, memory, machine, on=False)
             machine.memory = memory
+
+        vm_type = validation.validVmType(fields.getfirst('vmtype'))
+        if vm_type is not None:
+            machine.type = vm_type
+
         disksize = validation.testDisk(user, fields.getfirst('disk'))
         if disksize is not None:
             disksize = validation.validDisk(user, disksize, machine)
@@ -377,17 +403,22 @@ def modifyDict(user, fields):
                 olddisk[disk.guest_device_name] = disksize
                 disk.size = disksize
                 ctx.current.save(disk)
-        
-        if owner is not None:
+
+        update_acl = False
+        if owner is not None and owner != machine.owner:
             machine.owner = owner
+            update_acl = True
         if name is not None:
             machine.name = name
-        if admin is not None:
+        if admin is not None and admin != machine.administrator:
             machine.administrator = admin
+            update_acl = True
         if contact is not None:
             machine.contact = contact
-            
+
         ctx.current.save(machine)
+        if update_acl:
+            cache_acls.refreshMachine(machine)
         transaction.commit()
     except:
         transaction.rollback()
@@ -399,7 +430,7 @@ def modifyDict(user, fields):
     return dict(user=user,
                 command=command,
                 machine=machine)
-    
+
 def modify(user, fields):
     """Handler for modifying attributes of a machine."""
     try:
@@ -418,17 +449,18 @@ def modify(user, fields):
             setattr(info_dict['defaults'], field, fields.getfirst(field))
     info_dict['result'] = result
     return templates.info(searchList=[info_dict])
-    
+
 
 def helpHandler(user, fields):
     """Handler for help messages."""
     simple = fields.getfirst('simple')
     subjects = fields.getlist('subject')
-    
+
     help_mapping = dict(paravm_console="""
-ParaVM machines do not support console access over VNC.  To access
-these machines, you either need to boot with a liveCD and ssh in or
-hope that the sipb-xen maintainers add support for serial consoles.""",
+ParaVM machines do not support local console access over VNC.  To
+access the serial console of these machines, you can SSH with Kerberos
+to sipb-xen-console.mit.edu, using the name of the machine as your
+username.""",
                         hvm_paravm="""
 HVM machines use the virtualization features of the processor, while
 ParaVM machines use Xen's emulation of virtualization features.  You
@@ -440,15 +472,15 @@ The owner field is used to determine <a
 href="help?subject=quotas">quotas</a>.  It must be the name of a
 locker that you are an AFS administrator of.  In particular, you or an
 AFS group you are a member of must have AFS rlidwka bits on the
-locker.  You can check see who administers the LOCKER locker using the
-command 'fs la /mit/LOCKER' on Athena.)  See also <a
+locker.  You can check who administers the LOCKER locker using the
+commands 'attach LOCKER; fs la /mit/LOCKER' on Athena.)  See also <a
 href="help?subject=administrator">administrator</a>.""",
                         administrator="""
 The administrator field determines who can access the console and
 power on and off the machine.  This can be either a user or a moira
 group.""",
                         quotas="""
-Quotas are determined on a per-locker basis.  Each quota may have a
+Quotas are determined on a per-locker basis.  Each locker may have a
 maximum of 512 megabytes of active ram, 50 gigabytes of disk, and 4
 active machines.""",
                         console="""
@@ -458,22 +490,24 @@ your machine will run just fine, but the applet's display of the
 console will suffer artifacts.
 """
                    )
-    
+
     if not subjects:
         subjects = sorted(help_mapping.keys())
-        
+
     d = dict(user=user,
              simple=simple,
              subjects=subjects,
              mapping=help_mapping)
-    
+
     return templates.help(searchList=[d])
-    
+
 
 def badOperation(u, e):
+    """Function called when accessing an unknown URI."""
     raise CodeError("Unknown operation")
 
 def infoDict(user, machine):
+    """Get the variables used by info.tmpl."""
     status = controls.statusInfo(machine)
     checkpoint.checkpoint('Getting status info')
     has_vnc = hasVnc(status)
@@ -520,14 +554,14 @@ def infoDict(user, machine):
 
     nic_fields = getNicInfo(machine_info, machine)
     nic_point = display_fields.index('NIC_INFO')
-    display_fields = (display_fields[:nic_point] + nic_fields + 
+    display_fields = (display_fields[:nic_point] + nic_fields +
                       display_fields[nic_point+1:])
 
     disk_fields = getDiskInfo(machine_info, machine)
     disk_point = display_fields.index('DISK_INFO')
-    display_fields = (display_fields[:disk_point] + disk_fields + 
+    display_fields = (display_fields[:disk_point] + disk_fields +
                       display_fields[disk_point+1:])
-    
+
     main_status['memory'] += ' MiB'
     for field, disp in display_fields:
         if field in ('uptime', 'cputime') and locals()[field] is not None:
@@ -549,10 +583,10 @@ def infoDict(user, machine):
     defaults = Defaults()
     for name in 'machine_id name administrator owner memory contact'.split():
         setattr(defaults, name, getattr(machine, name))
+    defaults.type = machine.type.type_id
     defaults.disk = "%0.2f" % (machine.disks[0].size/1024.)
     checkpoint.checkpoint('Got defaults')
     d = dict(user=user,
-             cdroms=CDROM.select(),
              on=status is not None,
              machine=machine,
              defaults=defaults,
@@ -572,15 +606,21 @@ def info(user, fields):
     checkpoint.checkpoint('Got infodict')
     return templates.info(searchList=[d])
 
+def unauthFront(_, fields):
+    """Information for unauth'd users."""
+    return templates.unauth(searchList=[{'simple' : True}])
+
 mapping = dict(list=listVms,
                vnc=vnc,
                command=command,
                modify=modify,
                info=info,
                create=create,
-               help=helpHandler)
+               help=helpHandler,
+               unauth=unauthFront)
 
 def printHeaders(headers):
+    """Print a dictionary as HTTP headers."""
     for key, value in headers.iteritems():
         print '%s: %s' % (key, value)
     print
@@ -588,10 +628,12 @@ def printHeaders(headers):
 
 def getUser():
     """Return the current user based on the SSL environment variables"""
-    username = os.environ['SSL_CLIENT_S_DN_Email'].split("@")[0]
-    return username
+    email = os.environ.get('SSL_CLIENT_S_DN_Email', None)
+    if email is None:
+        return None
+    return email.split("@")[0]
 
-def main(operation, user, fields):    
+def main(operation, user, fields):
     start_time = time.time()
     fun = mapping.get(operation, badOperation)
 
@@ -613,7 +655,8 @@ def main(operation, user, fields):
         output_string =  str(output)
         checkpoint.checkpoint('output as a string')
         print output_string
-        print '<pre>%s</pre>' % checkpoint
+        if fields.has_key('timedebug'):
+            print '<pre>%s</pre>' % checkpoint
     except Exception, err:
         if not fields.has_key('js'):
             if isinstance(err, CodeError):
@@ -637,6 +680,13 @@ def main(operation, user, fields):
 
 if __name__ == '__main__':
     fields = cgi.FieldStorage()
+
+    if fields.has_key('sqldebug'):
+        import logging
+        logging.basicConfig()
+        logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
+        logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.INFO)
+
     u = getUser()
     g.user = u
     operation = os.environ.get('PATH_INFO', '')
@@ -645,6 +695,9 @@ if __name__ == '__main__':
         print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
         sys.exit(0)
 
+    if u is None:
+        operation = 'unauth'
+
     if operation.startswith('/'):
         operation = operation[1:]
     if not operation: