Another silly error.
[invirt/packages/invirt-vnc-server.git] / python / vnc / extauth.py
1 """
2 Wrapper for Invirt VNC proxying
3 """
4
5 # twisted imports
6 from twisted.internet import reactor, protocol, defer
7 from twisted.python import log
8
9 # python imports
10 import sys
11 import struct
12 import string
13 import cPickle
14 # Python 2.5:
15 #import hashlib
16 import sha
17 import hmac
18 import base64
19 import socket
20 import time
21
def getTokenKey(path='/etc/invirt/vnc/token-key'):
    """Return the shared secret used to HMAC-sign VNC auth tokens.

    The key is read from *path* (by default the system-wide Invirt key
    file) and surrounding whitespace is stripped.  The file is closed
    explicitly instead of being left for the garbage collector.
    """
    keyfile = open(path)
    try:
        return keyfile.read().strip()
    finally:
        keyfile.close()
24
def getPort(name, auth_data):
    """Look up the VNC port of VM *name* for an authenticated request.

    Returns:
      None -- the token's machine (auth_data["machine"]) is not *name*;
      0    -- get_port.findPort() returned None for the VM;
      int  -- otherwise, the second ':'-separated field of findPort's
              result (presumably a "host:port" string -- confirm against
              get_port).
    """
    import get_port
    if auth_data["machine"] != name:
        return None
    address = get_port.findPort(name)
    if address is None:
        return 0
    fields = address.split(':')
    return int(fields[1])
34     
class VNCAuthOutgoing(protocol.Protocol):
    """Protocol for the proxy's outgoing connection to a VM's VNC port.

    Each instance is paired with the client-facing VNCAuth instance that
    created it (``socks``).  Bytes received from the backend are written to
    that instance.  Bytes from the client are forwarded here through write().
    """

    def __init__(self, socks):
        # The client-facing VNCAuth protocol this backend connection serves.
        self.socks = socks

    def connectionMade(self):
        # Report success of CONNECTVNC to the client.  Then register this
        # object so the client side relays further bytes to us.
        self.socks.makeReply(200)
        self.socks.otherConn = self

    def connectionLost(self, reason):
        # The backend went away, so close the client connection as well.
        self.socks.transport.loseConnection()

    def dataReceived(self, data):
        # Backend -> client.
        self.socks.write(data)

    def write(self, data):
        # Client -> backend.
        self.transport.write(data)
53
54
class VNCAuth(protocol.Protocol):
    """Client-facing protocol: authenticate a request, then proxy to VNC.

    The client sends a text request: a command line, then "Name: value"
    header lines, then a blank line.  Two commands are handled:

      AUTHTOKEN <user>     legacy login (see _authToken).
      CONNECTVNC <vmname>  needs an "Auth-token" header with a signed token.
                           On success a TCP connection is opened to the VM's
                           VNC port on self.server, and from then on bytes
                           are relayed in both directions.

    Replies are "VNCProxy/1.0 <code> <message>\\r\\n\\r\\n".  Any non-2xx
    reply closes the connection (see makeReply).
    """

    def __init__(self, server="localhost"):
        # Host to which proxied VNC connections are made.
        self.server = server
        # Authentication state.  validateToken() fills these in.
        self.auth = None
        self.auth_error = None
        self.auth_machine = None
        self.auth_data = None
        # Request buffer and backend connection.  connectionMade resets them.
        self.buf = ""
        self.otherConn = None

    def connectionMade(self):
        self.buf = ""
        self.otherConn = None

    def _digestsEqual(self, a, b):
        """Compare two MAC digests in time independent of where they differ."""
        compare = getattr(hmac, 'compare_digest', None)
        if compare is not None:
            return compare(a, b)
        # Fallback for Pythons without hmac.compare_digest.
        if len(a) != len(b):
            return False
        diff = 0
        for x, y in zip(a, b):
            diff |= ord(x) ^ ord(y)
        return diff == 0

    def validateToken(self, token):
        """Verify a signed token and record whom it authenticates.

        The token has the form "<b64(pickle)>.<b64(hmac)>", with both parts
        URL-safe base64.  The HMAC uses getTokenKey() with the sha module.

        On success this sets self.auth (the token's "user"),
        self.auth_machine, and self.auth_data (the unpickled object), and it
        clears self.auth_error.  On any failure self.auth is None and
        self.auth_error holds the message to send to the client.
        """
        # Reset first so that a failure never leaves earlier credentials
        # behind, for example ones set by a previous AUTHTOKEN command.
        self.auth = None
        self.auth_error = "Invalid token"
        try:
            pickled_data, digest = map(base64.urlsafe_b64decode, token.split("."))
            mac = hmac.new(getTokenKey(), digestmod=sha)
            mac.update(pickled_data)
            if not self._digestsEqual(mac.digest(), digest):
                return
            # SECURITY: unpickling can execute arbitrary code.  That is only
            # acceptable because the payload was just authenticated with the
            # server-side key.  Never move this above the digest check.
            data = cPickle.loads(pickled_data)
            expires = data["expires"]
            user = data["user"]
            machine = data["machine"]
        except (TypeError, ValueError, KeyError, cPickle.UnpicklingError):
            log.msg("Token validation failed: %s: %s" % tuple(sys.exc_info()[:2]))
            return
        if not (time.time() < expires):
            self.auth_error = "Token has expired; please try logging in again"
            return
        self.auth = user
        self.auth_machine = machine
        self.auth_data = data
        self.auth_error = None

    def dataReceived(self, data):
        # After a successful CONNECTVNC, act as a transparent relay.
        if self.otherConn:
            self.otherConn.write(data)
            return
        self.buf = self.buf + data
        # Wait until the blank line that ends the request has arrived.
        # CRLF, LF, and CR line endings are all accepted.
        if ('\r\n\r\n' not in self.buf and '\n\n' not in self.buf
                and '\r\r' not in self.buf):
            return
        request = self.buf
        self.buf = ''
        self._handleRequest(request)

    def _parseHeaders(self, lines):
        """Map "Name: value" lines to a dict.  Other lines are ignored."""
        headers = {}
        for line in lines:
            parts = line.split(": ", 1)
            if len(parts) == 2:
                headers[parts[0]] = parts[1]
        return headers

    def _handleRequest(self, request):
        """Parse one complete request and dispatch it by command name."""
        lines = request.splitlines()
        args = []
        if lines:
            args = lines.pop(0).split()
        if not args:
            self.makeReply(400, "Malformed request")
            return
        command = args.pop(0)
        headers = self._parseHeaders(lines)
        if command == "AUTHTOKEN":
            self._authToken(args, headers)
        elif command == "CONNECTVNC":
            self._connectVNC(args, headers)
        else:
            self.makeReply(501, "unknown method "+command)

    def _authToken(self, args, headers):
        """Handle the legacy AUTHTOKEN <user> command."""
        if not args:
            self.makeReply(400, "Missing user name")
            return
        # NOTE(review): the original marked this check FIXME.  It accepts
        # the literal token "1" for any user.  The self.auth it sets is
        # discarded by validateToken(), so it grants nothing for CONNECTVNC,
        # but this command should get a real check or be removed.
        if headers.get("Auth-token") == "1":
            self.auth = args[0]
            self.makeReply(200, "Authentication successful")
        else:
            self.makeReply(401)

    def _connectVNC(self, args, headers):
        """Handle CONNECTVNC <vmname>: authorize, then connect to the VM's VNC."""
        if not args:
            self.makeReply(400, "Missing VM name")
            return
        vmname = args[0]
        if "Auth-token" not in headers:
            self.makeReply(401, "Login first")
            return
        self.validateToken(headers["Auth-token"])
        if self.auth is None:
            self.makeReply(401, self.auth_error or "Invalid token")
            return
        port = getPort(vmname, self.auth_data)
        if port is None:
            self.makeReply(401, "Unauthorized to connect to VM "+vmname)
            return
        if port == 0:
            self.makeReply(404, "Unable to find VNC for VM "+vmname)
            return
        # VNCAuthOutgoing.connectionMade sends the 200 reply once connected.
        d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
        d.addErrback(lambda failure, self=self: self.makeReply(404, failure.getErrorMessage()))

    def connectionLost(self, reason):
        # If relaying, close the backend side too.
        if self.otherConn:
            self.otherConn.transport.loseConnection()

    def authorize(self, code, server, port, user):
        """Log the request and allow it unconditionally (always returns 1).

        Not called from within this class.
        """
        log.msg("code %s connection to %s:%s (user %s) authorized" % (code, server, port, user))
        return 1

    def connectClass(self, host, port, klass, *args):
        """Connect to host:port with a klass(*args) protocol.

        Returns the Deferred from ClientCreator.connectTCP.
        """
        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host, port)

    def makeReply(self, reply, message=""):
        """Send a status reply.  Any non-2xx code also closes the connection."""
        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
        if reply // 100 != 2:
            self.transport.loseConnection()

    def write(self, data):
        self.transport.write(data)

    def log(self, proto, data):
        """Write a hex/ASCII dump of *data* to stdout (debugging aid).

        The header shows both endpoints.  The arrow is '<' when *proto* is
        this protocol and '>' otherwise.  Requires self.otherConn to be set.
        """
        peer = self.transport.getPeer()
        their_peer = self.otherConn.transport.getPeer()
        if proto == self:
            arrow = '<'
        else:
            arrow = '>'
        out = ["%s\t%s:%d %s %s:%d\n" % (time.ctime(), peer.host, peer.port,
                                         arrow, their_peer.host, their_peer.port)]
        while data:
            row, data = data[:16], data[16:]
            hexcols = ' '.join(['%02X' % ord(c) for c in row])
            padding = (16 - len(row)) * 3 * ' '
            shown = []
            for c in row:
                # repr() of a printable character is exactly 3 chars ('x').
                # Anything longer needed an escape, so show '.' instead.
                if len(repr(c)) > 3:
                    shown.append('.')
                else:
                    shown.append(c)
            out.append("%s  %s %s\n" % (hexcols, padding, ' '.join(shown)))
        out.append("\n")
        sys.stdout.write(''.join(out))
192
193
class VNCAuthFactory(protocol.Factory):
    """Factory that creates a VNCAuth protocol for each incoming connection.

    The constructor takes the host to which VNCAuth opens proxied VNC
    connections.  It defaults to "localhost", as VNCAuth's does.
    """

    def __init__(self, server="localhost"):
        self.server = server

    def buildProtocol(self, addr):
        p = VNCAuth(self.server)
        # Twisted convention: a protocol built by a factory references it.
        p.factory = self
        return p
205