Drop privileges in VNC proxy if requested
[invirt/packages/invirt-web.git] / files / usr / bin / invirt-novnc-wsproxy
1 #!/usr/bin/env python
2
3 import os
4 import pwd
5 import grp
6 import contextlib
7 import socket
8 import ssl
9 from select import select
10 import tempfile
11 import urlparse
12 from novnc.websocket import WebSocketServer
13 import invirt.remctl
14 from invirt.config import structs as config
15 from optparse import OptionParser
16
def drop_privileges(uid_name='nobody', gid_name='nogroup'):
    """Switch the process from root to an unprivileged user and group.

    Does nothing unless the process currently runs as uid 0.

    :param uid_name: name of the user to switch to.
    :param gid_name: name of the group to switch to.
    :raises KeyError: if the user or group name does not exist.
    :raises OSError: if changing the process credentials fails.
    """
    if os.getuid() != 0:
        return

    # Get the uid/gid from the name
    uid = pwd.getpwnam(uid_name).pw_uid
    gid = grp.getgrnam(gid_name).gr_gid

    # setgid()/setuid() alone leave root's supplementary groups in place,
    # so clear them first (this needs root, hence before setuid()).
    os.setgroups([])

    # The group must change before the user: once the uid is no longer 0
    # the process is no longer allowed to call setgid().
    os.setgid(gid)
    os.setuid(uid)

    # Ensure a very conservative umask: new files are private to the owner.
    os.umask(0o077)
32
33 # From Python >=2.7.9
34 _RESTRICTED_SERVER_CIPHERS = (
35     'ECDH+AESGCM:DH+AESGCM:ECDH+AES256:DH+AES256:ECDH+AES128:DH+AES:ECDH+HIGH:'
36     'DH+HIGH:ECDH+3DES:DH+3DES:RSA+AESGCM:RSA+AES:RSA+HIGH:RSA+3DES:!aNULL:'
37     '!eNULL:!MD5:!DSS:!RC4'
38 )
39
@contextlib.contextmanager
def noop():
    """Return a context manager that does nothing.

    Serves as a stand-in where a real context manager (such as a temporary
    file) is only sometimes needed. The ``as`` target is None.
    """
    yield None
43
class WebSocketProxy(WebSocketServer):
    """
    Proxy traffic from a WebSockets client to an Invirt VNC server,
    doing the auth handshake for the client.

    If a user is given, started() reads the TLS material it needs and then
    drops privileges to that user/group.
    """

    def __init__(self, user=None, group=None, *args, **kwargs):
        super(WebSocketProxy, self).__init__(*args, **kwargs)
        # Account to drop privileges to in started(); a falsy user disables it.
        self.user = user
        self.group = group
        # Path of a CA bundle covering every host, prepared by started().
        # While it is unset, new_client() fetches the data per connection.
        self.server_cas = None

    @staticmethod
    def _make_tempfile(contents):
        """Write contents to a new NamedTemporaryFile and return it, flushed."""
        tf = tempfile.NamedTemporaryFile()
        tf.write(contents)
        tf.flush()
        return tf

    def started(self):
        """Copy TLS material to private temp files, then drop privileges.

        Only acts when self.user is set. The certificate, key and the
        remotely fetched CA data for all hosts are read while still
        privileged. After drop_privileges() they are re-written to temporary
        files whose paths replace self.cert, self.key and self.server_cas.
        """
        super(WebSocketProxy, self).started()
        if not self.user:
            return

        with open(self.cert) as f:
            cert = f.read()
        with open(self.key) as f:
            key = f.read()
        # The bundle is used as an ssl ca_certs file (concatenated PEM).
        # Separating the per-host results with newlines keeps it parseable
        # even if one result lacks a trailing newline. Blank lines between
        # PEM blocks are harmless.
        cas = "\n".join(
            invirt.remctl.remctl(config.remote.hostname, "web", "vnccert",
                                 h.hostname)
            for h in config.hosts)

        drop_privileges(self.user, self.group)

        # Keep the temp file objects referenced on self: closing one deletes
        # the file behind the path stored next to it.
        self.cert_tf = self._make_tempfile(cert)
        self.cert = self.cert_tf.name
        self.key_tf = self._make_tempfile(key)
        self.key = self.key_tf.name
        self.server_cafile = self._make_tempfile(cas)
        self.server_cas = self.server_cafile.name

    @staticmethod
    def _close_target(tsock):
        """Shut down and close tsock, tolerating an already-dead connection.

        This runs in its own frame so that the exceptions it swallows cannot
        clobber the exception being propagated by the caller (Python 2).
        """
        try:
            tsock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            # Best effort: the peer may already have gone away (ENOTCONN,
            # TLS errors; ssl.SSLError subclasses socket.error).
            pass
        tsock.close()

    def new_client(self):
        """Handle one WebSocket client: connect to the target, authenticate, proxy.

        The target comes from the ``host``, ``vmname`` and ``token`` query
        parameters of the request path.

        :raises Exception: if the host is unknown or a parameter is missing.
        """
        url = urlparse.urlparse(self.path)
        query = urlparse.parse_qs(url.query)

        host = query.get('host', [None])[-1]
        vmname = query.get('vmname', [None])[-1]
        token = query.get('token', [None])[-1]

        target_host = None
        target_port = config.vnc.base_port

        for h in config.hosts:
            if h.hostname == host:
                target_host = h.ip

        if not target_host:
            raise Exception("host not found")
        if not vmname:
            raise Exception("vmname not provided")
        if not token:
            raise Exception("token not provided")

        server_cas = self.server_cas
        # Without a bundle from started(), fetch this host's CA data into a
        # temporary file that lives for the rest of the connection.
        ctx = noop() if server_cas else tempfile.NamedTemporaryFile()

        with ctx as cafile:
            if not server_cas:
                cadata = invirt.remctl.remctl(config.remote.hostname, "web", "vnccert", host)
                cafile.write(cadata)
                cafile.flush()
                server_cas = cafile.name

            # Connect only after the CA data is available, so a failed fetch
            # cannot leak the socket.
            tsock = self.socket(target_host, target_port, connect=True)
            try:
                # TODO: Use ssl.create_default_context when we move to Python >=2.7.9
                # NOTE(review): the peer is checked against the CA data but
                # not matched to `host`. With the all-hosts bundle, any host
                # could presumably answer for another. Confirm this is acceptable.
                tsock = ssl.wrap_socket(tsock, ca_certs=server_cas, cert_reqs=ssl.CERT_REQUIRED, ssl_version=ssl.PROTOCOL_SSLv23, ciphers=_RESTRICTED_SERVER_CIPHERS)

                # Start proxying
                extra_data = self.do_auth_handshake(tsock, vmname, token)
                self.do_proxy(tsock, extra_data)
            finally:
                # Runs on every exit, including a failed TLS wrap (tsock is
                # then still the plain socket). Any pending exception
                # propagates unchanged afterwards.
                self._close_target(tsock)
                self.vmsg("%s:%s: Target closed" % (
                    target_host, target_port))

    def do_auth_handshake(self, target, vmname, token):
        """Send the CONNECTVNC request and consume the proxy's reply header.

        :param target: TLS socket connected to the host's VNC proxy.
        :param vmname: name of the VM whose console is requested.
        :param token: auth token passed in the Auth-token header.
        :returns: bytes received after the reply header (already part of the
            proxied stream), or None if there were none.
        :raises Exception: carrying the received data, if the reply is not a
            200, the connection closes, or the header grows too large
            before it is complete.
        """
        max_header_len = 4096
        # sendall: a plain send() may write only part of the request.
        target.sendall("CONNECTVNC %s VNCProxy/1.0\r\nAuth-token: %s\r\n\r\n" % (vmname, token))
        data = target.recv(128)
        if not data.startswith("VNCProxy/1.0 200 "):
            raise Exception(data)
        # The reply header ends with an empty line. One recv() may stop short
        # of it, so keep reading. Otherwise the rest of the header would be
        # forwarded to the client as if it were VNC data.
        while "\r\n\r\n" not in data:
            if len(data) > max_header_len:
                raise Exception(data)
            chunk = target.recv(128)
            if not chunk:
                raise Exception(data)
            data += chunk
        extra_data = data[data.index("\r\n\r\n") + 4:]
        return extra_data or None

    def do_proxy(self, target, extra_data):
        """
        Proxy client WebSocket to normal target socket.

        extra_data, if any, is queued for the target first. The loop has no
        normal exit: it ends by raising, e.g. self.EClose when either side
        closes.
        """
        cqueue = []  # target -> client data waiting to be framed and sent
        c_pend = 0   # result of the last send_frames() call
        tqueue = []  # client -> target data waiting to be sent
        rlist = [self.client, target]

        if extra_data:
            tqueue.append(extra_data)

        while True:
            wlist = []

            if tqueue: wlist.append(target)
            if cqueue or c_pend: wlist.append(self.client)
            ins, outs, excepts = select(rlist, wlist, [], 1)
            if excepts: raise Exception("Socket exception")

            if target in outs:
                # Send queued client data to the target
                dat = tqueue.pop(0)
                sent = target.send(dat)
                if sent != len(dat):
                    # requeue the remaining data
                    tqueue.insert(0, dat[sent:])

            if target in ins:
                # Receive target data, encode it and queue for client
                buf = target.recv(self.buffer_size)
                if len(buf) == 0: raise self.EClose("Target closed")

                cqueue.append(buf)

            if self.client in outs:
                # Send queued target data to the client
                c_pend = self.send_frames(cqueue)

                cqueue = []

            if self.client in ins:
                # Receive client data, decode it, and queue for target
                bufs, closed = self.recv_frames()
                tqueue.extend(bufs)

                if closed:
                    # TODO: What about blocking on client socket?
                    self.send_close()
                    raise self.EClose(closed)
191                     raise self.EClose(closed)
192
193 if __name__ == '__main__':
194     parser = OptionParser()
195     parser.add_option("-u", "--user", dest="user", default="nobody",
196                       help="user to drop privileges to", metavar="USER")
197     parser.add_option("-g", "--group", dest="group", default="nogroup",
198                       help="group to drop privileges to", metavar="GROUP")
199
200     (options, args) = parser.parse_args()
201
202     server = WebSocketProxy(cert="/etc/apache2/ssl/server.crt",
203                             key="/etc/apache2/ssl/server.key",
204                             listen_port=config.vnc.novnc_port,
205                             ssl_only=True,
206                             user=options.user,
207                             group=options.group,
208     )
209     server.start_server()