Fix another stupid bug.
[invirt/packages/invirt-web.git] / code / main.py
index 8a5c178..6dcb70d 100755 (executable)
@@ -6,7 +6,6 @@ import cPickle
 import cgi
 import datetime
 import hmac
-import os
 import sha
 import simplejson
 import sys
@@ -41,7 +40,7 @@ import sipb_xen_database
 from sipb_xen_database import Machine, CDROM, ctx, connect, MachineAccess, Type, Autoinstall
 import validation
 import cache_acls
-from webcommon import InvalidInput, CodeError, state
+from webcommon import InvalidInput, CodeError, State
 import controls
 
 class Checkpoint:
@@ -144,7 +143,7 @@ def hasVnc(status):
 
 def parseCreate(username, state, fields):
     kws = dict([(kw, fields.getfirst(kw)) for kw in 'name owner memory disksize vmtype cdrom clone_from'.split()])
-    validate = validation.Validate(username, state, **kws)
+    validate = validation.Validate(username, state, strict=True, **kws)
     return dict(contact=username, name=validate.name, memory=validate.memory,
                 disksize=validate.disksize, owner=validate.owner, machine_type=validate.vmtype,
                 cdrom=getattr(validate, 'cdrom', None),
@@ -154,13 +153,13 @@ def create(username, state, fields):
     """Handler for create requests."""
     try:
         parsed_fields = parseCreate(username, state, fields)
-        machine = controls.createVm(username, **parsed_fields)
+        machine = controls.createVm(username, state, **parsed_fields)
     except InvalidInput, err:
         pass
     else:
         err = None
     state.clear() #Changed global state
-    d = getListDict(username)
+    d = getListDict(username, state)
     d['err'] = err
     if err:
         for field in fields.keys():
@@ -262,7 +261,7 @@ def vnc(username, state, fields):
              on=status,
              has_vnc=has_vnc,
              machine=machine,
-             hostname=os.environ.get('SERVER_NAME', 'localhost'),
+             hostname=state.environ.get('SERVER_NAME', 'localhost'),
              authtoken=token)
     return templates.vnc(searchList=[d])
 
@@ -335,7 +334,7 @@ def command(username, state, fields):
             return templates.command(searchList=[d])
     if back == 'list':
         state.clear() #Changed global state
-        d = getListDict(username)
+        d = getListDict(username, state)
         d['result'] = result
         return templates.list(searchList=[d])
     elif back == 'info':
@@ -412,7 +411,7 @@ def modify(username, state, fields):
         machine = modify_dict['machine']
         result = 'Success!'
         err = None
-    info_dict = infoDict(username, machine)
+    info_dict = infoDict(username, state, machine)
     info_dict['err'] = err
     if err:
         for field in fields.keys():
@@ -472,11 +471,11 @@ console will suffer artifacts.
     return templates.help(searchList=[d])
 
 
-def badOperation(u, e):
+def badOperation(u, s, e):  # takes (username, state, fields) like every other handler
     """Function called when accessing an unknown URI."""
     raise CodeError("Unknown operation")
 
-def infoDict(username, machine):
+def infoDict(username, state, machine):
     """Get the variables used by info.tmpl."""
     status = controls.statusInfo(machine)
     checkpoint.checkpoint('Getting status info')
@@ -572,7 +571,7 @@ def infoDict(username, machine):
 def info(username, state, fields):
     """Handler for info on a single VM."""
     machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
-    d = infoDict(username, machine)
+    d = infoDict(username, state, machine)
     checkpoint.checkpoint('Got infodict')
     return templates.info(searchList=[d])
 
@@ -605,83 +604,89 @@ def getUser(environ):
         return None
     return email[:-8]
 
-def main(operation, username, state, fields):
-    start_time = time.time()
-    fun = mapping.get(operation, badOperation)
-
-    if fun not in (helpHandler, ):
-        connect('postgres://sipb-xen@sipb-xen-dev.mit.edu/sipb_xen')
-    try:
-        checkpoint.checkpoint('Before')
-        output = fun(username, state, fields)
-        checkpoint.checkpoint('After')
-
-        headers = dict(DEFAULT_HEADERS)
-        if isinstance(output, tuple):
-            new_headers, output = output
-            headers.update(new_headers)
-        e = revertStandardError()
-        if e:
-            if isinstance(output, basestring):
-                sys.stderr = StringIO()
-                x = str(output)
-                print >> sys.stderr, x
-                print >> sys.stderr, 'XXX'
-                print >> sys.stderr, e
-                raise Exception()
-            output.addError(e)
-        printHeaders(headers)
-        output_string =  str(output)
-        checkpoint.checkpoint('output as a string')
-        print output_string
+class App:
+    """WSGI application: built from (environ, start_response); iterating it yields the response body."""
+    def __init__(self, environ, start_response):
+        """Store the WSGI arguments and create the State for the user returned by getUser(environ)."""
+        self.environ = environ
+        self.start = start_response
+
+        self.username = getUser(environ)
+        self.state = State(self.username)
+        self.state.environ = environ  # lets handlers read e.g. SERVER_NAME via state
+
+    def __iter__(self):
+        """Run the handler selected by PATH_INFO and yield its page, or an error page if it raises."""
+        fields = cgi.FieldStorage(fp=self.environ['wsgi.input'], environ=self.environ)
+        print >> sys.stderr, fields
+        operation = self.environ.get('PATH_INFO', '')
+        if not operation:
+            self.start("301 Moved Permanently", [('Location',
+                                                  self.environ['SCRIPT_NAME']+'/')])
+            return
+        if self.username is None:
+            operation = 'unauth'
+        if operation.startswith('/'):
+            operation = operation[1:]
+        if not operation:
+            operation = 'list'
+        print 'Starting', operation
+
+        fun = mapping.get(operation, badOperation)
+        try:
+            checkpoint.checkpoint('Before')
+            output = fun(self.username, self.state, fields)
+            checkpoint.checkpoint('After')
+
+            headers = dict(DEFAULT_HEADERS)
+            if isinstance(output, tuple):  # handler returned (extra_headers, body)
+                new_headers, output = output
+                headers.update(new_headers)
+            e = revertStandardError()
+            if e:
+                if isinstance(output, basestring):  # a plain string has no addError(); report via stderr and fail
+                    sys.stderr = StringIO()
+                    print >> sys.stderr, str(output)
+                    print >> sys.stderr, 'XXX'
+                    print >> sys.stderr, e
+                    raise Exception()
+                output.addError(e)
+            output_string = str(output)
+            checkpoint.checkpoint('output as a string')
+        except Exception, err:
+            if not fields.has_key('js'):
+                if isinstance(err, CodeError):
+                    self.start('500 Internal Server Error', [('Content-Type', 'text/html')])
+                    e = revertStandardError()
+                    yield str(error(operation, self.username, fields, err, e))
+                    return
+                if isinstance(err, InvalidInput):
+                    self.start('200 OK', [('Content-Type', 'text/html')])
+                    e = revertStandardError()
+                    yield str(invalidInput(operation, self.username, fields, err, e))
+                    return
+            self.start('500 Internal Server Error', [('Content-Type', 'text/plain')])
+            import traceback
+            yield '''Uh-oh!  We experienced an error.
+Please email xvm-dev@mit.edu with the contents of this page.
+----
+%s
+----
+%s
+----''' % (str(err), traceback.format_exc())
+            return  # error response already started; must not fall through to the 200 path below
+        self.start('200 OK', headers.items())
+        yield output_string
         if fields.has_key('timedebug'):
-            print '<pre>%s</pre>' % cgi.escape(checkpoint)
-    except Exception, err:
-        if not fields.has_key('js'):
-            if isinstance(err, CodeError):
-                print 'Content-Type: text/html\n'
-                e = revertStandardError()
-                print error(operation, state.username, fields, err, e)
-                sys.exit(1)
-            if isinstance(err, InvalidInput):
-                print 'Content-Type: text/html\n'
-                e = revertStandardError()
-                print invalidInput(operation, state.username, fields, err, e)
-                sys.exit(1)
-        print 'Content-Type: text/plain\n'
-        print 'Uh-oh!  We experienced an error.'
-        print 'Please email xvm-dev@mit.edu with the contents of this page.'
-        print '----'
-        e = revertStandardError()
-        print e
-        print '----'
-        raise
+            yield '<pre>%s</pre>' % cgi.escape(str(checkpoint))
+
+def constructor():  # open the database connection, then return the App class as the WSGI callable
+    connect('postgres://sipb-xen@sipb-xen-dev.mit.edu/sipb_xen')
+    return App
+
+def main():  # serve App through flup's forking FastCGI WSGIServer
+    from flup.server.fcgi_fork import WSGIServer
+    WSGIServer(constructor()).run()
 
 if __name__ == '__main__':
-    fields = cgi.FieldStorage()
-
-    if fields.has_key('sqldebug'):
-        import logging
-        logging.basicConfig()
-        logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
-        logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.INFO)
-
-    username = getUser(os.environ)
-    state.username = username
-    operation = os.environ.get('PATH_INFO', '')
-    if not operation:
-        print "Status: 301 Moved Permanently"
-        print 'Location: ' + os.environ['SCRIPT_NAME']+'/\n'
-        sys.exit(0)
-    if username is None:
-        operation = 'unauth'
-    if operation.startswith('/'):
-        operation = operation[1:]
-    if not operation:
-        operation = 'list'
-
-    if os.getenv("SIPB_XEN_PROFILE"):
-        import profile
-        profile.run('main(operation, username, state, fields)', 'log-'+operation)
-    else:
-        main(operation, username, state, fields)
+    main()