Cleanup of TLS code

This commit is contained in:
Sara Dickinson 2015-04-16 18:01:17 +01:00
parent 99aa79b48f
commit 99c1973fae
4 changed files with 20 additions and 63 deletions

View File

@ -1145,15 +1145,15 @@ set_ub_dns_transport(struct getdns_context* context,
case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN:
case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN:
/* TODO: Investigate why ssl-upstream in Unbound isn't working (error
that the SSL lib isn't init'ed but that is done in prep_for_res.*/
/* Note: no fallback or pipelining available directly in unbound.*/
* that the SSL lib isn't init'ed but that is done in prep_for_res.
* Note: no fallback or pipelining available directly in unbound.*/
set_ub_string_opt(context, "do-udp:", "no");
set_ub_string_opt(context, "do-tcp:", "yes");
//set_ub_string_opt(context, "ssl-upstream:", "yes");
/* set_ub_string_opt(context, "ssl-upstream:", "yes");*/
/* TODO: Specifying a different port to do TLS on in unbound is a bit
tricky as it involves modifying each fwd upstream defined on the
unbound ctx... And to support fallback this would have to be reset
from the stub code while trying to connect...*/
* tricky as it involves modifying each fwd upstream defined on the
* unbound ctx... And to support fallback this would have to be reset
* from the stub code while trying to connect...*/
break;
default:
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
@ -1171,10 +1171,10 @@ getdns_context_set_dns_transport(struct getdns_context *context,
{
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
/* Note that the call below does not have any effect in unbound after the
ctx is finalised. So will not apply for recursive mode or stub + dnssec.
However the method returns success as otherwise the transport could not
be reset for stub mode..... */
/* Also, not all transport options supported in libunbound yet*/
* ctx is finalised. So will not apply for recursive mode or stub + dnssec.
* However the method returns success as otherwise the transport could not
* be reset for stub mode.....
* Also, not all transport options supported in libunbound yet */
if (set_ub_dns_transport(context, value) != GETDNS_RETURN_GOOD) {
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
}
@ -1799,9 +1799,7 @@ getdns_context_prepare_for_resolution(struct getdns_context *context,
return GETDNS_RETURN_BAD_CONTEXT;
}
/* Transport can in theory be set per query in stub mode so deal with it
here */
printf("[TLS] preparing for resolution, checking transport type\n");
/* Transport can in theory be set per query in stub mode */
if (context->resolution_type == GETDNS_RESOLUTION_STUB) {
switch (context->dns_transport) {
case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN:
@ -1809,13 +1807,10 @@ getdns_context_prepare_for_resolution(struct getdns_context *context,
if (context->tls_ctx == NULL) {
/* Init the SSL library */
SSL_library_init();
/* Load error messages */
SSL_load_error_strings();
/* Create client context, use TLS v1.2 only for now */
SSL_CTX* tls_ctx = SSL_CTX_new(TLSv1_2_client_method());
if(!tls_ctx) {
ERR_print_errors_fp(stderr);
return GETDNS_RETURN_BAD_CONTEXT;
}
context->tls_ctx = tls_ctx;

View File

@ -98,7 +98,6 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner,
net_req->write_queue_tail = NULL;
net_req->query_len = 0;
net_req->response_len = 0;
net_req->tls_obj = NULL;
net_req->wire_data_sz = wire_data_sz;
if (max_query_sz) {

View File

@ -320,8 +320,8 @@ upstream_erred(getdns_upstream *upstream)
netreq->state = NET_REQ_FINISHED;
priv_getdns_check_dns_req_complete(netreq->owner);
}
// TODO[TLS]: When we get an error (which is probably a timeout) and are
// using to keep connections open should we leave the connection up here?
/* TODO[TLS]: When we get an error (which is probably a timeout) and are
* using to keep connections open should we leave the connection up here? */
if (upstream->tls_obj) {
SSL_shutdown(upstream->tls_obj);
SSL_free(upstream->tls_obj);
@ -507,8 +507,6 @@ stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf)
uint8_t *buf;
size_t buf_size;
fprintf(stderr, "[TLS] method: stub_tcp_read\n");
if (!tcp->read_buf) {
/* First time tcp read, create a buffer for reading */
if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096)))
@ -529,7 +527,6 @@ stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf)
/* TODO: Try to reconnect */
return STUB_TCP_ERROR;
}
fprintf(stderr, "[TLS] method: read %d TCP bytes \n", (int)read);
tcp->to_read -= read;
tcp->read_pos += read;
@ -575,16 +572,13 @@ stub_tcp_read_cb(void *userarg)
&dnsreq->context->mf))) {
case STUB_TCP_AGAIN:
fprintf(stderr, "[TLS] method: stub_tcp_read_cb -> tcp again\n");
return;
case STUB_TCP_ERROR:
fprintf(stderr, "[TLS] method: stub_tcp_read_cb -> tcp error\n");
stub_erred(netreq);
return;
default:
fprintf(stderr, "[TLS] method: stub_tcp_read_cb -> All done. close fd %d\n", netreq->fd);
GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event);
if (q != netreq->query_id)
return;
@ -618,7 +612,8 @@ sock_wait(int sockfd)
fd_set fds;
FD_ZERO(&fds);
FD_SET(FD_SET_T sockfd, &fds);
struct timeval timeout = {2, 0 };
/*TODO[TLS]: Pick up this timeout from the context*/
struct timeval timeout = {5, 0 };
ret = select(sockfd+1, NULL, &fds, NULL, &timeout);
if(ret == 0)
/* timeout expired */
@ -632,7 +627,6 @@ sock_wait(int sockfd)
static int
sock_connected(int sockfd)
{
fprintf(stderr, "[TLS] connect in progress \n");
/* wait(write) until connected or error */
while(1) {
int error = 0;
@ -667,10 +661,8 @@ do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream)
{
/*Lets make sure the connection is up before we try a handshake*/
if (errno == EINPROGRESS && sock_connected(upstream->fd) == -1) {
fprintf(stderr, "[TLS] connect failed \n");
return NULL;
}
fprintf(stderr, "[TLS] connect done \n");
/* Create SSL instance */
SSL* ssl = SSL_new(dnsreq->context->tls_ctx);
@ -694,29 +686,24 @@ do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream)
while ((r = SSL_do_handshake(ssl)) != 1)
{
want = SSL_get_error(ssl, r);
fprintf(stderr, "[TLS] in handshake loop %d, want is %d \n", r, want);
switch (want) {
case SSL_ERROR_WANT_READ:
if (select(upstream->fd + 1, &fds, NULL, NULL, &timeout) == 0) {
fprintf(stderr, "[TLS] ssl handshake timeout %d\n", want);
SSL_free(ssl);
return NULL;
}
break;
case SSL_ERROR_WANT_WRITE:
if (select(upstream->fd + 1, NULL, &fds, NULL, &timeout) == 0) {
fprintf(stderr, "[TLS] ssl handshake timeout %d\n", want);
SSL_free(ssl);
return NULL;
}
break;
default:
fprintf(stderr, "[TLS] got ssl error code %d\n", want);
SSL_free(ssl);
return NULL;
}
}
fprintf(stderr, "[TLS] got TLS connection\n");
return ssl;
}
@ -727,8 +714,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
uint8_t *buf;
size_t buf_size;
fprintf(stderr, "[TLS] method: stub_tls_read\n");
if (!tcp->read_buf) {
/* First time tls read, create a buffer for reading */
if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096)))
@ -750,7 +735,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
} else
return STUB_TCP_ERROR;
}
fprintf(stderr, "[TLS] method: read %d TLS bytes \n", (int)read);
tcp->to_read -= read;
tcp->read_pos += read;
@ -761,7 +745,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
if (read == 2) {
/* Read the packet size short */
tcp->to_read = gldns_read_uint16(tcp->read_buf);
fprintf(stderr, "[TLS] method: %d TLS bytes to read \n", (int)tcp->to_read);
if (tcp->to_read < GLDNS_HEADER_SIZE)
return STUB_TCP_ERROR;
@ -781,7 +764,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
}
/* Ready to start reading the packet */
fprintf(stderr, "[TLS] method: resetting read_pos \n");
tcp->read_pos = tcp->read_buf;
read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read);
if (read <= 0) {
@ -803,7 +785,6 @@ stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
static void netreq_upstream_read_cb(void *userarg);
static void netreq_upstream_write_cb(void *userarg);
static void upstream_write_cb(void *userarg);
static void
upstream_read_cb(void *userarg)
{
@ -814,8 +795,6 @@ upstream_read_cb(void *userarg)
uint16_t query_id;
intptr_t query_id_intptr;
fprintf(stderr, "[TLS] method: upstream_read_cb\n");
if (upstream->tls_obj)
q = stub_tls_read(upstream->tls_obj, &upstream->tcp,
&upstream->upstreams->mf);
@ -825,7 +804,6 @@ upstream_read_cb(void *userarg)
switch (q) {
case STUB_TCP_AGAIN:
fprintf(stderr, "[TLS] method: upstream_read_cb -> STUB_TCP_AGAIN\n");
return;
case STUB_TCP_ERROR:
@ -833,7 +811,6 @@ upstream_read_cb(void *userarg)
return;
default:
fprintf(stderr, "[TLS] method: upstream_read_cb -> processing reponse\n");
/* Lookup netreq */
query_id = (uint16_t) q;
@ -851,7 +828,6 @@ upstream_read_cb(void *userarg)
netreq->response = upstream->tcp.read_buf;
netreq->response_len =
upstream->tcp.read_pos - upstream->tcp.read_buf;
netreq->tls_obj = upstream->tls_obj;
upstream->tcp.read_buf = NULL;
upstream->upstreams->current = 0;
@ -906,7 +882,6 @@ static int
stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq)
{
getdns_dns_req *dnsreq = netreq->owner;
fprintf(stderr, "[TLS] method: stub_tcp_write\n");
size_t pkt_len = netreq->response - netreq->query;
ssize_t written;
@ -1045,8 +1020,6 @@ stub_tcp_write_cb(void *userarg)
static int
stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq)
{
fprintf(stderr, "[TLS] method: stub_tls_write\n");
size_t pkt_len = netreq->response - netreq->query;
ssize_t written;
uint16_t query_id;
@ -1100,8 +1073,6 @@ upstream_write_cb(void *userarg)
getdns_dns_req *dnsreq = netreq->owner;
int q;
fprintf(stderr, "[TLS] method: upstream_write_cb for %s with class %d\n", dnsreq->name, (int)netreq->request_class);
if (upstream->tls_obj)
q = stub_tls_write(upstream->tls_obj, &upstream->tcp, netreq);
else
@ -1243,7 +1214,6 @@ tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) {
connect_addr = upstream->addr;
addr = &connect_addr;
set_port(addr, TLS_PORT);
fprintf(stderr, "[TLS] Forcing switch to port %d for TLS\n", TLS_PORT);
}
if ((fd = socket(addr->ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1)
@ -1271,8 +1241,6 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
getdns_dns_req *dnsreq = netreq->owner;
getdns_upstream *upstream = pick_upstream(dnsreq);
fprintf(stderr, "[TLS] method: priv_getdns_submit_stub_request\n");
if (!upstream)
return GETDNS_RETURN_GENERIC_ERROR;
@ -1325,19 +1293,16 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
/* We are the first. Make global socket and connect. */
if ((upstream->fd = tcp_connect(upstream, transport)) == -1) {
//TODO: Hum, a reset doesn't make the connect fail...
if (fb_transport == NONE)
return GETDNS_RETURN_GENERIC_ERROR;
fprintf(stderr, "[TLS] Connect failed on fd... %d\n", upstream->fd);
if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1)
return GETDNS_RETURN_GENERIC_ERROR;
fallback = 1;
}
/* Now do a handshake for TLS. Note waiting for this to succeed or
timeout blocks the scheduling of any messages for this upstream*/
* timeout blocks the scheduling of any messages for this upstream*/
if (transport == TLS && (fallback == 0)) {
fprintf(stderr, "[TLS] Doing SSL handshake... %d\n", upstream->fd);
upstream->tls_obj = do_tls_handshake(dnsreq, upstream);
if (!upstream->tls_obj) {
if (fb_transport == NONE)
@ -1347,6 +1312,10 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
return GETDNS_RETURN_GENERIC_ERROR;
}
}
/* Attach to the global event loop
* so it can do it's own scheduling
*/
upstream->loop = dnsreq->context->extension;
} else {
/* Cater for the case of the user downgrading and existing TLS
connection to TCP for some reason...*/
@ -1358,11 +1327,6 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
}
netreq->upstream = upstream;
/* Attach to the global event loop
* so it can do it's own scheduling
*/
upstream->loop = dnsreq->context->extension;
/* We have a context wide socket.
* Now schedule the write request.
*/

View File

@ -212,7 +212,6 @@ typedef struct getdns_network_req
uint8_t *opt; /* offset of OPT RR in query */
size_t response_len;
uint8_t *response;
SSL* tls_obj;
size_t wire_data_sz;
uint8_t wire_data[];