2 Wrapper for Invirt VNC proxying
6 from twisted.internet import reactor, protocol, defer
7 from twisted.python import log
# NOTE(review): the enclosing `def` line is not present in this excerpt;
# presumably this is the body of getTokenKey(), which validateToken uses as
# its HMAC key -- confirm.
# Returns the contents of /etc/invirt/vnc/token-key with surrounding
# whitespace stripped.
# NOTE(review): `file()` exists only in Python 2, and the handle is never
# closed explicitly (left to the garbage collector); prefer
# `with open(...) as f:` when porting.
23 return file('/etc/invirt/vnc/token-key').read().strip()
# getPort(name, auth_data): return the VNC port of VM `name` when the
# authorization data names that same machine.
# auth_data must be a mapping with a "machine" key (indexed below).
# get_port.findPort(name) is expected to return a ':'-separated string; the
# second ':'-separated field is converted to int and returned.
# NOTE(review): original lines 26, 29 and 30 are missing from this excerpt,
# so the mismatch / not-found paths are not visible here. The caller in
# VNCAuth.dataReceived compares the result against None.
25 def getPort(name, auth_data):
27 if (auth_data["machine"] == name):
28 port = get_port.findPort(name)
31 return int(port.split(':')[1])
# VNCAuthOutgoing: Twisted protocol for the proxy's outgoing TCP connection
# to a VM's VNC server. VNCAuth.dataReceived creates it via
# connectClass(..., VNCAuthOutgoing, self), so `socks` is the client-facing
# VNCAuth instance. Bytes received here are written back to that instance.
35 class VNCAuthOutgoing(protocol.Protocol):
# NOTE(review): the body of __init__ (original lines 38-39) is missing from
# this excerpt; presumably it stores `socks` as self.socks, which the other
# methods use -- confirm.
37 def __init__(self,socks):
def connectionMade(self):
    """Report success to the client once the VNC-server connection is up.

    Sends a 200 status line through the client-facing protocol (self.socks)
    and registers this protocol as its ``otherConn``, so that subsequent
    client bytes are relayed to the VNC server.
    """
    self.socks.makeReply(200)
    self.socks.otherConn = self
def connectionLost(self, reason):
    """Drop the client connection when the VNC-server side goes away.

    ``reason`` is not inspected.
    """
    client = self.socks
    client.transport.loseConnection()
def dataReceived(self, data):
    """Relay bytes received from the VNC server to the client-facing protocol."""
    relay = self.socks.write
    relay(data)
# NOTE(review): the `def` line for this body (original line 51) is missing
# from this excerpt. It writes `data` to this protocol's own transport, i.e.
# towards the VNC server. VNCAuth.dataReceived calls
# self.otherConn.write(data), which suggests this is VNCAuthOutgoing.write --
# confirm.
52 self.transport.write(data)
# VNCAuth: client-facing protocol of the VNC proxy. It parses a small
# header-style request (commands "AUTHTOKEN" / "CONNECTVNC" followed by
# "Header: value" lines), validates a signed token, and opens a
# VNCAuthOutgoing connection to the VM's VNC port. After that, data is
# relayed between the two connections.
55 class VNCAuth(protocol.Protocol):
# `server` is the host the outgoing VNC connection goes to (dataReceived uses
# self.server in connectClass).
# NOTE(review): the rest of __init__ (original lines 58-60) is missing from
# this excerpt.
57 def __init__(self,server="localhost"):
# NOTE(review): the body of VNCAuth.connectionMade (original lines 62-64) is
# missing from this excerpt. Presumably it initialises state that
# dataReceived relies on (e.g. self.buf, self.otherConn) -- confirm.
61 def connectionMade(self):
# Validate a signed token of the form "<b64 pickle>.<b64 HMAC digest>", both
# halves URL-safe base64.
# On success, sets self.auth (the "user" field) and self.auth_machine (the
# "machine" field), and clears self.auth_error. Otherwise self.auth_error
# keeps a human-readable reason ("Invalid token" by default, or the expiry
# message).
# NOTE(review): the `try:` line (original 67) and lines 78-79 and 82-84 are
# missing from this excerpt.
65 def validateToken(self, token):
66 self.auth_error = "Invalid token"
# A token that does not split into exactly two parts makes this unpacking
# raise ValueError. That exception is among those caught by the except clause
# below.
68 (pickled_data, digest) = map(base64.urlsafe_b64decode, token.split("."))
# NOTE(review): `sha` is presumably the Python 2 `sha` module (HMAC-SHA1);
# its import is not visible in this excerpt.
69 m = hmac.new(getTokenKey(), digestmod=sha)
70 m.update(pickled_data)
# SECURITY(review): comparing digests with `==` is not constant-time and can
# leak timing information. hmac.compare_digest is the standard replacement.
71 if (m.digest() == digest):
# Unpickling happens only after the HMAC check. Its safety therefore rests
# entirely on the token key staying secret, since unpickling attacker data
# can execute arbitrary code.
72 data = cPickle.loads(pickled_data)
73 expires = data["expires"]
# "expires" is compared with time.time(), so it is presumably a Unix
# timestamp in seconds -- confirm against the token issuer.
74 if (time.time() < expires):
75 self.auth = data["user"]
76 self.auth_error = None
77 self.auth_machine = data["machine"]
80 self.auth_error = "Token has expired; please try logging in again"
81 except (TypeError, ValueError, cPickle.UnpicklingError):
# Handle bytes from the client.
# Once an outgoing VNC connection exists (self.otherConn), data is forwarded
# to it. Before that, bytes accumulate in self.buf until a blank line ends
# the request header; then the first line is split into the command and its
# arguments, and the remaining lines into "Header: value" pairs.
# NOTE(review): many lines of this method are missing from this excerpt,
# including the guard before line 87, the loop header for lines 97-98, and
# the assignments of command, vmname and headers.
85 def dataReceived(self,data):
87 self.otherConn.write(data)
89 self.buf=self.buf+data
# The header is complete once any style of blank line (CRLF, LF or CR) appears.
90 if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
91 lines = self.buf.splitlines()
92 args = lines.pop(0).split()
97 (header, data) = line.split(": ", 1)
98 headers[header] = data
# SECURITY(review): AUTHTOKEN accepts the literal token "1" (marked FIXME by
# the original author).
102 if command == "AUTHTOKEN":
104 token = headers["Auth-token"]
105 if token == "1": #FIXME
107 self.makeReply(200, "Authentication successful")
110 elif command == "CONNECTVNC":
112 if ("Auth-token" in headers):
113 token = headers["Auth-token"]
114 self.validateToken(token)
115 if self.auth is not None:
# NOTE(review): the visible part of validateToken sets self.auth_machine, not
# self.auth_data. Where self.auth_data is assigned is not visible in this
# excerpt; confirm it is always set before this line, otherwise this raises
# AttributeError.
116 port = getPort(vmname, self.auth_data)
117 if port is not None: # FIXME
119 d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
# If the outgoing connection fails, reply 404 with the failure's error message.
120 d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
122 self.makeReply(404, "Unable to find VNC for VM "+vmname)
124 self.makeReply(401, "Unauthorized to connect to VM "+vmname)
127 self.makeReply(401, self.auth_error)
129 self.makeReply(401, "Invalid token")
131 self.makeReply(401, "Login first")
133 self.makeReply(501, "unknown method "+command)
# NOTE(review): `if False and ...` can never be true, so the body of this `if`
# never runs. The lines that follow look like leftover SOCKSv4 request
# parsing (they reference SOCKSv4Outgoing, which is not defined in this
# excerpt) and are candidates for removal.
135 if False and '\000' in self.buf[8:]:
136 head,self.buf=self.buf[:8],self.buf[8:]
138 version,code,port=struct.unpack("!BBH",head[:4])
140 raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
141 user,self.buf=string.split(self.buf,"\000",1)
142 if head[4:7]=="\000\000\000": # domain is after
143 server,self.buf=string.split(self.buf,'\000',1)
144 #server=gethostbyname(server)
146 server=socket.inet_ntoa(head[4:8])
147 assert version==4, "Bad version code: %s"%version
148 if not self.authorize(code,server,port,user):
151 if code==1: # CONNECT
152 d = self.connectClass(server, port, SOCKSv4Outgoing, self)
153 d.addErrback(lambda result, self=self: self.makeReply(91))
155 raise RuntimeError, "Bad Connect Code: %s" % code
156 assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
# Client connection closed: also close the outgoing VNC-server connection.
# NOTE(review): original line 159 is missing from this excerpt. Presumably it
# guards against self.otherConn not existing yet -- confirm.
158 def connectionLost(self, reason):
160 self.otherConn.transport.loseConnection()
# Authorization hook called from the SOCKS-style parsing in dataReceived. It
# logs the request through twisted.python.log.
# NOTE(review): its return statement (original line 164) is missing from this
# excerpt.
162 def authorize(self,code,server,port,user):
163 log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
def connectClass(self, host, port, klass, *args):
    """Open a TCP connection to host:port driven by protocol class ``klass``.

    ``klass`` is instantiated with ``*args`` by a ``ClientCreator`` bound to
    the global reactor. Returns the Deferred produced by ``connectTCP``.
    """
    creator = protocol.ClientCreator(reactor, klass, *args)
    return creator.connectTCP(host, port)
def makeReply(self, reply, message=""):
    """Send a "VNCProxy/1.0 <code> <message>" status line to the client.

    The status line is terminated by a blank line. Any code outside the
    2xx range also closes the client connection.
    """
    status_line = "VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message)
    self.transport.write(status_line)
    succeeded = 200 <= reply < 300
    if not succeeded:
        self.transport.loseConnection()
def write(self, data):
    """Send raw bytes to the connected client.

    VNCAuthOutgoing calls this to relay VNC-server traffic back to the client.
    """
    client_transport = self.transport
    client_transport.write(data)
# Debug hex-dump helper (Python 2 print statements).
# Prints a timestamped header naming both peers, with '<' when proto is self
# and '>' otherwise. It then prints `data` in chunks of 16 bytes: the hex
# bytes, padding for short rows, and '.' for characters whose repr is longer
# than 3 characters (non-printable).
# NOTE(review): original lines 180, 183, 187 and 189-193 are missing from
# this excerpt, so the loop structure cannot be confirmed here.
176 def log(self,proto,data):
177 peer = self.transport.getPeer()
178 their_peer = self.otherConn.transport.getPeer()
179 print "%s\t%s:%d %s %s:%d\n"%(time.ctime(),
181 ((proto==self and '<') or '>'),
182 their_peer.host,their_peer.port),
# Take the next 16-byte row off the front of data.
184 p,data=data[:16],data[16:]
185 print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ',
186 print ((16-len(p))*3*' '),
188 if len(repr(c))>3: print '.',
# Twisted factory that builds one VNCAuth protocol per incoming connection,
# passing self.server to VNCAuth's `server` parameter.
# NOTE(review): the class docstring says the constructor takes "a log file
# name", but __init__ takes `server` and buildProtocol hands self.server to
# VNCAuth. The docstring looks stale. Its closing lines and the __init__ body
# (original lines 196, 198-199, 201-202) are missing from this excerpt, so it
# is left unchanged here.
194 class VNCAuthFactory(protocol.Factory):
195 """A factory for a VNC auth proxy.
197 Constructor accepts one argument, a log file name.
200 def __init__(self, server):
203 def buildProtocol(self, addr):
204 return VNCAuth(self.server)