Automatically detect TLS/SSL during handshake.

Use MSG_PEEK flag on recv to detect whether we are getting a flash
policy request, an SSL/TLS header, or a plain socket connection.
This commit is contained in:
Joel Martin 2010-04-30 15:54:59 -05:00
parent 0e486e1ba0
commit ca5785f570
1 changed files with 45 additions and 29 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/python #!/usr/bin/python
import sys, os, socket, time, traceback, re import sys, os, socket, ssl, time, traceback, re
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
from select import select from select import select
@ -12,7 +12,7 @@ server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
Upgrade: WebSocket\r Upgrade: WebSocket\r
Connection: Upgrade\r Connection: Upgrade\r
WebSocket-Origin: %s\r WebSocket-Origin: %s\r
WebSocket-Location: ws://%s%s\r WebSocket-Location: %s://%s%s\r
WebSocket-Protocol: sample\r WebSocket-Protocol: sample\r
\r \r
""" """
@ -32,31 +32,6 @@ Traffic Legend:
""" """
def do_handshake(client):
global client_settings
handshake = client.recv(1024)
#print "Handshake [%s]" % handshake
if handshake.startswith("<policy-file-request/>"):
print "Sending flash policy response"
client.send(policy_response)
client.close()
return False
req_lines = handshake.split("\r\n")
_, path, _ = req_lines[0].split(" ")
_, origin = req_lines[4].split(" ")
_, host = req_lines[3].split(" ")
# Parse settings from the path
cvars = path.partition('?')[2].partition('#')[0].split('&')
for cvar in [c for c in cvars if c]:
name, _, value = cvar.partition('=')
client_settings[name] = value and value or True
print "client_settings:", client_settings
client.send(server_handshake % (origin, host, path))
return True
def traffic(token="."): def traffic(token="."):
sys.stdout.write(token) sys.stdout.write(token)
sys.stdout.flush() sys.stdout.flush()
@ -139,6 +114,46 @@ def proxy(client, target):
cpartial = cpartial + buf cpartial = cpartial + buf
def do_handshake(sock):
global client_settings
# Peek, but don't read the data
handshake = sock.recv(1024, socket.MSG_PEEK)
#print "Handshake [%s]" % repr(handshake)
if handshake.startswith("<policy-file-request/>"):
handshake = sock.recv(1024)
print "Sending flash policy response"
sock.send(policy_response)
sock.close()
return False
elif handshake.startswith("\x16"):
retsock = ssl.wrap_socket(
sock,
server_side=True,
certfile='wsproxy.pem',
ssl_version=ssl.PROTOCOL_TLSv1)
scheme = "wss"
print "Using SSL/TLS"
else:
retsock = sock
scheme = "ws"
print "Using plain (not SSL) socket"
handshake = retsock.recv(4096)
req_lines = handshake.split("\r\n")
_, path, _ = req_lines[0].split(" ")
_, origin = req_lines[4].split(" ")
_, host = req_lines[3].split(" ")
# Parse settings from the path
cvars = path.partition('?')[2].partition('#')[0].split('&')
for cvar in [c for c in cvars if c]:
name, _, value = cvar.partition('=')
client_settings[name] = value and value or True
print "client_settings:", client_settings
retsock.send(server_handshake % (origin, scheme, host, path))
return retsock
def start_server(listen_port, target_host, target_port): def start_server(listen_port, target_host, target_port):
global send_seq global send_seq
lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -150,9 +165,10 @@ def start_server(listen_port, target_host, target_port):
try: try:
csock = tsock = None csock = tsock = None
print 'waiting for connection on port %s' % listen_port print 'waiting for connection on port %s' % listen_port
csock, address = lsock.accept() startsock, address = lsock.accept()
print 'Got client connection from %s' % address[0] print 'Got client connection from %s' % address[0]
if not do_handshake(csock): continue csock = do_handshake(startsock)
if not csock: continue
print "Connecting to: %s:%s" % (target_host, target_port) print "Connecting to: %s:%s" % (target_host, target_port)
tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tsock.connect((target_host, target_port)) tsock.connect((target_host, target_port))