From 635cf9e1823fc2532f7b2c635301622e24ee84a2 Mon Sep 17 00:00:00 2001 From: Sara Dickinson Date: Fri, 19 Jun 2015 18:28:29 +0100 Subject: [PATCH] Re-factor of internal handing of transport list. --- src/context.c | 242 ++++++++++++++++++---------------------- src/context.h | 7 +- src/request-internal.c | 11 +- src/stub.c | 142 ++++++++++------------- src/test/getdns_query.c | 2 +- src/types-internal.h | 30 ++--- 6 files changed, 202 insertions(+), 232 deletions(-) diff --git a/src/context.c b/src/context.c index f0a3631a..c8c7be81 100644 --- a/src/context.c +++ b/src/context.c @@ -69,11 +69,15 @@ typedef struct host_name_addrs { uint8_t host_name[]; } host_name_addrs; +static getdns_transport_list_t +getdns_upstream_transports[GETDNS_UPSTREAM_TRANSPORTS] = { + GETDNS_TRANSPORT_TCP, + GETDNS_TRANSPORT_TLS, + GETDNS_TRANSPORT_STARTTLS +}; + static in_port_t -getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = { - GETDNS_PORT_ZERO, - GETDNS_PORT_ZERO, - GETDNS_PORT_ZERO, +getdns_port_array[GETDNS_UPSTREAM_TRANSPORTS] = { GETDNS_PORT_DNS, GETDNS_PORT_DNS, GETDNS_PORT_DNS_OVER_TLS @@ -81,9 +85,6 @@ getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = { char* getdns_port_str_array[] = { - GETDNS_STR_PORT_ZERO, - GETDNS_STR_PORT_ZERO, - GETDNS_STR_PORT_ZERO, GETDNS_STR_PORT_DNS, GETDNS_STR_PORT_DNS, GETDNS_STR_PORT_DNS_OVER_TLS @@ -91,6 +92,7 @@ getdns_port_str_array[] = { /* Private functions */ getdns_return_t create_default_namespaces(struct getdns_context *context); +getdns_return_t create_default_dns_transports(struct getdns_context *context); static struct getdns_list *create_default_root_servers(void); static getdns_return_t set_os_defaults(struct getdns_context *); static int transaction_id_cmp(const void *, const void *); @@ -139,42 +141,22 @@ create_default_namespaces(struct getdns_context *context) return GETDNS_RETURN_GOOD; } -static getdns_transport_list_t * -get_dns_transport_list(getdns_context *context, int *count) +/** + * Helper to get default transports. + */ +getdns_return_t +create_default_dns_transports(struct getdns_context *context) { - if (context == NULL) - return NULL; + context->dns_transports = GETDNS_XMALLOC(context->my_mf, getdns_transport_list_t, 2); + if(context->dns_transports == NULL) + return GETDNS_RETURN_GENERIC_ERROR; - /* Count how many we have*/ - for (*count = 0; *count < GETDNS_BASE_TRANSPORT_MAX; (*count)++) { - if (context->dns_base_transports[*count] == GETDNS_BASE_TRANSPORT_NONE) - break; - } + context->dns_transports[0] = GETDNS_TRANSPORT_UDP; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; + context->dns_transport_count = 2; + context->dns_transport_current = 0; - // use normal malloc here so users can do normal free - getdns_transport_list_t * transports = malloc(*count * sizeof(getdns_transport_list_t)); - - if(transports == NULL) - return NULL; - for (int i = 0; i < (int)*count; i++) { - switch(context->dns_base_transports[i]) { - case GETDNS_BASE_TRANSPORT_UDP: - transports[i] = GETDNS_TRANSPORT_UDP; - break; - case GETDNS_BASE_TRANSPORT_TCP: - transports[i] = GETDNS_TRANSPORT_TCP; - break; - case GETDNS_BASE_TRANSPORT_TLS: - transports[i] = GETDNS_TRANSPORT_TLS; - break; - case GETDNS_BASE_TRANSPORT_STARTTLS: - transports[i] = GETDNS_TRANSPORT_STARTTLS; - break; - default: - break; - } - } - return transports; + return GETDNS_RETURN_GOOD; } static inline void canonicalize_dname(uint8_t *dname) @@ -621,7 +603,7 @@ upstream_init(getdns_upstream *upstream, upstream->fd = -1; upstream->tls_obj = NULL; upstream->starttls_req = NULL; - upstream->dns_base_transport = GETDNS_BASE_TRANSPORT_TCP; + upstream->transport = GETDNS_TRANSPORT_TCP; upstream->tls_hs_state = GETDNS_HS_NONE; upstream->loop = NULL; (void) getdns_eventloop_event_init( @@ -725,11 +707,8 @@ set_os_defaults(struct getdns_context *context) token = parse + strcspn(parse, " \t\r\n"); *token = 0; - getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; - for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { - char *port_str = getdns_port_str_array[base_transport]; - if (strncmp(port_str, GETDNS_STR_PORT_ZERO, 1) == 0) - continue; + for (size_t i = 0; i < GETDNS_UPSTREAM_TRANSPORTS; i++) { + char *port_str = getdns_port_str_array[i]; if ((s = getaddrinfo(parse, port_str, &hints, &result))) continue; if (!result) @@ -743,7 +722,7 @@ set_os_defaults(struct getdns_context *context) upstream = &context->upstreams-> upstreams[context->upstreams->count++]; upstream_init(upstream, context->upstreams, result); - upstream->dns_base_transport = base_transport; + upstream->transport = getdns_upstream_transports[i]; freeaddrinfo(result); } } @@ -873,8 +852,8 @@ getdns_context_create_with_extended_memory_functions( result->dnssec_allowed_skew = 0; result->edns_maximum_udp_payload_size = -1; - result->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; - result->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + if ((r = create_default_dns_transports(result))) + goto error; result->limit_outstanding_queries = 0; result->has_ta = priv_getdns_parse_ta_file(NULL, NULL); result->return_dnssec_status = GETDNS_EXTENSION_FALSE; @@ -1201,62 +1180,61 @@ static getdns_return_t getdns_set_base_dns_transports(struct getdns_context *context, size_t transport_count, getdns_transport_list_t *transports) { - RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); - for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) - context->dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE; + size_t i; - if ((int)transport_count == 0 || transports == NULL || - (int)transport_count > GETDNS_BASE_TRANSPORT_MAX) { - return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; + RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); + if (transport_count == 0 || transports == NULL) { + return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; + } + + for(i=0; idns_base_transports[j] = GETDNS_BASE_TRANSPORT_UDP; - break; - case GETDNS_TRANSPORT_TCP: - context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_TCP; - break; - case GETDNS_TRANSPORT_TLS: - context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_TLS; - break; - case GETDNS_TRANSPORT_STARTTLS: - context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_STARTTLS; - break; - default: - return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; - } - } - return GETDNS_RETURN_GOOD; + GETDNS_FREE(context->my_mf, context->dns_transports); + + /** duplicate **/ + context->dns_transports = GETDNS_XMALLOC(context->my_mf, + getdns_transport_list_t, transport_count); + memcpy(context->dns_transports, transports, + transport_count * sizeof(getdns_transport_list_t)); + context->dns_transport_count = transport_count; + dispatch_updated(context, GETDNS_CONTEXT_CODE_NAMESPACES); + + return GETDNS_RETURN_GOOD; } static getdns_return_t set_ub_dns_transport(struct getdns_context* context) { /* These mappings are not exact because Unbound is configured differently, so just map as close as possible from the first 1 or 2 transports. */ - switch (context->dns_base_transports[0]) { - case GETDNS_BASE_TRANSPORT_UDP: + switch (context->dns_transports[0]) { + case GETDNS_TRANSPORT_UDP: set_ub_string_opt(context, "do-udp:", "yes"); - if (context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_TCP) + if (context->dns_transports[1] == GETDNS_TRANSPORT_TCP) set_ub_string_opt(context, "do-tcp:", "yes"); else set_ub_string_opt(context, "do-tcp:", "no"); break; - case GETDNS_BASE_TRANSPORT_TLS: + case GETDNS_TRANSPORT_TLS: /* Note: If TLS is used in recursive mode this will try TLS on port * 53... So this is prohibited when preparing for resolution.*/ - if (context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE) { + if (context->dns_transport_count == 0) { set_ub_string_opt(context, "ssl-upstream:", "yes"); set_ub_string_opt(context, "do-udp:", "no"); set_ub_string_opt(context, "do-tcp:", "yes"); break; } - if (context->dns_base_transports[1] != GETDNS_BASE_TRANSPORT_TCP) + if (context->dns_transports[1] != GETDNS_TRANSPORT_TCP) break; /* Fallthrough */ - case GETDNS_BASE_TRANSPORT_STARTTLS: - case GETDNS_BASE_TRANSPORT_TCP: + case GETDNS_TRANSPORT_STARTTLS: + case GETDNS_TRANSPORT_TCP: /* Note: no STARTTLS or fallback to TCP available directly in unbound, so we just * use TCP for now to make sure the messages are sent. */ set_ub_string_opt(context, "do-udp:", "no"); @@ -1278,31 +1256,37 @@ getdns_context_set_dns_transport(struct getdns_context *context, { RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); - for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) - context->dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE; + size_t count = 2; + if (value == GETDNS_TRANSPORT_UDP_ONLY || + value == GETDNS_TRANSPORT_TCP_ONLY || + value == GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN || + value == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) + count = 1; + context->dns_transports = GETDNS_XMALLOC(context->my_mf, + getdns_transport_list_t, count); switch (value) { case GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; - context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_UDP; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; break; case GETDNS_TRANSPORT_UDP_ONLY: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; + context->dns_transports[0] = GETDNS_TRANSPORT_UDP; break; case GETDNS_TRANSPORT_TCP_ONLY: case GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_TCP; break; case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS; + context->dns_transports[0] = GETDNS_TRANSPORT_TLS; break; case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS; - context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_TLS; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; break; case GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_STARTTLS; - context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_STARTTLS; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; break; default: return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; @@ -1658,15 +1642,14 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, } /* Loop to create upstreams as needed*/ - getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; - for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { + for (size_t j = 0; j < GETDNS_UPSTREAM_TRANSPORTS; j++) { uint32_t port; struct addrinfo *ai; - port = getdns_port_array[base_transport]; + port = getdns_port_array[j]; if (port == GETDNS_PORT_ZERO) continue; - if (base_transport != GETDNS_BASE_TRANSPORT_TLS) + if (getdns_upstream_transports[j] != GETDNS_TRANSPORT_TLS) (void) getdns_dict_get_int(dict, "port", &port); else (void) getdns_dict_get_int(dict, "tls_port", &port); @@ -1689,7 +1672,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, upstream = &upstreams->upstreams[upstreams->count]; upstream->addr.ss_family = addr.ss_family; upstream_init(upstream, upstreams, ai); - upstream->dns_base_transport = base_transport; + upstream->transport = getdns_upstream_transports[j]; upstreams->count++; freeaddrinfo(ai); } @@ -1913,9 +1896,9 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) upstream = &upstreams->upstreams[i]; /*[TLS]: Use only the TLS subset of upstreams when TLS is the only thing * used. All other cases must currently fallback to TCP for libunbound.*/ - if (context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && - context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE && - upstream->dns_base_transport != GETDNS_BASE_TRANSPORT_TLS) + if (context->dns_transports[0] == GETDNS_TRANSPORT_TLS && + context->dns_transport_count ==1 && + upstream->transport != GETDNS_TRANSPORT_TLS) continue; upstream_ntop_buf(upstream, addr, 1024); ub_ctx_set_fwd(ctx, addr); @@ -2025,8 +2008,8 @@ getdns_context_prepare_for_resolution(struct getdns_context *context, } /* Block use of TLS ONLY in recursive mode as it won't work */ if (context->resolution_type == GETDNS_RESOLUTION_RECURSING && - context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && - context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE) + context->dns_transports[0] == GETDNS_TRANSPORT_TLS && + context->dns_transport_count == 1) return GETDNS_RETURN_BAD_CONTEXT; if (context->resolution_type_set == context->resolution_type) @@ -2321,17 +2304,16 @@ priv_get_context_settings(getdns_context* context) { upstreams); getdns_list_destroy(upstreams); } - /* create a transport list */ - getdns_list* transports = getdns_list_create_with_context(context); - if (transports) { - int transport_count; - getdns_transport_list_t *transport_list = - get_dns_transport_list(context, &transport_count); - for (int i = 0; i < transport_count; i++) { - r |= getdns_list_set_int(transports, i, transport_list[i]); + if (context->dns_transport_count > 0) { + /* create a namespace list */ + size_t i; + getdns_list* transports = getdns_list_create_with_context(context); + if (transports) { + for (i = 0; i < context->dns_transport_count; ++i) { + r |= getdns_list_set_int(transports, i, context->dns_transports[i]); + } + r |= getdns_dict_set_list(result, "dns_transport_list", transports); } - r |= getdns_dict_set_list(result, "dns_transport_list", transports); - free(transport_list); } if (context->namespace_count > 0) { /* create a namespace list */ @@ -2525,35 +2507,34 @@ getdns_context_get_dns_transport(getdns_context *context, getdns_transport_t* value) { RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(value, GETDNS_RETURN_INVALID_PARAMETER); - int count; - getdns_transport_list_t *transport_list = - get_dns_transport_list(context, &count); - if (!count) + int count = context->dns_transport_count; + getdns_transport_list_t *transports = context->dns_transports; + if (!count) return GETDNS_RETURN_WRONG_TYPE_REQUESTED; /* Best effort mapping for backwards compatibility*/ - if (transport_list[0] == GETDNS_TRANSPORT_UDP) { + if (transports[0] == GETDNS_TRANSPORT_UDP) { if (count == 1) *value = GETDNS_TRANSPORT_UDP_ONLY; - else if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP) + else if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP) *value = GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP; else return GETDNS_RETURN_WRONG_TYPE_REQUESTED; } - if (transport_list[0] == GETDNS_TRANSPORT_TCP) { + if (transports[0] == GETDNS_TRANSPORT_TCP) { if (count == 1) *value = GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN; } - if (transport_list[0] == GETDNS_TRANSPORT_TLS) { + if (transports[0] == GETDNS_TRANSPORT_TLS) { if (count == 1) *value = GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN; - else if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP) + else if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP) *value = GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN; else return GETDNS_RETURN_WRONG_TYPE_REQUESTED; } - if (transport_list[0] == GETDNS_TRANSPORT_STARTTLS) { - if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP) + if (transports[0] == GETDNS_TRANSPORT_STARTTLS) { + if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP) *value = GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN; else return GETDNS_RETURN_WRONG_TYPE_REQUESTED; @@ -2567,16 +2548,15 @@ getdns_context_get_dns_transport_list(getdns_context *context, RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(transport_count, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(transports, GETDNS_RETURN_INVALID_PARAMETER); - - int count; - getdns_transport_list_t *transport_list = - get_dns_transport_list(context, &count); - *transport_count = count; - if (!transport_count) { + *transport_count = context->dns_transport_count; + if (!context->dns_transport_count) { *transports = NULL; return GETDNS_RETURN_GOOD; } - *transports = transport_list; + // use normal malloc here so users can do normal free + *transports = malloc(context->dns_transport_count * sizeof(getdns_transport_list_t)); + memcpy(*transports, context->dns_transports, + context->dns_transport_count * sizeof(getdns_transport_list_t)); return GETDNS_RETURN_GOOD; } diff --git a/src/context.h b/src/context.h index 040f18d3..942a0ed4 100644 --- a/src/context.h +++ b/src/context.h @@ -91,7 +91,7 @@ typedef struct getdns_upstream { /* For sharing a TCP socket to this upstream */ int fd; - getdns_base_transport_t dns_base_transport; + getdns_transport_list_t transport; SSL* tls_obj; getdns_tls_hs_state_t tls_hs_state; getdns_dns_req * starttls_req; @@ -138,10 +138,13 @@ struct getdns_context { struct getdns_list *suffix; struct getdns_list *dnssec_trust_anchors; getdns_upstreams *upstreams; - getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; uint16_t limit_outstanding_queries; uint32_t dnssec_allowed_skew; + getdns_transport_list_t *dns_transports; + size_t dns_transport_count; + size_t dns_transport_current; + uint8_t edns_extended_rcode; uint8_t edns_version; uint8_t edns_do_bit; diff --git a/src/request-internal.c b/src/request-internal.c index b4267022..90b343aa 100644 --- a/src/request-internal.c +++ b/src/request-internal.c @@ -62,6 +62,7 @@ network_req_cleanup(getdns_network_req *net_req) if (net_req->response && (net_req->response < net_req->wire_data || net_req->response > net_req->wire_data+ net_req->wire_data_sz)) GETDNS_FREE(net_req->owner->my_mf, net_req->response); + GETDNS_FREE(net_req->owner->my_mf, net_req->transports); } static int @@ -89,9 +90,13 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner, net_req->upstream = NULL; net_req->fd = -1; - for (i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) - net_req->dns_base_transports[i] = owner->context->dns_base_transports[i]; - net_req->transport = 0; + net_req->transports = GETDNS_XMALLOC(net_req->owner->my_mf, + getdns_transport_list_t, + owner->context->dns_transport_count); + memcpy(owner->context->dns_transports, net_req->transports, + owner->context->dns_transport_count * sizeof(getdns_transport_list_t)); + net_req->transport_count = owner->context->dns_transport_count; + net_req->transport_current = 0; memset(&net_req->event, 0, sizeof(net_req->event)); memset(&net_req->tcp, 0, sizeof(net_req->tcp)); net_req->query_id = 0; diff --git a/src/stub.c b/src/stub.c index df06c9c4..ad913f19 100644 --- a/src/stub.c +++ b/src/stub.c @@ -59,6 +59,9 @@ static void upstream_read_cb(void *userarg); static void upstream_write_cb(void *userarg); static void upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq); +static int upstream_connect(getdns_upstream *upstream, + getdns_transport_list_t transport, + getdns_dns_req *dnsreq); static void netreq_upstream_read_cb(void *userarg); static void netreq_upstream_write_cb(void *userarg); static int fallback_on_write(getdns_network_req *netreq); @@ -354,7 +357,7 @@ getdns_sock_nonblock(int sockfd) } static int -tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) +tcp_connect(getdns_upstream *upstream, getdns_transport_list_t transport) { int fd = -1; if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) @@ -363,9 +366,8 @@ tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) getdns_sock_nonblock(fd); #ifdef USE_TCP_FASTOPEN /* Leave the connect to the later call to sendto() if using TCP*/ - if (transport == GETDNS_BASE_TRANSPORT_TCP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE || - transport == GETDNS_BASE_TRANSPORT_STARTTLS) + if (transport == GETDNS_TRANSPORT_TCP || + transport == GETDNS_TRANSPORT_STARTTLS) return fd; #endif if (connect(fd, (struct sockaddr *)&upstream->addr, @@ -667,20 +669,7 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) if (! tcp->write_buf) { /* No, this is an initial write. Try to send */ - - /* Not keeping connections open? Then the first random number - * will do as the query id. - * - * Otherwise find a unique query_id not already written (or in - * the write_queue) for that upstream. Register this netreq - * by query_id in the process. - */ - if ((netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_TCP_SINGLE) || - (netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_UDP)) - query_id = arc4random(); - else do { + do { query_id = arc4random(); query_id_intptr = (intptr_t)query_id; netreq->node.key = (void *)query_id_intptr; @@ -774,10 +763,10 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) static int tls_requested(getdns_network_req *netreq) { - return (netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_TLS || - netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_STARTTLS) ? + return (netreq->transports[netreq->transport_current] == + GETDNS_TRANSPORT_TLS || + netreq->transports[netreq->transport_current] == + GETDNS_TRANSPORT_STARTTLS) ? 1 : 0; } @@ -786,16 +775,16 @@ tls_should_write(getdns_upstream *upstream) { /* Should messages be written on TLS upstream. Remember that for STARTTLS * the first message should got over TCP as the handshake isn't started yet.*/ - return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + return ((upstream->transport == GETDNS_TRANSPORT_TLS || + upstream->transport == GETDNS_TRANSPORT_STARTTLS) && upstream->tls_hs_state != GETDNS_HS_NONE) ? 1 : 0; } static int tls_should_read(getdns_upstream *upstream) { - return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + return ((upstream->transport == GETDNS_TRANSPORT_TLS || + upstream->transport == GETDNS_TRANSPORT_STARTTLS) && !(upstream->tls_hs_state == GETDNS_HS_FAILED || upstream->tls_hs_state == GETDNS_HS_NONE)) ? 1 : 0; } @@ -804,8 +793,8 @@ static int tls_failed(getdns_upstream *upstream) { /* No messages should be scheduled onto an upstream in this state */ - return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + return ((upstream->transport == GETDNS_TRANSPORT_TLS || + upstream->transport == GETDNS_TRANSPORT_STARTTLS) && upstream->tls_hs_state == GETDNS_HS_FAILED) ? 1: 0; } @@ -1070,23 +1059,16 @@ stub_udp_read_cb(void *userarg) return; /* Client cookie didn't match? */ close(netreq->fd); - /* TODO: check not past end of transports*/ - getdns_base_transport_t next_transport = - netreq->dns_base_transports[netreq->transport + 1]; - if (GLDNS_TC_WIRE(netreq->response) && - next_transport == GETDNS_BASE_TRANSPORT_TCP) { - - if ((netreq->fd = socket( - upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) + if (GLDNS_TC_WIRE(netreq->response)) { + if (!(netreq->transport_current < netreq->transport_count)) goto done; - - getdns_sock_nonblock(netreq->fd); - if (connect(netreq->fd, (struct sockaddr *)&upstream->addr, - upstream->addr_len) == -1 && errno != EINPROGRESS) { - - close(netreq->fd); + netreq->transport_current++; + if (netreq->transport_current != GETDNS_TRANSPORT_TCP) goto done; - } + if ((netreq->fd = upstream_connect(upstream, netreq->transport_current, + dnsreq)) == -1) + goto done; + GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, @@ -1427,20 +1409,19 @@ netreq_upstream_write_cb(void *userarg) static int upstream_transport_valid(getdns_upstream *upstream, - getdns_base_transport_t transport) + getdns_transport_list_t transport) { - /* For single shot transports, use only the TCP upstream. */ - if (transport == GETDNS_BASE_TRANSPORT_UDP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) - return (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP ? 1:0); + /* Single shot UDP, uses same upstream as plain TCP. */ + if (transport == GETDNS_TRANSPORT_UDP) + return (upstream->transport == GETDNS_TRANSPORT_TCP ? 1:0); /* Allow TCP messages to be sent on a STARTTLS upstream that hasn't * upgraded to avoid opening a new connection if one is aleady open. */ - if (transport == GETDNS_BASE_TRANSPORT_TCP && - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && + if (transport == GETDNS_TRANSPORT_TCP && + upstream->transport == GETDNS_TRANSPORT_STARTTLS && upstream->tls_hs_state == GETDNS_HS_FAILED) return 1; /* Otherwise, transport must match, and not have failed */ - if (upstream->dns_base_transport != transport) + if (upstream->transport != transport) return 0; if (tls_failed(upstream)) return 0; @@ -1448,7 +1429,7 @@ upstream_transport_valid(getdns_upstream *upstream, } static getdns_upstream * -upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) +upstream_select(getdns_network_req *netreq, getdns_transport_list_t transport) { getdns_upstream *upstream; getdns_upstreams *upstreams = netreq->owner->upstreams; @@ -1489,31 +1470,29 @@ upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) int -upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, +upstream_connect(getdns_upstream *upstream, getdns_transport_list_t transport, getdns_dns_req *dnsreq) { DEBUG_STUB("%s\n", __FUNCTION__); int fd = -1; switch(transport) { - case GETDNS_BASE_TRANSPORT_UDP: + case GETDNS_TRANSPORT_UDP: if ((fd = socket( upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1) return -1; getdns_sock_nonblock(fd); return fd; - case GETDNS_BASE_TRANSPORT_TCP: + case GETDNS_TRANSPORT_TCP: /* Use existing if available*/ if (upstream->fd != -1) return upstream->fd; - /* Otherwise, fall through */ - case GETDNS_BASE_TRANSPORT_TCP_SINGLE: fd = tcp_connect(upstream, transport); upstream->loop = dnsreq->context->extension; upstream->fd = fd; break; - case GETDNS_BASE_TRANSPORT_TLS: + case GETDNS_TRANSPORT_TLS: /* Use existing if available*/ if (upstream->fd != -1 && !tls_failed(upstream)) return upstream->fd; @@ -1528,7 +1507,7 @@ upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, upstream->loop = dnsreq->context->extension; upstream->fd = fd; break; - case GETDNS_BASE_TRANSPORT_STARTTLS: + case GETDNS_TRANSPORT_STARTTLS: /* Use existing if available. Let the fallback code handle it if * STARTTLS isn't availble. */ if (upstream->fd != -1) @@ -1559,7 +1538,7 @@ upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, static getdns_upstream* find_upstream_for_specific_transport(getdns_network_req *netreq, - getdns_base_transport_t transport, + getdns_transport_list_t transport, int *fd) { /* TODO[TLS]: Fallback through upstreams....?*/ @@ -1574,15 +1553,13 @@ static int find_upstream_for_netreq(getdns_network_req *netreq) { int fd = -1; - int i = netreq->transport; - for (; i < GETDNS_BASE_TRANSPORT_MAX && - netreq->dns_base_transports[i] != GETDNS_BASE_TRANSPORT_NONE; i++) { + for (size_t i = 0; i < netreq->transport_count; i++) { netreq->upstream = find_upstream_for_specific_transport(netreq, - netreq->dns_base_transports[i], + netreq->transports[i], &fd); if (fd == -1 || !netreq->upstream) continue; - netreq->transport = i; + netreq->transport_current = i; return fd; } return -1; @@ -1645,7 +1622,7 @@ move_netreq(getdns_network_req *netreq, getdns_upstream *upstream, stub_timeout_cb)); } } - netreq->transport++; + netreq->transport_current++; return upstream->fd; } @@ -1654,16 +1631,18 @@ fallback_on_write(getdns_network_req *netreq) { DEBUG_STUB("%s\n", __FUNCTION__); /* TODO[TLS]: Fallback through all transports.*/ - getdns_base_transport_t next_transport = - netreq->dns_base_transports[netreq->transport + 1]; - if (next_transport == GETDNS_BASE_TRANSPORT_NONE) + if (netreq->transport_current = netreq->transport_count - 1) return STUB_TCP_ERROR; - if (netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_STARTTLS && - next_transport == GETDNS_BASE_TRANSPORT_TCP) { - /* Special case where can stay on same upstream*/ - netreq->transport++; + getdns_transport_list_t next_transport = + netreq->transports[netreq->transport_current + 1]; + + if (netreq->transports[netreq->transport_current] == + GETDNS_TRANSPORT_STARTTLS && + next_transport == GETDNS_TRANSPORT_TCP) { + /* TODO[TLS]: Check this is always OK.... + * Special case where can stay on same upstream*/ + netreq->transport_current++; return netreq->upstream->fd; } getdns_upstream *upstream = netreq->upstream; @@ -1722,22 +1701,21 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq) if (fd == -1) return GETDNS_RETURN_GENERIC_ERROR; - getdns_base_transport_t transport = - netreq->dns_base_transports[netreq->transport]; + getdns_transport_list_t transport = + netreq->transports[netreq->transport_current]; switch(transport) { - case GETDNS_BASE_TRANSPORT_UDP: - case GETDNS_BASE_TRANSPORT_TCP_SINGLE: + case GETDNS_TRANSPORT_UDP: netreq->fd = fd; GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, - NULL, (transport == GETDNS_BASE_TRANSPORT_UDP ? + NULL, (transport == GETDNS_TRANSPORT_UDP ? stub_udp_write_cb: stub_tcp_write_cb), stub_timeout_cb)); return GETDNS_RETURN_GOOD; - case GETDNS_BASE_TRANSPORT_STARTTLS: - case GETDNS_BASE_TRANSPORT_TLS: - case GETDNS_BASE_TRANSPORT_TCP: + case GETDNS_TRANSPORT_STARTTLS: + case GETDNS_TRANSPORT_TLS: + case GETDNS_TRANSPORT_TCP: upstream_schedule_netreq(netreq->upstream, netreq); /* TODO[TLS]: Change scheduling for sync calls. */ GETDNS_SCHEDULE_EVENT( diff --git a/src/test/getdns_query.c b/src/test/getdns_query.c index 83064557..a390261e 100644 --- a/src/test/getdns_query.c +++ b/src/test/getdns_query.c @@ -443,7 +443,7 @@ getdns_return_t parse_args(int argc, char **argv) return GETDNS_RETURN_GENERIC_ERROR; } size_t transport_count = 0; - getdns_transport_list_t transports[GETDNS_BASE_TRANSPORT_MAX]; + getdns_transport_list_t transports[strlen(argv[])]; if ((r = fill_transport_list(context, argv[i], transports, &transport_count)) || (r = getdns_context_set_dns_transport_list(context, transport_count, transports))){ diff --git a/src/types-internal.h b/src/types-internal.h index 53d9db2e..fb6af68e 100644 --- a/src/types-internal.h +++ b/src/types-internal.h @@ -99,6 +99,9 @@ struct getdns_upstream; #define TIMEOUT_FOREVER ((int64_t)-1) #define ASSERT_UNREACHABLE 0 +#define GETDNS_TRANSPORTS_MAX 4 +#define GETDNS_UPSTREAM_TRANSPORTS 3 + /** @} */ @@ -164,17 +167,17 @@ typedef struct getdns_tcp_state { } getdns_tcp_state; -/* TODO[TLS]: change this name to getdns_transport when API updated*/ -typedef enum getdns_base_transport { - GETDNS_BASE_TRANSPORT_MIN = 0, - GETDNS_BASE_TRANSPORT_NONE = 0, - GETDNS_BASE_TRANSPORT_UDP, - GETDNS_BASE_TRANSPORT_TCP_SINGLE, /* To be removed? */ - GETDNS_BASE_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */ - GETDNS_BASE_TRANSPORT_TCP, - GETDNS_BASE_TRANSPORT_TLS, - GETDNS_BASE_TRANSPORT_MAX -} getdns_base_transport_t; +// /* TODO[TLS]: change this name to getdns_transport when API updated*/ +// typedef enum getdns_base_transport { +// GETDNS_TRANSPORT_MIN = 0, +// GETDNS_TRANSPORT_NONE = 0, +// GETDNS_TRANSPORT_UDP, +// GETDNS_TRANSPORT_TCP_SINGLE, /* To be removed? */ +// GETDNS_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */ +// GETDNS_TRANSPORT_TCP, +// GETDNS_TRANSPORT_TLS, +// GETDNS_TRANSPORT_MAX +// } getdns_base_transport_t; /** * Request data @@ -203,8 +206,9 @@ typedef struct getdns_network_req /* For stub resolving */ struct getdns_upstream *upstream; int fd; - getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; - int transport; + getdns_transport_list_t *transports; + size_t transport_count; + size_t transport_current; getdns_eventloop_event event; getdns_tcp_state tcp; uint16_t query_id;