diff --git a/include/websock.js b/include/websock.js index 20d51d6a..e3f9d2af 100644 --- a/include/websock.js +++ b/include/websock.js @@ -61,6 +61,8 @@ function Websock() { var api = {}, // Public API websocket = null, // WebSocket object + protocols, // Protocols to request in priority order + mode = 'base64', rQ = [], // Receive queue rQi = 0, // Receive queue index rQmax = 10000, // Max receive queue size before compacting @@ -123,26 +125,46 @@ function rQshift32() { (rQ[rQi++] << 8) + (rQ[rQi++] ); } +function rQslice(start, end) { + if (mode === 'binary') { + if (end) { + return rQ.subarray(rQi + start, rQi + end); + } else { + return rQ.subarray(rQi + start); + } + } else { + if (end) { + return rQ.slice(rQi + start, rQi + end); + } else { + return rQ.slice(rQi + start); + } + } +} + function rQshiftStr(len) { if (typeof(len) === 'undefined') { len = rQlen(); } - var arr = rQ.slice(rQi, rQi + len); + var arr = rQslice(0, len); rQi += len; - return arr.map(function (num) { - return String.fromCharCode(num); } ).join(''); - + return String.fromCharCode.apply(null, arr); } function rQshiftBytes(len) { if (typeof(len) === 'undefined') { len = rQlen(); } - rQi += len; - return rQ.slice(rQi-len, rQi); -} - -function rQslice(start, end) { - if (end) { - return rQ.slice(rQi + start, rQi + end); + var a = rQslice(0, len), b = []; + if (mode === 'binary') { + // Convert to plain array + b.push.apply(b, a); } else { - return rQ.slice(rQi + start); + // Already plain array, just return the original + b = a } + rQi += len; + return b; +} +function rQshiftArray(len) { + if (typeof(len) === 'undefined') { len = rQlen(); } + var a = rQslice(0, len); + rQi += len; + return a; } // Check to see if we must wait for 'num' bytes (default to FBU.bytes) @@ -170,13 +192,26 @@ function rQwait(msg, num, goback) { function encode_message() { /* base64 encode */ - return Base64.encode(sQ); + if (mode === 'binary') { + return (new Uint8Array(sQ)).buffer; + } else { + return Base64.encode(sQ); + } } function decode_message(data) { //Util.Debug(">> decode_message: " + data); - /* base64 decode */ - rQ = rQ.concat(Base64.decode(data, 0)); + if (mode === 'binary') { + // Create new arraybuffer and dump old and new data into it + // TODO: this could be far more efficient and re-use the array + var new_rQ = new Uint8Array(rQ.length + data.byteLength); + new_rQ.set(rQ); + new_rQ.set(new Uint8Array(data), rQ.length); + rQ = new_rQ; + } else { + /* base64 decode and concat to the end */ + rQ = rQ.concat(Base64.decode(data, 0)); + } //Util.Debug(">> decode_message, rQ: " + rQ); } @@ -230,7 +265,7 @@ function recv_message(e) { // Compact the receive queue if (rQ.length > rQmax) { //Util.Debug("Compacting receive queue"); - rQ = rQ.slice(rQi); + rQ = rQslice(rQi); rQi = 0; } } else { @@ -263,7 +298,32 @@ function init() { rQ = []; rQi = 0; sQ = []; - websocket = null; + websocket = null, + protocols = "base64"; + + var bt = false, + wsbt = false; + + if (('Uint8Array' in window) && + ('set' in Uint8Array.prototype)) { + bt = true; + } + // TODO: this sucks, the property should exist on the prototype + // but it does not. + try { + if (bt && ('binaryType' in (new WebSocket("ws://localhost:17523")))) { + wsbt = true; + } + } catch (exc) { + // Just ignore failed test localhost connections + } + if (bt && wsbt) { + Util.Info("Detected binaryType support in WebSockets"); + protocols = ['binary', 'base64']; + } else { + Util.Info("No binaryType support in WebSockets, using base64 encoding"); + protocols = 'base64'; + } } function open(uri) { @@ -272,19 +332,22 @@ function open(uri) { if (test_mode) { websocket = {}; } else { - websocket = new WebSocket(uri, 'base64'); - // TODO: future native binary support - //websocket = new WebSocket(uri, ['binary', 'base64']); + websocket = new WebSocket(uri, protocols); } websocket.onmessage = recv_message; websocket.onopen = function() { Util.Debug(">> WebSock.onopen"); if (websocket.protocol) { + mode = websocket.protocol; Util.Info("Server chose sub-protocol: " + websocket.protocol); } else { + mode = 'base64'; Util.Error("Server select no sub-protocol!: " + websocket.protocol); } + if (mode === 'binary') { + websocket.binaryType = 'arraybuffer'; + } eventHandlers.open(); Util.Debug("<< WebSock.onopen"); }; diff --git a/utils/websocket.py b/utils/websocket.py index d3bb48cb..11f718cd 100644 --- a/utils/websocket.py +++ b/utils/websocket.py @@ -18,7 +18,6 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates import os, sys, time, errno, signal, socket, traceback, select import array, struct -from cgi import parse_qsl from base64 import b64encode, b64decode # Imports that vary by python version @@ -36,8 +35,6 @@ try: from io import StringIO except: from cStringIO import StringIO try: from http.server import SimpleHTTPRequestHandler except: from SimpleHTTPServer import SimpleHTTPRequestHandler -try: from urllib.parse import urlsplit -except: from urlparse import urlsplit # python 2.6 differences try: from hashlib import md5, sha1 @@ -75,6 +72,7 @@ class WebSocketServer(object): buffer_size = 65536 + server_handshake_hixie = """HTTP/1.1 101 Web Socket Protocol Handshake\r Upgrade: WebSocket\r Connection: Upgrade\r @@ -109,11 +107,12 @@ Sec-WebSocket-Accept: %s\r self.verbose = verbose self.listen_host = listen_host self.listen_port = listen_port + self.prefer_ipv6 = source_is_ipv6 self.ssl_only = ssl_only self.daemon = daemon self.run_once = run_once self.timeout = timeout - + self.launch_time = time.time() self.ws_connection = False self.handler_id = 1 @@ -163,7 +162,7 @@ Sec-WebSocket-Accept: %s\r # @staticmethod - def socket(host, port=None, connect=False, prefer_ipv6=False): + def socket(host, port=None, connect=False, prefer_ipv6=False, unix_socket=None, use_ssl=False): """ Resolve a host (and optional port) to an IPv4 or IPv6 address. Create a socket. Bind to it if listen is set, otherwise connect to it. Return the socket. @@ -171,24 +170,36 @@ Sec-WebSocket-Accept: %s\r flags = 0 if host == '': host = None - if connect and not port: + if connect and not (port or unix_socket): raise Exception("Connect mode requires a port") + if use_ssl and not ssl: + raise Exception("SSL socket requested but Python SSL module not loaded."); + if not connect and use_ssl: + raise Exception("SSL only supported in connect mode (for now)") if not connect: flags = flags | socket.AI_PASSIVE - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, - socket.IPPROTO_TCP, flags) - if not addrs: - raise Exception("Could resolve host '%s'" % host) - addrs.sort(key=lambda x: x[0]) - if prefer_ipv6: - addrs.reverse() - sock = socket.socket(addrs[0][0], addrs[0][1]) - if connect: - sock.connect(addrs[0][4]) - else: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(addrs[0][4]) - sock.listen(100) + + if not unix_socket: + addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, + socket.IPPROTO_TCP, flags) + if not addrs: + raise Exception("Could not resolve host '%s'" % host) + addrs.sort(key=lambda x: x[0]) + if prefer_ipv6: + addrs.reverse() + sock = socket.socket(addrs[0][0], addrs[0][1]) + if connect: + sock.connect(addrs[0][4]) + if use_ssl: + sock = ssl.wrap_socket(sock) + else: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addrs[0][4]) + sock.listen(100) + else: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(unix_socket) + return sock @staticmethod @@ -552,93 +563,9 @@ Sec-WebSocket-Accept: %s\r # No orderly close for 75 - def do_handshake(self, sock, address): - """ - do_handshake does the following: - - Peek at the first few bytes from the socket. - - If the connection is Flash policy request then answer it, - close the socket and return. - - If the connection is an HTTPS/SSL/TLS connection then SSL - wrap the socket. - - Read from the (possibly wrapped) socket. - - If we have received a HTTP GET request and the webserver - functionality is enabled, answer it, close the socket and - return. - - Assume we have a WebSockets connection, parse the client - handshake data. - - Send a WebSockets handshake server response. - - Return the socket for this WebSocket client. - """ - - stype = "" - - ready = select.select([sock], [], [], 3)[0] - if not ready: - raise self.EClose("ignoring socket not ready") - # Peek, but do not read the data so that we have a opportunity - # to SSL wrap the socket first - handshake = sock.recv(1024, socket.MSG_PEEK) - #self.msg("Handshake [%s]" % handshake) - - if handshake == "": - raise self.EClose("ignoring empty handshake") - - elif handshake.startswith(s2b("")): - # Answer Flash policy request - handshake = sock.recv(1024) - sock.send(s2b(self.policy_response)) - raise self.EClose("Sending flash policy response") - - elif handshake[0] in ("\x16", "\x80", 22, 128): - # SSL wrap the connection - if not ssl: - raise self.EClose("SSL connection but no 'ssl' module") - if not os.path.exists(self.cert): - raise self.EClose("SSL connection but '%s' not found" - % self.cert) - retsock = None - try: - retsock = ssl.wrap_socket( - sock, - server_side=True, - certfile=self.cert, - keyfile=self.key) - except ssl.SSLError: - _, x, _ = sys.exc_info() - if x.args[0] == ssl.SSL_ERROR_EOF: - if len(x.args) > 1: - raise self.EClose(x.args[1]) - else: - raise self.EClose("Got SSL_ERROR_EOF") - else: - raise - - scheme = "wss" - stype = "SSL/TLS (wss://)" - - elif self.ssl_only: - raise self.EClose("non-SSL connection received but disallowed") - - else: - retsock = sock - scheme = "ws" - stype = "Plain non-SSL (ws://)" - - wsh = WSRequestHandler(retsock, address, not self.web) - if wsh.last_code == 101: - # Continue on to handle WebSocket upgrade - pass - elif wsh.last_code == 405: - raise self.EClose("Normal web request received but disallowed") - elif wsh.last_code < 200 or wsh.last_code >= 300: - raise self.EClose(wsh.last_message) - elif self.verbose: - raise self.EClose(wsh.last_message) - else: - raise self.EClose("") - - h = self.headers = wsh.headers - path = self.path = wsh.path + def do_websocket_handshake(self, headers, path): + h = self.headers = headers + self.path = path prot = 'WebSocket-Protocol' protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') @@ -691,7 +618,7 @@ Sec-WebSocket-Accept: %s\r self.base64 = True response = self.server_handshake_hixie % (pre, - h['Origin'], pre, scheme, h['Host'], path) + h['Origin'], pre, self.scheme, h['Host'], path) if 'base64' in protocols: response += "%sWebSocket-Protocol: base64\r\n" % pre @@ -699,6 +626,96 @@ Sec-WebSocket-Accept: %s\r self.msg("Warning: client does not report 'base64' protocol support") response += "\r\n" + trailer + return response + + + def do_handshake(self, sock, address): + """ + do_handshake does the following: + - Peek at the first few bytes from the socket. + - If the connection is Flash policy request then answer it, + close the socket and return. + - If the connection is an HTTPS/SSL/TLS connection then SSL + wrap the socket. + - Read from the (possibly wrapped) socket. + - If we have received a HTTP GET request and the webserver + functionality is enabled, answer it, close the socket and + return. + - Assume we have a WebSockets connection, parse the client + handshake data. + - Send a WebSockets handshake server response. + - Return the socket for this WebSocket client. + """ + stype = "" + ready = select.select([sock], [], [], 3)[0] + + + if not ready: + raise self.EClose("ignoring socket not ready") + # Peek, but do not read the data so that we have a opportunity + # to SSL wrap the socket first + handshake = sock.recv(1024, socket.MSG_PEEK) + #self.msg("Handshake [%s]" % handshake) + + if handshake == "": + raise self.EClose("ignoring empty handshake") + + elif handshake.startswith(s2b("")): + # Answer Flash policy request + handshake = sock.recv(1024) + sock.send(s2b(self.policy_response)) + raise self.EClose("Sending flash policy response") + + elif handshake[0] in ("\x16", "\x80", 22, 128): + # SSL wrap the connection + if not ssl: + raise self.EClose("SSL connection but no 'ssl' module") + if not os.path.exists(self.cert): + raise self.EClose("SSL connection but '%s' not found" + % self.cert) + retsock = None + try: + retsock = ssl.wrap_socket( + sock, + server_side=True, + certfile=self.cert, + keyfile=self.key) + except ssl.SSLError: + _, x, _ = sys.exc_info() + if x.args[0] == ssl.SSL_ERROR_EOF: + if len(x.args) > 1: + raise self.EClose(x.args[1]) + else: + raise self.EClose("Got SSL_ERROR_EOF") + else: + raise + + self.scheme = "wss" + stype = "SSL/TLS (wss://)" + + elif self.ssl_only: + raise self.EClose("non-SSL connection received but disallowed") + + else: + retsock = sock + self.scheme = "ws" + stype = "Plain non-SSL (ws://)" + + wsh = WSRequestHandler(retsock, address, not self.web) + if wsh.last_code == 101: + # Continue on to handle WebSocket upgrade + pass + elif wsh.last_code == 405: + raise self.EClose("Normal web request received but disallowed") + elif wsh.last_code < 200 or wsh.last_code >= 300: + raise self.EClose(wsh.last_message) + elif self.verbose: + raise self.EClose(wsh.last_message) + else: + raise self.EClose("") + + response = self.do_websocket_handshake(wsh.headers, wsh.path) + self.msg("%s: %s WebSocket connection" % (address[0], stype)) self.msg("%s: Version %s, base64: '%s'" % (address[0], self.version, self.base64)) @@ -750,7 +767,7 @@ Sec-WebSocket-Accept: %s\r self.rec = None self.start_time = int(time.time()*1000) - # handler process + # handler process try: try: self.client = self.do_handshake(startsock, address) @@ -801,7 +818,7 @@ Sec-WebSocket-Accept: %s\r is a WebSockets client then call new_client() method (which must be overridden) for each new client connection. """ - lsock = self.socket(self.listen_host, self.listen_port) + lsock = self.socket(self.listen_host, self.listen_port, False, self.prefer_ipv6) if self.daemon: self.daemonize(keepfd=lsock.fileno(), chdir=self.web) @@ -848,7 +865,7 @@ Sec-WebSocket-Accept: %s\r continue else: raise - + if self.run_once: # Run in same process if run_once self.top_new_client(startsock, address) @@ -927,4 +944,3 @@ class WSRequestHandler(SimpleHTTPRequestHandler): def log_message(self, f, *args): # Save instead of printing self.last_message = f % args - diff --git a/utils/websockify b/utils/websockify index 550dff7a..965ce13c 100755 --- a/utils/websockify +++ b/utils/websockify @@ -13,9 +13,11 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates import socket, optparse, time, os, sys, subprocess from select import select -from websocket import WebSocketServer +import websocket +try: from urllib.parse import parse_qs, urlparse +except: from urlparse import parse_qs, urlparse -class WebSocketProxy(WebSocketServer): +class WebSocketProxy(websocket.WebSocketServer): """ Proxy traffic to and from a WebSockets client to a normal TCP socket server target. All traffic to/from the client is base64 @@ -43,6 +45,9 @@ Traffic Legend: self.target_port = kwargs.pop('target_port') self.wrap_cmd = kwargs.pop('wrap_cmd') self.wrap_mode = kwargs.pop('wrap_mode') + self.unix_target = kwargs.pop('unix_target') + self.ssl_target = kwargs.pop('ssl_target') + self.target_cfg = kwargs.pop('target_cfg') # Last 3 timestamps command was run self.wrap_times = [0, 0, 0] @@ -58,6 +63,7 @@ Traffic Legend: if not self.rebinder: raise Exception("rebind.so not found, perhaps you need to run make") + self.rebinder = os.path.abspath(self.rebinder) self.target_host = "127.0.0.1" # Loopback # Find a free high port @@ -71,7 +77,10 @@ Traffic Legend: "REBIND_OLD_PORT": str(kwargs['listen_port']), "REBIND_NEW_PORT": str(self.target_port)}) - WebSocketServer.__init__(self, *args, **kwargs) + if self.target_cfg: + self.target_cfg = os.path.abspath(self.target_cfg) + + websocket.WebSocketServer.__init__(self, *args, **kwargs) def run_wrap_cmd(self): print("Starting '%s'" % " ".join(self.wrap_cmd)) @@ -88,14 +97,26 @@ Traffic Legend: # Need to call wrapped command after daemonization so we can # know when the wrapped command exits if self.wrap_cmd: - print(" - proxying from %s:%s to '%s' (port %s)\n" % ( - self.listen_host, self.listen_port, - " ".join(self.wrap_cmd), self.target_port)) - self.run_wrap_cmd() + dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) + elif self.unix_target: + dst_string = self.unix_target else: - print(" - proxying from %s:%s to %s:%s\n" % ( - self.listen_host, self.listen_port, - self.target_host, self.target_port)) + dst_string = "%s:%s" % (self.target_host, self.target_port) + + if self.target_cfg: + msg = " - proxying from %s:%s to targets in %s" % ( + self.listen_host, self.listen_port, self.target_cfg) + else: + msg = " - proxying from %s:%s to %s" % ( + self.listen_host, self.listen_port, dst_string) + + if self.ssl_target: + msg += " (using SSL)" + + print(msg + "\n") + + if self.wrap_cmd: + self.run_wrap_cmd() def poll(self): # If we are wrapping a command, check it's status @@ -137,12 +158,26 @@ Traffic Legend: """ Called after a new WebSocket connection has been established. """ + # Checks if we receive a token, and look + # for a valid target for it then + if self.target_cfg: + (self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) # Connect to the target - self.msg("connecting to: %s:%s" % ( - self.target_host, self.target_port)) + if self.wrap_cmd: + msg = "connecting to command: %s" % (" ".join(self.wrap_cmd), self.target_port) + elif self.unix_target: + msg = "connecting to unix socket: %s" % self.unix_target + else: + msg = "connecting to: %s:%s" % ( + self.target_host, self.target_port) + + if self.ssl_target: + msg += " (using SSL)" + self.msg(msg) + tsock = self.socket(self.target_host, self.target_port, - connect=True) + connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) if self.verbose and not self.daemon: print(self.traffic_legend) @@ -154,10 +189,49 @@ Traffic Legend: if tsock: tsock.shutdown(socket.SHUT_RDWR) tsock.close() - self.vmsg("%s:%s: Target closed" %( + self.vmsg("%s:%s: Closed target" %( self.target_host, self.target_port)) raise + def get_target(self, target_cfg, path): + """ + Parses the path, extracts a token, and looks for a valid + target for that token in the configuration file(s). Sets + target_host and target_port if successful + """ + # The files in targets contain the lines + # in the form of token: host:port + + # Extract the token parameter from url + args = parse_qs(urlparse(path)[4]) # 4 is the query from url + + if not len(args['token']): + raise self.EClose("Token not present") + + token = args['token'][0].rstrip('\n') + + # target_cfg can be a single config file or directory of + # config files + if os.path.isdir(target_cfg): + cfg_files = [os.path.join(target_cfg, f) + for f in os.listdir(target_cfg)] + else: + cfg_files = [target_cfg] + + targets = {} + for f in cfg_files: + for line in [l.strip() for l in file(f).readlines()]: + if line and not line.startswith('#'): + ttoken, target = line.split(': ') + targets[ttoken] = target.strip() + + self.vmsg("Target config: %s" % repr(targets)) + + if targets.has_key(token): + return targets[token].split(':') + else: + raise self.EClose("Token '%s' not found" % token) + def do_proxy(self, target): """ Proxy client WebSocket to normal target socket. @@ -191,6 +265,8 @@ Traffic Legend: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) if len(buf) == 0: + self.vmsg("%s:%s: Target closed connection" %( + self.target_host, self.target_port)) raise self.CClose(1000, "Target closed") cqueue.append(buf) @@ -211,11 +287,13 @@ Traffic Legend: if closed: # TODO: What about blocking on client socket? + self.vmsg("%s:%s: Client closed connection" %( + self.target_host, self.target_port)) raise self.CClose(closed['code'], closed['reason']) def websockify_init(): usage = "\n %prog [options]" - usage += " [source_addr:]source_port target_addr:target_port" + usage += " [source_addr:]source_port [target_addr:target_port]" usage += "\n %prog [options]" usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE" parser = optparse.OptionParser(usage=usage) @@ -235,17 +313,29 @@ def websockify_init(): parser.add_option("--key", default=None, help="SSL key file (if separate from cert)") parser.add_option("--ssl-only", action="store_true", - help="disallow non-encrypted connections") + help="disallow non-encrypted client connections") + parser.add_option("--ssl-target", action="store_true", + help="connect to SSL target as SSL client") + parser.add_option("--unix-target", + help="connect to unix socket target", metavar="FILE") parser.add_option("--web", default=None, metavar="DIR", help="run webserver on same port. Serve files from DIR.") parser.add_option("--wrap-mode", default="exit", metavar="MODE", choices=["exit", "ignore", "respawn"], help="action to take when the wrapped program exits " "or daemonizes: exit (default), ignore, respawn") + parser.add_option("--prefer-ipv6", "-6", + action="store_true", dest="source_is_ipv6", + help="prefer IPv6 when resolving source_addr") + parser.add_option("--target-config", metavar="FILE", + dest="target_cfg", + help="Configuration file containing valid targets " + "in the form 'token: host:port' or, alternatively, a " + "directory containing configuration files of this form") (opts, args) = parser.parse_args() # Sanity checks - if len(args) < 2: + if len(args) < 2 and not opts.target_cfg: parser.error("Too few arguments") if sys.argv.count('--'): opts.wrap_cmd = args[1:] @@ -254,24 +344,29 @@ def websockify_init(): if len(args) > 2: parser.error("Too many arguments") + if not websocket.ssl and opts.ssl_target: + parser.error("SSL target requested and Python SSL module not loaded."); + if opts.ssl_only and not os.path.exists(opts.cert): parser.error("SSL only and %s not found" % opts.cert) # Parse host:port and convert ports to numbers if args[0].count(':') > 0: opts.listen_host, opts.listen_port = args[0].rsplit(':', 1) + opts.listen_host = opts.listen_host.strip('[]') else: opts.listen_host, opts.listen_port = '', args[0] try: opts.listen_port = int(opts.listen_port) except: parser.error("Error parsing listen port") - if opts.wrap_cmd: + if opts.wrap_cmd or opts.unix_target or opts.target_cfg: opts.target_host = None opts.target_port = None else: if args[1].count(':') > 0: opts.target_host, opts.target_port = args[1].rsplit(':', 1) + opts.target_host = opts.target_host.strip('[]') else: parser.error("Error parsing target") try: opts.target_port = int(opts.target_port)