Use joins, new xmlist.py
[invirt/packages/invirt-web.git] / code / main.py
index 3fe04d5..e208a13 100755 (executable)
@@ -11,6 +11,7 @@ import sha
 import simplejson
 import sys
 import time
+import urllib
 from StringIO import StringIO
 
 def revertStandardError():
@@ -39,6 +40,7 @@ from Cheetah.Template import Template
 import sipb_xen_database
 from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess, Type, Autoinstall
 import validation
+import cache_acls
 from webcommon import InvalidInput, CodeError, g
 import controls
 
@@ -57,12 +59,15 @@ class Checkpoint:
 
 checkpoint = Checkpoint()
 
+def jquote(string):
+    return "'" + string.replace('\\', '\\\\').replace("'", "\\'").replace('\n', '\\n') + "'"
 
 def helppopup(subj):
     """Return HTML code for a (?) link to a specified help topic"""
-    return ('<span class="helplink"><a href="help?subject=' + subj +
-            '&amp;simple=true" target="_blank" ' +
-            'onclick="return helppopup(\'' + subj + '\')">(?)</a></span>')
+    return ('<span class="helplink"><a href="help?' +
+            cgi.escape(urllib.urlencode(dict(subject=subj, simple='true')))
+            +'" target="_blank" ' +
+            'onclick="return helppopup(' + cgi.escape(jquote(subj)) + ')">(?)</a></span>')
 
 def makeErrorPre(old, addition):
     if addition is None:
@@ -100,8 +105,9 @@ class Defaults:
     cdrom = ''
     autoinstall = ''
     name = ''
+    type = 'linux-hvm'
+
     def __init__(self, max_memory=None, max_disk=None, **kws):
-        self.type = Type.get('linux-hvm')
         if max_memory is not None:
             self.memory = min(self.memory, max_memory)
         if max_disk is not None:
@@ -190,20 +196,25 @@ def create(user, fields):
 
 def getListDict(user):
     """Gets the list of local variables used by list.tmpl."""
+    checkpoint.checkpoint('Starting')
     machines = g.machines
     checkpoint.checkpoint('Got my machines')
     on = {}
     has_vnc = {}
-    on = g.uptimes
+    xmlist = g.xmlist
     checkpoint.checkpoint('Got uptimes')
     for m in machines:
-        m.uptime = g.uptimes.get(m)
-        if not on[m]:
+        if m not in xmlist:
             has_vnc[m] = 'Off'
-        elif m.type.hvm:
-            has_vnc[m] = True
+            m.uptime = None
         else:
-            has_vnc[m] = "ParaVM"+helppopup("paravm_console")
+            m.uptime = xmlist[m]['uptime']
+            if xmlist[m]['console']:
+                has_vnc[m] = True
+            elif m.type.hvm:
+                has_vnc[m] = "WTF?"
+            else:
+                has_vnc[m] = "ParaVM"+helppopup("paravm_console")
     max_memory = validation.maxMemory(user)
     max_disk = validation.maxDisk(user)
     checkpoint.checkpoint('Got max mem/disk')
@@ -221,8 +232,7 @@ def getListDict(user):
              max_disk=max_disk,
              defaults=defaults,
              machines=machines,
-             has_vnc=has_vnc,
-             uptimes=g.uptimes)
+             has_vnc=has_vnc)
     return d
 
 def listVms(user, fields):
@@ -286,7 +296,7 @@ def getHostname(nic):
     if nic.hostname and '.' in nic.hostname:
         return nic.hostname
     elif nic.machine:
-        return nic.machine.name + '.servers.csail.mit.edu'
+        return nic.machine.name + '.xvm.mit.edu'
     else:
         return None
 
@@ -394,16 +404,21 @@ def modifyDict(user, fields):
                 disk.size = disksize
                 ctx.current.save(disk)
 
-        if owner is not None:
+        update_acl = False
+        if owner is not None and owner != machine.owner:
             machine.owner = owner
+            update_acl = True
         if name is not None:
             machine.name = name
-        if admin is not None:
+        if admin is not None and admin != machine.administrator:
             machine.administrator = admin
+            update_acl = True
         if contact is not None:
             machine.contact = contact
 
         ctx.current.save(machine)
+        if update_acl:
+            cache_acls.refreshMachine(machine)
         transaction.commit()
     except:
         transaction.rollback()
@@ -566,8 +581,9 @@ def infoDict(user, machine):
     checkpoint.checkpoint('Got mem')
     max_disk = validation.maxDisk(user, machine)
     defaults = Defaults()
-    for name in 'machine_id name administrator owner memory contact type'.split():
+    for name in 'machine_id name administrator owner memory contact'.split():
         setattr(defaults, name, getattr(machine, name))
+    defaults.type = machine.type.type_id
     defaults.disk = "%0.2f" % (machine.disks[0].size/1024.)
     checkpoint.checkpoint('Got defaults')
     d = dict(user=user,
@@ -590,13 +606,18 @@ def info(user, fields):
     checkpoint.checkpoint('Got infodict')
     return templates.info(searchList=[d])
 
+def unauthFront(_, fields):
+    """Information for unauth'd users."""
+    return templates.unauth(searchList=[{'simple' : True}])
+
 mapping = dict(list=listVms,
                vnc=vnc,
                command=command,
                modify=modify,
                info=info,
                create=create,
-               help=helpHandler)
+               help=helpHandler,
+               unauth=unauthFront)
 
 def printHeaders(headers):
     """Print a dictionary as HTTP headers."""
@@ -607,8 +628,10 @@ def printHeaders(headers):
 
 def getUser():
     """Return the current user based on the SSL environment variables"""
-    username = os.environ['SSL_CLIENT_S_DN_Email'].split("@")[0]
-    return username
+    email = os.environ.get('SSL_CLIENT_S_DN_Email', None)
+    if email is None:
+        return None
+    return email.split("@")[0]
 
 def main(operation, user, fields):
     start_time = time.time()
@@ -632,7 +655,8 @@ def main(operation, user, fields):
         output_string =  str(output)
         checkpoint.checkpoint('output as a string')
         print output_string
-        print '<!-- <pre>%s</pre> -->' % checkpoint
+        if fields.has_key('timedebug'):
+            print '<pre>%s</pre>' % checkpoint
     except Exception, err:
         if not fields.has_key('js'):
             if isinstance(err, CodeError):
@@ -656,6 +680,13 @@ def main(operation, user, fields):
 
 if __name__ == '__main__':
     fields = cgi.FieldStorage()
+
+    if fields.has_key('sqldebug'):
+        import logging
+        logging.basicConfig()
+        logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
+        logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.INFO)
+
     u = getUser()
     g.user = u
     operation = os.environ.get('PATH_INFO', '')
@@ -664,6 +695,9 @@ if __name__ == '__main__':
         print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
         sys.exit(0)
 
+    if u is None:
+        operation = 'unauth'
+
     if operation.startswith('/'):
         operation = operation[1:]
     if not operation: