--- /dev/null
+"""
+Wrapper for sipb-xen VNC proxying
+"""
+
+# twisted imports
+from twisted.internet import reactor, protocol, defer
+from twisted.python import log
+
+# python imports
+import struct
+import string
+import cPickle
+# Python 2.5:
+#import hashlib
+import sha
+import hmac
+import base64
+import socket
+import time
+import get_port
+
+# Shared secret used as the HMAC key when VNCAuth.validateToken verifies the
+# signature on an incoming auth token.
+# NOTE(review): this secret is committed in source; it should be loaded from
+# a protected config file or the environment instead, and rotated.
+TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
+
+def getPort(name, auth):
+ port = get_port.findPort(name)
+ if port is None:
+ return 0
+ return int(port.split(':')[1])
+
+class VNCAuthOutgoing(protocol.Protocol):
+
+ def __init__(self,socks):
+ self.socks=socks
+
+ def connectionMade(self):
+ peer = self.transport.getPeer()
+ self.socks.makeReply(200)
+ self.socks.otherConn=self
+
+ def connectionLost(self, reason):
+ self.socks.transport.loseConnection()
+
+ def dataReceived(self,data):
+ self.socks.write(data)
+
+ def write(self,data):
+ #self.socks.log(self,data)
+ self.transport.write(data)
+
+
+class VNCAuth(protocol.Protocol):
+
+ def __init__(self,logging=None,server="localhost"):
+ self.logging=logging
+ self.server=server
+ self.auth=None
+
+ def connectionMade(self):
+ self.buf=""
+ self.otherConn=None
+
+ def validateToken(self, token):
+ global TOKEN_KEY
+ if token == "quentin":
+ self.auth = "quentin@ATHENA.MIT.EDU"
+ return #FIXME
+ token = base64.urlsafe_b64decode(token)
+ token = cPickle.load(token)
+ m = hmac.new(TOKEN_KEY, digestmod=sha)
+ m.update(token['data'])
+ if (m.digest() == token['digest']):
+ data = cPickle.load(token['data'])
+ expires = data["expires"]
+ if (time.time() < expires):
+ self.auth = data["user"]
+
+ def dataReceived(self,data):
+ if self.otherConn:
+ self.otherConn.write(data)
+ return
+ self.buf=self.buf+data
+ if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
+ lines = self.buf.splitlines()
+ args = lines.pop(0).split()
+ command = args.pop(0)
+ headers = {}
+ for line in lines:
+ try:
+ (header, data) = line.split(": ", 1)
+ headers[header] = data
+ except:
+ pass
+
+ if command == "AUTHTOKEN":
+ user = args[0]
+ token = headers["Auth-token"]
+ if token == "1": #FIXME
+ self.auth = user
+ self.makeReply(200, "Authentication successful")
+ else:
+ self.makeReply(401)
+ elif command == "CONNECTVNC":
+ vmname = args[0]
+ if ("Auth-token" in headers):
+ token = headers["Auth-token"]
+ try:
+ self.validateToken(token)
+ finally:
+ if self.auth is not None:
+ port = getPort(vmname, self.auth)
+ if port is not None: # FIXME
+ d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
+ d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
+ else:
+ self.makeReply(401, "Unauthorized to connect to VM "+vmname)
+ else:
+ self.makeReply(401, "Invalid token")
+ else:
+ self.makeReply(401, "Login first")
+ else:
+ self.makeReply(501, "unknown method "+command)
+ self.buf=''
+ if False and '\000' in self.buf[8:]:
+ head,self.buf=self.buf[:8],self.buf[8:]
+ try:
+ version,code,port=struct.unpack("!BBH",head[:4])
+ except struct.error:
+ raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
+ user,self.buf=string.split(self.buf,"\000",1)
+ if head[4:7]=="\000\000\000": # domain is after
+ server,self.buf=string.split(self.buf,'\000',1)
+ #server=gethostbyname(server)
+ else:
+ server=socket.inet_ntoa(head[4:8])
+ assert version==4, "Bad version code: %s"%version
+ if not self.authorize(code,server,port,user):
+ self.makeReply(91)
+ return
+ if code==1: # CONNECT
+ d = self.connectClass(server, port, SOCKSv4Outgoing, self)
+ d.addErrback(lambda result, self=self: self.makeReply(91))
+ else:
+ raise RuntimeError, "Bad Connect Code: %s" % code
+ assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
+
+ def connectionLost(self, reason):
+ if self.otherConn:
+ self.otherConn.transport.loseConnection()
+
+ def authorize(self,code,server,port,user):
+ log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
+ return 1
+
+ def connectClass(self, host, port, klass, *args):
+ return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
+
+ def makeReply(self,reply,message=""):
+ self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
+ if int(reply / 100)!=2: self.transport.loseConnection()
+
+ def write(self,data):
+ #self.log(self,data)
+ self.transport.write(data)
+
+ def log(self,proto,data):
+ if not self.logging: return
+ peer = self.transport.getPeer()
+ their_peer = self.otherConn.transport.getPeer()
+ f=open(self.logging,"a")
+ f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(),
+ peer.host,peer.port,
+ ((proto==self and '<') or '>'),
+ their_peer.host,their_peer.port))
+ while data:
+ p,data=data[:16],data[16:]
+ f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ')
+ f.write((16-len(p))*3*' ')
+ for c in p:
+ if len(repr(c))>3: f.write('.')
+ else: f.write(c)
+ f.write('\n')
+ f.write('\n')
+ f.close()
+
+
+class VNCAuthFactory(protocol.Factory):
+ """A factory for a VNC auth proxy.
+
+ Constructor accepts one argument, a log file name.
+ """
+
+ def __init__(self, log, server):
+ self.logging = log
+ self.server = server
+
+ def buildProtocol(self, addr):
+ return VNCAuth(self.logging, self.server)
+