Correctly verify authentication tokens, and disable backdoor
[invirt/packages/invirt-vnc-server.git] / vncexternalauth.py
index 75a4170..9f0edec 100644 (file)
@@ -7,6 +7,7 @@ from twisted.internet import reactor, protocol, defer
 from twisted.python import log
 
 # python imports
+import sys
 import struct
 import string
 import cPickle
@@ -21,12 +22,20 @@ import get_port
 
 TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
 
-def getPort(name, auth):
-    port = get_port.findPort(name)
-    if port is None:
-        return 0
-    return int(port.split(':')[1])
-
+def getPort(name, auth_data):
+    """Return the VNC port for VM `name`, or a sentinel on failure.
+
+    Returns None if auth_data (the decoded token payload) is for a
+    different machine than `name`, 0 if get_port.findPort() returns
+    None, and otherwise the integer after the ':' in findPort()'s result.
+    """
+    if auth_data["machine"] != name:
+        return None
+    port = get_port.findPort(name)
+    if port is None:
+        return 0
+    return int(port.split(':')[1])
+
 class VNCAuthOutgoing(protocol.Protocol):
     
     def __init__(self,socks):
@@ -61,18 +65,58 @@ class VNCAuth(protocol.Protocol):
 
     def validateToken(self, token):
         global TOKEN_KEY
-        if token == "quentin":
-            self.auth = "quentin@ATHENA.MIT.EDU"
-            return #FIXME
-        token = base64.urlsafe_b64decode(token)
-        token = cPickle.load(token)
-        m = hmac.new(TOKEN_KEY, digestmod=sha)
-        m.update(token['data'])
-        if (m.digest() == token['digest']):
-            data = cPickle.load(token['data'])
-            expires = data["expires"]
-            if (time.time() < expires):
-                self.auth = data["user"]
+        # Validate a client-supplied, base64-encoded auth token and record
+        # the outcome on self.  On success: self.auth is the token's "user",
+        # self.auth_machine its "machine", self.auth_data the whole decoded
+        # payload, and self.auth_error is None.  On any failure: self.auth
+        # is None and self.auth_error holds a message for the 401 reply.
+        from cStringIO import StringIO
+
+        def restricted_loads(s):
+            # The outer token arrives from the network before any signature
+            # check, so it must never reach a full unpickler (that allows
+            # arbitrary code execution).  With find_global set to None,
+            # cPickle refuses to resolve any class or function.
+            unpickler = cPickle.Unpickler(StringIO(s))
+            unpickler.find_global = None
+            return unpickler.load()
+
+        def digests_match(a, b):
+            # Constant-time comparison, so the expected HMAC cannot be
+            # recovered byte by byte from response timing.
+            # hmac.compare_digest only exists on Python 2.7.7+.
+            if hasattr(hmac, 'compare_digest'):
+                return hmac.compare_digest(a, b)
+            if len(a) != len(b):
+                return False
+            result = 0
+            for x, y in zip(a, b):
+                result |= ord(x) ^ ord(y)
+            return result == 0
+
+        self.auth = None
+        self.auth_error = "Invalid token"
+        try:
+            token = restricted_loads(base64.urlsafe_b64decode(token))
+            m = hmac.new(TOKEN_KEY, digestmod=sha)
+            m.update(token['data'])
+            if not digests_match(m.digest(), token['digest']):
+                return
+            # token['data'] is covered by the HMAC checked above, so it was
+            # produced by a holder of TOKEN_KEY and may be fully unpickled.
+            data = cPickle.loads(token['data'])
+            if time.time() >= data["expires"]:
+                self.auth_error = "Token has expired; please try logging in again"
+                return
+            user = data["user"]
+            machine = data["machine"]
+        except Exception:
+            log.err(None, "Rejected malformed VNC auth token")
+            return
+        self.auth = user
+        self.auth_machine = machine
+        self.auth_data = data
+        self.auth_error = None
 
     def dataReceived(self,data):
         if self.otherConn:
@@ -107,14 +118,20 @@ class VNCAuth(protocol.Protocol):
                         self.validateToken(token)
                     finally:
                         if self.auth is not None:
-                            port = getPort(vmname, self.auth)
+                            port = getPort(vmname, self.auth_data)
                             if port is not None: # FIXME
-                                d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
-                                d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
+                                # getPort() returns 0 when get_port.findPort() found nothing.
+                                if port != 0:
+                                    d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
+                                    d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
+                                else:
+                                    self.makeReply(404, "Unable to find VNC for VM "+vmname)
                             else:
                                 self.makeReply(401, "Unauthorized to connect to VM "+vmname)
                         else:
-                            self.makeReply(401, "Invalid token")
+                            # Prefer validateToken()'s specific reason; getattr guards
+                            # against it never having been set on this instance.
+                            self.makeReply(401, getattr(self, 'auth_error', None) or "Invalid token")
                 else:
                     self.makeReply(401, "Login first")
             else: