diff --git a/websocket.py b/websocket.py new file mode 100755 index 00000000..6320d1b3 --- /dev/null +++ b/websocket.py @@ -0,0 +1,113 @@ +#!/usr/bin/python + +''' +Python WebSocket library with support for "wss://" encryption. + +You can make a cert/key with openssl using: +openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem +as taken from http://docs.python.org/dev/library/ssl.html#certificates + +''' + +import sys, socket, ssl, traceback +from base64 import b64encode, b64decode + +client_settings = {} +send_seq = 0 + +server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r +Upgrade: WebSocket\r +Connection: Upgrade\r +WebSocket-Origin: %s\r +WebSocket-Location: %s://%s%s\r +WebSocket-Protocol: sample\r +\r +""" + +policy_response = """\n""" + +def traffic(token="."): + sys.stdout.write(token) + sys.stdout.flush() + +def decode(buf): + """ Parse out WebSocket packets. """ + if buf.count('\xff') > 1: + return [b64decode(d[1:]) for d in buf.split('\xff')] + else: + return [b64decode(buf[1:-1])] + +def encode(buf): + global send_seq + if client_settings.get("b64encode"): + buf = b64encode(buf) + + if client_settings.get("seq_num"): + send_seq += 1 + return "\x00%d:%s\xff" % (send_seq-1, buf) + else: + return "\x00%s\xff" % buf + + +def do_handshake(sock): + global client_settings, send_seq + send_seq = 0 + # Peek, but don't read the data + handshake = sock.recv(1024, socket.MSG_PEEK) + #print "Handshake [%s]" % repr(handshake) + if handshake.startswith(""): + handshake = sock.recv(1024) + print "Sending flash policy response" + sock.send(policy_response) + sock.close() + return False + elif handshake.startswith("\x16"): + retsock = ssl.wrap_socket( + sock, + server_side=True, + certfile='self.pem', + ssl_version=ssl.PROTOCOL_TLSv1) + scheme = "wss" + print "Using SSL/TLS" + else: + retsock = sock + scheme = "ws" + print "Using plain (not SSL) socket" + handshake = retsock.recv(4096) + req_lines = handshake.split("\r\n") + _, path, _ = req_lines[0].split(" ") + _, origin = req_lines[4].split(" ") + _, host = req_lines[3].split(" ") + + # Parse settings from the path + cvars = path.partition('?')[2].partition('#')[0].split('&') + for cvar in [c for c in cvars if c]: + name, _, value = cvar.partition('=') + client_settings[name] = value and value or True + + print "client_settings:", client_settings + + retsock.send(server_handshake % (origin, scheme, host, path)) + return retsock + +def start_server(listen_port, handler): + lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + lsock.bind(('', listen_port)) + lsock.listen(100) + while True: + try: + csock = None + print 'waiting for connection on port %s' % listen_port + startsock, address = lsock.accept() + print 'Got client connection from %s' % address[0] + csock = do_handshake(startsock) + if not csock: continue + + handler(csock) + + except Exception: + print "Ignoring exception:" + print traceback.format_exc() + if csock: csock.close() + diff --git a/wsproxy.py b/wsproxy.py index 6f6296a0..34ea575b 100755 --- a/wsproxy.py +++ b/wsproxy.py @@ -9,24 +9,11 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' -import sys, os, socket, ssl, time, traceback, re -from base64 import b64encode, b64decode +import sys, socket, ssl from select import select +from websocket import * buffer_size = 65536 -send_seq = 0 -client_settings = {} - -server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r -Upgrade: WebSocket\r -Connection: Upgrade\r -WebSocket-Origin: %s\r -WebSocket-Location: %s://%s%s\r -WebSocket-Protocol: sample\r -\r -""" - -policy_response = """\n""" traffic_legend = """ Traffic Legend: @@ -40,22 +27,8 @@ Traffic Legend: <. - Client send partial """ - -def traffic(token="."): - sys.stdout.write(token) - sys.stdout.flush() - -def decode(buf): - """ Parse out WebSocket packets. """ - if buf.count('\xff') > 1: - traffic(str(buf.count('\xff'))) - return [b64decode(d[1:]) for d in buf.split('\xff')] - else: - return [b64decode(buf[1:-1])] - -def proxy(client, target): +def do_proxy(client, target): """ Proxy WebSocket to normal socket. """ - global send_seq cqueue = [] cpartial = "" tqueue = [] @@ -66,15 +39,14 @@ def proxy(client, target): if excepts: raise Exception("Socket exception") if tqueue and target in outs: - #print "Target send: %s" % repr(tqueue[0]) - ##log.write("Target send: %s\n" % map(ord, tqueue[0])) dat = tqueue.pop(0) sent = target.send(dat) if sent == len(dat): traffic(">") else: tqueue.insert(0, dat[sent:]) - traffic(">.") + traffic(".>") + ##log.write("Target send: %s\n" % map(ord, dat)) if cqueue and client in outs: dat = cqueue.pop(0) @@ -92,24 +64,17 @@ def proxy(client, target): buf = target.recv(buffer_size) if len(buf) == 0: raise Exception("Target closed") - ##log.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf))) - - if client_settings.get("b64encode"): - buf = b64encode(buf) - - if client_settings.get("seq_num"): - cqueue.append("\x00%d:%s\xff" % (send_seq, buf)) - send_seq += 1 - else: - cqueue.append("\x00%s\xff" % buf) - + cqueue.append(encode(buf)) traffic("{") + ##log.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf))) if client in ins: buf = client.recv(buffer_size) if len(buf) == 0: raise Exception("Client closed") - if buf[-1] == "\xff": + if buf[-1] == '\xff': + if buf.count('\xff') > 1: + traffic(str(buf.count('\xff'))) traffic("}") ##log.write("Client recv (%d): %s\n" % (len(buf), repr(buf))) if cpartial: @@ -118,78 +83,24 @@ def proxy(client, target): else: tqueue.extend(decode(buf)) else: - traffic("}.") + traffic(".}") ##log.write("Client recv partial (%d): %s\n" % (len(buf), repr(buf))) cpartial = cpartial + buf +def proxy_handler(client): + global target_host, target_port -def do_handshake(sock): - global client_settings - # Peek, but don't read the data - handshake = sock.recv(1024, socket.MSG_PEEK) - #print "Handshake [%s]" % repr(handshake) - if handshake.startswith(""): - handshake = sock.recv(1024) - print "Sending flash policy response" - sock.send(policy_response) - sock.close() - return False - elif handshake.startswith("\x16"): - retsock = ssl.wrap_socket( - sock, - server_side=True, - certfile='self.pem', - ssl_version=ssl.PROTOCOL_TLSv1) - scheme = "wss" - print "Using SSL/TLS" - else: - retsock = sock - scheme = "ws" - print "Using plain (not SSL) socket" - handshake = retsock.recv(4096) - req_lines = handshake.split("\r\n") - _, path, _ = req_lines[0].split(" ") - _, origin = req_lines[4].split(" ") - _, host = req_lines[3].split(" ") + print "Connecting to: %s:%s" % (target_host, target_port) + tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + tsock.connect((target_host, target_port)) - # Parse settings from the path - cvars = path.partition('?')[2].partition('#')[0].split('&') - for cvar in [c for c in cvars if c]: - name, _, value = cvar.partition('=') - client_settings[name] = value and value or True - - print "client_settings:", client_settings - - retsock.send(server_handshake % (origin, scheme, host, path)) - return retsock - -def start_server(listen_port, target_host, target_port): - global send_seq - lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - lsock.bind(('', listen_port)) - lsock.listen(100) print traffic_legend - while True: - try: - csock = tsock = None - print 'waiting for connection on port %s' % listen_port - startsock, address = lsock.accept() - print 'Got client connection from %s' % address[0] - csock = do_handshake(startsock) - if not csock: continue - print "Connecting to: %s:%s" % (target_host, target_port) - tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - tsock.connect((target_host, target_port)) - send_seq = 0 - proxy(csock, tsock) - - except Exception: - print "Ignoring exception:" - print traceback.format_exc() - if csock: csock.close() - if tsock: tsock.close() + try: + do_proxy(client, tsock) + except: + if tsock: tsock.close() + raise if __name__ == '__main__': ##log = open("ws.log", 'w') @@ -201,4 +112,4 @@ if __name__ == '__main__': except: print "Usage: " sys.exit(1) - start_server(listen_port, target_host, target_port) + start_server(listen_port, proxy_handler) diff --git a/wstest.html b/wstest.html index bfc6bfc0..10fc8ff5 100644 --- a/wstest.html +++ b/wstest.html @@ -224,8 +224,8 @@ } window.onload = function() { - WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf"; console.log("onload"); + WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf"; var url = document.location.href; $('host').value = (url.match(/host=([^&#]*)/) || ['',''])[1]; $('port').value = (url.match(/port=([^&#]*)/) || ['',''])[1]; diff --git a/wstest.py b/wstest.py index 14ec257e..8bb1541d 100755 --- a/wstest.py +++ b/wstest.py @@ -1,50 +1,20 @@ #!/usr/bin/python -import sys, os, socket, time, traceback, random, time +''' +WebSocket server-side load test program. Sends and receives traffic +that has a random payload (length and content) that is checksummed and +given a sequence number. Any errors are reported and counted. +''' + +import sys, socket, ssl, time, traceback +import random, time from base64 import b64encode, b64decode from select import select +from websocket import * buffer_size = 65536 recv_cnt = send_cnt = 0 -server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r -Upgrade: WebSocket\r -Connection: Upgrade\r -WebSocket-Origin: %s\r -WebSocket-Location: ws://%s%s\r -WebSocket-Protocol: sample\r -\r -""" - -policy_response = """""" - -def do_handshake(client): - handshake = client.recv(1024) - print "Handshake [%s]" % handshake - if handshake.startswith(""): - print "Sending flash policy response" - client.send(policy_response) - client.close() - return False - req_lines = handshake.split("\r\n") - _, path, _ = req_lines[0].split(" ") - _, origin = req_lines[4].split(" ") - _, host = req_lines[3].split(" ") - client.send(server_handshake % (origin, host, path)) - return True - -def traffic(token="."): - sys.stdout.write(token) - sys.stdout.flush() - - -def decode(buf): - """ Parse out WebSocket packets. """ - if buf.count('\xff') > 1: - traffic(str(buf.count('\xff'))) - return [b64decode(d[1:]) for d in buf.split('\xff')] - else: - return [b64decode(buf[1:-1])] def check(buf): global recv_cnt @@ -103,7 +73,7 @@ def check(buf): def generate(): - global send_cnt + global send_cnt, rand_array length = random.randint(10, 100000) numlist = rand_array[100000-length:] # Error in length @@ -156,31 +126,19 @@ def responder(client, delay=10): client.send(generate()) traffic("<") -def start_server(listen_port, delay=10): - global errors, send_cnt, recv_cnt - lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - lsock.bind(('', listen_port)) - lsock.listen(100) - while True: - try: - csock = None - print 'listening on port %s' % listen_port - csock, address = lsock.accept() - print 'Got client connection from %s' % address[0] - if not do_handshake(csock): - continue +def test_handler(client): + global errors, delay, send_cnt, recv_cnt - send_cnt = 0 - recv_cnt = 0 - responder(csock, delay=delay) + send_cnt = 0 + recv_cnt = 0 + + try: + responder(client, delay) + except: + print "accumulated errors:", errors + errors = 0 + raise - except Exception: - print "accumulated errors:", errors - errors = 0 - print "Ignoring exception:" - print traceback.format_exc() - if csock: csock.close() if __name__ == '__main__': errors = 0 @@ -200,4 +158,4 @@ if __name__ == '__main__': for i in range(0, 100000): rand_array.append(random.randint(0, 9)) - start_server(listen_port, delay=delay) + start_server(listen_port, test_handler)