Cleanup of TLS code

This commit is contained in:
Sara Dickinson 2015-04-16 18:01:17 +01:00
parent 99aa79b48f
commit 99c1973fae
4 changed files with 20 additions and 63 deletions

View File

@ -1145,15 +1145,15 @@ set_ub_dns_transport(struct getdns_context* context,
case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN:
case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN:
/* TODO: Investigate why ssl-upstream in Unbound isn't working (error /* TODO: Investigate why ssl-upstream in Unbound isn't working (error
that the SSL lib isn't init'ed but that is done in prep_for_res.*/ * that the SSL lib isn't init'ed but that is done in prep_for_res.
/* Note: no fallback or pipelining available directly in unbound.*/ * Note: no fallback or pipelining available directly in unbound.*/
set_ub_string_opt(context, "do-udp:", "no"); set_ub_string_opt(context, "do-udp:", "no");
set_ub_string_opt(context, "do-tcp:", "yes"); set_ub_string_opt(context, "do-tcp:", "yes");
//set_ub_string_opt(context, "ssl-upstream:", "yes"); /* set_ub_string_opt(context, "ssl-upstream:", "yes");*/
/* TODO: Specifying a different port to do TLS on in unbound is a bit /* TODO: Specifying a different port to do TLS on in unbound is a bit
tricky as it involves modifying each fwd upstream defined on the * tricky as it involves modifying each fwd upstream defined on the
unbound ctx... And to support fallback this would have to be reset * unbound ctx... And to support fallback this would have to be reset
from the stub code while trying to connect...*/ * from the stub code while trying to connect...*/
break; break;
default: default:
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
@ -1171,10 +1171,10 @@ getdns_context_set_dns_transport(struct getdns_context *context,
{ {
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
/* Note that the call below does not have any effect in unbound after the /* Note that the call below does not have any effect in unbound after the
ctx is finalised. So will not apply for recursive mode or stub + dnssec. * ctx is finalised. So will not apply for recursive mode or stub + dnssec.
However the method returns success as otherwise the transport could not * However the method returns success as otherwise the transport could not
be reset for stub mode..... */ * be reset for stub mode.....
/* Also, not all transport options supported in libunbound yet*/ * Also, not all transport options supported in libunbound yet */
if (set_ub_dns_transport(context, value) != GETDNS_RETURN_GOOD) { if (set_ub_dns_transport(context, value) != GETDNS_RETURN_GOOD) {
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
} }
@ -1799,9 +1799,7 @@ getdns_context_prepare_for_resolution(struct getdns_context *context,
return GETDNS_RETURN_BAD_CONTEXT; return GETDNS_RETURN_BAD_CONTEXT;
} }
/* Transport can in theory be set per query in stub mode so deal with it /* Transport can in theory be set per query in stub mode */
here */
printf("[TLS] preparing for resolution, checking transport type\n");
if (context->resolution_type == GETDNS_RESOLUTION_STUB) { if (context->resolution_type == GETDNS_RESOLUTION_STUB) {
switch (context->dns_transport) { switch (context->dns_transport) {
case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN:
@ -1809,13 +1807,10 @@ getdns_context_prepare_for_resolution(struct getdns_context *context,
if (context->tls_ctx == NULL) { if (context->tls_ctx == NULL) {
/* Init the SSL library */ /* Init the SSL library */
SSL_library_init(); SSL_library_init();
/* Load error messages */
SSL_load_error_strings();
/* Create client context, use TLS v1.2 only for now */ /* Create client context, use TLS v1.2 only for now */
SSL_CTX* tls_ctx = SSL_CTX_new(TLSv1_2_client_method()); SSL_CTX* tls_ctx = SSL_CTX_new(TLSv1_2_client_method());
if(!tls_ctx) { if(!tls_ctx) {
ERR_print_errors_fp(stderr);
return GETDNS_RETURN_BAD_CONTEXT; return GETDNS_RETURN_BAD_CONTEXT;
} }
context->tls_ctx = tls_ctx; context->tls_ctx = tls_ctx;

View File

@ -98,7 +98,6 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner,
net_req->write_queue_tail = NULL; net_req->write_queue_tail = NULL;
net_req->query_len = 0; net_req->query_len = 0;
net_req->response_len = 0; net_req->response_len = 0;
net_req->tls_obj = NULL;
net_req->wire_data_sz = wire_data_sz; net_req->wire_data_sz = wire_data_sz;
if (max_query_sz) { if (max_query_sz) {

View File

@ -320,8 +320,8 @@ upstream_erred(getdns_upstream *upstream)
netreq->state = NET_REQ_FINISHED; netreq->state = NET_REQ_FINISHED;
priv_getdns_check_dns_req_complete(netreq->owner); priv_getdns_check_dns_req_complete(netreq->owner);
} }
// TODO[TLS]: When we get an error (which is probably a timeout) and are /* TODO[TLS]: When we get an error (which is probably a timeout) and are
// using to keep connections open should we leave the connection up here? * using to keep connections open should we leave the connection up here? */
if (upstream->tls_obj) { if (upstream->tls_obj) {
SSL_shutdown(upstream->tls_obj); SSL_shutdown(upstream->tls_obj);
SSL_free(upstream->tls_obj); SSL_free(upstream->tls_obj);
@ -507,8 +507,6 @@ stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf)
uint8_t *buf; uint8_t *buf;
size_t buf_size; size_t buf_size;
fprintf(stderr, "[TLS] method: stub_tcp_read\n");
if (!tcp->read_buf) { if (!tcp->read_buf) {
/* First time tcp read, create a buffer for reading */ /* First time tcp read, create a buffer for reading */
if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096)))
@ -529,7 +527,6 @@ stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf)
/* TODO: Try to reconnect */ /* TODO: Try to reconnect */
return STUB_TCP_ERROR; return STUB_TCP_ERROR;
} }
fprintf(stderr, "[TLS] method: read %d TCP bytes \n", (int)read);
tcp->to_read -= read; tcp->to_read -= read;
tcp->read_pos += read; tcp->read_pos += read;
@ -575,16 +572,13 @@ stub_tcp_read_cb(void *userarg)
&dnsreq->context->mf))) { &dnsreq->context->mf))) {
case STUB_TCP_AGAIN: case STUB_TCP_AGAIN:
fprintf(stderr, "[TLS] method: stub_tcp_read_cb -> tcp again\n");
return; return;
case STUB_TCP_ERROR: case STUB_TCP_ERROR:
fprintf(stderr, "[TLS] method: stub_tcp_read_cb -> tcp error\n");
stub_erred(netreq); stub_erred(netreq);
return; return;
default: default:
fprintf(stderr, "[TLS] method: stub_tcp_read_cb -> All done. close fd %d\n", netreq->fd);
GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event);
if (q != netreq->query_id) if (q != netreq->query_id)
return; return;
@ -618,7 +612,8 @@ sock_wait(int sockfd)
fd_set fds; fd_set fds;
FD_ZERO(&fds); FD_ZERO(&fds);
FD_SET(FD_SET_T sockfd, &fds); FD_SET(FD_SET_T sockfd, &fds);
struct timeval timeout = {2, 0 }; /*TODO[TLS]: Pick up this timeout from the context*/
struct timeval timeout = {5, 0 };
ret = select(sockfd+1, NULL, &fds, NULL, &timeout); ret = select(sockfd+1, NULL, &fds, NULL, &timeout);
if(ret == 0) if(ret == 0)
/* timeout expired */ /* timeout expired */
@ -632,7 +627,6 @@ sock_wait(int sockfd)
static int static int
sock_connected(int sockfd) sock_connected(int sockfd)
{ {
fprintf(stderr, "[TLS] connect in progress \n");
/* wait(write) until connected or error */ /* wait(write) until connected or error */
while(1) { while(1) {
int error = 0; int error = 0;
@ -667,10 +661,8 @@ do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream)
{ {
/*Lets make sure the connection is up before we try a handshake*/ /*Lets make sure the connection is up before we try a handshake*/
if (errno == EINPROGRESS && sock_connected(upstream->fd) == -1) { if (errno == EINPROGRESS && sock_connected(upstream->fd) == -1) {
fprintf(stderr, "[TLS] connect failed \n");
return NULL; return NULL;
} }
fprintf(stderr, "[TLS] connect done \n");
/* Create SSL instance */ /* Create SSL instance */
SSL* ssl = SSL_new(dnsreq->context->tls_ctx); SSL* ssl = SSL_new(dnsreq->context->tls_ctx);
@ -694,29 +686,24 @@ do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream)
while ((r = SSL_do_handshake(ssl)) != 1) while ((r = SSL_do_handshake(ssl)) != 1)
{ {
want = SSL_get_error(ssl, r); want = SSL_get_error(ssl, r);
fprintf(stderr, "[TLS] in handshake loop %d, want is %d \n", r, want);
switch (want) { switch (want) {
case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_READ:
if (select(upstream->fd + 1, &fds, NULL, NULL, &timeout) == 0) { if (select(upstream->fd + 1, &fds, NULL, NULL, &timeout) == 0) {
fprintf(stderr, "[TLS] ssl handshake timeout %d\n", want);
SSL_free(ssl); SSL_free(ssl);
return NULL; return NULL;
} }
break; break;
case SSL_ERROR_WANT_WRITE: case SSL_ERROR_WANT_WRITE:
if (select(upstream->fd + 1, NULL, &fds, NULL, &timeout) == 0) { if (select(upstream->fd + 1, NULL, &fds, NULL, &timeout) == 0) {
fprintf(stderr, "[TLS] ssl handshake timeout %d\n", want);
SSL_free(ssl); SSL_free(ssl);
return NULL; return NULL;
} }
break; break;
default: default:
fprintf(stderr, "[TLS] got ssl error code %d\n", want);
SSL_free(ssl); SSL_free(ssl);
return NULL; return NULL;
} }
} }
fprintf(stderr, "[TLS] got TLS connection\n");
return ssl; return ssl;
} }
@ -727,8 +714,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
uint8_t *buf; uint8_t *buf;
size_t buf_size; size_t buf_size;
fprintf(stderr, "[TLS] method: stub_tls_read\n");
if (!tcp->read_buf) { if (!tcp->read_buf) {
/* First time tls read, create a buffer for reading */ /* First time tls read, create a buffer for reading */
if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096)))
@ -750,7 +735,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
} else } else
return STUB_TCP_ERROR; return STUB_TCP_ERROR;
} }
fprintf(stderr, "[TLS] method: read %d TLS bytes \n", (int)read);
tcp->to_read -= read; tcp->to_read -= read;
tcp->read_pos += read; tcp->read_pos += read;
@ -761,7 +745,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
if (read == 2) { if (read == 2) {
/* Read the packet size short */ /* Read the packet size short */
tcp->to_read = gldns_read_uint16(tcp->read_buf); tcp->to_read = gldns_read_uint16(tcp->read_buf);
fprintf(stderr, "[TLS] method: %d TLS bytes to read \n", (int)tcp->to_read);
if (tcp->to_read < GLDNS_HEADER_SIZE) if (tcp->to_read < GLDNS_HEADER_SIZE)
return STUB_TCP_ERROR; return STUB_TCP_ERROR;
@ -781,7 +764,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
} }
/* Ready to start reading the packet */ /* Ready to start reading the packet */
fprintf(stderr, "[TLS] method: resetting read_pos \n");
tcp->read_pos = tcp->read_buf; tcp->read_pos = tcp->read_buf;
read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read);
if (read <= 0) { if (read <= 0) {
@ -803,7 +785,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
static void netreq_upstream_read_cb(void *userarg); static void netreq_upstream_read_cb(void *userarg);
static void netreq_upstream_write_cb(void *userarg); static void netreq_upstream_write_cb(void *userarg);
static void upstream_write_cb(void *userarg);
static void static void
upstream_read_cb(void *userarg) upstream_read_cb(void *userarg)
{ {
@ -814,8 +795,6 @@ upstream_read_cb(void *userarg)
uint16_t query_id; uint16_t query_id;
intptr_t query_id_intptr; intptr_t query_id_intptr;
fprintf(stderr, "[TLS] method: upstream_read_cb\n");
if (upstream->tls_obj) if (upstream->tls_obj)
q = stub_tls_read(upstream->tls_obj, &upstream->tcp, q = stub_tls_read(upstream->tls_obj, &upstream->tcp,
&upstream->upstreams->mf); &upstream->upstreams->mf);
@ -825,7 +804,6 @@ upstream_read_cb(void *userarg)
switch (q) { switch (q) {
case STUB_TCP_AGAIN: case STUB_TCP_AGAIN:
fprintf(stderr, "[TLS] method: upstream_read_cb -> STUB_TCP_AGAIN\n");
return; return;
case STUB_TCP_ERROR: case STUB_TCP_ERROR:
@ -833,7 +811,6 @@ upstream_read_cb(void *userarg)
return; return;
default: default:
fprintf(stderr, "[TLS] method: upstream_read_cb -> processing reponse\n");
/* Lookup netreq */ /* Lookup netreq */
query_id = (uint16_t) q; query_id = (uint16_t) q;
@ -851,7 +828,6 @@ upstream_read_cb(void *userarg)
netreq->response = upstream->tcp.read_buf; netreq->response = upstream->tcp.read_buf;
netreq->response_len = netreq->response_len =
upstream->tcp.read_pos - upstream->tcp.read_buf; upstream->tcp.read_pos - upstream->tcp.read_buf;
netreq->tls_obj = upstream->tls_obj;
upstream->tcp.read_buf = NULL; upstream->tcp.read_buf = NULL;
upstream->upstreams->current = 0; upstream->upstreams->current = 0;
@ -906,7 +882,6 @@ static int
stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq)
{ {
getdns_dns_req *dnsreq = netreq->owner; getdns_dns_req *dnsreq = netreq->owner;
fprintf(stderr, "[TLS] method: stub_tcp_write\n");
size_t pkt_len = netreq->response - netreq->query; size_t pkt_len = netreq->response - netreq->query;
ssize_t written; ssize_t written;
@ -1045,8 +1020,6 @@ stub_tcp_write_cb(void *userarg)
static int static int
stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq) stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq)
{ {
fprintf(stderr, "[TLS] method: stub_tls_write\n");
size_t pkt_len = netreq->response - netreq->query; size_t pkt_len = netreq->response - netreq->query;
ssize_t written; ssize_t written;
uint16_t query_id; uint16_t query_id;
@ -1100,8 +1073,6 @@ upstream_write_cb(void *userarg)
getdns_dns_req *dnsreq = netreq->owner; getdns_dns_req *dnsreq = netreq->owner;
int q; int q;
fprintf(stderr, "[TLS] method: upstream_write_cb for %s with class %d\n", dnsreq->name, (int)netreq->request_class);
if (upstream->tls_obj) if (upstream->tls_obj)
q = stub_tls_write(upstream->tls_obj, &upstream->tcp, netreq); q = stub_tls_write(upstream->tls_obj, &upstream->tcp, netreq);
else else
@ -1243,7 +1214,6 @@ tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) {
connect_addr = upstream->addr; connect_addr = upstream->addr;
addr = &connect_addr; addr = &connect_addr;
set_port(addr, TLS_PORT); set_port(addr, TLS_PORT);
fprintf(stderr, "[TLS] Forcing switch to port %d for TLS\n", TLS_PORT);
} }
if ((fd = socket(addr->ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) if ((fd = socket(addr->ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1)
@ -1271,8 +1241,6 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
getdns_dns_req *dnsreq = netreq->owner; getdns_dns_req *dnsreq = netreq->owner;
getdns_upstream *upstream = pick_upstream(dnsreq); getdns_upstream *upstream = pick_upstream(dnsreq);
fprintf(stderr, "[TLS] method: priv_getdns_submit_stub_request\n");
if (!upstream) if (!upstream)
return GETDNS_RETURN_GENERIC_ERROR; return GETDNS_RETURN_GENERIC_ERROR;
@ -1325,19 +1293,16 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
/* We are the first. Make global socket and connect. */ /* We are the first. Make global socket and connect. */
if ((upstream->fd = tcp_connect(upstream, transport)) == -1) { if ((upstream->fd = tcp_connect(upstream, transport)) == -1) {
//TODO: Hum, a reset doesn't make the connect fail...
if (fb_transport == NONE) if (fb_transport == NONE)
return GETDNS_RETURN_GENERIC_ERROR; return GETDNS_RETURN_GENERIC_ERROR;
fprintf(stderr, "[TLS] Connect failed on fd... %d\n", upstream->fd);
if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1) if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1)
return GETDNS_RETURN_GENERIC_ERROR; return GETDNS_RETURN_GENERIC_ERROR;
fallback = 1; fallback = 1;
} }
/* Now do a handshake for TLS. Note waiting for this to succeed or /* Now do a handshake for TLS. Note waiting for this to succeed or
timeout blocks the scheduling of any messages for this upstream*/ * timeout blocks the scheduling of any messages for this upstream*/
if (transport == TLS && (fallback == 0)) { if (transport == TLS && (fallback == 0)) {
fprintf(stderr, "[TLS] Doing SSL handshake... %d\n", upstream->fd);
upstream->tls_obj = do_tls_handshake(dnsreq, upstream); upstream->tls_obj = do_tls_handshake(dnsreq, upstream);
if (!upstream->tls_obj) { if (!upstream->tls_obj) {
if (fb_transport == NONE) if (fb_transport == NONE)
@ -1347,6 +1312,10 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
return GETDNS_RETURN_GENERIC_ERROR; return GETDNS_RETURN_GENERIC_ERROR;
} }
} }
/* Attach to the global event loop
* so it can do it's own scheduling
*/
upstream->loop = dnsreq->context->extension;
} else { } else {
/* Cater for the case of the user downgrading and existing TLS /* Cater for the case of the user downgrading and existing TLS
connection to TCP for some reason...*/ connection to TCP for some reason...*/
@ -1358,11 +1327,6 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
} }
netreq->upstream = upstream; netreq->upstream = upstream;
/* Attach to the global event loop
* so it can do it's own scheduling
*/
upstream->loop = dnsreq->context->extension;
/* We have a context wide socket. /* We have a context wide socket.
* Now schedule the write request. * Now schedule the write request.
*/ */

View File

@ -212,7 +212,6 @@ typedef struct getdns_network_req
uint8_t *opt; /* offset of OPT RR in query */ uint8_t *opt; /* offset of OPT RR in query */
size_t response_len; size_t response_len;
uint8_t *response; uint8_t *response;
SSL* tls_obj;
size_t wire_data_sz; size_t wire_data_sz;
uint8_t wire_data[]; uint8_t wire_data[];