From 7905eda8b702eb45d4a211bb08edcc848e588d81 Mon Sep 17 00:00:00 2001 From: Sara Dickinson Date: Thu, 30 Apr 2015 12:24:13 +0100 Subject: [PATCH] Some clean up of connection handling. Still a problem with STARTTLS fallback that needs fixing. --- src/context.c | 20 +++++++++------- src/context.h | 2 -- src/stub.c | 65 +++++++++++++++++++-------------------------------- 3 files changed, 35 insertions(+), 52 deletions(-) diff --git a/src/context.c b/src/context.c index ab075d2d..56d08476 100644 --- a/src/context.c +++ b/src/context.c @@ -267,7 +267,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in *)sa)->sin_port); - if (port != 0 && port != GETDNS_PORT_TCP && + if (port != GETDNS_PORT_ZERO && port != GETDNS_PORT_TCP && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -283,7 +283,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in6 *)sa)->sin6_port); - if (port != 0 && port != GETDNS_PORT_TCP && + if (port != GETDNS_PORT_TCP && port != GETDNS_PORT_TCP && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -553,7 +553,7 @@ upstream_ntop_buf(getdns_upstream *upstream, char *buf, size_t len) if (upstream_scope_id(upstream)) (void) snprintf(buf + strlen(buf), len - strlen(buf), "%%%d", (int)*upstream_scope_id(upstream)); - else if (upstream_port(upstream) != GETDNS_PORT_TCP && upstream_port(upstream) != 0) + else if (upstream_port(upstream) != GETDNS_PORT_TCP && upstream_port(upstream) != GETDNS_PORT_ZERO) (void) snprintf(buf + strlen(buf), len - strlen(buf), "@%d", (int)upstream_port(upstream)); } @@ -687,7 +687,7 @@ set_os_defaults(struct getdns_context *context) getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { - char * port_str = getdns_port_str_array[base_transport]; + char *port_str = getdns_port_str_array[base_transport]; if (strncmp(port_str, GETDNS_STR_PORT_ZERO, 1) == 0) continue; if ((s = getaddrinfo(parse, port_str, &hints, &result))) @@ -1219,7 +1219,7 @@ set_ub_dns_transport(struct getdns_context* context, /* Note: If TLS is used in recursive mode this will try TLS on port * 53... So this is prohibited when preparing for resolution.*/ set_ub_string_opt(context, "ssl-upstream:", "yes"); - /* Fall through*/ + /* Fall through */ case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: case GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: /* Note: no fallback to TCP available directly in unbound, so we just @@ -1792,11 +1792,13 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) (void) ub_ctx_set_fwd(ctx, NULL); for (i = 0; i < upstreams->count; i++) { upstream = &upstreams->upstreams[i]; - /*[TLS]: Use only the subset of upstreams that match the first transport */ - if (context->dns_transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) { - if (upstream_port(upstream) != GETDNS_PORT_TLS) + /*[TLS]: Use only the TLS subset of upstreams when only TLS is used. + * All other cases must currently fallback to TCP for libunbound. */ + if (context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && + context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_NONE && + upstream_port(upstream) != GETDNS_PORT_TLS) continue; - } else if (upstream_port(upstream) != GETDNS_PORT_TCP) + else if (upstream_port(upstream) != GETDNS_PORT_TCP) continue; upstream_ntop_buf(upstream, addr, 1024); ub_ctx_set_fwd(ctx, addr); diff --git a/src/context.h b/src/context.h index 78488f77..3ad6b342 100644 --- a/src/context.h +++ b/src/context.h @@ -237,8 +237,6 @@ void priv_getdns_context_ub_read_cb(void *userarg); getdns_return_t priv_set_base_dns_transports(getdns_base_transport_t *, getdns_transport_t); -getdns_base_transport_t priv_get_base_transport(getdns_transport_t transport, int level); - void priv_getdns_upstreams_dereference(getdns_upstreams *upstreams); #endif /* _GETDNS_CONTEXT_H_ */ diff --git a/src/stub.c b/src/stub.c index bcdadce0..1e1724f5 100755 --- a/src/stub.c +++ b/src/stub.c @@ -757,9 +757,8 @@ create_tls_object(getdns_context *context, int fd) if (context->tls_ctx == NULL) return NULL; SSL* ssl = SSL_new(context->tls_ctx); - if(!ssl) { + if(!ssl) return NULL; - } /* Connect the SSL object with a file descriptor */ if(!SSL_set_fd(ssl,fd)) { SSL_free(ssl); @@ -1084,7 +1083,7 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) * by query_id in the process. */ if ((*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) || - (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_UDP)) + (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_UDP)) query_id = arc4random(); else do { query_id = arc4random(); @@ -1379,7 +1378,8 @@ tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) #ifdef USE_TCP_FASTOPEN /* Leave the connect to the later call to sendto() if using TCP*/ if (transport == GETDNS_BASE_TRANSPORT_TCP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) + transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE || + transport == GETDNS_BASE_TRANSPORT_STARTTLS) return fd; #endif if (connect(fd, (struct sockaddr *)&upstream->addr, @@ -1396,38 +1396,6 @@ int connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport, getdns_dns_req *dnsreq) { - /* First check if existing connection can be used, which may still be being - * set up. */ - switch(transport) { - case GETDNS_BASE_TRANSPORT_TCP: - if (upstream->fd != -1) { - fprintf(stderr,"[TLS]: CONNECT(connect_to_upstream):" - "tcp_connect using existing TCP fd %d\n", upstream->fd); - return upstream->fd; - } - break; - case GETDNS_BASE_TRANSPORT_TLS: - if (tls_handshake_active(upstream->tls_hs_state)) { - fprintf(stderr,"[TLS]: CONNECT(connect_to_upstream):" - "tcp_connect using existing TLS fd %d\n", upstream->fd); - return upstream->fd; - } - break; - case GETDNS_BASE_TRANSPORT_STARTTLS: - /* Either negotiating, or doing handshake*/ - if ((upstream->starttls_req != NULL) || - (upstream->starttls_req == NULL && - tls_handshake_active(upstream->tls_hs_state))) { - fprintf(stderr,"[TLS]: CONNECT(connect_to_upstream):" - "tcp_connect using existing STARTTLS fd %d\n", upstream->fd); - return upstream->fd; - } - break; - default: - break; - } - - /* If not, create a new one */ int fd = -1; switch(transport) { case GETDNS_BASE_TRANSPORT_UDP: @@ -1437,14 +1405,21 @@ connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport getdns_sock_nonblock(fd); return fd; - case GETDNS_BASE_TRANSPORT_TCP_SINGLE: case GETDNS_BASE_TRANSPORT_TCP: + /* Use existing if available*/ + if (upstream->fd != -1) + return upstream->fd; + /* Otherwise, fall through */ + case GETDNS_BASE_TRANSPORT_TCP_SINGLE: fd = tcp_connect(upstream, transport); upstream->loop = dnsreq->context->extension; upstream->fd = fd; break; case GETDNS_BASE_TRANSPORT_TLS: + /* Use existing if available*/ + if (upstream->fd != 1 && tls_handshake_active(upstream->tls_hs_state)) + return upstream->fd; fd = tcp_connect(upstream, transport); if (fd == -1) return -1; upstream->tls_obj = create_tls_object(dnsreq->context, fd); @@ -1458,6 +1433,12 @@ connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport upstream->fd = fd; break; case GETDNS_BASE_TRANSPORT_STARTTLS: + /* Use existing if available. May be either negotiating or doing TLS */ + if (upstream->fd != 1 && + (upstream->starttls_req != NULL) || + (upstream->starttls_req == NULL && + tls_handshake_active(upstream->tls_hs_state))) + return upstream->fd; fd = tcp_connect(upstream, transport); if (fd == -1) return -1; if (!create_starttls_request(dnsreq, upstream, dnsreq->loop)) @@ -1479,12 +1460,13 @@ connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport return -1; /* Nothing to do*/ } - fprintf(stderr,"[TLS]: CONNECT(connect_to_upstream): created new connection %d\n", fd); + fprintf(stderr,"[TLS]: CONNECT(connect_to_upstream):" + " created new connection %d\n", fd); return fd; } static getdns_upstream* -pick_and_connect_to_upstream(getdns_network_req *netreq, +find_upstream_for_specific_transport(getdns_network_req *netreq, getdns_base_transport_t transport, int *fd) { @@ -1502,7 +1484,7 @@ find_upstream_for_netreq(getdns_network_req *netreq) int fd = -1; for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX && netreq->dns_base_transports[i] != GETDNS_BASE_TRANSPORT_NONE; i++) { - netreq->upstream = pick_and_connect_to_upstream(netreq, + netreq->upstream = find_upstream_for_specific_transport(netreq, netreq->dns_base_transports[i], &fd); if (fd == -1) @@ -1524,6 +1506,7 @@ move_netreq(getdns_network_req *netreq, getdns_upstream *upstream, upstream->write_queue_last = NULL; upstream->event.write_cb = NULL; GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + close(upstream->fd); upstream->fd = -1; } @@ -1591,7 +1574,7 @@ fallback_on_write(getdns_network_req *netreq) getdns_upstream *upstream = netreq->upstream; int fd; getdns_upstream *new_upstream = - pick_and_connect_to_upstream(netreq, *next_transport, &fd); + find_upstream_for_specific_transport(netreq, *next_transport, &fd); if (!new_upstream) return STUB_TCP_ERROR; return move_netreq(netreq, upstream, new_upstream);