diff --git a/utils/Makefile b/utils/Makefile index bd824f84..099d4790 100644 --- a/utils/Makefile +++ b/utils/Makefile @@ -1,11 +1,7 @@ wsproxy: wsproxy.o websocket.o $(CC) $^ -l ssl -l resolv -o $@ -#websocket.o: websocket.c -# $(CC) -c $^ -o $@ -# -#wsproxy.o: wsproxy.c -# $(CC) -c $^ -o $@ +websocket.o wsproxy.o: websocket.h clean: rm -f wsproxy wsproxy.o websocket.o diff --git a/utils/websocket.c b/utils/websocket.c index eb4e1a6a..3b555e43 100644 --- a/utils/websocket.c +++ b/utils/websocket.c @@ -296,12 +296,13 @@ ws_ctx_t *do_handshake(int sock) { return NULL; } else if (bcmp(handshake, "\x16", 1) == 0) { // SSL + if (! settings.cert) { return NULL; } ws_ctx = ws_socket_ssl(sock, settings.cert); if (! ws_ctx) { return NULL; } scheme = "wss"; printf(" using SSL socket\n"); } else if (settings.ssl_only) { - printf("Non-SSL connection disallowed"); + printf("Non-SSL connection disallowed\n"); close(sock); return NULL; } else { @@ -401,10 +402,6 @@ void start_server() { struct sockaddr_in serv_addr, cli_addr; ws_ctx_t *ws_ctx; - if (settings.daemon) { - daemonize(); - } - /* Initialize buffers */ bufsize = 65536; if (! (tbuf = malloc(bufsize)) ) @@ -416,6 +413,10 @@ void start_server() { if (! (cbuf_tmp = malloc(bufsize)) ) { fatal("malloc()"); } + if (settings.daemon) { + daemonize(); + } + lsock = socket(AF_INET, SOCK_STREAM, 0); if (lsock < 0) { error("ERROR creating listener socket"); } bzero((char *) &serv_addr, sizeof(serv_addr)); diff --git a/utils/websocket.h b/utils/websocket.h index 9520018f..42d55b82 100644 --- a/utils/websocket.h +++ b/utils/websocket.h @@ -1,5 +1,4 @@ #include -#include typedef struct { int sockfd; @@ -13,8 +12,8 @@ typedef struct { void (*handler)(ws_ctx_t*); int ssl_only; int daemon; - char record[1024]; - char cert[1024]; + char *record; + char *cert; } settings_t; typedef struct { diff --git a/utils/wsproxy.c b/utils/wsproxy.c index a03c3a74..b522a117 100644 --- a/utils/wsproxy.c +++ b/utils/wsproxy.c @@ -7,6 +7,7 @@ */ #include #include +#include #include #include #include @@ -198,8 +199,8 @@ void proxy_handler(ws_ctx_t *ws_ctx) { int tsock = 0; struct sockaddr_in taddr; - if (settings.record) { - recordfd = open(settings.record, O_WRONLY | O_CREAT | O_TRUNC, + if (settings.record && settings.record[0] != '\0') { + recordfd = open(settings.record, O_WRONLY | O_CREAT | O_APPEND, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); } @@ -238,7 +239,7 @@ void proxy_handler(ws_ctx_t *ws_ctx) { int main(int argc, char *argv[]) { - int listen_port, c, option_index = 0; + int listen_port, fd, c, option_index = 0; static int ssl_only = 0, foreground = 0; char *found; static struct option long_options[] = { @@ -250,8 +251,8 @@ int main(int argc, char *argv[]) {0, 0, 0, 0} }; - settings.record[0] = '\0'; - strcpy(settings.cert, "self.pem"); + settings.record = NULL; + settings.cert = realpath("self.pem", NULL); while (1) { c = getopt_long (argc, argv, "fr:c:", @@ -269,10 +270,18 @@ int main(int argc, char *argv[]) foreground = 1; break; case 'r': - memcpy(settings.record, optarg, sizeof(settings.record)); + if ((fd = open(optarg, O_CREAT, + S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)) < -1) { + fatal("Could not access %s\n", optarg); + } + close(fd); + settings.record = realpath(optarg, NULL); break; case 'c': - memcpy(settings.cert, optarg, sizeof(settings.cert)); + settings.cert = realpath(optarg, NULL); + if (! settings.cert) { + fatal("No cert file at %s\n", optarg); + } break; default: usage(); @@ -314,17 +323,6 @@ int main(int argc, char *argv[]) usage(); } - /* Initialize buffers */ - bufsize = 65536; - if (! (tbuf = malloc(bufsize)) ) - { fatal("malloc()"); } - if (! (cbuf = malloc(bufsize)) ) - { fatal("malloc()"); } - if (! (tbuf_tmp = malloc(bufsize)) ) - { fatal("malloc()"); } - if (! (cbuf_tmp = malloc(bufsize)) ) - { fatal("malloc()"); } - settings.handler = proxy_handler; start_server(); diff --git a/utils/wsproxy.py b/utils/wsproxy.py index 64aeaa80..177df16c 100755 --- a/utils/wsproxy.py +++ b/utils/wsproxy.py @@ -101,7 +101,7 @@ def proxy_handler(client): if settings['record']: print "Opening record file: %s" % settings['record'] - rec = open(settings['record'], 'w') + rec = open(settings['record'], 'a') print "Connecting to: %s:%s" % (target_host, target_port) tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -152,5 +152,6 @@ if __name__ == '__main__': settings['cert'] = os.path.abspath(options.cert) settings['ssl_only'] = options.ssl_only settings['daemon'] = options.daemon - settings['record'] = os.path.abspath(options.record) + if options.record: + settings['record'] = os.path.abspath(options.record) start_server()