Update the cherrypy branch to use authz.afs.cells instead of just the hard-coded athena.mit.edu cell
[invirt/packages/invirt-web.git] / code / main.py
index 473647f..3655352 100755 (executable)
@@ -8,21 +8,13 @@ import datetime
 import hmac
 import random
 import sha
-import simplejson
 import sys
 import time
 import urllib
 import socket
 import cherrypy
+from cherrypy import _cperror
 from StringIO import StringIO
-def revertStandardError():
-    """Move stderr to stdout, and return the contents of the old stderr."""
-    errio = sys.stderr
-    if not isinstance(errio, StringIO):
-        return ''
-    sys.stderr = sys.stdout
-    errio.seek(0)
-    return errio.read()
 
 def printError():
     """Revert stderr to stdout, and print the contents of stderr"""
@@ -33,8 +25,6 @@ if __name__ == '__main__':
     import atexit
     atexit.register(printError)
 
-import templates
-from Cheetah.Template import Template
 import validation
 import cache_acls
 from webcommon import State
@@ -45,16 +35,78 @@ from invirt.database import Machine, CDROM, session, connect, MachineAccess, Typ
 from invirt.config import structs as config
 from invirt.common import InvalidInput, CodeError
 
-from view import View
+from view import View, revertStandardError
+
+class InvirtUnauthWeb(View):  # page for users who are not logged in (replaces unauthFront)
+    @cherrypy.expose
+    @cherrypy.tools.mako(filename="/unauth.mako")
+    def index(self):
+        return dict(simple=True)  # NOTE(review): old unauthFront also passed hostname; confirm unauth.mako doesn't need it
 
 class InvirtWeb(View):
     def __init__(self):
         super(self.__class__,self).__init__()
         connect()
         self._cp_config['tools.require_login.on'] = True
+        self._cp_config['tools.catch_stderr.on'] = True  # presumably captures stderr for revertStandardError() in handle_error; TODO confirm
         self._cp_config['tools.mako.imports'] = ['from invirt.config import structs as config',
                                                  'from invirt import database']
+        self._cp_config['request.error_response'] = self.handle_error  # uncaught handler exceptions are routed to handle_error
 
+    @cherrypy.expose
+    @cherrypy.tools.mako(filename="/invalid.mako")
+    def invalidInput(self):
+        """Render the InvalidInput error stashed by handle_error."""
+        # handle_error stored the exception and the captured stderr in the
+        # params of the failed request before redirecting here.
+        stash = cherrypy.request.prev.params
+        err, emsg = stash["err"], stash["emsg"]
+        return dict(err_field=err.err_field, err_value=str(err.err_value),
+                    stderr=emsg, errorMessage=str(err))
+
+    @cherrypy.expose
+    @cherrypy.tools.mako(filename="/error.mako")
+    def error(self):
+        """Print an error page when an exception occurs"""
+        prev = cherrypy.request.prev
+        op, username = prev.path_info, cherrypy.request.login
+        err, emsg = prev.params["err"], prev.params["emsg"]
+        # Report the request's fields minus the values handle_error stashed.
+        fields = dict((k, v) for k, v in prev.params.items()
+                      if k not in ('err', 'emsg', 'traceback'))
+        d = dict(op=op, user=username, fields=fields, errorMessage=str(err),
+                 stderr=emsg, traceback=prev.params["traceback"])
+        error_raw = cherrypy.request.lookup.get_template("/error_raw.mako")
+        details = error_raw.render(**d)
+        exclude = config.web.errormail_exclude
+        if username not in exclude and '*' not in exclude:
+            send_error_mail('xvm error on %s for %s: %s' % (op, username, err), details)
+        d['details'] = details
+        return d
+
+    def __getattr__(self, name):  # only reached when normal attribute lookup fails
+        if name not in ("admin", "overlord"):
+            return super(InvirtWeb, self).__getattr__(name)
+        login = cherrypy.request.login
+        admins = getAfsGroupMembers(config.adminacl, config.authz.afs.cells[0].cell)
+        if login not in admins:
+            raise InvalidInput('username', login, 'Not in admin group %s.' % config.adminacl)
+        cherrypy.request.state = State(login, isadmin=True)  # /admin/... and /overlord/... reuse these handlers as admin
+        return self
+
+    def handle_error(self):  # installed as request.error_response in __init__
+        req = cherrypy.request
+        err = sys.exc_info()[1]
+        if isinstance(err, InvalidInput):
+            req.params.update(err=err, emsg=revertStandardError())
+            raise cherrypy.InternalRedirect('/invalidInput')
+        if req.prev and 'err' in req.prev.params:
+            # Already handling an error redirect: plain CherryPy 500, no loop.
+            cherrypy.HTTPError(500).set_response()
+            return
+        req.params.update(err=err, emsg=revertStandardError(),
+                          traceback=_cperror.format_exc())
+        raise cherrypy.InternalRedirect('/error')
 
     @cherrypy.expose
     @cherrypy.tools.mako(filename="/list.mako")
@@ -147,10 +199,16 @@ console will suffer artifacts.
     help._cp_config['tools.require_login.on'] = False
 
     def parseCreate(self, fields):
-        kws = dict([(kw, fields.get(kw)) for kw in 'name description owner memory disksize vmtype cdrom autoinstall'.split() if fields.get(kw)])
-        validate = validation.Validate(cherrypy.request.login, cherrypy.request.state, strict=True, **kws)
-        return dict(contact=cherrypy.request.login, name=validate.name, description=validate.description, memory=validate.memory,
-                    disksize=validate.disksize, owner=validate.owner, machine_type=getattr(validate, 'vmtype', Defaults.type),
+        kws = dict([(kw, fields[kw]) for kw in
+         'name description owner memory disksize vmtype cdrom autoinstall'.split()
+                    if fields[kw]])
+        validate = validation.Validate(cherrypy.request.login,
+                                       cherrypy.request.state,
+                                       strict=True, **kws)
+        return dict(contact=cherrypy.request.login, name=validate.name,
+                    description=validate.description, memory=validate.memory,
+                    disksize=validate.disksize, owner=validate.owner,
+                    machine_type=getattr(validate, 'vmtype', Defaults.type),
                     cdrom=getattr(validate, 'cdrom', None),
                     autoinstall=getattr(validate, 'autoinstall', None))
 
@@ -161,7 +219,8 @@ console will suffer artifacts.
         """Handler for create requests."""
         try:
             parsed_fields = self.parseCreate(fields)
-            machine = controls.createVm(cherrypy.request.login, cherrypy.request.state, **parsed_fields)
+            machine = controls.createVm(cherrypy.request.login,
+                                        cherrypy.request.state, **parsed_fields)
         except InvalidInput, err:
             pass
         else:
@@ -170,8 +229,8 @@ console will suffer artifacts.
         d = getListDict(cherrypy.request.login, cherrypy.request.state)
         d['err'] = err
         if err:
-            for field in fields.keys():
-                setattr(d['defaults'], field, fields.get(field))
+            for field, value in fields.items():
+                setattr(d['defaults'], field, value)
         else:
             d['new_machine'] = parsed_fields['name']
         return d
@@ -185,17 +244,23 @@ console will suffer artifacts.
     @cherrypy.expose
     def errortest(self):
         """Throw an error, to test the error-tracing mechanisms."""
+        print >>sys.stderr, "look ma, it's a stderr"  # meant to appear as captured stderr on the error page
         raise RuntimeError("test of the emergency broadcast system")
 
     class MachineView(View):
-        # This is hairy. Fix when CherryPy 3.2 is out. (rename to
-        # _cp_dispatch, and parse the argument as a list instead of
-        # string
-
         def __getattr__(self, name):
+            """Synthesize attributes to allow RESTful URLs like
+            /machine/13/info: an integer path component is stored in
+            request.params['machine_id'] and self is returned, so the
+            next component (info, vnc, ...) resolves on this view.
+            Any other name yields None rather than AttributeError.
+
+            This is hairy; CherryPy 3.2's _cp_dispatch receives the
+            remaining path components as a list and may rewrite them,
+            which would be the cleaner replacement for this hook."""
+
             try:
-                machine_id = int(name)
-                cherrypy.request.params['machine_id'] = machine_id
+                cherrypy.request.params['machine_id'] = int(name)
                 return self
             except ValueError:
                 return None
@@ -204,13 +269,42 @@ console will suffer artifacts.
         @cherrypy.tools.mako(filename="/info.mako")
         def info(self, machine_id):
             """Handler for info on a single VM."""
-            machine = validation.Validate(cherrypy.request.login, cherrypy.request.state, machine_id=machine_id).machine
+            machine = validation.Validate(cherrypy.request.login,
+                                          cherrypy.request.state,
+                                          machine_id=machine_id).machine
             d = infoDict(cherrypy.request.login, cherrypy.request.state, machine)
             checkpoint.checkpoint('Got infodict')
             return d
         index = info
 
         @cherrypy.expose
+        @cherrypy.tools.mako(filename="/info.mako")
+        @cherrypy.tools.require_POST()
+        def modify(self, machine_id, **fields):
+            """Handler for modifying attributes of a machine."""
+            try:
+                modify_dict = modifyDict(cherrypy.request.login,
+                                         cherrypy.request.state,
+                                         machine_id, fields)
+            except InvalidInput, err:
+                result = None
+                machine = validation.Validate(cherrypy.request.login,
+                                              cherrypy.request.state,
+                                              machine_id=machine_id).machine
+            else:
+                machine = modify_dict['machine']
+                result = 'Success!'
+                err = None
+            info_dict = infoDict(cherrypy.request.login,
+                                 cherrypy.request.state, machine)
+            info_dict['err'] = err
+            if err:
+                for field, value in fields.items():
+                    setattr(info_dict['defaults'], field, value)
+            info_dict['result'] = result
+            return info_dict
+
+        @cherrypy.expose
         @cherrypy.tools.mako(filename="/vnc.mako")
         def vnc(self, machine_id):
             """VNC applet page.
@@ -232,8 +326,9 @@ console will suffer artifacts.
             Remember to enable iptables!
             echo 1 > /proc/sys/net/ipv4/ip_forward
             """
-            machine = validation.Validate(cherrypy.request.login, cherrypy.request.state, machine_id=machine_id).machine
-
+            machine = validation.Validate(cherrypy.request.login,
+                                          cherrypy.request.state,
+                                          machine_id=machine_id).machine
             token = controls.vnctoken(machine)
             host = controls.listHost(machine)
             if host:
@@ -251,43 +346,41 @@ console will suffer artifacts.
                      port=port,
                      authtoken=token)
             return d
+
         @cherrypy.expose
         @cherrypy.tools.mako(filename="/command.mako")
         @cherrypy.tools.require_POST()
         def command(self, command_name, machine_id, **kwargs):
             """Handler for running commands like boot and delete on a VM."""
-            back = kwargs.get('back', None)
+            back = kwargs.get('back')
             try:
-                d = controls.commandResult(cherrypy.request.login, cherrypy.request.state, command_name, machine_id, kwargs)
+                d = controls.commandResult(cherrypy.request.login,
+                                           cherrypy.request.state,
+                                           command_name, machine_id, kwargs)
                 if d['command'] == 'Delete VM':
                     back = 'list'
             except InvalidInput, err:
                 if not back:
                     raise
                 print >> sys.stderr, err
-                result = err
+                result = str(err)
             else:
                 result = 'Success!'
                 if not back:
                     return d
             if back == 'list':
                 cherrypy.request.state.clear() #Changed global state
-                raise cherrypy.InternalRedirect('/list?result=%s' % urllib.quote(result))
+                raise cherrypy.InternalRedirect('/list?result=%s'
+                                                % urllib.quote(result))
             elif back == 'info':
-                raise cherrypy.HTTPRedirect(cherrypy.request.base + '/machine/%d/' % machine_id, status=303)
+                raise cherrypy.HTTPRedirect(cherrypy.request.base
+                                            + '/machine/%d/' % machine_id,
+                                            status=303)
             else:
                 raise InvalidInput('back', back, 'Not a known back page.')
 
     machine = MachineView()
 
-def pathSplit(path):
-    if path.startswith('/'):
-        path = path[1:]
-    i = path.find('/')
-    if i == -1:
-        i = len(path)
-    return path[:i], path[i:]
-
 class Checkpoint:
     def __init__(self):
         self.start_time = time.time()
@@ -303,35 +396,6 @@ class Checkpoint:
 
 checkpoint = Checkpoint()
 
-def makeErrorPre(old, addition):
-    if addition is None:
-        return
-    if old:
-        return old[:-6]  + '\n----\n' + str(addition) + '</pre>'
-    else:
-        return '<p>STDERR:</p><pre>' + str(addition) + '</pre>'
-
-Template.database = database
-Template.config = config
-Template.err = None
-
-class JsonDict:
-    """Class to store a dictionary that will be converted to JSON"""
-    def __init__(self, **kws):
-        self.data = kws
-        if 'err' in kws:
-            err = kws['err']
-            del kws['err']
-            self.addError(err)
-
-    def __str__(self):
-        return simplejson.dumps(self.data)
-
-    def addError(self, text):
-        """Add stderr text to be displayed on the website."""
-        self.data['err'] = \
-            makeErrorPre(self.data.get('err'), text)
-
 class Defaults:
     """Class to store default values for fields."""
     memory = 256
@@ -340,6 +404,7 @@ class Defaults:
     autoinstall = ''
     name = ''
     description = ''
+    administrator = ''
     type = 'linux-hvm'
 
     def __init__(self, max_memory=None, max_disk=None, **kws):
@@ -350,17 +415,6 @@ class Defaults:
         for key in kws:
             setattr(self, key, kws[key])
 
-
-
-DEFAULT_HEADERS = {'Content-Type': 'text/html'}
-
-def invalidInput(op, username, fields, err, emsg):
-    """Print an error page when an InvalidInput exception occurs"""
-    d = dict(op=op, user=username, err_field=err.err_field,
-             err_value=str(err.err_value), stderr=emsg,
-             errorMessage=str(err))
-    return templates.invalid(searchList=[d])
-
 def hasVnc(status):
     """Does the machine with a given status list support VNC?"""
     if status is None:
@@ -471,15 +525,18 @@ def getDiskInfo(data_dict, machine):
         data_dict['%s_size' % name] = "%0.1f GiB" % (disk.size / 1024.)
     return disk_fields
 
-def modifyDict(username, state, fields):
+def modifyDict(username, state, machine_id, fields):
     """Modify a machine as specified by CGI arguments.
 
-    Return a list of local variables for modify.tmpl.
+    fields is the submitted form dict; returns dict(machine=machine).
     """
     olddisk = {}
     session.begin()
     try:
-        kws = dict([(kw, fields.getfirst(kw)) for kw in 'machine_id owner admin contact name description memory vmtype disksize'.split()])
+        kws = dict([(kw, fields.get(kw)) for kw in
+         'owner admin contact name description memory vmtype disksize'.split()
+                    if fields.get(kw)])  # skip fields absent from the request (no KeyError) or left blank
+        kws['machine_id'] = machine_id
         validate = validation.Validate(username, state, **kws)
         machine = validate.machine
         oldname = machine.name
@@ -526,32 +583,7 @@ def modifyDict(username, state, fields):
         controls.resizeDisk(oldname, diskname, str(olddisk[diskname]))
     if hasattr(validate, 'name'):
         controls.renameMachine(machine, oldname, validate.name)
-    return dict(user=username,
-                command="modify",
-                machine=machine)
-
-def modify(username, state, path, fields):
-    """Handler for modifying attributes of a machine."""
-    try:
-        modify_dict = modifyDict(username, state, fields)
-    except InvalidInput, err:
-        result = None
-        machine = validation.Validate(username, state, machine_id=fields.getfirst('machine_id')).machine
-    else:
-        machine = modify_dict['machine']
-        result = 'Success!'
-        err = None
-    info_dict = infoDict(username, state, machine)
-    info_dict['err'] = err
-    if err:
-        for field in fields.keys():
-            setattr(info_dict['defaults'], field, fields.getfirst(field))
-    info_dict['result'] = result
-    return templates.info(searchList=[info_dict])
-
-def badOperation(u, s, p, e):
-    """Function called when accessing an unknown URI."""
-    return ({'Status': '404 Not Found'}, 'Invalid operation.')
+    return dict(machine=machine)
 
 def infoDict(username, state, machine):
     """Get the variables used by info.tmpl."""
@@ -624,7 +656,8 @@ def infoDict(username, state, machine):
     max_disk = validation.maxDisk(machine.owner, machine)
     defaults = Defaults()
     for name in 'machine_id name description administrator owner memory contact'.split():
-        setattr(defaults, name, getattr(machine, name))
+        if getattr(machine, name):
+            setattr(defaults, name, getattr(machine, name))
     defaults.type = machine.type.type_id
     defaults.disk = "%0.2f" % (machine.disks[0].size/1024.)
     checkpoint.checkpoint('Got defaults')
@@ -640,35 +673,6 @@ def infoDict(username, state, machine):
              fields = fields)
     return d
 
-def unauthFront(_, _2, _3, fields):
-    """Information for unauth'd users."""
-    return templates.unauth(searchList=[{'simple' : True, 
-            'hostname' : socket.getfqdn()}])
-
-def admin(username, state, path, fields):
-    if path == '':
-        return ({'Status': '303 See Other',
-                 'Location': 'admin/'},
-                "You shouldn't see this message.")
-    if not username in getAfsGroupMembers(config.adminacl, 'athena.mit.edu'):
-        raise InvalidInput('username', username,
-                           'Not in admin group %s.' % config.adminacl)
-    newstate = State(username, isadmin=True)
-    newstate.environ = state.environ
-    return handler(username, newstate, path, fields)
-
-mapping = dict(
-               modify=modify,
-               unauth=unauthFront,
-               admin=admin,
-               overlord=admin)
-
-def printHeaders(headers):
-    """Print a dictionary as HTTP headers."""
-    for key, value in headers.iteritems():
-        print '%s: %s' % (key, value)
-    print
-
 def send_error_mail(subject, body):
     import subprocess
 
@@ -685,98 +689,4 @@ Subject: %s
     p.stdin.close()
     p.wait()
 
-def show_error(op, username, fields, err, emsg, traceback):
-    """Print an error page when an exception occurs"""
-    d = dict(op=op, user=username, fields=fields,
-             errorMessage=str(err), stderr=emsg, traceback=traceback)
-    details = templates.error_raw(searchList=[d])
-    exclude = config.web.errormail_exclude
-    if username not in exclude and '*' not in exclude:
-        send_error_mail('xvm error on %s for %s: %s' % (op, username, err),
-                        details)
-    d['details'] = details
-    return templates.error(searchList=[d])
-
-def handler(username, state, path, fields):
-    operation, path = pathSplit(path)
-    if not operation:
-        operation = 'list'
-    print 'Starting', operation
-    fun = mapping.get(operation, badOperation)
-    return fun(username, state, path, fields)
-
-class App:
-    def __init__(self, environ, start_response):
-        self.environ = environ
-        self.start = start_response
-
-        self.username = getUser(environ)
-        self.state = State(self.username)
-        self.state.environ = environ
-
-        random.seed() #sigh
-
-    def __iter__(self):
-        start_time = time.time()
-        database.clear_cache()
-        sys.stderr = StringIO()
-        fields = cgi.FieldStorage(fp=self.environ['wsgi.input'], environ=self.environ)
-        operation = self.environ.get('PATH_INFO', '')
-        if not operation:
-            self.start("301 Moved Permanently", [('Location', './')])
-            return
-        if self.username is None:
-            operation = 'unauth'
-
-        try:
-            checkpoint.checkpoint('Before')
-            output = handler(self.username, self.state, operation, fields)
-            checkpoint.checkpoint('After')
-
-            headers = dict(DEFAULT_HEADERS)
-            if isinstance(output, tuple):
-                new_headers, output = output
-                headers.update(new_headers)
-            e = revertStandardError()
-            if e:
-                if hasattr(output, 'addError'):
-                    output.addError(e)
-                else:
-                    # This only happens on redirects, so it'd be a pain to get
-                    # the message to the user.  Maybe in the response is useful.
-                    output = output + '\n\nstderr:\n' + e
-            output_string =  str(output)
-            checkpoint.checkpoint('output as a string')
-        except Exception, err:
-            if not fields.has_key('js'):
-                if isinstance(err, InvalidInput):
-                    self.start('200 OK', [('Content-Type', 'text/html')])
-                    e = revertStandardError()
-                    yield str(invalidInput(operation, self.username, fields,
-                                           err, e))
-                    return
-            import traceback
-            self.start('500 Internal Server Error',
-                       [('Content-Type', 'text/html')])
-            e = revertStandardError()
-            s = show_error(operation, self.username, fields,
-                           err, e, traceback.format_exc())
-            yield str(s)
-            return
-        status = headers.setdefault('Status', '200 OK')
-        del headers['Status']
-        self.start(status, headers.items())
-        yield output_string
-        if fields.has_key('timedebug'):
-            yield '<pre>%s</pre>' % cgi.escape(str(checkpoint))
-
-def constructor():
-    connect()
-    return App
-
-def main():
-    from flup.server.fcgi_fork import WSGIServer
-    WSGIServer(constructor()).run()
-
-if __name__ == '__main__':
-    main()
+random.seed() # reseeds once at import; the removed App.__init__ reseeded per request under flup's forking server -- TODO confirm still needed