diff --git a/utils/wswrapper.c b/utils/wswrapper.c index ef9fa586..6b5e60f2 100644 --- a/utils/wswrapper.c +++ b/utils/wswrapper.c @@ -48,7 +48,6 @@ return -1; - const char _WS_response[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\ Upgrade: WebSocket\r\n\ Connection: Upgrade\r\n\ @@ -57,6 +56,17 @@ Connection: Upgrade\r\n\ %sWebSocket-Protocol: sample\r\n\ \r\n%s"; +#define WS_BUFSIZE 65536 + +typedef struct { + char rbuf[WS_BUFSIZE]; + char sbuf[WS_BUFSIZE]; + int rcarry_cnt; + char rcarry[3]; + int newframe; +} _WS_connection; + + /* * If WSWRAP_PORT environment variable is set then listen to the bind fd that * matches WSWRAP_PORT, otherwise listen to the first socket fd that bind is @@ -65,26 +75,12 @@ Connection: Upgrade\r\n\ int _WS_listen_fd = 0; int _WS_sockfd = 0; -typedef struct { - char _WS_rbuf[65536]; - char _WS_sbuf[65536]; -} _WS_connection; +_WS_connection * _WS_connections[65546]; -int _WS_bufsize = 65536; -char *_WS_rbuf = NULL; -char *_WS_sbuf = NULL; -int _WS_rcarry_cnt = 0; -char _WS_rcarry[3] = ""; -int _WS_newframe = 1; -int _WS_init() { - if (! (_WS_rbuf = malloc(_WS_bufsize)) ) { - return 0; - } - if (! (_WS_sbuf = malloc(_WS_bufsize)) ) { - return 0; - } -} +/* + * WebSocket handshake routines + */ int _WS_gen_md5(char *key1, char *key2, char *key3, char *target) { unsigned int i, spaces1 = 0, spaces2 = 0; @@ -246,13 +242,17 @@ int _WS_handshake(int sockfd) return ret; } +/* + * WebSockets recv and read interposer routine + */ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, size_t len, int flags) { + _WS_connection *ws = _WS_connections[sockfd]; int rawcount, deccount, left, rawlen, retlen, decodelen; int sockflags; int i; - char * fstart, * fend, * cstart; + char *fstart, *fend, *cstart; static void * (*rfunc)(), * (*rfunc2)(); if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv"); @@ -262,7 +262,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, return 0; } - if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) { + if (! ws) { // Not our file descriptor, just pass through if (recvf) { return (ssize_t) rfunc(sockfd, buf, len, flags); @@ -277,26 +277,26 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, retlen = 0; // first copy in any carry-over bytes - if (_WS_rcarry_cnt) { - if (_WS_rcarry_cnt == 1) { - DEBUG("Using carry byte: %u (", _WS_rcarry[0]); - } else if (_WS_rcarry_cnt == 2) { - DEBUG("Using carry bytes: %u,%u (", _WS_rcarry[0], - _WS_rcarry[1]); + if (ws->rcarry_cnt) { + if (ws->rcarry_cnt == 1) { + DEBUG("Using carry byte: %u (", ws->rcarry[0]); + } else if (ws->rcarry_cnt == 2) { + DEBUG("Using carry bytes: %u,%u (", ws->rcarry[0], + ws->rcarry[1]); } else { RET_ERROR(EIO, "Too many carry-over bytes\n"); } - if (len <= _WS_rcarry_cnt) { + if (len <= ws->rcarry_cnt) { DEBUG("final)\n"); - memcpy((char *) buf, _WS_rcarry, len); - _WS_rcarry_cnt -= len; + memcpy((char *) buf, ws->rcarry, len); + ws->rcarry_cnt -= len; return len; } else { DEBUG("prepending)\n"); - memcpy((char *) buf, _WS_rcarry, _WS_rcarry_cnt); - retlen += _WS_rcarry_cnt; - left -= _WS_rcarry_cnt; - _WS_rcarry_cnt = 0; + memcpy((char *) buf, ws->rcarry, ws->rcarry_cnt); + retlen += ws->rcarry_cnt; + left -= ws->rcarry_cnt; + ws->rcarry_cnt = 0; } } @@ -304,20 +304,20 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, rawcount = (left * 4) / 3 + 3; rawcount -= rawcount%4; - if (rawcount > _WS_bufsize - 1) { + if (rawcount > WS_BUFSIZE - 1) { RET_ERROR(ENOMEM, "recv of %d bytes is larger than buffer\n", rawcount); } i = 0; while (1) { // Peek at everything available - rawlen = (int) rfunc(sockfd, _WS_rbuf, _WS_bufsize-1, + rawlen = (int) rfunc(sockfd, ws->rbuf, WS_BUFSIZE-1, flags | MSG_PEEK); if (rawlen <= 0) { DEBUG("_WS_recv: returning because rawlen %d\n", rawlen); return (ssize_t) rawlen; } - fstart = _WS_rbuf; + fstart = ws->rbuf; /* while (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') { @@ -326,7 +326,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, } */ if (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') { - rawlen = (int) rfunc(sockfd, _WS_rbuf, 2, flags); + rawlen = (int) rfunc(sockfd, ws->rbuf, 2, flags); if (rawlen != 2) { RET_ERROR(EIO, "Could not strip empty frame headers\n"); } @@ -335,7 +335,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, fstart[rawlen] = '\x00'; - if (rawlen - _WS_newframe >= 4) { + if (rawlen - ws->newframe >= 4) { // We have enough to base64 decode at least 1 byte break; } @@ -362,19 +362,19 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, DEBUG("\n"); */ - if (_WS_newframe) { + if (ws->newframe) { if (fstart[0] != '\x00') { RET_ERROR(EPROTO, "Missing frame start\n"); } fstart++; rawlen--; - _WS_newframe = 0; + ws->newframe = 0; } fend = memchr(fstart, '\xff', rawlen); if (fend) { - _WS_newframe = 1; + ws->newframe = 1; if ((fend - fstart) % 4) { RET_ERROR(EPROTO, "Frame length is not multiple of 4\n"); } @@ -387,7 +387,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, // How much should we consume if (rawcount < fend - fstart) { - _WS_newframe = 0; + ws->newframe = 0; deccount = rawcount; } else { deccount = fend - fstart; @@ -397,7 +397,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, if (flags & MSG_PEEK) { MSG("*** Got MSG_PEEK ***\n"); } else { - rfunc(sockfd, _WS_rbuf, fstart - _WS_rbuf + deccount + _WS_newframe, flags); + rfunc(sockfd, ws->rbuf, fstart - ws->rbuf + deccount + ws->newframe, flags); } fstart[deccount] = '\x00'; // base64 terminator @@ -415,16 +415,16 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, if (! (flags & MSG_PEEK)) { // Add anything left over to the carry-over - _WS_rcarry_cnt = decodelen - left; - if (_WS_rcarry_cnt > 2) { + ws->rcarry_cnt = decodelen - left; + if (ws->rcarry_cnt > 2) { RET_ERROR(EPROTO, "Got too much base64 data\n"); } - memcpy(_WS_rcarry, buf + retlen, _WS_rcarry_cnt); - if (_WS_rcarry_cnt == 1) { - DEBUG("Saving carry byte: %u\n", _WS_rcarry[0]); - } else if (_WS_rcarry_cnt == 2) { - DEBUG("Saving carry bytes: %u,%u\n", _WS_rcarry[0], - _WS_rcarry[1]); + memcpy(ws->rcarry, buf + retlen, ws->rcarry_cnt); + if (ws->rcarry_cnt == 1) { + DEBUG("Saving carry byte: %u\n", ws->rcarry[0]); + } else if (ws->rcarry_cnt == 2) { + DEBUG("Saving carry bytes: %u,%u\n", ws->rcarry[0], + ws->rcarry[1]); } else { MSG("Waah2!\n"); } @@ -442,9 +442,13 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf, return retlen; } +/* + * WebSockets send and write interposer routine + */ ssize_t _WS_send(int sendf, int sockfd, const void *buf, size_t len, int flags) { + _WS_connection *ws = _WS_connections[sockfd]; int rawlen, enclen, rlen, over, left, clen, retlen, dbufsize; int sockflags; char * target; @@ -453,7 +457,7 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf, if (!sfunc) sfunc = (void *(*)()) dlsym(RTLD_NEXT, "send"); if (!sfunc2) sfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "write"); - if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) { + if (! ws) { // Not our file descriptor, just pass through if (sendf) { return (ssize_t) sfunc(sockfd, buf, len, flags); @@ -465,22 +469,22 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf, sockflags = fcntl(sockfd, F_GETFL, 0); - dbufsize = (_WS_bufsize * 3)/4 - 2; + dbufsize = (WS_BUFSIZE * 3)/4 - 2; if (len > dbufsize) { RET_ERROR(ENOMEM, "send of %d bytes is larger than send buffer\n", len); } // base64 encode and add frame markers rawlen = 0; - _WS_sbuf[rawlen++] = '\x00'; - enclen = b64_ntop(buf, len, _WS_sbuf+rawlen, _WS_bufsize-rawlen); + ws->sbuf[rawlen++] = '\x00'; + enclen = b64_ntop(buf, len, ws->sbuf+rawlen, WS_BUFSIZE-rawlen); if (enclen < 0) { RET_ERROR(EPROTO, "Base64 encoding error\n"); } rawlen += enclen; - _WS_sbuf[rawlen++] = '\xff'; + ws->sbuf[rawlen++] = '\xff'; - rlen = (int) sfunc(sockfd, _WS_sbuf, rawlen, flags); + rlen = (int) sfunc(sockfd, ws->sbuf, rawlen, flags); if (rlen <= 0) { return rlen; @@ -490,11 +494,11 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf, left = (4 - over) % 4 + 1; // left to send DEBUG("_WS_send: rlen: %d (over: %d, left: %d), rawlen: %d\n", rlen, over, left, rawlen); rlen += left; - _WS_sbuf[rlen-1] = '\xff'; + ws->sbuf[rlen-1] = '\xff'; i = 0; do { i++; - clen = (int) sfunc(sockfd, _WS_sbuf + rlen - left, left, flags); + clen = (int) sfunc(sockfd, ws->sbuf + rlen - left, left, flags); if (clen > 0) { left -= clen; } else if (clen == 0) { @@ -518,8 +522,8 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf, // Adjust for framing retlen = rlen - 2; // Adjust for base64 padding - if (_WS_sbuf[rlen-1] == '=') { retlen --; } - if (_WS_sbuf[rlen-2] == '=') { retlen --; } + if (ws->sbuf[rlen-1] == '=') { retlen --; } + if (ws->sbuf[rlen-2] == '=') { retlen --; } // Adjust for base64 encoding retlen = (retlen*3)/4; @@ -529,13 +533,15 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf, for (i = 0; i < retlen; i++) { DEBUG("%u,", (unsigned char) ((char *)buf)[i]); } - DEBUG(" as '%s' (%d)\n", _WS_sbuf+1, rlen); + DEBUG(" as '%s' (%d)\n", ws->sbuf+1, rlen); */ return (ssize_t) retlen; } -/* Override network routines */ +/* + * Overload (LD_PRELOAD) standard library network routines + */ /* int socket(int domain, int type, int protocol) @@ -603,24 +609,24 @@ int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) return fd; } - if (_WS_sockfd == 0) { - // TODO: not just first connection - _WS_sockfd = fd; - - if (!_WS_rbuf) { - if (! _WS_init()) { - RET_ERROR(ENOMEM, "Could not allocate interposer buffer\n"); - } + if (_WS_connections[fd]) { + MSG("error, already interposing on fd %d\n", fd); + } else { + if (! (_WS_connections[fd] = malloc(sizeof(_WS_connection)))) { + RET_ERROR(ENOMEM, "Could not allocate interposer memory\n"); } + _WS_connections[fd]->rcarry_cnt = 0; + _WS_connections[fd]->rcarry[0] = '\0'; + _WS_connections[fd]->newframe = 1; - ret = _WS_handshake(_WS_sockfd); + ret = _WS_handshake(fd); if (ret < 0) { + free(_WS_connections[fd]); + _WS_connections[fd] = NULL; errno = EPROTO; return ret; } - MSG("interposing on fd %d\n", _WS_sockfd); - } else { - DEBUG("already interposing on fd %d\n", _WS_sockfd); + MSG("interposing on fd %d (allocated memory)\n", fd); } return fd; @@ -631,9 +637,10 @@ int close(int fd) static void * (*func)(); if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "close"); - if ((_WS_sockfd != 0) && (_WS_sockfd == fd)) { - MSG("finished interposing on fd %d\n", _WS_sockfd); - _WS_sockfd = 0; + if (_WS_connections[fd]) { + free(_WS_connections[fd]); + _WS_connections[fd] = NULL; + MSG("finished interposing on fd %d (freed memory)\n", fd); } return (int) func(fd); }