Changed VNC proxy to spew to stdout instead of an arbitrary logfile so that it will...
[invirt/packages/invirt-vnc-server.git] / code / vncexternalauth.py
index 9f0edec..30e89e1 100644 (file)
@@ -48,14 +48,12 @@ class VNCAuthOutgoing(protocol.Protocol):
         self.socks.write(data)
 
     def write(self,data):
-        #self.socks.log(self,data)
         self.transport.write(data)
 
 
 class VNCAuth(protocol.Protocol):
     
-    def __init__(self,logging=None,server="localhost"):
-        self.logging=logging
+    def __init__(self,server="localhost"):
         self.server=server
         self.auth=None
     
@@ -65,12 +63,12 @@ class VNCAuth(protocol.Protocol):
 
     def validateToken(self, token):
         global TOKEN_KEY
+        self.auth_error = "Invalid token"
         try:
             token = base64.urlsafe_b64decode(token)
             token = cPickle.loads(token)
             m = hmac.new(TOKEN_KEY, digestmod=sha)
             m.update(token['data'])
-            self.auth_error = "Invalid token"
             if (m.digest() == token['digest']):
                 data = cPickle.loads(token['data'])
                 expires = data["expires"]
@@ -81,8 +79,8 @@ class VNCAuth(protocol.Protocol):
                     self.auth_data = data
                 else:
                     self.auth_error = "Token has expired; please try logging in again"
-        except:
-            self.auth = None
+        except (TypeError, cPickle.UnpicklingError):
+            self.auth = None            
             print sys.exc_info()
 
     def dataReceived(self,data):
@@ -99,7 +97,7 @@ class VNCAuth(protocol.Protocol):
                 try:
                     (header, data) = line.split(": ", 1)
                     headers[header] = data
-                except:
+                except ValueError:
                     pass
 
             if command == "AUTHTOKEN":
@@ -114,24 +112,22 @@ class VNCAuth(protocol.Protocol):
                 vmname = args[0]
                 if ("Auth-token" in headers):
                     token = headers["Auth-token"]
-                    try:
-                        self.validateToken(token)
-                    finally:
-                        if self.auth is not None:
-                            port = getPort(vmname, self.auth_data)
-                            if port is not None: # FIXME
-                                if port is not 0:
-                                    d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
-                                    d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
-                                else:
-                                    self.makeReply(404, "Unable to find VNC for VM "+vmname)
+                    self.validateToken(token)
+                    if self.auth is not None:
+                        port = getPort(vmname, self.auth_data)
+                        if port is not None: # FIXME
+                            if port != 0:
+                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
+                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
                             else:
-                                self.makeReply(401, "Unauthorized to connect to VM "+vmname)
+                                self.makeReply(404, "Unable to find VNC for VM "+vmname)
                         else:
-                            if self.auth_error:
-                                self.makeReply(401, self.auth_error)
-                            else:
-                                self.makeReply(401, "Invalid token")
+                            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
+                    else:
+                        if self.auth_error:
+                            self.makeReply(401, self.auth_error)
+                        else:
+                            self.makeReply(401, "Invalid token")
                 else:
                     self.makeReply(401, "Login first")
             else:
@@ -176,28 +172,24 @@ class VNCAuth(protocol.Protocol):
         if int(reply / 100)!=2: self.transport.loseConnection()
 
     def write(self,data):
-        #self.log(self,data)
         self.transport.write(data)
 
     def log(self,proto,data):
-        if not self.logging: return
         peer = self.transport.getPeer()
         their_peer = self.otherConn.transport.getPeer()
-        f=open(self.logging,"a")
-        f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(),
+        print "%s\t%s:%d %s %s:%d\n"%(time.ctime(),
                                         peer.host,peer.port,
                                         ((proto==self and '<') or '>'),
-                                        their_peer.host,their_peer.port))
+                                        their_peer.host,their_peer.port),
         while data:
             p,data=data[:16],data[16:]
-            f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ')
-            f.write((16-len(p))*3*' ')
+            print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ',
+            print ((16-len(p))*3*' '),
             for c in p:
-                if len(repr(c))>3: f.write('.')
-                else: f.write(c)
-            f.write('\n')
-        f.write('\n')
-        f.close()
+                if len(repr(c))>3: print '.',
+                else: print c,
+            print ""
+        print ""
 
 
 class VNCAuthFactory(protocol.Factory):
@@ -206,10 +198,9 @@ class VNCAuthFactory(protocol.Factory):
-    Constructor accepts one argument, a log file name.
+    Constructor accepts one argument, the host of the VNC server that connections are proxied to.
     """
     
-    def __init__(self, log, server):
-        self.logging = log
+    def __init__(self, server):
         self.server = server
     
     def buildProtocol(self, addr):
-        return VNCAuth(self.logging, self.server)
+        return VNCAuth(self.server)