#!/usr/bin/env python
"""WebSocket-to-VNC proxy for Invirt.

Accepts WebSocket connections (served by novnc's WebSocketServer), opens a
TLS connection to the VNC port of the requested Invirt host, performs the
``CONNECTVNC`` auth handshake on the client's behalf, and then relays
traffic in both directions.
"""

import os
import pwd
import grp
import contextlib
import socket
import ssl
from select import select
import tempfile
import urlparse

from novnc.websocket import WebSocketServer
import invirt.remctl
from invirt.config import structs as config
from optparse import OptionParser


def drop_privileges(uid_name='nobody', gid_name='nogroup'):
    """Switch the process from root to *uid_name* / *gid_name*.

    Does nothing when the process is not running as root (uid 0).  Also
    clears the supplementary group list and sets a restrictive umask.
    """
    if os.getuid() != 0:
        return

    # Resolve the account names to numeric ids.
    uid = pwd.getpwnam(uid_name).pw_uid
    gid = grp.getgrnam(gid_name).gr_gid

    # setgid()/setuid() leave the supplementary group list untouched, so
    # clear it explicitly; otherwise the process would keep root's extra
    # groups after "dropping" privileges.
    os.setgroups([])
    # Change the gid while we still have the privilege to do so, i.e. before
    # giving up root with setuid().
    os.setgid(gid)
    os.setuid(uid)

    # Ensure a very conservative umask: files created from now on are
    # owner-only.
    os.umask(0o077)


# Cipher string taken from Python >= 2.7.9's ssl module, for use until the
# code can rely on ssl.create_default_context().
_RESTRICTED_SERVER_CIPHERS = (
    'ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:ECDH+HIGH:'
    'DH+HIGH:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+HIGH:RSA+3DES:!aNULL:'
    '!eNULL:!MD5:!DSS:!RC4'
)

# Status-line prefix of a successful CONNECTVNC reply.
_HANDSHAKE_OK = "VNCProxy/1.0 200 "
# Upper bound on how much of the handshake reply we buffer before giving up.
_HANDSHAKE_MAX_RESPONSE = 16384
# Read size used while waiting for the handshake reply.
_HANDSHAKE_RECV_SIZE = 4096


@contextlib.contextmanager
def noop():
    """Context manager that does nothing and yields None."""
    yield


class WebSocketProxy(WebSocketServer):
    """
    Proxy traffic from a WebSockets client to an Invirt VNC server, doing
    the auth handshake for the client.
    """

    def __init__(self, user=None, group=None, *args, **kwargs):
        """
        :param user: account to drop privileges to in started(); when falsy,
            privileges are not dropped and no CA bundle is prepared.
        :param group: group to drop privileges to together with *user*.

        All other arguments are passed through to WebSocketServer.
        """
        super(WebSocketProxy, self).__init__(*args, **kwargs)
        self.user = user
        self.group = group
        # Path of a CA bundle with the VNC certs of every configured host;
        # filled in by started() when a user is configured.
        self.server_cas = None

    def started(self):
        """Gather TLS material, then drop privileges.

        The server cert/key are read and every host's VNC certificate is
        fetched via remctl *before* privileges are dropped.  Afterwards the
        cert and key are copied into private temporary files and
        self.cert/self.key are pointed at them (presumably because
        WebSocketServer re-reads them later as the unprivileged user --
        confirm against novnc).
        """
        super(WebSocketProxy, self).started()
        if not self.user:
            return

        with open(self.cert) as f:
            cert = f.read()
        with open(self.key) as f:
            key = f.read()
        cas = "".join(
            invirt.remctl.remctl(config.remote.hostname, "web", "vnccert",
                                 h.hostname)
            for h in config.hosts)

        drop_privileges(self.user, self.group)

        # Keep the NamedTemporaryFile objects on self: each file is deleted
        # as soon as its object is closed or garbage-collected.
        self.cert_tf = self._write_tempfile(cert)
        self.cert = self.cert_tf.name
        self.key_tf = self._write_tempfile(key)
        self.key = self.key_tf.name
        self.server_cafile = self._write_tempfile(cas)
        self.server_cas = self.server_cafile.name

    @staticmethod
    def _write_tempfile(data):
        """Return a flushed NamedTemporaryFile whose contents are *data*."""
        tf = tempfile.NamedTemporaryFile()
        tf.write(data)
        tf.flush()
        return tf

    def new_client(self):
        """Connect the current client to its VM's VNC server and proxy.

        The target comes from the request path's query string: ``host``
        (matched against config.hosts), ``vmname`` and ``token``.  When a
        parameter is repeated, the last value wins.

        :raises Exception: if the host is unknown or vmname/token are missing.
        """
        url = urlparse.urlparse(self.path)
        query = urlparse.parse_qs(url.query)
        host = query.get('host', [None])[-1]
        vmname = query.get('vmname', [None])[-1]
        token = query.get('token', [None])[-1]

        target_host = None
        target_port = config.vnc.base_port
        for h in config.hosts:
            if h.hostname == host:
                target_host = h.ip

        if not target_host:
            raise Exception("host not found")
        if not vmname:
            raise Exception("vmname not provided")
        if not token:
            raise Exception("token not provided")

        tsock = self.socket(target_host, target_port, connect=True)
        try:
            # Rebinding tsock means the finally clause closes whichever
            # socket exists: the raw TCP one if the TLS wrap fails, else the
            # TLS one.
            tsock = self._wrap_target(tsock, host)
            extra_data = self.do_auth_handshake(tsock, vmname, token)
            self.do_proxy(tsock, extra_data)
        finally:
            self._close_target(tsock)
            self.vmsg("%s:%s: Target closed" % (
                target_host, target_port))

    def _wrap_target(self, sock, host):
        """Return *sock* wrapped in TLS, requiring a cert signed by the
        Invirt VNC CA material (the prepared bundle, or *host*'s cert fetched
        via remctl when no bundle exists).

        NOTE(review): ssl.wrap_socket does not check the peer's hostname, and
        the bundle built in started() holds every host's cert, so a cert
        valid for one host is accepted for any target.  Consider
        ssl.match_hostname once Python >= 2.7.9 is available.
        """
        server_cas = self.server_cas
        ctx = noop()
        if not server_cas:
            ctx = tempfile.NamedTemporaryFile()
        with ctx as cafile:
            if not server_cas:
                cadata = invirt.remctl.remctl(config.remote.hostname, "web",
                                              "vnccert", host)
                cafile.write(cadata)
                cafile.flush()
                server_cas = cafile.name
            # Wrapping a connected socket runs the handshake immediately, so
            # the temporary CA file only needs to live for this call.
            # TODO: Use ssl.create_default_context when we move to
            # Python >=2.7.9
            return ssl.wrap_socket(sock,
                                   ca_certs=server_cas,
                                   cert_reqs=ssl.CERT_REQUIRED,
                                   ssl_version=ssl.PROTOCOL_SSLv23,
                                   ciphers=_RESTRICTED_SERVER_CIPHERS)

    @staticmethod
    def _close_target(sock):
        """Shut down and close *sock*, tolerating an already-dead peer."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            # e.g. ENOTCONN when the peer already went away; raising here
            # would hide the exception that actually ended the session.
            pass
        sock.close()

    def do_auth_handshake(self, target, vmname, token):
        """Send a CONNECTVNC request for *vmname* authenticated by *token*.

        :returns: bytes the server sent after the blank line ending its
            reply (the start of the VNC stream), or None if there were none.
        :raises Exception: with the server's reply if it is not a
            ``VNCProxy/1.0 200`` response, or if vmname/token contain line
            breaks.
        """
        # vmname and token come straight from the client's URL; a CR or LF
        # in either would let the client inject extra protocol lines.
        for value in (vmname, token):
            if "\r" in value or "\n" in value:
                raise Exception("invalid vmname or token")
        target.sendall("CONNECTVNC %s VNCProxy/1.0\r\nAuth-token: %s\r\n\r\n"
                       % (vmname, token))

        # The reply may arrive in several pieces; read until the blank line
        # that ends it, the peer closes, or the size cap is hit.
        data = ""
        while "\r\n\r\n" not in data and len(data) < _HANDSHAKE_MAX_RESPONSE:
            if "\n" in data and not data.startswith(_HANDSHAKE_OK):
                # The status line is complete and is an error; don't wait
                # for more data that may never come.
                break
            chunk = target.recv(_HANDSHAKE_RECV_SIZE)
            if not chunk:
                break
            data += chunk

        if not data.startswith(_HANDSHAKE_OK):
            raise Exception(data)
        end = data.find("\r\n\r\n")
        if end < 0:
            raise Exception("incomplete VNCProxy response: %r" % data)
        return data[end + 4:] or None

    def do_proxy(self, target, extra_data):
        """
        Proxy client WebSocket to normal target socket.

        *extra_data* is target data that arrived along with the handshake
        reply; it is forwarded to the client before anything else.  This
        method only exits by raising (self.EClose when either side closes).
        """
        cqueue = []  # target -> client data awaiting send_frames()
        c_pend = 0   # last send_frames() result; nonzero keeps the client
                     # in the write set (presumably frames still pending)
        tqueue = []  # client -> target data awaiting send()
        rlist = [self.client, target]

        if extra_data:
            # This data came *from* the target, so it belongs to the client.
            cqueue.append(extra_data)

        while True:
            wlist = []

            if tqueue:
                wlist.append(target)
            if cqueue or c_pend:
                wlist.append(self.client)
            ins, outs, _ = select(rlist, wlist, [], 1)

            if target in outs:
                # Send queued client data to the target
                dat = tqueue.pop(0)
                sent = target.send(dat)
                if sent != len(dat):
                    # requeue the remaining data
                    tqueue.insert(0, dat[sent:])

            if target in ins:
                # Receive target data, encode it and queue for client
                buf = target.recv(self.buffer_size)
                if len(buf) == 0:
                    raise self.EClose("Target closed")
                cqueue.append(buf)
                # The SSL layer may already hold decrypted bytes that
                # select() cannot see; drain them now rather than stalling
                # until more traffic reaches the raw socket.
                while target.pending():
                    cqueue.append(target.recv(self.buffer_size))

            if self.client in outs:
                # Send queued target data to the client
                c_pend = self.send_frames(cqueue)
                cqueue = []

            if self.client in ins:
                # Receive client data, decode it, and queue for target
                bufs, closed = self.recv_frames()
                tqueue.extend(bufs)

                if closed:
                    # TODO: What about blocking on client socket?
                    self.send_close()
                    raise self.EClose(closed)


if __name__ == '__main__':
    parser = OptionParser()
    parser.add_option("-u", "--user", dest="user", default="nobody",
                      help="user to drop privileges to", metavar="USER")
    parser.add_option("-g", "--group", dest="group", default="nogroup",
                      help="group to drop privileges to", metavar="GROUP")
    options, _ = parser.parse_args()

    server = WebSocketProxy(cert="/etc/apache2/ssl/server.crt",
                            key="/etc/apache2/ssl/server.key",
                            listen_port=config.vnc.novnc_port,
                            ssl_only=True,
                            user=options.user,
                            group=options.group,
                            )
    server.start_server()