diff --git a/wsproxy.py b/wsproxy.py index b47cb9b1..a1710b19 100755 --- a/wsproxy.py +++ b/wsproxy.py @@ -1,6 +1,6 @@ #!/usr/bin/python -import sys, os, socket, time, traceback, re +import sys, os, socket, ssl, time, traceback, re from base64 import b64encode, b64decode from select import select @@ -12,7 +12,7 @@ server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r Upgrade: WebSocket\r Connection: Upgrade\r WebSocket-Origin: %s\r -WebSocket-Location: ws://%s%s\r +WebSocket-Location: %s://%s%s\r WebSocket-Protocol: sample\r \r """ @@ -32,31 +32,6 @@ Traffic Legend: """ -def do_handshake(client): - global client_settings - handshake = client.recv(1024) - #print "Handshake [%s]" % handshake - if handshake.startswith(""): - print "Sending flash policy response" - client.send(policy_response) - client.close() - return False - req_lines = handshake.split("\r\n") - _, path, _ = req_lines[0].split(" ") - _, origin = req_lines[4].split(" ") - _, host = req_lines[3].split(" ") - - # Parse settings from the path - cvars = path.partition('?')[2].partition('#')[0].split('&') - for cvar in [c for c in cvars if c]: - name, _, value = cvar.partition('=') - client_settings[name] = value and value or True - - print "client_settings:", client_settings - - client.send(server_handshake % (origin, host, path)) - return True - def traffic(token="."): sys.stdout.write(token) sys.stdout.flush() @@ -139,6 +114,46 @@ def proxy(client, target): cpartial = cpartial + buf +def do_handshake(sock): + global client_settings + # Peek, but don't read the data + handshake = sock.recv(1024, socket.MSG_PEEK) + #print "Handshake [%s]" % repr(handshake) + if handshake.startswith(""): + handshake = sock.recv(1024) + print "Sending flash policy response" + sock.send(policy_response) + sock.close() + return False + elif handshake.startswith("\x16"): + retsock = ssl.wrap_socket( + sock, + server_side=True, + certfile='wsproxy.pem', + ssl_version=ssl.PROTOCOL_TLSv1) + scheme = "wss" + print "Using SSL/TLS" + else: + retsock = sock + scheme = "ws" + print "Using plain (not SSL) socket" + handshake = retsock.recv(4096) + req_lines = handshake.split("\r\n") + _, path, _ = req_lines[0].split(" ") + _, origin = req_lines[4].split(" ") + _, host = req_lines[3].split(" ") + + # Parse settings from the path + cvars = path.partition('?')[2].partition('#')[0].split('&') + for cvar in [c for c in cvars if c]: + name, _, value = cvar.partition('=') + client_settings[name] = value and value or True + + print "client_settings:", client_settings + + retsock.send(server_handshake % (origin, scheme, host, path)) + return retsock + def start_server(listen_port, target_host, target_port): global send_seq lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -150,9 +165,10 @@ def start_server(listen_port, target_host, target_port): try: csock = tsock = None print 'waiting for connection on port %s' % listen_port - csock, address = lsock.accept() + startsock, address = lsock.accept() print 'Got client connection from %s' % address[0] - if not do_handshake(csock): continue + csock = do_handshake(startsock) + if not csock: continue print "Connecting to: %s:%s" % (target_host, target_port) tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) tsock.connect((target_host, target_port))