diff --git a/utils/websocket.py b/utils/websocket.py index b931aa8f..3414f528 100755 --- a/utils/websocket.py +++ b/utils/websocket.py @@ -144,16 +144,30 @@ Sec-WebSocket-Accept: %s\r # @staticmethod - def addrinfo(host, port=None): - """ Resolve a host (and optional port) to an IPv4 or IPv6 address. - Returns: family, socktype, proto, canonname, sockaddr + def socket(host, port=None, connect=False, prefer_ipv6=False): + """ Resolve a host (and optional port) to an IPv4 or IPv6 + address. Create a socket. Bind to it if listen is set. Return + a socket that is ready for listen or connect. """ - if not host: - host = 'localhost' - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP) + flags = 0 + if host == '': host = None + if not connect: + flags = flags | socket.AI_PASSIVE + addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, + socket.IPPROTO_TCP, flags) if not addrs: raise Exception("Could resolve host '%s'" % host) - return addrs[0] + addrs.sort(key=lambda x: x[0]) + if prefer_ipv6: + addrs.reverse() + sock = socket.socket(addrs[0][0], addrs[0][1]) + if connect: + sock.connect(addrs[0][4]) + else: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addrs[0][4]) + sock.listen(100) + return sock @staticmethod def daemonize(keepfd=None, chdir='/'): @@ -738,11 +752,7 @@ Sec-WebSocket-Accept: %s\r is a WebSockets client then call new_client() method (which must be overridden) for each new client connection. """ - addr = self.addrinfo(self.listen_host, self.listen_port) - lsock = socket.socket(addr[0], addr[1]) - lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - lsock.bind((self.listen_host, self.listen_port)) - lsock.listen(100) + lsock = self.socket(self.listen_host, self.listen_port) if self.daemon: self.daemonize(keepfd=lsock.fileno(), chdir=self.web) diff --git a/utils/websockify b/utils/websockify index e05ecabe..e82e364b 100755 --- a/utils/websockify +++ b/utils/websockify @@ -141,9 +141,8 @@ Traffic Legend: # Connect to the target self.msg("connecting to: %s:%s" % ( self.target_host, self.target_port)) - addr = self.addrinfo(self.target_host, self.target_port) - tsock = socket.socket(addr[0], addr[1]) - tsock.connect((self.target_host, self.target_port)) + tsock = self.socket(self.target_host, self.target_port, + connect=True) if self.verbose and not self.daemon: print(self.traffic_legend)