Generate the VNC server certificates at install time
[invirt/packages/invirt-vnc-server.git] / python / vnc / extauth.py
1 """
2 Wrapper for Invirt VNC proxying
3 """
4
5 # twisted imports
6 from twisted.internet import reactor, protocol, defer
7 from twisted.python import log
8
9 # python imports
10 import sys
11 import struct
12 import string
13 import cPickle
14 # Python 2.5:
15 #import hashlib
16 import sha
17 import hmac
18 import base64
19 import socket
20 import time
21
# Holds the VNC token HMAC key once it has been read successfully.
_token_key_cache = []


def getTokenKey():
    """Return the secret key used to sign VNC auth tokens.

    The key is read from /etc/invirt/secrets/vnc-key (surrounding whitespace
    stripped) on first use and cached for the life of the process.  If the
    read fails the error propagates and the next call simply retries, so a
    transient failure does not permanently break token validation.

    Raises:
        IOError: the key file could not be opened or read.
    """
    if not _token_key_cache:
        key_file = open('/etc/invirt/secrets/vnc-key')
        try:
            _token_key_cache.append(key_file.read().strip())
        finally:
            # Close explicitly instead of relying on garbage collection.
            key_file.close()
    return _token_key_cache[0]
27
def getPort(name, auth_data):
    """Look up the VNC port of VM ``name`` if the token payload allows it.

    ``auth_data`` is the decoded token payload; access is granted only when
    its ``"machine"`` entry equals ``name``.

    Returns:
        None if the payload is not for ``name``; 0 if
        ``get_port.findPort`` returned None for the VM; otherwise the second
        ``:``-separated field of findPort's result, as an int.
    """
    import get_port
    if auth_data["machine"] != name:
        # Token was issued for a different machine.
        return None
    location = get_port.findPort(name)
    if location is None:
        return 0
    # Presumably "<host>:<port>" -- only the field after the first ':' is used.
    return int(location.split(':')[1])
37     
class VNCAuthOutgoing(protocol.Protocol):
    """Proxy-to-VM side of an authenticated VNC session.

    Instantiated through VNCAuth.connectClass with the client-facing VNCAuth
    protocol as ``socks``.  Bytes are relayed in both directions between the
    VM connection (this protocol) and the client connection (``socks``).
    """

    def __init__(self, socks):
        # The client-facing VNCAuth protocol this connection relays for.
        self.socks = socks

    def connectionMade(self):
        # Report success to the client, then register as the relay target so
        # VNCAuth.dataReceived forwards the client's bytes here.
        self.socks.makeReply(200)
        self.socks.otherConn = self

    def connectionLost(self, reason):
        # The VM side went away: close the client side as well.
        self.socks.transport.loseConnection()

    def dataReceived(self, data):
        # VM -> client.
        self.socks.write(data)

    def write(self, data):
        # Client -> VM; called by VNCAuth.dataReceived once relaying.
        self.transport.write(data)
56
57
58 class VNCAuth(protocol.Protocol):
59     
60     def __init__(self,server="localhost"):
61         self.server=server
62         self.auth=None
63     
64     def connectionMade(self):
65         self.buf=""
66         self.otherConn=None
67
68     def validateToken(self, token):
69         self.auth_error = "Invalid token"
70         try:
71             token = base64.urlsafe_b64decode(token)
72             token = cPickle.loads(token)
73             m = hmac.new(getTokenKey(), digestmod=sha)
74             m.update(token['data'])
75             if (m.digest() == token['digest']):
76                 data = cPickle.loads(token['data'])
77                 expires = data["expires"]
78                 if (time.time() < expires):
79                     self.auth = data["user"]
80                     self.auth_error = None
81                     self.auth_machine = data["machine"]
82                     self.auth_data = data
83                 else:
84                     self.auth_error = "Token has expired; please try logging in again"
85         except (TypeError, cPickle.UnpicklingError):
86             self.auth = None            
87             print sys.exc_info()
88
89     def dataReceived(self,data):
90         if self.otherConn:
91             self.otherConn.write(data)
92             return
93         self.buf=self.buf+data
94         if ('\r\n\r\n' in self.buf) or ('\n\n' in self.buf) or ('\r\r' in self.buf):
95             lines = self.buf.splitlines()
96             args = lines.pop(0).split()
97             command = args.pop(0)
98             headers = {}
99             for line in lines:
100                 try:
101                     (header, data) = line.split(": ", 1)
102                     headers[header] = data
103                 except ValueError:
104                     pass
105
106             if command == "AUTHTOKEN":
107                 user = args[0]
108                 token = headers["Auth-token"]
109                 if token == "1": #FIXME
110                     self.auth = user
111                     self.makeReply(200, "Authentication successful")
112                 else:
113                     self.makeReply(401)
114             elif command == "CONNECTVNC":
115                 vmname = args[0]
116                 if ("Auth-token" in headers):
117                     token = headers["Auth-token"]
118                     self.validateToken(token)
119                     if self.auth is not None:
120                         port = getPort(vmname, self.auth_data)
121                         if port is not None: # FIXME
122                             if port != 0:
123                                 d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
124                                 d.addErrback(lambda result, self=self: self.makeReply(404, result.getErrorMessage()))
125                             else:
126                                 self.makeReply(404, "Unable to find VNC for VM "+vmname)
127                         else:
128                             self.makeReply(401, "Unauthorized to connect to VM "+vmname)
129                     else:
130                         if self.auth_error:
131                             self.makeReply(401, self.auth_error)
132                         else:
133                             self.makeReply(401, "Invalid token")
134                 else:
135                     self.makeReply(401, "Login first")
136             else:
137                 self.makeReply(501, "unknown method "+command)
138             self.buf=''
139         if False and '\000' in self.buf[8:]:
140             head,self.buf=self.buf[:8],self.buf[8:]
141             try:
142                 version,code,port=struct.unpack("!BBH",head[:4])
143             except struct.error:
144                 raise RuntimeError, "struct error with head='%s' and buf='%s'"%(repr(head),repr(self.buf))
145             user,self.buf=string.split(self.buf,"\000",1)
146             if head[4:7]=="\000\000\000": # domain is after
147                 server,self.buf=string.split(self.buf,'\000',1)
148                 #server=gethostbyname(server)
149             else:
150                 server=socket.inet_ntoa(head[4:8])
151             assert version==4, "Bad version code: %s"%version
152             if not self.authorize(code,server,port,user):
153                 self.makeReply(91)
154                 return
155             if code==1: # CONNECT
156                 d = self.connectClass(server, port, SOCKSv4Outgoing, self)
157                 d.addErrback(lambda result, self=self: self.makeReply(91))
158             else:
159                 raise RuntimeError, "Bad Connect Code: %s" % code
160             assert self.buf=="","hmm, still stuff in buffer... %s" % repr(self.buf)
161
162     def connectionLost(self, reason):
163         if self.otherConn:
164             self.otherConn.transport.loseConnection()
165
166     def authorize(self,code,server,port,user):
167         log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
168         return 1
169
170     def connectClass(self, host, port, klass, *args):
171         return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
172
173     def makeReply(self,reply,message=""):
174         self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
175         if int(reply / 100)!=2: self.transport.loseConnection()
176
177     def write(self,data):
178         self.transport.write(data)
179
180     def log(self,proto,data):
181         peer = self.transport.getPeer()
182         their_peer = self.otherConn.transport.getPeer()
183         print "%s\t%s:%d %s %s:%d\n"%(time.ctime(),
184                                         peer.host,peer.port,
185                                         ((proto==self and '<') or '>'),
186                                         their_peer.host,their_peer.port),
187         while data:
188             p,data=data[:16],data[16:]
189             print string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ',
190             print ((16-len(p))*3*' '),
191             for c in p:
192                 if len(repr(c))>3: print '.',
193                 else: print c,
194             print ""
195         print ""
196
197
class VNCAuthFactory(protocol.Factory):
    """Factory producing one VNCAuth protocol per incoming client connection.

    Constructor accepts one argument, ``server``: the host to which
    authorized clients' VNC connections are forwarded.  It is passed through
    unchanged to each VNCAuth instance.
    """

    def __init__(self, server):
        self.server = server

    def buildProtocol(self, addr):
        # Each client connection gets its own VNCAuth bound to this server.
        return VNCAuth(self.server)
209