diff --git a/websocket.py b/websocket.py
new file mode 100755
index 00000000..6320d1b3
--- /dev/null
+++ b/websocket.py
@@ -0,0 +1,113 @@
+#!/usr/bin/python
+
+'''
+Python WebSocket library with support for "wss://" encryption.
+
+You can make a cert/key with openssl using:
+openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
+as taken from http://docs.python.org/dev/library/ssl.html#certificates
+
+'''
+
+import sys, socket, ssl, traceback
+from base64 import b64encode, b64decode
+
+client_settings = {}
+send_seq = 0
+
+server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
+Upgrade: WebSocket\r
+Connection: Upgrade\r
+WebSocket-Origin: %s\r
+WebSocket-Location: %s://%s%s\r
+WebSocket-Protocol: sample\r
+\r
+"""
+
+policy_response = """\n"""
+
+def traffic(token="."):
+ sys.stdout.write(token)
+ sys.stdout.flush()
+
+def decode(buf):
+ """ Parse out WebSocket packets. """
+ if buf.count('\xff') > 1:
+ return [b64decode(d[1:]) for d in buf.split('\xff')]
+ else:
+ return [b64decode(buf[1:-1])]
+
+def encode(buf):
+ global send_seq
+ if client_settings.get("b64encode"):
+ buf = b64encode(buf)
+
+ if client_settings.get("seq_num"):
+ send_seq += 1
+ return "\x00%d:%s\xff" % (send_seq-1, buf)
+ else:
+ return "\x00%s\xff" % buf
+
+
+def do_handshake(sock):
+ global client_settings, send_seq
+ send_seq = 0
+ # Peek, but don't read the data
+ handshake = sock.recv(1024, socket.MSG_PEEK)
+ #print "Handshake [%s]" % repr(handshake)
+ if handshake.startswith(""):
+ handshake = sock.recv(1024)
+ print "Sending flash policy response"
+ sock.send(policy_response)
+ sock.close()
+ return False
+ elif handshake.startswith("\x16"):
+ retsock = ssl.wrap_socket(
+ sock,
+ server_side=True,
+ certfile='self.pem',
+ ssl_version=ssl.PROTOCOL_TLSv1)
+ scheme = "wss"
+ print "Using SSL/TLS"
+ else:
+ retsock = sock
+ scheme = "ws"
+ print "Using plain (not SSL) socket"
+ handshake = retsock.recv(4096)
+ req_lines = handshake.split("\r\n")
+ _, path, _ = req_lines[0].split(" ")
+ _, origin = req_lines[4].split(" ")
+ _, host = req_lines[3].split(" ")
+
+ # Parse settings from the path
+ cvars = path.partition('?')[2].partition('#')[0].split('&')
+ for cvar in [c for c in cvars if c]:
+ name, _, value = cvar.partition('=')
+ client_settings[name] = value and value or True
+
+ print "client_settings:", client_settings
+
+ retsock.send(server_handshake % (origin, scheme, host, path))
+ return retsock
+
+def start_server(listen_port, handler):
+ lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ lsock.bind(('', listen_port))
+ lsock.listen(100)
+ while True:
+ try:
+ csock = None
+ print 'waiting for connection on port %s' % listen_port
+ startsock, address = lsock.accept()
+ print 'Got client connection from %s' % address[0]
+ csock = do_handshake(startsock)
+ if not csock: continue
+
+ handler(csock)
+
+ except Exception:
+ print "Ignoring exception:"
+ print traceback.format_exc()
+ if csock: csock.close()
+
diff --git a/wsproxy.py b/wsproxy.py
index 6f6296a0..34ea575b 100755
--- a/wsproxy.py
+++ b/wsproxy.py
@@ -9,24 +9,11 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
'''
-import sys, os, socket, ssl, time, traceback, re
-from base64 import b64encode, b64decode
+import sys, socket, ssl
from select import select
+from websocket import *
buffer_size = 65536
-send_seq = 0
-client_settings = {}
-
-server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
-Upgrade: WebSocket\r
-Connection: Upgrade\r
-WebSocket-Origin: %s\r
-WebSocket-Location: %s://%s%s\r
-WebSocket-Protocol: sample\r
-\r
-"""
-
-policy_response = """\n"""
traffic_legend = """
Traffic Legend:
@@ -40,22 +27,8 @@ Traffic Legend:
<. - Client send partial
"""
-
-def traffic(token="."):
- sys.stdout.write(token)
- sys.stdout.flush()
-
-def decode(buf):
- """ Parse out WebSocket packets. """
- if buf.count('\xff') > 1:
- traffic(str(buf.count('\xff')))
- return [b64decode(d[1:]) for d in buf.split('\xff')]
- else:
- return [b64decode(buf[1:-1])]
-
-def proxy(client, target):
+def do_proxy(client, target):
""" Proxy WebSocket to normal socket. """
- global send_seq
cqueue = []
cpartial = ""
tqueue = []
@@ -66,15 +39,14 @@ def proxy(client, target):
if excepts: raise Exception("Socket exception")
if tqueue and target in outs:
- #print "Target send: %s" % repr(tqueue[0])
- ##log.write("Target send: %s\n" % map(ord, tqueue[0]))
dat = tqueue.pop(0)
sent = target.send(dat)
if sent == len(dat):
traffic(">")
else:
tqueue.insert(0, dat[sent:])
- traffic(">.")
+ traffic(".>")
+ ##log.write("Target send: %s\n" % map(ord, dat))
if cqueue and client in outs:
dat = cqueue.pop(0)
@@ -92,24 +64,17 @@ def proxy(client, target):
buf = target.recv(buffer_size)
if len(buf) == 0: raise Exception("Target closed")
- ##log.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf)))
-
- if client_settings.get("b64encode"):
- buf = b64encode(buf)
-
- if client_settings.get("seq_num"):
- cqueue.append("\x00%d:%s\xff" % (send_seq, buf))
- send_seq += 1
- else:
- cqueue.append("\x00%s\xff" % buf)
-
+ cqueue.append(encode(buf))
traffic("{")
+ ##log.write("Target recv (%d): %s\n" % (len(buf), map(ord, buf)))
if client in ins:
buf = client.recv(buffer_size)
if len(buf) == 0: raise Exception("Client closed")
- if buf[-1] == "\xff":
+ if buf[-1] == '\xff':
+ if buf.count('\xff') > 1:
+ traffic(str(buf.count('\xff')))
traffic("}")
##log.write("Client recv (%d): %s\n" % (len(buf), repr(buf)))
if cpartial:
@@ -118,78 +83,24 @@ def proxy(client, target):
else:
tqueue.extend(decode(buf))
else:
- traffic("}.")
+ traffic(".}")
##log.write("Client recv partial (%d): %s\n" % (len(buf), repr(buf)))
cpartial = cpartial + buf
+def proxy_handler(client):
+ global target_host, target_port
-def do_handshake(sock):
- global client_settings
- # Peek, but don't read the data
- handshake = sock.recv(1024, socket.MSG_PEEK)
- #print "Handshake [%s]" % repr(handshake)
- if handshake.startswith(""):
- handshake = sock.recv(1024)
- print "Sending flash policy response"
- sock.send(policy_response)
- sock.close()
- return False
- elif handshake.startswith("\x16"):
- retsock = ssl.wrap_socket(
- sock,
- server_side=True,
- certfile='self.pem',
- ssl_version=ssl.PROTOCOL_TLSv1)
- scheme = "wss"
- print "Using SSL/TLS"
- else:
- retsock = sock
- scheme = "ws"
- print "Using plain (not SSL) socket"
- handshake = retsock.recv(4096)
- req_lines = handshake.split("\r\n")
- _, path, _ = req_lines[0].split(" ")
- _, origin = req_lines[4].split(" ")
- _, host = req_lines[3].split(" ")
+ print "Connecting to: %s:%s" % (target_host, target_port)
+ tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ tsock.connect((target_host, target_port))
- # Parse settings from the path
- cvars = path.partition('?')[2].partition('#')[0].split('&')
- for cvar in [c for c in cvars if c]:
- name, _, value = cvar.partition('=')
- client_settings[name] = value and value or True
-
- print "client_settings:", client_settings
-
- retsock.send(server_handshake % (origin, scheme, host, path))
- return retsock
-
-def start_server(listen_port, target_host, target_port):
- global send_seq
- lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- lsock.bind(('', listen_port))
- lsock.listen(100)
print traffic_legend
- while True:
- try:
- csock = tsock = None
- print 'waiting for connection on port %s' % listen_port
- startsock, address = lsock.accept()
- print 'Got client connection from %s' % address[0]
- csock = do_handshake(startsock)
- if not csock: continue
- print "Connecting to: %s:%s" % (target_host, target_port)
- tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- tsock.connect((target_host, target_port))
- send_seq = 0
- proxy(csock, tsock)
-
- except Exception:
- print "Ignoring exception:"
- print traceback.format_exc()
- if csock: csock.close()
- if tsock: tsock.close()
+ try:
+ do_proxy(client, tsock)
+ except:
+ if tsock: tsock.close()
+ raise
if __name__ == '__main__':
##log = open("ws.log", 'w')
@@ -201,4 +112,4 @@ if __name__ == '__main__':
except:
print "Usage: "
sys.exit(1)
- start_server(listen_port, target_host, target_port)
+ start_server(listen_port, proxy_handler)
diff --git a/wstest.html b/wstest.html
index bfc6bfc0..10fc8ff5 100644
--- a/wstest.html
+++ b/wstest.html
@@ -224,8 +224,8 @@
}
window.onload = function() {
- WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf";
console.log("onload");
+ WebSocket.__swfLocation = "include/web-socket-js/WebSocketMain.swf";
var url = document.location.href;
$('host').value = (url.match(/host=([^]*)/) || ['',''])[1];
$('port').value = (url.match(/port=([^]*)/) || ['',''])[1];
diff --git a/wstest.py b/wstest.py
index 14ec257e..8bb1541d 100755
--- a/wstest.py
+++ b/wstest.py
@@ -1,50 +1,20 @@
#!/usr/bin/python
-import sys, os, socket, time, traceback, random, time
+'''
+WebSocket server-side load test program. Sends and receives traffic
+that has a random payload (length and content) that is checksummed and
+given a sequence number. Any errors are reported and counted.
+'''
+
+import sys, socket, ssl, time, traceback
+import random, time
from base64 import b64encode, b64decode
from select import select
+from websocket import *
buffer_size = 65536
recv_cnt = send_cnt = 0
-server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
-Upgrade: WebSocket\r
-Connection: Upgrade\r
-WebSocket-Origin: %s\r
-WebSocket-Location: ws://%s%s\r
-WebSocket-Protocol: sample\r
-\r
-"""
-
-policy_response = """"""
-
-def do_handshake(client):
- handshake = client.recv(1024)
- print "Handshake [%s]" % handshake
- if handshake.startswith(""):
- print "Sending flash policy response"
- client.send(policy_response)
- client.close()
- return False
- req_lines = handshake.split("\r\n")
- _, path, _ = req_lines[0].split(" ")
- _, origin = req_lines[4].split(" ")
- _, host = req_lines[3].split(" ")
- client.send(server_handshake % (origin, host, path))
- return True
-
-def traffic(token="."):
- sys.stdout.write(token)
- sys.stdout.flush()
-
-
-def decode(buf):
- """ Parse out WebSocket packets. """
- if buf.count('\xff') > 1:
- traffic(str(buf.count('\xff')))
- return [b64decode(d[1:]) for d in buf.split('\xff')]
- else:
- return [b64decode(buf[1:-1])]
def check(buf):
global recv_cnt
@@ -103,7 +73,7 @@ def check(buf):
def generate():
- global send_cnt
+ global send_cnt, rand_array
length = random.randint(10, 100000)
numlist = rand_array[100000-length:]
# Error in length
@@ -156,31 +126,19 @@ def responder(client, delay=10):
client.send(generate())
traffic("<")
-def start_server(listen_port, delay=10):
- global errors, send_cnt, recv_cnt
- lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- lsock.bind(('', listen_port))
- lsock.listen(100)
- while True:
- try:
- csock = None
- print 'listening on port %s' % listen_port
- csock, address = lsock.accept()
- print 'Got client connection from %s' % address[0]
- if not do_handshake(csock):
- continue
+def test_handler(client):
+ global errors, delay, send_cnt, recv_cnt
- send_cnt = 0
- recv_cnt = 0
- responder(csock, delay=delay)
+ send_cnt = 0
+ recv_cnt = 0
+
+ try:
+ responder(client, delay)
+ except:
+ print "accumulated errors:", errors
+ errors = 0
+ raise
- except Exception:
- print "accumulated errors:", errors
- errors = 0
- print "Ignoring exception:"
- print traceback.format_exc()
- if csock: csock.close()
if __name__ == '__main__':
errors = 0
@@ -200,4 +158,4 @@ if __name__ == '__main__':
for i in range(0, 100000):
rand_array.append(random.randint(0, 9))
- start_server(listen_port, delay=delay)
+ start_server(listen_port, test_handler)