VNC server commit.
[invirt/packages/invirt-vnc-server.git] / vncexternalauth.py
1 """
2 Wrapper for sipb-xen VNC proxying
3 """
4
# python imports
import base64
import cPickle
import hmac
import os
import socket
# Python 2.5:
#import hashlib
import sha
import string
import struct
import time

# twisted imports
from twisted.internet import reactor, protocol, defer
from twisted.python import log

# local imports
import get_port
21
# Shared HMAC key used to verify VNC auth tokens (see VNCAuth.validateToken);
# whoever issues the tokens must sign them with the same key.
# NOTE(review): the fallback secret below is committed to source control and
# should be considered compromised.  Set INVIRT_VNC_TOKEN_KEY in the
# environment to rotate it without editing code.
TOKEN_KEY = os.environ.get("INVIRT_VNC_TOKEN_KEY",
                           "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN")
23
def getPort(name, auth):
    """Look up the TCP port of VM *name*'s VNC console.

    get_port.findPort(name) is expected to yield a ':'-separated string
    whose second field is the port number, or None.  Returns that port as
    an int, or 0 when findPort returns None.

    NOTE(review): *auth* is accepted but never consulted, so no per-user
    authorization happens here -- confirm whether that is intended.
    """
    location = get_port.findPort(name)
    if location is not None:
        fields = location.split(':')
        return int(fields[1])
    return 0
29
class VNCAuthOutgoing(protocol.Protocol):
    """Proxy leg from this server to a VM's VNC console port.

    Created by VNCAuth.connectClass with the client-facing VNCAuth protocol
    as its only argument; bytes are relayed verbatim in both directions.
    """

    def __init__(self, socks):
        # The client-facing VNCAuth connection this leg is paired with.
        self.socks = socks

    def connectionMade(self):
        """Report success to the client, then start relaying its bytes."""
        self.socks.makeReply(200)
        # From now on VNCAuth.dataReceived forwards client data to write().
        self.socks.otherConn = self

    def connectionLost(self, reason):
        """Close the client side when the VM side goes away."""
        self.socks.transport.loseConnection()

    def dataReceived(self, data):
        """Relay bytes arriving from the VM console to the client."""
        self.socks.write(data)

    def write(self, data):
        """Send client bytes on to the VM console."""
        self.transport.write(data)
47         #self.socks.log(self,data)
48         self.transport.write(data)
49
50
class VNCAuth(protocol.Protocol):
    """Client-facing side of the authenticating VNC proxy.

    The client sends a request block terminated by a blank line, e.g.

        CONNECTVNC <vmname>
        Auth-token: <token>

    Every request is answered with a "VNCProxy/1.0 <code> <message>" line,
    and any reply outside the 2xx range closes the connection.  After a
    successful CONNECTVNC a VNCAuthOutgoing connection to the VM's console
    port is opened and bytes are then relayed verbatim in both directions.
    """

    def __init__(self, logging=None, server="localhost"):
        # Optional path of the hex-dump traffic log written by log().
        self.logging = logging
        # Host to which VM console connections are made.
        self.server = server
        # Name of the authenticated user, or None.
        self.auth = None

    def connectionMade(self):
        # Request bytes received but not yet parsed.
        self.buf = ""
        # Paired VNCAuthOutgoing once the VM connection is up, else None.
        self.otherConn = None

    # ------------------------------------------------------------------
    # Token handling
    # ------------------------------------------------------------------

    def validateToken(self, token):
        """Authenticate from a signed token, setting self.auth.

        *token* is urlsafe-base64 text wrapping a pickled dict with keys
        'data' (itself a pickled dict with 'user' and 'expires') and
        'digest' (an HMAC of 'data' keyed with TOKEN_KEY).  If the digest
        matches and time.time() is before 'expires', self.auth becomes the
        token's 'user'; otherwise self.auth is None.  Malformed tokens are
        rejected without raising.
        """
        # Start unauthenticated so a failed check can never leave an
        # earlier identity in place.
        self.auth = None
        try:
            envelope = self._loadsNoGlobals(base64.urlsafe_b64decode(token))
            signed = envelope['data']
            digest = envelope['digest']
            if not (isinstance(signed, str) and isinstance(digest, str)):
                return
            mac = hmac.new(TOKEN_KEY, digestmod=sha)
            mac.update(signed)
            if not self._digestsEqual(mac.digest(), digest):
                return
            # Only now, with the signature verified, is the payload trusted.
            payload = self._loadsNoGlobals(signed)
            if time.time() < payload["expires"]:
                self.auth = payload["user"]
        except Exception:
            # Any decoding / unpickling / lookup failure means "invalid".
            log.msg("rejecting malformed VNC auth token")
            self.auth = None

    def _loadsNoGlobals(self, pickled):
        """Unpickle the string *pickled* without resolving any global names.

        SECURITY: tokens come from unauthenticated clients, and an ordinary
        unpickle can import and invoke arbitrary callables.  With
        find_global set to None, cPickle refuses every global reference, so
        only plain values (dicts, strings, numbers, ...) can be loaded.
        Switching the token format to a non-executable encoding (e.g. JSON),
        both here and in the token issuer, would be safer still.
        """
        import cStringIO
        unpickler = cPickle.Unpickler(cStringIO.StringIO(pickled))
        unpickler.find_global = None
        return unpickler.load()

    def _digestsEqual(self, a, b):
        """Compare two digest strings in time independent of their content.

        Avoids leaking, through timing, how many leading bytes of a forged
        digest were correct.
        """
        if len(a) != len(b):
            return False
        diff = 0
        for x, y in zip(a, b):
            diff |= ord(x) ^ ord(y)
        return diff == 0

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    def dataReceived(self, data):
        """Relay data once connected; otherwise buffer and parse a request."""
        if self.otherConn:
            self.otherConn.write(data)
            return
        self.buf = self.buf + data
        # Wait until the request block is terminated by a blank line.
        if not ('\r\n\r\n' in self.buf or '\n\n' in self.buf
                or '\r\r' in self.buf):
            return
        lines = self.buf.splitlines()
        self.buf = ''
        # The buffer holds at least one line break, so lines is non-empty.
        args = lines[0].split()
        if not args:
            self.makeReply(400, "Bad request")
            return
        command = args.pop(0)
        headers = self._parseHeaders(lines[1:])
        if command == "AUTHTOKEN":
            self._handleAuthToken(args, headers)
        elif command == "CONNECTVNC":
            self._handleConnectVNC(args, headers)
        else:
            self.makeReply(501, "unknown method "+command)

    def _parseHeaders(self, lines):
        """Collect "Name: value" lines into a dict; other lines are ignored."""
        headers = {}
        for line in lines:
            parts = line.split(": ", 1)
            if len(parts) == 2:
                headers[parts[0]] = parts[1]
        return headers

    def _handleAuthToken(self, args, headers):
        """AUTHTOKEN <user>: placeholder login step."""
        if not args:
            self.makeReply(400, "Bad request")
            return
        user = args[0]
        # NOTE(review): the token is not actually verified here -- any
        # client sending "1" is accepted.  This only sets self.auth, which
        # _handleConnectVNC discards and re-derives from its own token, so
        # it does not grant VM access; it should still be replaced with a
        # real check once the intended semantics are settled.
        if headers.get("Auth-token") == "1": # FIXME
            self.auth = user
            self.makeReply(200, "Authentication successful")
        else:
            self.makeReply(401)

    def _handleConnectVNC(self, args, headers):
        """CONNECTVNC <vmname>: authenticate, then proxy to the VM console."""
        if not args:
            self.makeReply(400, "Bad request")
            return
        vmname = args[0]
        if "Auth-token" not in headers:
            self.makeReply(401, "Login first")
            return
        self.validateToken(headers["Auth-token"])
        if self.auth is None:
            self.makeReply(401, "Invalid token")
            return
        # getPort() returns 0 when get_port cannot find a port for the VM.
        port = getPort(vmname, self.auth)
        if not port:
            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
            return
        # VNCAuthOutgoing.connectionMade sends the 200 reply and installs
        # itself as self.otherConn; a failed connect is reported as 404.
        d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
        d.addErrback(lambda result, self=self:
                     self.makeReply(404, result.getErrorMessage()))

    # ------------------------------------------------------------------
    # Plumbing
    # ------------------------------------------------------------------

    def connectionLost(self, reason):
        """Close the VM side when the client goes away."""
        if self.otherConn:
            self.otherConn.transport.loseConnection()

    def authorize(self, code, server, port, user):
        """Log a connection request and approve it; always returns 1.

        Not called from within this class; kept for compatibility.
        """
        log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
        return 1

    def connectClass(self, host, port, klass, *args):
        """Open a TCP connection using protocol *klass*(*args); returns a Deferred."""
        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)

    def makeReply(self, reply, message=""):
        """Send a "VNCProxy/1.0 <reply> <message>" status line.

        Any reply outside the 2xx range also closes the connection.
        """
        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
        if reply // 100 != 2:
            self.transport.loseConnection()

    def write(self, data):
        """Send bytes from the VM console on to the client."""
        self.transport.write(data)

    def log(self, proto, data):
        """Append a hex dump of *data* to the traffic log, if one is set.

        The direction marker is '<' when *proto* is this protocol and '>'
        otherwise.  Requires self.otherConn to be connected.
        """
        if not self.logging:
            return
        peer = self.transport.getPeer()
        their_peer = self.otherConn.transport.getPeer()
        if proto is self:
            direction = '<'
        else:
            direction = '>'
        f = open(self.logging, "a")
        try:
            f.write("%s\t%s:%d %s %s:%d\n" % (time.ctime(),
                                              peer.host, peer.port,
                                              direction,
                                              their_peer.host, their_peer.port))
            while data:
                chunk, data = data[:16], data[16:]
                f.write(' '.join(['%02X' % ord(c) for c in chunk]) + ' ')
                # Pad short final rows so the text column lines up.
                f.write((16 - len(chunk)) * 3 * ' ')
                for c in chunk:
                    # Characters whose repr is an escape sequence print as '.'.
                    if len(repr(c)) > 3:
                        f.write('.')
                    else:
                        f.write(c)
                f.write('\n')
            f.write('\n')
        finally:
            f.close()
184
185
class VNCAuthFactory(protocol.Factory):
    """A factory for a VNC auth proxy.

    Constructor arguments:
      log    -- path of the traffic log file handed to each VNCAuth,
                or None to disable logging.
      server -- host to which VM console connections are made
                (defaults to "localhost", matching VNCAuth).
    """

    def __init__(self, log=None, server="localhost"):
        self.logging = log
        self.server = server

    def buildProtocol(self, addr):
        """Return a fresh VNCAuth protocol for the connection from *addr*."""
        return VNCAuth(self.logging, self.server)
198