Correctly verify authentication tokens, and disable backdoor
[invirt/packages/invirt-vnc-server.git] / vncexternalauth.py
1 """
2 Wrapper for sipb-xen VNC proxying
3 """
4
5 # twisted imports
6 from twisted.internet import reactor, protocol, defer
7 from twisted.python import log
8
9 # python imports
10 import sys
11 import struct
12 import string
13 import cPickle
14 # Python 2.5:
15 #import hashlib
16 import sha
17 import hmac
18 import base64
19 import socket
20 import time
21 import get_port
22
# HMAC-SHA1 key that VNCAuth.validateToken uses to check token digests; a
# token only validates if its digest was computed with this same key.
# NOTE(review): this is a secret committed to source.  Consider loading it
# from a protected file/configuration at startup -- confirm how the token
# issuer obtains its copy of the key before changing this.
TOKEN_KEY = "0M6W0U1IXexThi5idy8mnkqPKEq1LtEnlK/pZSn0cDrN"
24
def getPort(name, auth_data):
    """Return the VNC port of VM *name* if *auth_data* authorizes it.

    Returns None when auth_data["machine"] is not *name* (not authorized),
    0 when get_port.findPort() returns None for the VM, and otherwise the
    integer in the second ':'-separated field of findPort()'s result.
    """
    if not (auth_data["machine"] == name):
        return None
    location = get_port.findPort(name)
    if location is None:
        return 0
    fields = location.split(':')
    return int(fields[1])
33     
class VNCAuthOutgoing(protocol.Protocol):
    """Protocol for the proxy's connection to a VM's VNC server.

    *socks* is the VNCAuth instance serving the client.  Bytes received
    here are written to the client via socks.write(); bytes the client
    sends are passed to this object's write() once socks.otherConn is set.
    """

    def __init__(self, socks):
        self.socks = socks

    def connectionMade(self):
        # Report success to the client, then hand this connection to the
        # client-side protocol; VNCAuth.dataReceived forwards everything to
        # otherConn once it is set.
        self.socks.makeReply(200)
        self.socks.otherConn = self

    def connectionLost(self, reason):
        # The VNC server went away: drop the client connection too.
        self.socks.transport.loseConnection()

    def dataReceived(self, data):
        """Relay bytes from the VNC server to the client."""
        self.socks.write(data)

    def write(self, data):
        """Send bytes coming from the client to the VNC server."""
        self.transport.write(data)
53
54
class VNCAuth(protocol.Protocol):
    """Server-side protocol for one client of the VNC auth proxy.

    The client sends a request line ('AUTHTOKEN ...' or 'CONNECTVNC <vm>'),
    then 'Name: value' header lines, then a blank line.  Requests carry an
    'Auth-token' header checked by validateToken().  After a successful
    CONNECTVNC a VNCAuthOutgoing connection to *server* is opened and bytes
    are relayed in both directions.  Replies look like
    'VNCProxy/1.0 <code> <message>'; any non-2xx reply closes the connection.
    """

    def __init__(self, logging=None, server="localhost"):
        # logging: traffic-log file name used by log(); falsy disables it.
        # server: host that VNC connections are opened to.
        self.logging = logging
        self.server = server
        # Authentication state, filled in by validateToken().  Initialised
        # here so it can be read safely even before/without a valid token.
        self.auth = None
        self.auth_error = None
        self.auth_machine = None
        self.auth_data = None

    def connectionMade(self):
        self.buf = ""
        self.otherConn = None

    def _digestsEqual(self, expected, actual):
        """Compare two digest strings without an early exit on mismatch.

        Avoids leaking, through response timing, how many leading bytes of
        a forged digest were correct.
        """
        if not isinstance(expected, str) or not isinstance(actual, str):
            return False
        if len(expected) != len(actual):
            return False
        diff = 0
        for x, y in zip(expected, actual):
            diff |= ord(x) ^ ord(y)
        return diff == 0

    def validateToken(self, token):
        """Check an Auth-token value and record who it authenticates.

        The token is URL-safe base64 of a pickled dict with 'data' (itself
        a pickled dict) and 'digest' (HMAC-SHA1 of 'data' under TOKEN_KEY).
        The inner dict must hold 'expires' (compared with time.time()),
        'user' and 'machine'.

        On success self.auth, self.auth_machine and self.auth_data are set
        and self.auth_error is None.  On any failure self.auth is None and
        self.auth_error holds a message for the client.
        """
        # Start from a clean slate so a failed check never leaves an
        # identity from an earlier command on this connection in place.
        self.auth = None
        self.auth_machine = None
        self.auth_data = None
        self.auth_error = "Invalid token"
        try:
            raw = base64.urlsafe_b64decode(token)
            # SECURITY(review): this outer pickle is attacker-controlled and
            # is deserialized *before* the HMAC check; cPickle.loads on
            # untrusted bytes can execute arbitrary code.  Fixing it needs a
            # non-pickle token format agreed with the token issuer.
            envelope = cPickle.loads(raw)
            mac = hmac.new(TOKEN_KEY, digestmod=sha)
            mac.update(envelope['data'])
            if not self._digestsEqual(mac.digest(), envelope['digest']):
                return
            data = cPickle.loads(envelope['data'])
            if time.time() >= data["expires"]:
                self.auth_error = "Token has expired; please try logging in again"
                return
            user = data["user"]
            machine = data["machine"]
        except Exception:
            # Malformed tokens are expected from untrusted clients; record
            # the reason and fall through with "Invalid token".
            log.err()
            return
        self.auth = user
        self.auth_machine = machine
        self.auth_data = data
        self.auth_error = None

    def dataReceived(self, data):
        """Relay data once proxying, otherwise buffer and parse a request."""
        if self.otherConn:
            self.otherConn.write(data)
            return
        self.buf = self.buf + data
        if not ('\r\n\r\n' in self.buf or '\n\n' in self.buf
                or '\r\r' in self.buf):
            return  # request not complete yet
        lines = self.buf.splitlines()
        self.buf = ''
        args = lines.pop(0).split()
        if not args:
            self.makeReply(400, "Bad request")
            return
        command = args.pop(0)
        headers = {}
        for line in lines:
            parts = line.split(": ", 1)
            if len(parts) == 2:  # ignore lines that are not "Name: value"
                headers[parts[0]] = parts[1]

        if command == "AUTHTOKEN":
            self._handleAuthToken(args, headers)
        elif command == "CONNECTVNC":
            self._handleConnectVNC(args, headers)
        else:
            self.makeReply(501, "unknown method " + command)

    def _handleAuthToken(self, args, headers):
        """AUTHTOKEN: report whether the supplied Auth-token is valid."""
        # Previously any user was accepted with the literal token "1"
        # (a debugging backdoor); the token is now actually verified.
        token = headers.get("Auth-token")
        if token is None:
            self.makeReply(401, "Login first")
            return
        self.validateToken(token)
        if self.auth is None:
            self.makeReply(401, self.auth_error or "Invalid token")
        else:
            self.makeReply(200, "Authentication successful")

    def _handleConnectVNC(self, args, headers):
        """CONNECTVNC <vm>: verify the token, then proxy to the VM's VNC."""
        if not args:
            self.makeReply(400, "Bad request")
            return
        vmname = args[0]
        token = headers.get("Auth-token")
        if token is None:
            self.makeReply(401, "Login first")
            return
        self.validateToken(token)
        if self.auth is None:
            self.makeReply(401, self.auth_error or "Invalid token")
            return
        # getPort: None = token not for this VM, 0 = VM's VNC not found.
        port = getPort(vmname, self.auth_data)
        if port is None:
            self.makeReply(401, "Unauthorized to connect to VM " + vmname)
        elif port == 0:
            self.makeReply(404, "Unable to find VNC for VM " + vmname)
        else:
            d = self.connectClass(self.server, port, VNCAuthOutgoing, self)
            d.addErrback(lambda failure:
                         self.makeReply(404, failure.getErrorMessage()))

    def connectionLost(self, reason):
        if self.otherConn:
            self.otherConn.transport.loseConnection()

    def authorize(self, code, server, port, user):
        """Log the request and allow it (always returns 1).

        Leftover from the SOCKS proxy this was derived from; not called
        anywhere in this file.
        """
        log.msg("code %s connection to %s:%s (user %s) authorized" % (code, server, port, user))
        return 1

    def connectClass(self, host, port, klass, *args):
        """Open a TCP connection to host:port using protocol *klass*(*args).

        Returns the Deferred from ClientCreator.connectTCP.
        """
        return protocol.ClientCreator(reactor, klass, *args).connectTCP(host, port)

    def makeReply(self, reply, message=""):
        """Send a status reply; close the connection unless it is 2xx."""
        self.transport.write("VNCProxy/1.0 %d %s\r\n\r\n" % (reply, message))
        if reply // 100 != 2:
            self.transport.loseConnection()

    def write(self, data):
        """Send bytes coming from the VNC server to the client."""
        self.transport.write(data)

    def log(self, proto, data):
        """Append a timestamped hex/ASCII dump of *data* to self.logging.

        Direction is '<' when *proto* is this protocol, '>' otherwise.
        Requires self.otherConn to be set.
        """
        if not self.logging:
            return
        peer = self.transport.getPeer()
        their_peer = self.otherConn.transport.getPeer()
        if proto == self:
            direction = '<'
        else:
            direction = '>'
        f = open(self.logging, "a")
        try:
            f.write("%s\t%s:%d %s %s:%d\n" % (time.ctime(),
                                              peer.host, peer.port,
                                              direction,
                                              their_peer.host, their_peer.port))
            while data:
                chunk, data = data[:16], data[16:]
                f.write(' '.join(['%02X' % ord(c) for c in chunk]) + ' ')
                f.write((16 - len(chunk)) * 3 * ' ')
                for c in chunk:
                    # repr of a printable 1-char string is 3 chars ('x');
                    # anything longer is an escape, shown as '.'.
                    if len(repr(c)) > 3:
                        f.write('.')
                    else:
                        f.write(c)
                f.write('\n')
            f.write('\n')
        finally:
            f.close()
201
202
class VNCAuthFactory(protocol.Factory):
    """A factory for a VNC auth proxy.

    Constructor arguments:
      log    -- traffic-log file name handed to each VNCAuth (falsy
                disables VNCAuth.log()).
      server -- host that authenticated clients' VNC connections go to.
    """

    def __init__(self, log, server):
        self.logging = log
        self.server = server

    def buildProtocol(self, addr):
        """Return a new VNCAuth for the incoming connection from *addr*."""
        return VNCAuth(self.logging, self.server)
215