diff --git a/tests/ws.py b/tests/ws.py deleted file mode 100755 index 7d9b0bf3..00000000 --- a/tests/ws.py +++ /dev/null @@ -1,166 +0,0 @@ -#!/usr/bin/python - -''' -WebSocket server-side load test program. Sends and receives traffic -that has a random payload (length and content) that is checksummed and -given a sequence number. Any errors are reported and counted. -''' - -import sys, os, socket, ssl, time, traceback -import random, time -from base64 import b64encode, b64decode -from select import select - -sys.path.insert(0,os.path.dirname(__file__) + "/../utils/") -from websocket import * - -buffer_size = 65536 -max_packet_size = 10000 -recv_cnt = send_cnt = 0 - - -def check(buf): - global recv_cnt - - try: - data_list = decode(buf) - except: - print "\n" + repr(buf) + "" - return "Failed to decode" - - err = "" - for data in data_list: - if data.count('$') > 1: - raise Exception("Multiple parts within single packet") - if len(data) == 0: - traffic("_") - continue - - if data[0] != "^": - err += "buf did not start with '^'\n" - continue - - try: - cnt, length, chksum, nums = data[1:-1].split(':') - cnt = int(cnt) - length = int(length) - chksum = int(chksum) - except: - print "\n" + repr(data) + "" - err += "Invalid data format\n" - continue - - if recv_cnt != cnt: - err += "Expected count %d but got %d\n" % (recv_cnt, cnt) - recv_cnt = cnt + 1 - continue - - recv_cnt += 1 - - if len(nums) != length: - err += "Expected length %d but got %d\n" % (length, len(nums)) - continue - - inv = nums.translate(None, "0123456789") - if inv: - err += "Invalid characters found: %s\n" % inv - continue - - real_chksum = 0 - for num in nums: - real_chksum += int(num) - - if real_chksum != chksum: - err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum) - return err - - -def generate(): - global send_cnt, rand_array - length = random.randint(10, max_packet_size) - numlist = rand_array[max_packet_size-length:] - # Error in length - #numlist.append(5) - chksum = sum(numlist) - # Error in checksum - #numlist[0] = 5 - nums = "".join( [str(n) for n in numlist] ) - data = "^%d:%d:%d:%s$" % (send_cnt, length, chksum, nums) - send_cnt += 1 - - return encode(data) - -def responder(client, delay=10): - global errors - cqueue = [] - cpartial = "" - socks = [client] - last_send = time.time() * 1000 - - while True: - ins, outs, excepts = select(socks, socks, socks, 1) - if excepts: raise Exception("Socket exception") - - if client in ins: - buf = client.recv(buffer_size) - if len(buf) == 0: raise Exception("Client closed") - #print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf)) - if buf[-1] == '\xff': - if cpartial: - err = check(cpartial + buf) - cpartial = "" - else: - err = check(buf) - if err: - traffic("}") - errors = errors + 1 - print err - else: - traffic(">") - else: - traffic(".>") - cpartial = cpartial + buf - - now = time.time() * 1000 - if client in outs and now > (last_send + delay): - last_send = now - #print "Client send: %s" % repr(cqueue[0]) - client.send(generate()) - traffic("<") - -def test_handler(client): - global errors, delay, send_cnt, recv_cnt - - send_cnt = 0 - recv_cnt = 0 - - try: - responder(client, delay) - except: - print "accumulated errors:", errors - errors = 0 - raise - - -if __name__ == '__main__': - errors = 0 - try: - if len(sys.argv) < 2: raise - listen_port = int(sys.argv[1]) - if len(sys.argv) == 3: - delay = int(sys.argv[2]) - else: - delay = 10 - except: - print "Usage: [delay_ms]" - sys.exit(1) - - print "Prepopulating random array" - rand_array = [] - for i in range(0, max_packet_size): - rand_array.append(random.randint(0, 9)) - - settings['listen_port'] = listen_port - settings['daemon'] = False - settings['handler'] = test_handler - start_server() diff --git a/tests/wsecho.html b/tests/wsecho.html new file mode 100644 index 00000000..9e3c6d6c --- /dev/null +++ b/tests/wsecho.html @@ -0,0 +1,176 @@ + + + + WebSockets Echo Test + + + + + + + + + + + + Host:   + Port:   + Encrypt:   +   + + +
+ Log:
+ + + + + + + diff --git a/tests/ws.html b/tests/wstest.html similarity index 100% rename from tests/ws.html rename to tests/wstest.html diff --git a/tests/wstest.py b/tests/wstest.py new file mode 120000 index 00000000..2c5b2b30 --- /dev/null +++ b/tests/wstest.py @@ -0,0 +1 @@ +../utils/wstest.py \ No newline at end of file diff --git a/utils/websocket.py b/utils/websocket.py index 7efd01a3..48eb15ac 100755 --- a/utils/websocket.py +++ b/utils/websocket.py @@ -23,20 +23,13 @@ except: from urlparse import urlsplit from cgi import parse_qsl -settings = { - 'verbose' : False, - 'listen_host' : '', - 'listen_port' : None, - 'handler' : None, - 'handler_id' : 1, - 'cert' : None, - 'key' : None, - 'ssl_only' : False, - 'daemon' : True, - 'record' : None, - 'web' : False, } +class WebSocketServer(): + """ + WebSockets server class. + Must be sub-classed with handler method definition. + """ -server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r + server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r Upgrade: WebSocket\r Connection: Upgrade\r %sWebSocket-Origin: %s\r @@ -45,10 +38,323 @@ Connection: Upgrade\r \r %s""" -policy_response = """\n""" + policy_response = """\n""" + + class EClose(Exception): + pass + + def __init__(self, listen_host='', listen_port=None, + verbose=False, cert='', key='', ssl_only=None, + daemon=False, record='', web=''): + + # settings + self.verbose = verbose + self.listen_host = listen_host + self.listen_port = listen_port + self.ssl_only = ssl_only + self.daemon = daemon + + + # Make paths settings absolute + self.cert = os.path.abspath(cert) + self.key = self.web = self.record = '' + if key: + self.key = os.path.abspath(key) + if web: + self.web = os.path.abspath(web) + if record: + self.record = os.path.abspath(record) + + if self.web: + os.chdir(self.web) + + self.handler_id = 1 + + # + # WebSocketServer static methods + # + @staticmethod + def daemonize(self, keepfd=None): + os.umask(0) + if self.web: + os.chdir(self.web) + else: + os.chdir('/') + os.setgid(os.getgid()) # relinquish elevations + os.setuid(os.getuid()) # relinquish elevations + + # Double fork to daemonize + if os.fork() > 0: os._exit(0) # Parent exits + os.setsid() # Obtain new process group + if os.fork() > 0: os._exit(0) # Parent exits + + # Signal handling + def terminate(a,b): os._exit(0) + signal.signal(signal.SIGTERM, terminate) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Close open files + maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if maxfd == resource.RLIM_INFINITY: maxfd = 256 + for fd in reversed(range(maxfd)): + try: + if fd != keepfd: + os.close(fd) + except OSError, exc: + if exc.errno != errno.EBADF: raise + + # Redirect I/O to /dev/null + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) + + @staticmethod + def encode(buf): + """ Encode a WebSocket packet. """ + buf = b64encode(buf) + return "\x00%s\xff" % buf + + @staticmethod + def decode(buf): + """ Decode WebSocket packets. """ + if buf.count('\xff') > 1: + return [b64decode(d[1:]) for d in buf.split('\xff')] + else: + return [b64decode(buf[1:-1])] + + @staticmethod + def parse_handshake(handshake): + """ Parse fields from client WebSockets handshake. """ + ret = {} + req_lines = handshake.split("\r\n") + if not req_lines[0].startswith("GET "): + raise Exception("Invalid handshake: no GET request line") + ret['path'] = req_lines[0].split(" ")[1] + for line in req_lines[1:]: + if line == "": break + var, val = line.split(": ") + ret[var] = val + + if req_lines[-2] == "": + ret['key3'] = req_lines[-1] + + return ret + + @staticmethod + def gen_md5(keys): + """ Generate hash value for WebSockets handshake v76. """ + key1 = keys['Sec-WebSocket-Key1'] + key2 = keys['Sec-WebSocket-Key2'] + key3 = keys['key3'] + spaces1 = key1.count(" ") + spaces2 = key2.count(" ") + num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1 + num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2 + + return md5(struct.pack('>II8s', num1, num2, key3)).digest() + + + # + # WebSocketServer logging/output functions + # + + def traffic(self, token="."): + """ Show traffic flow in verbose mode. """ + if self.verbose and not self.daemon: + sys.stdout.write(token) + sys.stdout.flush() + + def msg(self, msg): + """ Output message with handler_id prefix. """ + if not self.daemon: + print "% 3d: %s" % (self.handler_id, msg) + + def vmsg(self, msg): + """ Same as msg() but only if verbose. """ + if self.verbose: + self.msg(msg) + + # + # Main WebSocketServer methods + # + + def do_handshake(self, sock, address): + """ + do_handshake does the following: + - Peek at the first few bytes from the socket. + - If the connection is Flash policy request then answer it, + close the socket and return. + - If the connection is an HTTPS/SSL/TLS connection then SSL + wrap the socket. + - Read from the (possibly wrapped) socket. + - If we have received a HTTP GET request and the webserver + functionality is enabled, answer it, close the socket and + return. + - Assume we have a WebSockets connection, parse the client + handshake data. + - Send a WebSockets handshake server response. + - Return the socket for this WebSocket client. + """ + + stype = "" + + # Peek, but don't read the data + handshake = sock.recv(1024, socket.MSG_PEEK) + #self.msg("Handshake [%s]" % repr(handshake)) + + if handshake == "": + raise self.EClose("ignoring empty handshake") + + elif handshake.startswith(""): + # Answer Flash policy request + handshake = sock.recv(1024) + sock.send(self.policy_response) + raise self.EClose("Sending flash policy response") + + elif handshake[0] in ("\x16", "\x80"): + # SSL wrap the connection + if not os.path.exists(self.cert): + raise self.EClose("SSL connection but '%s' not found" + % self.cert) + try: + retsock = ssl.wrap_socket( + sock, + server_side=True, + certfile=self.cert, + keyfile=self.key) + except ssl.SSLError, x: + if x.args[0] == ssl.SSL_ERROR_EOF: + raise self.EClose("") + else: + raise + + scheme = "wss" + stype = "SSL/TLS (wss://)" + + elif self.ssl_only: + raise self.EClose("non-SSL connection received but disallowed") + + else: + retsock = sock + scheme = "ws" + stype = "Plain non-SSL (ws://)" + + # Now get the data from the socket + handshake = retsock.recv(4096) + #self.msg("handshake: " + repr(handshake)) + + if len(handshake) == 0: + raise self.EClose("Client closed during handshake") + + # Check for and handle normal web requests + if handshake.startswith('GET ') and \ + handshake.find('Upgrade: WebSocket\r\n') == -1: + if not self.web: + raise self.EClose("Normal web request received but disallowed") + sh = SplitHTTPHandler(handshake, retsock, address) + if sh.last_code < 200 or sh.last_code >= 300: + raise self.EClose(sh.last_message) + elif self.verbose: + raise self.EClose(sh.last_message) + else: + raise self.EClose("") + + # Parse client WebSockets handshake + h = self.parse_handshake(handshake) + + if h.get('key3'): + trailer = self.gen_md5(h) + pre = "Sec-" + ver = 76 + else: + trailer = "" + pre = "" + ver = 75 + + self.msg("%s: %s WebSocket connection (version %s)" + % (address[0], stype, ver)) + + # Send server WebSockets handshake response + response = self.server_handshake % (pre, h['Origin'], pre, + scheme, h['Host'], h['path'], pre, trailer) + #self.msg("sending response:", repr(response)) + retsock.send(response) + + # Return the WebSockets socket which may be SSL wrapped + return retsock + + + def handler(self, client): + """ Do something with a WebSockets client connection. """ + raise("WebSocketServer.handler() must be overloaded") + + def start_server(self): + """ + Daemonize if requested. Listen for for connections. Run + do_handshake() method for each connection. If the connection + is a WebSockets client then call handler() method (which must + be overridden) for each connection. + """ + + lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + lsock.bind((self.listen_host, self.listen_port)) + lsock.listen(100) + + print "WebSocket server settings:" + print " - Listening on %s:%s" % ( + self.listen_host, self.listen_port) + if self.daemon: + print " - Backgrounding (daemon)" + print " - Flash security policy server" + if self.web: + print " - Web server" + if os.path.exists(self.cert): + print " - SSL/TLS support" + if self.ssl_only: + print " - Deny non-SSL/TLS connections" + + if self.daemon: + self.daemonize(self, keepfd=lsock.fileno()) + + # Reep zombies + signal.signal(signal.SIGCHLD, signal.SIG_IGN) + + while True: + try: + csock = startsock = None + pid = 0 + startsock, address = lsock.accept() + self.vmsg('%s: forking handler' % address[0]) + pid = os.fork() + + if pid == 0: + # handler process + csock = self.do_handshake(startsock, address) + self.handler(csock) + else: + # parent process + self.handler_id += 1 + + except self.EClose, exc: + # Connection was not a WebSockets connection + if exc.args[0]: + self.msg("%s: %s" % (address[0], exc.args[0])) + except KeyboardInterrupt, exc: + pass + except Exception, exc: + self.msg("handler exception: %s" % str(exc)) + if self.verbose: + self.msg(traceback.format_exc()) + finally: + if csock and csock != startsock: + csock.close() + if startsock: + startsock.close() + + if pid == 0: + break # Child process exits -class EClose(Exception): - pass # HTTP handler with request from a string and response to a socket class SplitHTTPHandler(SimpleHTTPRequestHandler): @@ -73,213 +379,3 @@ class SplitHTTPHandler(SimpleHTTPRequestHandler): self.last_message = f % args -def traffic(token="."): - if settings['verbose'] and not settings['daemon']: - sys.stdout.write(token) - sys.stdout.flush() - -def handler_msg(msg): - if not settings['daemon']: - print "% 3d: %s" % (settings['handler_id'], msg) - -def handler_vmsg(msg): - if settings['verbose']: handler_msg(msg) - -def encode(buf): - buf = b64encode(buf) - - return "\x00%s\xff" % buf - -def decode(buf): - """ Parse out WebSocket packets. """ - if buf.count('\xff') > 1: - return [b64decode(d[1:]) for d in buf.split('\xff')] - else: - return [b64decode(buf[1:-1])] - -def parse_handshake(handshake): - ret = {} - req_lines = handshake.split("\r\n") - if not req_lines[0].startswith("GET "): - raise Exception("Invalid handshake: no GET request line") - ret['path'] = req_lines[0].split(" ")[1] - for line in req_lines[1:]: - if line == "": break - var, val = line.split(": ") - ret[var] = val - - if req_lines[-2] == "": - ret['key3'] = req_lines[-1] - - return ret - -def gen_md5(keys): - key1 = keys['Sec-WebSocket-Key1'] - key2 = keys['Sec-WebSocket-Key2'] - key3 = keys['key3'] - spaces1 = key1.count(" ") - spaces2 = key2.count(" ") - num1 = int("".join([c for c in key1 if c.isdigit()])) / spaces1 - num2 = int("".join([c for c in key2 if c.isdigit()])) / spaces2 - - return md5(struct.pack('>II8s', num1, num2, key3)).digest() - - -def do_handshake(sock, address): - stype = "" - - # Peek, but don't read the data - handshake = sock.recv(1024, socket.MSG_PEEK) - #handler_msg("Handshake [%s]" % repr(handshake)) - if handshake == "": - raise EClose("ignoring empty handshake") - elif handshake.startswith(""): - handshake = sock.recv(1024) - sock.send(policy_response) - raise EClose("Sending flash policy response") - elif handshake[0] in ("\x16", "\x80"): - if not os.path.exists(settings['cert']): - raise EClose("SSL connection but '%s' not found" - % settings['cert']) - try: - retsock = ssl.wrap_socket( - sock, - server_side=True, - certfile=settings['cert'], - keyfile=settings['key']) - except ssl.SSLError, x: - if x.args[0] == ssl.SSL_ERROR_EOF: - raise EClose("") - else: - raise - - scheme = "wss" - stype = "SSL/TLS (wss://)" - elif settings['ssl_only']: - raise EClose("non-SSL connection received but disallowed") - else: - retsock = sock - scheme = "ws" - stype = "Plain non-SSL (ws://)" - - # Now get the data from the socket - handshake = retsock.recv(4096) - #handler_msg("handshake: " + repr(handshake)) - - if len(handshake) == 0: - raise EClose("Client closed during handshake") - - # Handle normal web requests - if handshake.startswith('GET ') and \ - handshake.find('Upgrade: WebSocket\r\n') == -1: - if not settings['web']: - raise EClose("Normal web request received but disallowed") - sh = SplitHTTPHandler(handshake, retsock, address) - if sh.last_code < 200 or sh.last_code >= 300: - raise EClose(sh.last_message) - elif settings['verbose']: - raise EClose(sh.last_message) - else: - raise EClose("") - - # Do WebSockets handshake and return the socket - h = parse_handshake(handshake) - - if h.get('key3'): - trailer = gen_md5(h) - pre = "Sec-" - ver = 76 - else: - trailer = "" - pre = "" - ver = 75 - - handler_msg("%s WebSocket connection (version %s) from %s" - % (stype, ver, address[0])) - - response = server_handshake % (pre, h['Origin'], pre, scheme, - h['Host'], h['path'], pre, trailer) - - #handler_msg("sending response:", repr(response)) - retsock.send(response) - return retsock - -def daemonize(keepfd=None): - os.umask(0) - os.chdir('/') - os.setgid(os.getgid()) # relinquish elevations - os.setuid(os.getuid()) # relinquish elevations - - # Double fork to daemonize - if os.fork() > 0: os._exit(0) # Parent exits - os.setsid() # Obtain new process group - if os.fork() > 0: os._exit(0) # Parent exits - - # Signal handling - def terminate(a,b): os._exit(0) - signal.signal(signal.SIGTERM, terminate) - signal.signal(signal.SIGINT, signal.SIG_IGN) - - # Close open files - maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] - if maxfd == resource.RLIM_INFINITY: maxfd = 256 - for fd in reversed(range(maxfd)): - try: - if fd != keepfd: - os.close(fd) - else: - handler_vmsg("Keeping fd: %d" % fd) - except OSError, exc: - if exc.errno != errno.EBADF: raise - - # Redirect I/O to /dev/null - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) - - -def start_server(): - - lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - lsock.bind((settings['listen_host'], settings['listen_port'])) - lsock.listen(100) - - if settings['daemon']: - daemonize(keepfd=lsock.fileno()) - - # Reep zombies - signal.signal(signal.SIGCHLD, signal.SIG_IGN) - - print 'Waiting for connections on %s:%s' % ( - settings['listen_host'], settings['listen_port']) - - while True: - try: - csock = startsock = None - pid = 0 - startsock, address = lsock.accept() - handler_vmsg('%s: forking handler' % address[0]) - pid = os.fork() - - if pid == 0: # handler process - csock = do_handshake(startsock, address) - settings['handler'](csock) - else: # parent process - settings['handler_id'] += 1 - - except EClose, exc: - if csock and csock != startsock: - csock.close() - startsock.close() - if exc.args[0]: - handler_msg("%s: %s" % (address[0], exc.args[0])) - except Exception, exc: - handler_msg("handler exception: %s" % str(exc)) - if settings['verbose']: - handler_msg(traceback.format_exc()) - - if pid == 0: - if csock: csock.close() - if startsock and startsock != csock: startsock.close() - break # Child process exits diff --git a/utils/wsproxy.py b/utils/wsproxy.py index 32a4021e..660a2a6b 100755 --- a/utils/wsproxy.py +++ b/utils/wsproxy.py @@ -11,14 +11,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' -import socket, optparse, time +import socket, optparse, time, os from select import select -from websocket import * +from websocket import WebSocketServer -buffer_size = 65536 -rec = None +class WebSocketProxy(WebSocketServer): + """ + Proxy traffic to and from a WebSockets client to a normal TCP + socket server target. All traffic to/from the client is base64 + encoded/decoded to allow binary data to be sent/received to/from + the target. + """ -traffic_legend = """ + buffer_size = 65536 + + traffic_legend = """ Traffic Legend: } - Client receive }. - Client receive partial @@ -30,101 +37,122 @@ Traffic Legend: <. - Client send partial """ -def do_proxy(client, target): - """ Proxy WebSocket to normal socket. """ - global rec - cqueue = [] - cpartial = "" - tqueue = [] - rlist = [client, target] - tstart = int(time.time()*1000) + def __init__(self, *args, **kwargs): + # Save off the target host:port + self.target_host = kwargs.pop('target_host') + self.target_port = kwargs.pop('target_port') + WebSocketServer.__init__(self, *args, **kwargs) - while True: - wlist = [] - tdelta = int(time.time()*1000) - tstart - if tqueue: wlist.append(target) - if cqueue: wlist.append(client) - ins, outs, excepts = select(rlist, wlist, [], 1) - if excepts: raise Exception("Socket exception") + def handler(self, client): + """ + Called after a new WebSocket connection has been established. + """ - if target in outs: - dat = tqueue.pop(0) - sent = target.send(dat) - if sent == len(dat): - traffic(">") - else: - tqueue.insert(0, dat[sent:]) - traffic(".>") - ##if rec: rec.write("Target send: %s\n" % map(ord, dat)) + self.rec = None + if self.record: + # Record raw frame data as a JavaScript compatible file + fname = "%s.%s" % (self.record, + self.handler_id) + self.msg("opening record file: %s" % fname) + self.rec = open(fname, 'w+') + self.rec.write("var VNC_frame_data = [\n") - if client in outs: - dat = cqueue.pop(0) - sent = client.send(dat) - if sent == len(dat): - traffic("<") - ##if rec: rec.write("Client send: %s ...\n" % repr(dat[0:80])) - if rec: rec.write("%s,\n" % repr("{%s{" % tdelta + dat[1:-1])) - else: - cqueue.insert(0, dat[sent:]) - traffic("<.") - ##if rec: rec.write("Client send partial: %s\n" % repr(dat[0:send])) + # Connect to the target + self.msg("connecting to: %s:%s" % ( + self.target_host, self.target_port)) + tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tsock.connect((self.target_host, self.target_port)) + if self.verbose and not self.daemon: + print self.traffic_legend - if target in ins: - buf = target.recv(buffer_size) - if len(buf) == 0: raise EClose("Target closed") + # Stat proxying + try: + self.do_proxy(client, tsock) + except: + if tsock: tsock.close() + if self.rec: + self.rec.write("'EOF']\n") + self.rec.close() + raise - cqueue.append(encode(buf)) - traffic("{") - ##if rec: rec.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf))) + def do_proxy(self, client, target): + """ + Proxy client WebSocket to normal target socket. + """ + cqueue = [] + cpartial = "" + tqueue = [] + rlist = [client, target] + tstart = int(time.time()*1000) - if client in ins: - buf = client.recv(buffer_size) - if len(buf) == 0: raise EClose("Client closed") + while True: + wlist = [] + tdelta = int(time.time()*1000) - tstart - if buf == '\xff\x00': - raise EClose("Client sent orderly close frame") - elif buf[-1] == '\xff': - if buf.count('\xff') > 1: - traffic(str(buf.count('\xff'))) - traffic("}") - ##if rec: rec.write("Client recv (%d): %s\n" % (len(buf), repr(buf))) - if rec: rec.write("%s,\n" % (repr("}%s}" % tdelta + buf[1:-1]))) - if cpartial: - tqueue.extend(decode(cpartial + buf)) - cpartial = "" + if tqueue: wlist.append(target) + if cqueue: wlist.append(client) + ins, outs, excepts = select(rlist, wlist, [], 1) + if excepts: raise Exception("Socket exception") + + if target in outs: + # Send queued client data to the target + dat = tqueue.pop(0) + sent = target.send(dat) + if sent == len(dat): + self.traffic(">") else: - tqueue.extend(decode(buf)) - else: - traffic(".}") - ##if rec: rec.write("Client recv partial (%d): %s\n" % (len(buf), repr(buf))) - cpartial = cpartial + buf + # requeue the remaining data + tqueue.insert(0, dat[sent:]) + self.traffic(".>") -def proxy_handler(client): - global target_host, target_port, options, rec, fname + if client in outs: + # Send queued target data to the client + dat = cqueue.pop(0) + sent = client.send(dat) + if sent == len(dat): + self.traffic("<") + if self.rec: + self.rec.write("%s,\n" % + repr("{%s{" % tdelta + dat[1:-1])) + else: + cqueue.insert(0, dat[sent:]) + self.traffic("<.") - if settings['record']: - fname = "%s.%s" % (settings['record'], - settings['handler_id']) - handler_msg("opening record file: %s" % fname) - rec = open(fname, 'w+') - rec.write("var VNC_frame_data = [\n") - handler_msg("connecting to: %s:%s" % (target_host, target_port)) - tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - tsock.connect((target_host, target_port)) + if target in ins: + # Receive target data, encode it and queue for client + buf = target.recv(self.buffer_size) + if len(buf) == 0: raise self.EClose("Target closed") - if settings['verbose'] and not settings['daemon']: - print traffic_legend + cqueue.append(self.encode(buf)) + self.traffic("{") - try: - do_proxy(client, tsock) - except: - if tsock: tsock.close() - if rec: - rec.write("'EOF']\n") - rec.close() - raise + if client in ins: + # Receive client data, decode it, and queue for target + buf = client.recv(self.buffer_size) + if len(buf) == 0: raise self.EClose("Client closed") + + if buf == '\xff\x00': + raise self.EClose("Client sent orderly close frame") + elif buf[-1] == '\xff': + if buf.count('\xff') > 1: + self.traffic(str(buf.count('\xff'))) + self.traffic("}") + if self.rec: + self.rec.write("%s,\n" % + (repr("}%s}" % tdelta + buf[1:-1]))) + if cpartial: + # Prepend saved partial and decode frame(s) + tqueue.extend(self.decode(cpartial + buf)) + cpartial = "" + else: + # decode frame(s) + tqueue.extend(self.decode(buf)) + else: + # Save off partial WebSockets frame + self.traffic(".}") + cpartial = cpartial + buf if __name__ == '__main__': usage = "%prog [--record FILE]" @@ -145,40 +173,31 @@ if __name__ == '__main__': help="disallow non-encrypted connections") parser.add_option("--web", default=None, metavar="DIR", help="run webserver on same port. Serve files from DIR.") - (options, args) = parser.parse_args() + (opts, args) = parser.parse_args() + # Sanity checks if len(args) > 2: parser.error("Too many arguments") if len(args) < 2: parser.error("Too few arguments") + + if opts.ssl_only and not os.path.exists(opts.cert): + parser.error("SSL only and %s not found" % opts.cert) + elif not os.path.exists(opts.cert): + print "Warning: %s not found" % opts.cert + + # Parse host:port and convert ports to numbers if args[0].count(':') > 0: - host,port = args[0].split(':') + opts.listen_host, opts.listen_port = args[0].split(':') else: - host,port = '',args[0] + opts.listen_host, opts.listen_port = '', args[0] if args[1].count(':') > 0: - target_host,target_port = args[1].split(':') + opts.target_host, opts.target_port = args[1].split(':') else: parser.error("Error parsing target") - try: port = int(port) + try: opts.listen_port = int(opts.listen_port) except: parser.error("Error parsing listen port") - try: target_port = int(target_port) + try: opts.target_port = int(opts.target_port) except: parser.error("Error parsing target port") - if options.ssl_only and not os.path.exists(options.cert): - parser.error("SSL only and %s not found" % options.cert) - elif not os.path.exists(options.cert): - print "Warning: %s not found" % options.cert - - settings['verbose'] = options.verbose - settings['listen_host'] = host - settings['listen_port'] = port - settings['handler'] = proxy_handler - settings['cert'] = os.path.abspath(options.cert) - if options.key: - settings['key'] = os.path.abspath(options.key) - settings['ssl_only'] = options.ssl_only - settings['daemon'] = options.daemon - if options.record: - settings['record'] = os.path.abspath(options.record) - if options.web: - os.chdir = options.web - settings['web'] = options.web - start_server() + # Create and start the WebSockets proxy + server = WebSocketProxy(**opts.__dict__) + server.start_server() diff --git a/utils/wstest.py b/utils/wstest.py new file mode 100755 index 00000000..0d005fa2 --- /dev/null +++ b/utils/wstest.py @@ -0,0 +1,171 @@ +#!/usr/bin/python + +''' +WebSocket server-side load test program. Sends and receives traffic +that has a random payload (length and content) that is checksummed and +given a sequence number. Any errors are reported and counted. +''' + +import sys, os, socket, ssl, time, traceback +import random, time +from select import select + +sys.path.insert(0,os.path.dirname(__file__) + "/../utils/") +from websocket import WebSocketServer + + +class WebSocketTest(WebSocketServer): + + buffer_size = 65536 + max_packet_size = 10000 + recv_cnt = 0 + send_cnt = 0 + + def __init__(self, *args, **kwargs): + self.errors = 0 + self.delay = kwargs.pop('delay') + + print "Prepopulating random array" + self.rand_array = [] + for i in range(0, self.max_packet_size): + self.rand_array.append(random.randint(0, 9)) + + WebSocketServer.__init__(self, *args, **kwargs) + + def handler(self, client): + self.send_cnt = 0 + self.recv_cnt = 0 + + try: + self.responder(client) + except: + print "accumulated errors:", self.errors + self.errors = 0 + raise + + def responder(self, client): + cqueue = [] + cpartial = "" + socks = [client] + last_send = time.time() * 1000 + + while True: + ins, outs, excepts = select(socks, socks, socks, 1) + if excepts: raise Exception("Socket exception") + + if client in ins: + buf = client.recv(self.buffer_size) + if len(buf) == 0: + raise self.EClose("Client closed") + #print "Client recv: %s (%d)" % (repr(buf[1:-1]), len(buf)) + if buf[-1] == '\xff': + if cpartial: + err = self.check(cpartial + buf) + cpartial = "" + else: + err = self.check(buf) + if err: + self.traffic("}") + self.errors = self.errors + 1 + print err + else: + self.traffic(">") + else: + self.traffic(".>") + cpartial = cpartial + buf + + now = time.time() * 1000 + if client in outs and now > (last_send + self.delay): + last_send = now + #print "Client send: %s" % repr(cqueue[0]) + client.send(self.generate()) + self.traffic("<") + + def generate(self): + length = random.randint(10, self.max_packet_size) + numlist = self.rand_array[self.max_packet_size-length:] + # Error in length + #numlist.append(5) + chksum = sum(numlist) + # Error in checksum + #numlist[0] = 5 + nums = "".join( [str(n) for n in numlist] ) + data = "^%d:%d:%d:%s$" % (self.send_cnt, length, chksum, nums) + self.send_cnt += 1 + + return WebSocketServer.encode(data) + + + def check(self, buf): + try: + data_list = WebSocketServer.decode(buf) + except: + print "\n" + repr(buf) + "" + return "Failed to decode" + + err = "" + for data in data_list: + if data.count('$') > 1: + raise Exception("Multiple parts within single packet") + if len(data) == 0: + self.traffic("_") + continue + + if data[0] != "^": + err += "buf did not start with '^'\n" + continue + + try: + cnt, length, chksum, nums = data[1:-1].split(':') + cnt = int(cnt) + length = int(length) + chksum = int(chksum) + except: + print "\n" + repr(data) + "" + err += "Invalid data format\n" + continue + + if self.recv_cnt != cnt: + err += "Expected count %d but got %d\n" % (self.recv_cnt, cnt) + self.recv_cnt = cnt + 1 + continue + + self.recv_cnt += 1 + + if len(nums) != length: + err += "Expected length %d but got %d\n" % (length, len(nums)) + continue + + inv = nums.translate(None, "0123456789") + if inv: + err += "Invalid characters found: %s\n" % inv + continue + + real_chksum = 0 + for num in nums: + real_chksum += int(num) + + if real_chksum != chksum: + err += "Expected checksum %d but real chksum is %d\n" % (chksum, real_chksum) + return err + + +if __name__ == '__main__': + try: + if len(sys.argv) < 2: raise + listen_port = int(sys.argv[1]) + if len(sys.argv) == 3: + delay = int(sys.argv[2]) + else: + delay = 10 + except: + print "Usage: %s [delay_ms]" % sys.argv[0] + sys.exit(1) + + server = WebSocketTest( + listen_port=listen_port, + verbose=True, + cert='self.pem', + web='.', + delay=delay) + server.start_server()