X-Git-Url: http://xvm.mit.edu/gitweb/invirt/packages/invirt-vnc-server.git/blobdiff_plain/14cf46b2e06a403c4238b415b2d762e747840d35..70b6236aa037ef3fd1c34cfef3454ec092976079:/code/vncexternalauth.py diff --git a/code/vncexternalauth.py b/code/vncexternalauth.py index 9f0edec..30e89e1 100644 --- a/code/vncexternalauth.py +++ b/code/vncexternalauth.py @@ -48,14 +48,12 @@ class VNCAuthOutgoing(protocol.Protocol): self.socks.write(data) def write(self,data): - #self.socks.log(self,data) self.transport.write(data) class VNCAuth(protocol.Protocol): - def __init__(self,logging=None,server="localhost"): - self.logging=logging + def __init__(self,server="localhost"): self.server=server self.auth=None @@ -65,12 +63,12 @@ class VNCAuth(protocol.Protocol): def validateToken(self, token): global TOKEN_KEY + self.auth_error = "Invalid token" try: token = base64.urlsafe_b64decode(token) token = cPickle.loads(token) m = hmac.new(TOKEN_KEY, digestmod=sha) m.update(token['data']) - self.auth_error = "Invalid token" if (m.digest() == token['digest']): data = cPickle.loads(token['data']) expires = data["expires"] @@ -81,8 +79,8 @@ class VNCAuth(protocol.Protocol): self.auth_data = data else: self.auth_error = "Token has expired; please try logging in again" - except: - self.auth = None + except (TypeError, cPickle.UnpicklingError): + self.auth = None print sys.exc_info() def dataReceived(self,data): @@ -99,7 +97,7 @@ class VNCAuth(protocol.Protocol): try: (header, data) = line.split(": ", 1) headers[header] = data - except: + except ValueError: pass if command == "AUTHTOKEN": @@ -114,24 +112,22 @@ class VNCAuth(protocol.Protocol): vmname = args[0] if ("Auth-token" in headers): token = headers["Auth-token"] - try: - self.validateToken(token) - finally: - if self.auth is not None: - port = getPort(vmname, self.auth_data) - if port is not None: # FIXME - if port is not 0: - d = self.connectClass(self.server, port, VNCAuthOutgoing, self) - d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage())) - else: - self.makeReply(404, "Unable to find VNC for VM "+vmname) + self.validateToken(token) + if self.auth is not None: + port = getPort(vmname, self.auth_data) + if port is not None: # FIXME + if port != 0: + d = self.connectClass(self.server, port, VNCAuthOutgoing, self) + d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage())) else: - self.makeReply(401, "Unauthorized to connect to VM "+vmname) + self.makeReply(404, "Unable to find VNC for VM "+vmname) else: - if self.auth_error: - self.makeReply(401, self.auth_error) - else: - self.makeReply(401, "Invalid token") + self.makeReply(401, "Unauthorized to connect to VM "+vmname) + else: + if self.auth_error: + self.makeReply(401, self.auth_error) + else: + self.makeReply(401, "Invalid token") else: self.makeReply(401, "Login first") else: @@ -176,28 +172,24 @@ class VNCAuth(protocol.Protocol): if int(reply / 100)!=2: self.transport.loseConnection() def write(self,data): - #self.log(self,data) self.transport.write(data) def log(self,proto,data): - if not self.logging: return peer = self.transport.getPeer() their_peer = self.otherConn.transport.getPeer() - f=open(self.logging,"a") - f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(), + print "%s\t%s:%d %s %s:%d\n"%(time.ctime(), peer.host,peer.port, ((proto==self and '<') or '>'), - their_peer.host,their_peer.port)) + their_peer.host,their_peer.port), while data: p,data=data[:16],data[16:] - f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ') - f.write((16-len(p))*3*' ') + print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ', + print ((16-len(p))*3*' '), for c in p: - if len(repr(c))>3: f.write('.') - else: f.write(c) - f.write('\n') - f.write('\n') - f.close() + if len(repr(c))>3: print '.', + else: print c, + print "" + print "" class VNCAuthFactory(protocol.Factory): @@ -206,10 +198,9 @@ class VNCAuthFactory(protocol.Factory): Constructor accepts one argument, a log file name. """ - def __init__(self, log, server): - self.logging = log + def __init__(self, server): self.server = server def buildProtocol(self, addr): - return VNCAuth(self.logging, self.server) + return VNCAuth(self.server)