Subclass pickle.Unpickler for security
[invirt/scripts/vnc-client.git] / invirt-vnc-client
1 #!/usr/bin/env python
2 from twisted.internet import reactor, ssl, protocol, error
3 from OpenSSL import SSL
4 import base64, pickle
5 import getopt, sys, os, time
6 import io
7
# Module-wide switch for diagnostic output; main() turns it on for -v.
verbose = False

def usage():
    """Write the command-line synopsis and option summary to stdout."""
    help_text = """%s [-v] [-l [HOST:]PORT] {-a AUTHTOKEN|VMNAME}
 -l, --listen [HOST:]PORT  port (and optionally host) to listen on for
                           connections (default is 127.0.0.1 and a randomly
                           chosen port). Use an empty HOST to listen on all
                           interfaces (INSECURE!)
 -a, --authtoken AUTHTOKEN Authentication token for connecting to the VNC server
 VMNAME                    VM name to connect to (automatically fetches an
                           authentication token using remctl)
 -v                        verbose status messages"""
    program = sys.argv[0]
    print(help_text % program)
20
class ClientContextFactory(ssl.ClientContextFactory):
    """SSL client context that requires a server certificate.

    The certificate must verify against the bundle at
    /mit/xvm/vnc/servers.cert; connections without a peer certificate fail.
    NOTE(review): no hostname check is performed here -- trust comes only
    from the loaded bundle.
    """

    def _verify(self, connection, x509, errnum, errdepth, ok):
        """pyOpenSSL verify callback: log details and pass OpenSSL's verdict through."""
        if verbose:
            print('_verify (ok=%d):' % ok)
            print('  subject:', x509.get_subject())
            print('  issuer:', x509.get_issuer())
            print('  errnum %s, errdepth %d' % (errnum, errdepth))
        # 10 is OpenSSL's X509_V_ERR_CERT_HAS_EXPIRED.
        cert_expired = (errnum == 10)
        if cert_expired:
            print('The VNC server certificate has expired. Please contact xvm@mit.edu.')
        return ok

    def getContext(self):
        """Return the base context with the XVM cert bundle and mandatory peer verification."""
        cert_path = '/mit/xvm/vnc/servers.cert'
        context = ssl.ClientContextFactory.getContext(self)
        if verbose:
            print("Loading certificates from %s" % cert_path)
        context.load_verify_locations(cert_path)
        verify_mode = SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT
        context.set_verify(verify_mode, self._verify)
        return context
43
class Proxy(protocol.Protocol):
    """Protocol that relays every received byte to a paired peer protocol.

    When this side's connection is lost, the peer's connection is closed too.
    """
    peer = None  # the paired Proxy whose transport receives our data

    def setPeer(self, peer):
        """Pair this protocol with *peer* (one direction only)."""
        self.peer = peer

    def connectionLost(self, reason):
        """Close the peer's connection (if paired), then forget the peer."""
        partner = self.peer
        if partner is None:
            return
        partner.transport.loseConnection()
        self.peer = None

    def dataReceived(self, data):
        """Forward *data* unchanged to the peer's transport."""
        self.peer.transport.write(data)
57
class ProxyClient(Proxy):
    """Outbound half of the proxy: performs the VNCProxy handshake.

    Its peer (the local ProxyServer) is set by the factory's buildProtocol
    before connectionMade runs.  The local side stays paused until the
    remote proxy answers with a ``VNCProxy/1.0 200`` reply terminated by a
    blank line; after that, bytes flow unchanged in both directions.
    """
    ready = False  # True once the handshake reply has been accepted
    # Reply bytes received so far while not yet ready.  TCP is a byte
    # stream, so the reply may be split across several dataReceived calls
    # or arrive in the same chunk as the first VNC bytes.
    _reply = b""

    _OK_PREFIX = b"VNCProxy/1.0 200 "
    _REPLY_END = b"\r\n\r\n"

    def connectionMade(self):
        """Pair the local server with us and send the CONNECTVNC request."""
        self.peer.setPeer(self)
        data = b"CONNECTVNC %s VNCProxy/1.0\r\nAuth-token: %s\r\n\r\n" % (self.factory.machine.encode(), self.factory.authtoken.encode())
        self.transport.write(data)
        if verbose: print("ProxyClient: connection made")

    def dataReceived(self, data):
        """Consume the handshake reply, then relay VNC traffic to the peer."""
        if self.ready:
            self.peer.transport.write(data)
            return
        if verbose: print('ProxyClient: received data %r' % data)
        self._reply += data
        reply = self._reply
        ok = self._OK_PREFIX
        # Fail as soon as the bytes seen so far cannot be a 200 reply
        # (they neither start with the OK line nor are a prefix of it).
        if not (reply.startswith(ok) or ok.startswith(reply)):
            print("Failed to connect: %r" % reply)
            self.transport.loseConnection()
            return
        end = reply.find(self._REPLY_END)
        if end == -1:
            return  # reply header incomplete; wait for more bytes
        self.ready = True
        self._reply = b""
        leftover = reply[end + len(self._REPLY_END):]
        if leftover:
            # VNC data that arrived in the same chunk as the reply.
            self.peer.transport.write(leftover)
        self.peer.transport.resumeProducing() # Allow reading
79
class ProxyClientFactory(protocol.ClientFactory):
    """Creates the outbound ProxyClient for one accepted local connection.

    Holds the machine name and auth token that ProxyClient sends in its
    CONNECTVNC request, plus the local ProxyServer it is paired with.
    """
    protocol = ProxyClient

    def __init__(self, authtoken, machine):
        self.authtoken = authtoken
        self.machine = machine

    def setServer(self, server):
        """Record the local ProxyServer to pair with the client we build."""
        self.server = server

    def buildProtocol(self, *args, **kw):
        """Build the ProxyClient and point it at the local server."""
        prot = protocol.ClientFactory.buildProtocol(self, *args, **kw)
        prot.setPeer(self.server)
        return prot

    def clientConnectionFailed(self, connector, reason):
        """Report why the outbound connection failed and drop the local one."""
        # Tell the user why; otherwise the local VNC connection just closes
        # with no explanation.
        print("Failed to connect to the VNC proxy: %s" % reason.getErrorMessage())
        self.server.transport.loseConnection()
97
98
class ProxyServer(Proxy):
    """Inbound half of the proxy: one accepted local VNC client.

    On connect it opens an SSL connection to the remote proxy described by
    its factory and relays traffic once that side is ready.
    """
    clientProtocolFactory = ProxyClientFactory
    # NOTE(review): these two class attributes are not read in this file;
    # connectionMade takes the values from self.factory instead.
    authtoken = None
    machine = None

    def connectionMade(self):
        """Pause the local client and start the outbound SSL connection."""
        # Nothing may be read from the local client until there is an
        # outbound connection to forward it to.
        self.transport.pauseProducing()

        if verbose:
            print("ProxyServer: connection made")

        settings = self.factory
        outbound = self.clientProtocolFactory(settings.authtoken, settings.machine)
        outbound.setServer(self)
        reactor.connectSSL(settings.host, settings.port, outbound,
                           ClientContextFactory())
115         
116
class ProxyFactory(protocol.Factory):
    """Listening-side factory that carries per-session connection settings.

    Each ProxyServer it builds reads host, port, authtoken and machine from
    this factory.
    """
    protocol = ProxyServer

    def __init__(self, host, port, authtoken, machine):
        # Remote endpoint that ProxyServer connects to over SSL.
        self.host, self.port = host, port
        # Credentials forwarded in the CONNECTVNC request.
        self.authtoken, self.machine = authtoken, machine
125
class SafeUnpickler(pickle.Unpickler):
    """Unpickler that refuses to resolve any global name.

    Because every class/function lookup goes through find_class, rejecting
    all of them leaves only pickle's built-in container and scalar types
    (dict, list, str, numbers, ...) loadable.  Untrusted input therefore
    cannot import modules or invoke callables.  NOTE(review): it can still
    make load() allocate large objects.
    """

    def find_class(self, module, name):
        """Reject every global lookup, whatever *module* and *name* are."""
        message = "globals are forbidden"
        raise pickle.UnpicklingError(message)
129
def _fetch_authtoken(vmname):
    """Fetch a VNC authentication token for *vmname* from xvm-remote via remctl.

    If no ``remctl`` executable can be started, retry through ``athrun``
    using the remctl locker.  Exits with status 1 if neither command can be
    started or if remctl exits non-zero.

    Returns the command's stdout with surrounding whitespace stripped; the
    token is later sent verbatim in the proxy's Auth-token header.
    """
    from subprocess import PIPE, Popen
    command = ["remctl", "xvm-remote.mit.edu", "control", vmname, "vnctoken"]
    try:
        p = Popen(command, stdout=PIPE, universal_newlines=True)
    except OSError:
        if verbose: print("remctl not found in path. Trying remctl locker.")
        try:
            p = Popen(["athrun", "remctl"] + command,
                      stdout=PIPE, universal_newlines=True)
        except OSError:
            # Neither remctl nor athrun could be started.
            print("Unable to get authentication token (could not run remctl)")
            sys.exit(1)
    output = p.communicate()[0]
    if p.returncode != 0:
        print("Unable to get authentication token")
        sys.exit(1)
    return output.strip()


def _unpack_authtoken(authtoken):
    """Decode the metadata embedded in *authtoken*.

    The token's first dot-separated field is URL-safe base64 of a pickled
    mapping.  It is loaded with SafeUnpickler so no globals can be resolved.

    Returns (machine, connect_host, connect_port, expires).  Exits with
    status 1 on any decoding error or missing/malformed field.
    """
    try:
        raw = base64.urlsafe_b64decode(authtoken.split(".")[0])
        token_inner = SafeUnpickler(io.BytesIO(raw)).load()
        machine = token_inner["machine"]
        connect_host = token_inner["connect_host"]
        connect_port = token_inner["connect_port"]
        # Validate here so a malformed expiry is reported as a bad token
        # instead of crashing time.ctime() later.
        token_expires = float(token_inner["expires"])
    except Exception:
        # Any error from hostile or garbled input (bad base64, unpickling
        # errors, wrong type, missing key) means "invalid token".
        # KeyboardInterrupt and SystemExit still propagate.
        print("Invalid authentication token")
        sys.exit(1)
    if verbose: print("Unpacked authentication token:\n%s" % \
                      repr(token_inner))
    return machine, connect_host, connect_port, token_expires


def _start_listening(listen, factory):
    """Listen for local VNC clients on ``listen == [host, port]``.

    When port is None, try 5900-5999 in order and use the first free port.
    Always binds to *host*, so the 127.0.0.1 default is honoured for
    explicit ports too; an empty host means all interfaces.  Returns the
    bound port.  Exits with status 1 if nothing could be bound.
    """
    host, port = listen
    candidates = range(5900, 6000) if port is None else [port]
    for candidate in candidates:
        try:
            reactor.listenTCP(candidate, factory, interface=host)
        except error.CannotListenError:
            continue
        return candidate
    if port is None:
        print("Unable to find a free port between 5900 and 5999 on '%s'" % host)
    else:
        print("Unable to listen on %s:%d" % (host, port))
    sys.exit(1)


def main():
    """Command-line entry point: obtain a token, listen locally, run the proxy.

    Parses the options described by usage(), then gets an auth token (from
    -a or via remctl for VMNAME) and decodes the remote endpoint from it.
    Listens locally (127.0.0.1 by default) and forwards each VNC client over
    SSL to that endpoint.  Blocks in reactor.run().
    """
    global verbose
    try:
        opts, args = getopt.gnu_getopt(sys.argv[1:], "hl:a:v",
                                       ["help", "listen=", "authtoken="])
    except getopt.GetoptError as err:
        print(str(err)) # will print something like "option -a not recognized"
        usage()
        sys.exit(2)
    listen = ["127.0.0.1", None]
    authtoken = None
    for o, a in opts:
        if o == "-v":
            verbose = True
        elif o in ("-h", "--help"):
            usage()
            sys.exit()
        elif o in ("-l", "--listen"):
            # Split on the LAST colon so hosts containing colons (IPv6
            # literals) work; a bare PORT keeps the current host.
            host, sep, port = a.rpartition(":")
            try:
                listen[1] = int(port)
            except ValueError:
                print("Invalid port in listen argument: %r" % a)
                usage()
                sys.exit(2)
            if sep:
                listen[0] = host
        elif o in ("-a", "--authtoken"):
            # Pasted tokens may carry stray whitespace; the token is sent
            # verbatim in the Auth-token header.
            authtoken = a.strip()
        else:
            assert False, "unhandled option"

    # Get authentication token
    if authtoken is None:
        # User didn't give us an authentication token, so we need to get one
        if len(args) != 1:
            print("VMNAME not given or too many arguments")
            usage()
            sys.exit(2)
        authtoken = _fetch_authtoken(args[0])
        if verbose: print('Got authentication token "%s" for VM %s' % \
                          (authtoken, args[0]))

    machine, connect_host, connect_port, token_expires = _unpack_authtoken(authtoken)

    if verbose: print("Will connect to %s:%s" % (connect_host, connect_port))
    factory = ProxyFactory(connect_host, connect_port, authtoken, machine)
    listen[1] = _start_listening(listen, factory)

    print("Ready to connect. Connect to %s:%s (display %d) now with your VNC client. The password is 'moocow'." % (listen[0], listen[1], listen[1]-5900))
    print("You must connect before your authentication token expires at %s." % \
          (time.ctime(token_expires)))

    reactor.run()


if '__main__' == __name__:
    main()