Port to Python 3; subclass pickle.Unpickler for security (forbid globals when decoding the auth token)
[invirt/scripts/vnc-client.git] / invirt-vnc-client
index 63d3999..f71366a 100755 (executable)
@@ -1,13 +1,14 @@
-#!/usr/bin/python
+#!/usr/bin/env python
 from twisted.internet import reactor, ssl, protocol, error
 from OpenSSL import SSL
 import base64, pickle
 import getopt, sys, os, time
 from twisted.internet import reactor, ssl, protocol, error
 from OpenSSL import SSL
 import base64, pickle
 import getopt, sys, os, time
+import io
 
 verbose = False
 
 def usage():
 
 verbose = False
 
 def usage():
-    print """%s [-v] [-l [HOST:]PORT] {-a AUTHTOKEN|VMNAME}
+    print("""%s [-v] [-l [HOST:]PORT] {-a AUTHTOKEN|VMNAME}
  -l, --listen [HOST:]PORT  port (and optionally host) to listen on for
                            connections (default is 127.0.0.1 and a randomly
                            chosen port). Use an empty HOST to listen on all
  -l, --listen [HOST:]PORT  port (and optionally host) to listen on for
                            connections (default is 127.0.0.1 and a randomly
                            chosen port). Use an empty HOST to listen on all
@@ -15,25 +16,25 @@ def usage():
  -a, --authtoken AUTHTOKEN Authentication token for connecting to the VNC server
  VMNAME                    VM name to connect to (automatically fetches an
                            authentication token using remctl)
  -a, --authtoken AUTHTOKEN Authentication token for connecting to the VNC server
  VMNAME                    VM name to connect to (automatically fetches an
                            authentication token using remctl)
- -v                        verbose status messages""" % (sys.argv[0])
+ -v                        verbose status messages""" % (sys.argv[0]))
 
 class ClientContextFactory(ssl.ClientContextFactory):
 
     def _verify(self, connection, x509, errnum, errdepth, ok):
         if verbose:
 
 class ClientContextFactory(ssl.ClientContextFactory):
 
     def _verify(self, connection, x509, errnum, errdepth, ok):
         if verbose:
-            print '_verify (ok=%d):' % ok
-            print '  subject:', x509.get_subject()
-            print '  issuer:', x509.get_issuer()
-            print '  errnum %s, errdepth %d' % (errnum, errdepth)
+            print('_verify (ok=%d):' % ok)
+            print('  subject:', x509.get_subject())
+            print('  issuer:', x509.get_issuer())
+            print('  errnum %s, errdepth %d' % (errnum, errdepth))
         if errnum == 10:
         if errnum == 10:
-            print 'The VNC server certificate has expired. Please contact xvm@mit.edu.'
+            print('The VNC server certificate has expired. Please contact xvm@mit.edu.')
         return ok
 
     def getContext(self):
         ctx = ssl.ClientContextFactory.getContext(self)
 
         certFile = '/mit/xvm/vnc/servers.cert'
         return ok
 
     def getContext(self):
         ctx = ssl.ClientContextFactory.getContext(self)
 
         certFile = '/mit/xvm/vnc/servers.cert'
-        if verbose: print "Loading certificates from %s" % certFile
+        if verbose: print("Loading certificates from %s" % certFile)
         ctx.load_verify_locations(certFile)
         ctx.set_verify(SSL.VERIFY_PEER|SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
                        self._verify)
         ctx.load_verify_locations(certFile)
         ctx.set_verify(SSL.VERIFY_PEER|SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
                        self._verify)
@@ -59,19 +60,19 @@ class ProxyClient(Proxy):
 
     def connectionMade(self):
         self.peer.setPeer(self)
 
     def connectionMade(self):
         self.peer.setPeer(self)
-        data = "CONNECTVNC %s VNCProxy/1.0\r\nAuth-token: %s\r\n\r\n" % (self.factory.machine, self.factory.authtoken)
+        data = b"CONNECTVNC %s VNCProxy/1.0\r\nAuth-token: %s\r\n\r\n" % (self.factory.machine.encode(), self.factory.authtoken.encode())
         self.transport.write(data)
         self.transport.write(data)
-        if verbose: print "ProxyClient: connection made"
+        if verbose: print("ProxyClient: connection made")
     def dataReceived(self, data):
         if not self.ready:
     def dataReceived(self, data):
         if not self.ready:
-            if verbose: print 'ProxyClient: received data "%s"' % data
-            if data.startswith("VNCProxy/1.0 200 "):
+            if verbose: print('ProxyClient: received data %r' % data)
+            if data.startswith(b"VNCProxy/1.0 200 "):
                 self.ready = True
                 self.ready = True
-                if "\n" in data:
-                    self.peer.transport.write(data[data.find("\n")+3:])
+                if b"\n" in data:
+                    self.peer.transport.write(data[data.find(b"\n")+3:])
                 self.peer.transport.resumeProducing() # Allow reading
             else:
                 self.peer.transport.resumeProducing() # Allow reading
             else:
-                print "Failed to connect: %s" % data
+                print("Failed to connect: %r" % data)
                 self.transport.loseConnection()
         else:
             self.peer.transport.write(data)
                 self.transport.loseConnection()
         else:
             self.peer.transport.write(data)
@@ -105,7 +106,7 @@ class ProxyServer(Proxy):
         # somewhere to send it to.
         self.transport.pauseProducing()
         
         # somewhere to send it to.
         self.transport.pauseProducing()
         
-        if verbose: print "ProxyServer: connection made"
+        if verbose: print("ProxyServer: connection made")
 
         client = self.clientProtocolFactory(self.factory.authtoken, self.factory.machine)
         client.setServer(self)
 
         client = self.clientProtocolFactory(self.factory.authtoken, self.factory.machine)
         client.setServer(self)
@@ -122,13 +123,17 @@ class ProxyFactory(protocol.Factory):
         self.authtoken = authtoken
         self.machine = machine
 
         self.authtoken = authtoken
         self.machine = machine
 
+class SafeUnpickler(pickle.Unpickler):
+    """Unpickler that rejects every global (module/name) reference.
+
+    find_class() is the hook pickle calls to resolve a pickled global,
+    such as a class or function.  Overriding it to always raise means that
+    any pickle referencing a global fails with UnpicklingError instead of
+    importing or calling that global.  Only data that pickle can encode
+    without globals can be loaded.  The module/name arguments are
+    deliberately ignored.
+
+    Used in main() to decode the externally supplied authentication token.
+    NOTE(review): this assumes the token producer pickles only such plain
+    data (main() reads the machine, connect_host, connect_port and
+    expires fields).  Confirm this, since values such as datetime objects
+    would be rejected.
+    """
+    def find_class(self, module, name):
+        raise pickle.UnpicklingError("globals are forbidden")
+
 def main():
     global verbose
     try:
         opts, args = getopt.gnu_getopt(sys.argv[1:], "hl:a:v",
                                        ["help", "listen=", "authtoken="])
 def main():
     global verbose
     try:
         opts, args = getopt.gnu_getopt(sys.argv[1:], "hl:a:v",
                                        ["help", "listen=", "authtoken="])
-    except getopt.GetoptError, err:
-        print str(err) # will print something like "option -a not recognized"
+    except getopt.GetoptError as err:
+        print(str(err)) # will print something like "option -a not recognized"
         usage()
         sys.exit(2)
     listen = ["127.0.0.1", None]
         usage()
         sys.exit(2)
     listen = ["127.0.0.1", None]
@@ -154,41 +159,39 @@ def main():
     if authtoken is None:
         # User didn't give us an authentication token, so we need to get one
         if len(args) != 1:
     if authtoken is None:
         # User didn't give us an authentication token, so we need to get one
         if len(args) != 1:
-            print "VMNAME not given or too many arguments"
+            print("VMNAME not given or too many arguments")
             usage()
             sys.exit(2)
         from subprocess import PIPE, Popen
         try:
             p = Popen(["remctl", "xvm-remote.mit.edu", "control", args[0], "vnctoken"],
             usage()
             sys.exit(2)
         from subprocess import PIPE, Popen
         try:
             p = Popen(["remctl", "xvm-remote.mit.edu", "control", args[0], "vnctoken"],
-                      stdout=PIPE)
+                      stdout=PIPE, universal_newlines=True)
         except OSError:
         except OSError:
-            if verbose: print "remctl not found in path. Trying remctl locker."
+            if verbose: print("remctl not found in path. Trying remctl locker.")
             p = Popen(["athrun", "remctl", "remctl",
                        "xvm-remote.mit.edu", "control", args[0], "vnctoken"],
             p = Popen(["athrun", "remctl", "remctl",
                        "xvm-remote.mit.edu", "control", args[0], "vnctoken"],
-                      stdout=PIPE)
+                      stdout=PIPE, universal_newlines=True)
         authtoken = p.communicate()[0]
         if p.returncode != 0:
         authtoken = p.communicate()[0]
         if p.returncode != 0:
-            print "Unable to get authentication token"
+            print("Unable to get authentication token")
             sys.exit(1)
             sys.exit(1)
-        if verbose: print 'Got authentication token "%s" for VM %s' % \
-                          (authtoken, args[0])
+        if verbose: print('Got authentication token "%s" for VM %s' % \
+                          (authtoken, args[0]))
 
     # Unpack authentication token
     try:
 
     # Unpack authentication token
     try:
-        token_outer = base64.urlsafe_b64decode(authtoken)
-        token_outer = pickle.loads(token_outer)
-        token_inner = pickle.loads(token_outer["data"])
+        token_inner = SafeUnpickler(io.BytesIO(base64.urlsafe_b64decode((authtoken.split("."))[0]))).load()
         machine = token_inner["machine"]
         connect_host = token_inner["connect_host"]
         connect_port = token_inner["connect_port"]
         token_expires = token_inner["expires"]
         machine = token_inner["machine"]
         connect_host = token_inner["connect_host"]
         connect_port = token_inner["connect_port"]
         token_expires = token_inner["expires"]
-        if verbose: print "Unpacked authentication token:\n%s" % \
-                          repr(token_inner)
+        if verbose: print("Unpacked authentication token:\n%s" % \
+                          repr(token_inner))
     except:
     except:
-        print "Invalid authentication token"
+        print("Invalid authentication token")
         sys.exit(1)
     
         sys.exit(1)
     
-    if verbose: print "Will connect to %s:%s" % (connect_host, connect_port) 
+    if verbose: print("Will connect to %s:%s" % (connect_host, connect_port))
     if listen[1] is None:
         listen[1] = 5900
         ready = False
     if listen[1] is None:
         listen[1] = 5900
         ready = False
@@ -201,9 +204,9 @@ def main():
     else:
         reactor.listenTCP(listen[1], ProxyFactory(connect_host, connect_port, authtoken, machine))
     
     else:
         reactor.listenTCP(listen[1], ProxyFactory(connect_host, connect_port, authtoken, machine))
     
-    print "Ready to connect. Connect to %s:%s (display %d) now with your VNC client. The password is 'moocow'." % (listen[0], listen[1], listen[1]-5900)
-    print "You must connect before your authentication token expires at %s." % \
-          (time.ctime(token_expires))
+    print("Ready to connect. Connect to %s:%s (display %d) now with your VNC client. The password is 'moocow'." % (listen[0], listen[1], listen[1]-5900))
+    print("You must connect before your authentication token expires at %s." % \
+          (time.ctime(token_expires)))
     
     reactor.run()
 
     
     reactor.run()