diff --git a/src/context.c b/src/context.c index f0a3631a..f1a51536 100644 --- a/src/context.c +++ b/src/context.c @@ -69,11 +69,15 @@ typedef struct host_name_addrs { uint8_t host_name[]; } host_name_addrs; +static getdns_transport_list_t +getdns_upstream_transports[GETDNS_UPSTREAM_TRANSPORTS] = { + GETDNS_TRANSPORT_STARTTLS, // Define before TCP to ease fallback + GETDNS_TRANSPORT_TCP, + GETDNS_TRANSPORT_TLS, +}; + static in_port_t -getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = { - GETDNS_PORT_ZERO, - GETDNS_PORT_ZERO, - GETDNS_PORT_ZERO, +getdns_port_array[GETDNS_UPSTREAM_TRANSPORTS] = { GETDNS_PORT_DNS, GETDNS_PORT_DNS, GETDNS_PORT_DNS_OVER_TLS @@ -81,9 +85,6 @@ getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = { char* getdns_port_str_array[] = { - GETDNS_STR_PORT_ZERO, - GETDNS_STR_PORT_ZERO, - GETDNS_STR_PORT_ZERO, GETDNS_STR_PORT_DNS, GETDNS_STR_PORT_DNS, GETDNS_STR_PORT_DNS_OVER_TLS @@ -91,6 +92,7 @@ getdns_port_str_array[] = { /* Private functions */ getdns_return_t create_default_namespaces(struct getdns_context *context); +getdns_return_t create_default_dns_transports(struct getdns_context *context); static struct getdns_list *create_default_root_servers(void); static getdns_return_t set_os_defaults(struct getdns_context *); static int transaction_id_cmp(const void *, const void *); @@ -139,42 +141,22 @@ create_default_namespaces(struct getdns_context *context) return GETDNS_RETURN_GOOD; } -static getdns_transport_list_t * -get_dns_transport_list(getdns_context *context, int *count) +/** + * Helper to get default transports. + */ +getdns_return_t +create_default_dns_transports(struct getdns_context *context) { - if (context == NULL) - return NULL; + context->dns_transports = GETDNS_XMALLOC(context->my_mf, getdns_transport_list_t, 2); + if(context->dns_transports == NULL) + return GETDNS_RETURN_GENERIC_ERROR; - /* Count how many we have*/ - for (*count = 0; *count < GETDNS_BASE_TRANSPORT_MAX; (*count)++) { - if (context->dns_base_transports[*count] == GETDNS_BASE_TRANSPORT_NONE) - break; - } + context->dns_transports[0] = GETDNS_TRANSPORT_UDP; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; + context->dns_transport_count = 2; + context->dns_transport_current = 0; - // use normal malloc here so users can do normal free - getdns_transport_list_t * transports = malloc(*count * sizeof(getdns_transport_list_t)); - - if(transports == NULL) - return NULL; - for (int i = 0; i < (int)*count; i++) { - switch(context->dns_base_transports[i]) { - case GETDNS_BASE_TRANSPORT_UDP: - transports[i] = GETDNS_TRANSPORT_UDP; - break; - case GETDNS_BASE_TRANSPORT_TCP: - transports[i] = GETDNS_TRANSPORT_TCP; - break; - case GETDNS_BASE_TRANSPORT_TLS: - transports[i] = GETDNS_TRANSPORT_TLS; - break; - case GETDNS_BASE_TRANSPORT_STARTTLS: - transports[i] = GETDNS_TRANSPORT_STARTTLS; - break; - default: - break; - } - } - return transports; + return GETDNS_RETURN_GOOD; } static inline void canonicalize_dname(uint8_t *dname) @@ -614,6 +596,8 @@ upstream_init(getdns_upstream *upstream, (void) memcpy(&upstream->addr, ai->ai_addr, ai->ai_addrlen); /* How is this upstream doing? */ + upstream->writes_done = 0; + upstream->responses_recieved = 0; upstream->to_retry = 2; upstream->back_off = 1; @@ -621,8 +605,9 @@ upstream_init(getdns_upstream *upstream, upstream->fd = -1; upstream->tls_obj = NULL; upstream->starttls_req = NULL; - upstream->dns_base_transport = GETDNS_BASE_TRANSPORT_TCP; + upstream->transport = GETDNS_TRANSPORT_TCP; upstream->tls_hs_state = GETDNS_HS_NONE; + upstream->tcp.write_error = 0; upstream->loop = NULL; (void) getdns_eventloop_event_init( &upstream->event, upstream, NULL, NULL, NULL); @@ -725,11 +710,8 @@ set_os_defaults(struct getdns_context *context) token = parse + strcspn(parse, " \t\r\n"); *token = 0; - getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; - for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { - char *port_str = getdns_port_str_array[base_transport]; - if (strncmp(port_str, GETDNS_STR_PORT_ZERO, 1) == 0) - continue; + for (size_t i = 0; i < GETDNS_UPSTREAM_TRANSPORTS; i++) { + char *port_str = getdns_port_str_array[i]; if ((s = getaddrinfo(parse, port_str, &hints, &result))) continue; if (!result) @@ -743,7 +725,7 @@ set_os_defaults(struct getdns_context *context) upstream = &context->upstreams-> upstreams[context->upstreams->count++]; upstream_init(upstream, context->upstreams, result); - upstream->dns_base_transport = base_transport; + upstream->transport = getdns_upstream_transports[i]; freeaddrinfo(result); } } @@ -873,8 +855,8 @@ getdns_context_create_with_extended_memory_functions( result->dnssec_allowed_skew = 0; result->edns_maximum_udp_payload_size = -1; - result->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; - result->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + if ((r = create_default_dns_transports(result))) + goto error; result->limit_outstanding_queries = 0; result->has_ta = priv_getdns_parse_ta_file(NULL, NULL); result->return_dnssec_status = GETDNS_EXTENSION_FALSE; @@ -958,6 +940,8 @@ getdns_context_destroy(struct getdns_context *context) if (context->namespaces) GETDNS_FREE(context->my_mf, context->namespaces); + if (context->dns_transports) + GETDNS_FREE(context->my_mf, context->dns_transports); if(context->fchg_resolvconf) { if(context->fchg_resolvconf->prevstat) @@ -1201,67 +1185,79 @@ static getdns_return_t getdns_set_base_dns_transports(struct getdns_context *context, size_t transport_count, getdns_transport_list_t *transports) { - RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); - for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) - context->dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE; + size_t i; - if ((int)transport_count == 0 || transports == NULL || - (int)transport_count > GETDNS_BASE_TRANSPORT_MAX) { - return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; + RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); + if (transport_count == 0 || transports == NULL) { + return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; + } + + for(i=0; idns_base_transports[j] = GETDNS_BASE_TRANSPORT_UDP; - break; - case GETDNS_TRANSPORT_TCP: - context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_TCP; - break; - case GETDNS_TRANSPORT_TLS: - context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_TLS; - break; - case GETDNS_TRANSPORT_STARTTLS: - context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_STARTTLS; - break; - default: - return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; - } - } - return GETDNS_RETURN_GOOD; + GETDNS_FREE(context->my_mf, context->dns_transports); + + /** duplicate **/ + context->dns_transports = GETDNS_XMALLOC(context->my_mf, + getdns_transport_list_t, transport_count); + memcpy(context->dns_transports, transports, + transport_count * sizeof(getdns_transport_list_t)); + context->dns_transport_count = transport_count; + dispatch_updated(context, GETDNS_CONTEXT_CODE_NAMESPACES); + + return GETDNS_RETURN_GOOD; } static getdns_return_t set_ub_dns_transport(struct getdns_context* context) { /* These mappings are not exact because Unbound is configured differently, - so just map as close as possible from the first 1 or 2 transports. */ - switch (context->dns_base_transports[0]) { - case GETDNS_BASE_TRANSPORT_UDP: + so just map as close as possible. Not all options can be supported.*/ + switch (context->dns_transports[0]) { + case GETDNS_TRANSPORT_UDP: set_ub_string_opt(context, "do-udp:", "yes"); - if (context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_TCP) + if (context->dns_transports[1] == GETDNS_TRANSPORT_TCP) set_ub_string_opt(context, "do-tcp:", "yes"); else set_ub_string_opt(context, "do-tcp:", "no"); break; - case GETDNS_BASE_TRANSPORT_TLS: - /* Note: If TLS is used in recursive mode this will try TLS on port - * 53... So this is prohibited when preparing for resolution.*/ - if (context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE) { - set_ub_string_opt(context, "ssl-upstream:", "yes"); - set_ub_string_opt(context, "do-udp:", "no"); - set_ub_string_opt(context, "do-tcp:", "yes"); - break; - } - if (context->dns_base_transports[1] != GETDNS_BASE_TRANSPORT_TCP) - break; - /* Fallthrough */ - case GETDNS_BASE_TRANSPORT_STARTTLS: - case GETDNS_BASE_TRANSPORT_TCP: - /* Note: no STARTTLS or fallback to TCP available directly in unbound, so we just - * use TCP for now to make sure the messages are sent. */ + case GETDNS_TRANSPORT_TCP: set_ub_string_opt(context, "do-udp:", "no"); set_ub_string_opt(context, "do-tcp:", "yes"); break; + case GETDNS_TRANSPORT_TLS: + case GETDNS_TRANSPORT_STARTTLS: + set_ub_string_opt(context, "do-udp:", "no"); + set_ub_string_opt(context, "do-tcp:", "yes"); + /* Find out if there is a fallback available. */ + int fallback = 0; + for (size_t i = 1; i < context->dns_transport_count; i++) { + if (context->dns_transports[i] == GETDNS_TRANSPORT_TCP) { + fallback = 1; + break; + } + else if (context->dns_transports[i] == GETDNS_TRANSPORT_UDP) { + set_ub_string_opt(context, "do-udp:", "yes"); + set_ub_string_opt(context, "do-tcp:", "no"); + fallback = 1; + break; + } + } + if (context->dns_transports[0] == GETDNS_TRANSPORT_TLS) { + if (fallback == 0) + /* Use TLS if it is the only thing.*/ + set_ub_string_opt(context, "ssl-upstream:", "yes"); + break; + } else if (fallback == 0) + /* Can't support STARTTLS with no fallback. This leads to + * timeouts with un stub validation.... */ + set_ub_string_opt(context, "do-tcp:", "no"); + break; default: return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; } @@ -1278,31 +1274,37 @@ getdns_context_set_dns_transport(struct getdns_context *context, { RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); - for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) - context->dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE; + size_t count = 2; + if (value == GETDNS_TRANSPORT_UDP_ONLY || + value == GETDNS_TRANSPORT_TCP_ONLY || + value == GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN || + value == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) + count = 1; + context->dns_transports = GETDNS_XMALLOC(context->my_mf, + getdns_transport_list_t, count); switch (value) { case GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; - context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_UDP; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; break; case GETDNS_TRANSPORT_UDP_ONLY: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; + context->dns_transports[0] = GETDNS_TRANSPORT_UDP; break; case GETDNS_TRANSPORT_TCP_ONLY: case GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_TCP; break; case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS; + context->dns_transports[0] = GETDNS_TRANSPORT_TLS; break; case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS; - context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_TLS; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; break; case GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_STARTTLS; - context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + context->dns_transports[0] = GETDNS_TRANSPORT_STARTTLS; + context->dns_transports[1] = GETDNS_TRANSPORT_TCP; break; default: return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; @@ -1658,15 +1660,14 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, } /* Loop to create upstreams as needed*/ - getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; - for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { + for (size_t j = 0; j < GETDNS_UPSTREAM_TRANSPORTS; j++) { uint32_t port; struct addrinfo *ai; - port = getdns_port_array[base_transport]; + port = getdns_port_array[j]; if (port == GETDNS_PORT_ZERO) continue; - if (base_transport != GETDNS_BASE_TRANSPORT_TLS) + if (getdns_upstream_transports[j] != GETDNS_TRANSPORT_TLS) (void) getdns_dict_get_int(dict, "port", &port); else (void) getdns_dict_get_int(dict, "tls_port", &port); @@ -1689,7 +1690,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, upstream = &upstreams->upstreams[upstreams->count]; upstream->addr.ss_family = addr.ss_family; upstream_init(upstream, upstreams, ai); - upstream->dns_base_transport = base_transport; + upstream->transport = getdns_upstream_transports[j]; upstreams->count++; freeaddrinfo(ai); } @@ -1913,9 +1914,9 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) upstream = &upstreams->upstreams[i]; /*[TLS]: Use only the TLS subset of upstreams when TLS is the only thing * used. All other cases must currently fallback to TCP for libunbound.*/ - if (context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && - context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE && - upstream->dns_base_transport != GETDNS_BASE_TRANSPORT_TLS) + if (context->dns_transports[0] == GETDNS_TRANSPORT_TLS && + context->dns_transport_count ==1 && + upstream->transport != GETDNS_TRANSPORT_TLS) continue; upstream_ntop_buf(upstream, addr, 1024); ub_ctx_set_fwd(ctx, addr); @@ -2023,10 +2024,13 @@ getdns_context_prepare_for_resolution(struct getdns_context *context, return GETDNS_RETURN_BAD_CONTEXT; } } - /* Block use of TLS ONLY in recursive mode as it won't work */ + /* Block use of STARTTLS/TLS ONLY in recursive mode as it won't work */ + /* Note: If TLS is used in recursive mode this will try TLS on port + * 53 so it is blocked here. So is STARTTLS only at the moment. */ if (context->resolution_type == GETDNS_RESOLUTION_RECURSING && - context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && - context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE) + context->dns_transport_count == 1 && + (context->dns_transports[0] == GETDNS_TRANSPORT_TLS || + context->dns_transports[0] == GETDNS_TRANSPORT_STARTTLS)) return GETDNS_RETURN_BAD_CONTEXT; if (context->resolution_type_set == context->resolution_type) @@ -2321,17 +2325,16 @@ priv_get_context_settings(getdns_context* context) { upstreams); getdns_list_destroy(upstreams); } - /* create a transport list */ - getdns_list* transports = getdns_list_create_with_context(context); - if (transports) { - int transport_count; - getdns_transport_list_t *transport_list = - get_dns_transport_list(context, &transport_count); - for (int i = 0; i < transport_count; i++) { - r |= getdns_list_set_int(transports, i, transport_list[i]); + if (context->dns_transport_count > 0) { + /* create a namespace list */ + size_t i; + getdns_list* transports = getdns_list_create_with_context(context); + if (transports) { + for (i = 0; i < context->dns_transport_count; ++i) { + r |= getdns_list_set_int(transports, i, context->dns_transports[i]); + } + r |= getdns_dict_set_list(result, "dns_transport_list", transports); } - r |= getdns_dict_set_list(result, "dns_transport_list", transports); - free(transport_list); } if (context->namespace_count > 0) { /* create a namespace list */ @@ -2525,35 +2528,34 @@ getdns_context_get_dns_transport(getdns_context *context, getdns_transport_t* value) { RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(value, GETDNS_RETURN_INVALID_PARAMETER); - int count; - getdns_transport_list_t *transport_list = - get_dns_transport_list(context, &count); - if (!count) + int count = context->dns_transport_count; + getdns_transport_list_t *transports = context->dns_transports; + if (!count) return GETDNS_RETURN_WRONG_TYPE_REQUESTED; /* Best effort mapping for backwards compatibility*/ - if (transport_list[0] == GETDNS_TRANSPORT_UDP) { + if (transports[0] == GETDNS_TRANSPORT_UDP) { if (count == 1) *value = GETDNS_TRANSPORT_UDP_ONLY; - else if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP) + else if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP) *value = GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP; else return GETDNS_RETURN_WRONG_TYPE_REQUESTED; } - if (transport_list[0] == GETDNS_TRANSPORT_TCP) { + if (transports[0] == GETDNS_TRANSPORT_TCP) { if (count == 1) *value = GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN; } - if (transport_list[0] == GETDNS_TRANSPORT_TLS) { + if (transports[0] == GETDNS_TRANSPORT_TLS) { if (count == 1) *value = GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN; - else if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP) + else if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP) *value = GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN; else return GETDNS_RETURN_WRONG_TYPE_REQUESTED; } - if (transport_list[0] == GETDNS_TRANSPORT_STARTTLS) { - if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP) + if (transports[0] == GETDNS_TRANSPORT_STARTTLS) { + if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP) *value = GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN; else return GETDNS_RETURN_WRONG_TYPE_REQUESTED; @@ -2567,16 +2569,15 @@ getdns_context_get_dns_transport_list(getdns_context *context, RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(transport_count, GETDNS_RETURN_INVALID_PARAMETER); RETURN_IF_NULL(transports, GETDNS_RETURN_INVALID_PARAMETER); - - int count; - getdns_transport_list_t *transport_list = - get_dns_transport_list(context, &count); - *transport_count = count; - if (!transport_count) { + *transport_count = context->dns_transport_count; + if (!context->dns_transport_count) { *transports = NULL; return GETDNS_RETURN_GOOD; } - *transports = transport_list; + // use normal malloc here so users can do normal free + *transports = malloc(context->dns_transport_count * sizeof(getdns_transport_list_t)); + memcpy(*transports, context->dns_transports, + context->dns_transport_count * sizeof(getdns_transport_list_t)); return GETDNS_RETURN_GOOD; } diff --git a/src/context.h b/src/context.h index 040f18d3..26c65bdd 100644 --- a/src/context.h +++ b/src/context.h @@ -86,12 +86,14 @@ typedef struct getdns_upstream { struct sockaddr_storage addr; /* How is this upstream doing? */ + size_t writes_done; + size_t responses_recieved; int to_retry; int back_off; /* For sharing a TCP socket to this upstream */ int fd; - getdns_base_transport_t dns_base_transport; + getdns_transport_list_t transport; SSL* tls_obj; getdns_tls_hs_state_t tls_hs_state; getdns_dns_req * starttls_req; @@ -138,10 +140,13 @@ struct getdns_context { struct getdns_list *suffix; struct getdns_list *dnssec_trust_anchors; getdns_upstreams *upstreams; - getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; uint16_t limit_outstanding_queries; uint32_t dnssec_allowed_skew; + getdns_transport_list_t *dns_transports; + size_t dns_transport_count; + size_t dns_transport_current; + uint8_t edns_extended_rcode; uint8_t edns_version; uint8_t edns_do_bit; diff --git a/src/request-internal.c b/src/request-internal.c index b4267022..33ef055c 100644 --- a/src/request-internal.c +++ b/src/request-internal.c @@ -89,9 +89,10 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner, net_req->upstream = NULL; net_req->fd = -1; - for (i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) - net_req->dns_base_transports[i] = owner->context->dns_base_transports[i]; - net_req->transport = 0; + net_req->transport_count = owner->context->dns_transport_count; + net_req->transport_current = 0; + memcpy(net_req->transports, owner->context->dns_transports, + net_req->transport_count * sizeof(getdns_transport_list_t)); memset(&net_req->event, 0, sizeof(net_req->event)); memset(&net_req->tcp, 0, sizeof(net_req->tcp)); net_req->query_id = 0; diff --git a/src/stub.c b/src/stub.c index b667ee64..4e143a77 100644 --- a/src/stub.c +++ b/src/stub.c @@ -57,14 +57,22 @@ static uint32_t prev_secret = 0; static void upstream_read_cb(void *userarg); static void upstream_write_cb(void *userarg); +static void upstream_idle_timeout_cb(void *userarg); static void upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq); +static void upstream_reschedule_events(getdns_upstream *upstream, + size_t idle_timeout); +static void upstream_reschedule_netreq_events(getdns_upstream *upstream, + getdns_network_req *netreq); +static int upstream_connect(getdns_upstream *upstream, + getdns_transport_list_t transport, + getdns_dns_req *dnsreq); static void netreq_upstream_read_cb(void *userarg); static void netreq_upstream_write_cb(void *userarg); static int fallback_on_write(getdns_network_req *netreq); static void stub_tcp_write_cb(void *userarg); - +static void stub_timeout_cb(void *userarg); /*****************************/ /* General utility functions */ /*****************************/ @@ -354,7 +362,7 @@ getdns_sock_nonblock(int sockfd) } static int -tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) +tcp_connect(getdns_upstream *upstream, getdns_transport_list_t transport) { int fd = -1; if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) @@ -363,9 +371,8 @@ tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) getdns_sock_nonblock(fd); #ifdef USE_TCP_FASTOPEN /* Leave the connect to the later call to sendto() if using TCP*/ - if (transport == GETDNS_BASE_TRANSPORT_TCP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE || - transport == GETDNS_BASE_TRANSPORT_STARTTLS) + if (transport == GETDNS_TRANSPORT_TCP || + transport == GETDNS_TRANSPORT_STARTTLS) return fd; #endif if (connect(fd, (struct sockaddr *)&upstream->addr, @@ -378,6 +385,22 @@ tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) return fd; } +static int +tcp_connected(getdns_upstream *upstream) { + /* Already tried and failed, so let the fallback code take care of things */ + if (upstream->fd == -1 || upstream->tcp.write_error != 0) + return STUB_TCP_ERROR; + + int error = 0; + socklen_t len = (socklen_t)sizeof(error); + getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len); + if (error == EINPROGRESS || error == EWOULDBLOCK) + return STUB_TCP_AGAIN; /* try again */ + else if (error != 0) + return STUB_TCP_ERROR; + return 0; +} + /**************************/ /* Error/cleanup functions*/ /**************************/ @@ -394,18 +417,18 @@ stub_next_upstream(getdns_network_req *netreq) * same upstream (and the next message may not use the same transport), * but the next message will find the next matching one thanks to logic in * upstream_select, but this could be better */ - if (++dnsreq->upstreams->current > dnsreq->upstreams->count) + if (++dnsreq->upstreams->current >= dnsreq->upstreams->count) dnsreq->upstreams->current = 0; } static void stub_cleanup(getdns_network_req *netreq) { + DEBUG_STUB("*** %s\n", __FUNCTION__); getdns_dns_req *dnsreq = netreq->owner; getdns_network_req *r, *prev_r; getdns_upstream *upstream; intptr_t query_id_intptr; - int reschedule; GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); @@ -433,28 +456,16 @@ stub_cleanup(getdns_network_req *netreq) if (r == upstream->write_queue_last) upstream->write_queue_last = prev_r ? prev_r : NULL; + netreq->write_queue_tail = NULL; break; } - reschedule = 0; - if (!upstream->write_queue && upstream->event.write_cb) { - upstream->event.write_cb = NULL; - reschedule = 1; - } - if (!upstream->netreq_by_query_id.count && upstream->event.read_cb) { - upstream->event.read_cb = NULL; - reschedule = 1; - } - if (reschedule) { - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - if (upstream->event.read_cb || upstream->event.write_cb) - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, &upstream->event); - } + upstream_reschedule_events(upstream, netreq->owner->context->idle_timeout); } static int tls_cleanup(getdns_upstream *upstream) { + DEBUG_STUB("*** %s\n", __FUNCTION__); SSL_free(upstream->tls_obj); upstream->tls_obj = NULL; upstream->tls_hs_state = GETDNS_HS_FAILED; @@ -463,6 +474,17 @@ tls_cleanup(getdns_upstream *upstream) GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, getdns_eventloop_event_init(&upstream->event, upstream, NULL, upstream_write_cb, NULL)); + /* Reset sync event, with full timeout (which isn't correct)*/ + getdns_network_req *netreq = upstream->write_queue; + if (netreq && (netreq->event.write_cb || netreq->event.read_cb)) { + GETDNS_CLEAR_EVENT(netreq->owner->loop, &netreq->event); + GETDNS_SCHEDULE_EVENT( + netreq->owner->loop, upstream->fd, netreq->owner->context->timeout, + getdns_eventloop_event_init( + &netreq->event, netreq, + NULL, netreq_upstream_write_cb, + stub_timeout_cb)); + } return STUB_TLS_SETUP_ERROR; } @@ -492,14 +514,6 @@ upstream_erred(getdns_upstream *upstream) upstream->fd = -1; } -static void -message_erred(getdns_network_req *netreq) -{ - stub_cleanup(netreq); - netreq->state = NET_REQ_FINISHED; - priv_getdns_check_dns_req_complete(netreq->owner); -} - void priv_getdns_cancel_stub_request(getdns_network_req *netreq) { @@ -510,6 +524,7 @@ priv_getdns_cancel_stub_request(getdns_network_req *netreq) static void stub_erred(getdns_network_req *netreq) { + DEBUG_STUB("*** %s\n", __FUNCTION__); stub_next_upstream(netreq); stub_cleanup(netreq); /* TODO[TLS]: When we get an error (which is probably a timeout) and are @@ -522,11 +537,12 @@ stub_erred(getdns_network_req *netreq) static void stub_timeout_cb(void *userarg) { + DEBUG_STUB("*** %s\n", __FUNCTION__); getdns_network_req *netreq = (getdns_network_req *)userarg; /* For now, mark a STARTTLS timeout as a failured negotiation and allow * fallback but don't close the connection. */ - if (is_starttls_response(netreq)) { + if (netreq->owner == netreq->upstream->starttls_req) { netreq->upstream->tls_hs_state = GETDNS_HS_FAILED; stub_next_upstream(netreq); stub_cleanup(netreq); @@ -542,18 +558,23 @@ stub_timeout_cb(void *userarg) static void upstream_idle_timeout_cb(void *userarg) { - DEBUG_STUB("%s\n", __FUNCTION__); getdns_upstream *upstream = (getdns_upstream *)userarg; + DEBUG_STUB("*** %s: **Closing connection %d**\n", + __FUNCTION__, upstream->fd); /*There is a race condition with a new request being scheduled while this happens so take ownership of the fd asap*/ int fd = upstream->fd; upstream->fd = -1; upstream->event.timeout_cb = NULL; + upstream->event.read_cb = NULL; + upstream->event.write_cb = NULL; GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - upstream->tls_hs_state = GETDNS_HS_NONE; + if (upstream->tls_hs_state != GETDNS_HS_FAILED) + upstream->tls_hs_state = GETDNS_HS_NONE; if (upstream->tls_obj != NULL) { SSL_shutdown(upstream->tls_obj); SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; } close(fd); } @@ -562,7 +583,7 @@ upstream_idle_timeout_cb(void *userarg) static void upstream_tls_timeout_cb(void *userarg) { - DEBUG_STUB("%s\n", __FUNCTION__); + DEBUG_STUB("*** %s\n", __FUNCTION__); getdns_upstream *upstream = (getdns_upstream *)userarg; /* Clean up and trigger a write to let the fallback code to its job */ tls_cleanup(upstream); @@ -585,6 +606,33 @@ upstream_tls_timeout_cb(void *userarg) } } +static void +stub_tls_timeout_cb(void *userarg) +{ + DEBUG_STUB("*** %s\n", __FUNCTION__); + getdns_network_req *netreq = (getdns_network_req *)userarg; + getdns_upstream *upstream = netreq->upstream; + /* Clean up and trigger a write to let the fallback code to its job */ + tls_cleanup(upstream); + + /* Need to handle the case where the far end doesn't respond to a + * TCP SYN and doesn't do a reset (as is the case with e.g. 8.8.8.8@1021). + * For that case the socket never becomes writable so doesn't trigger any + * callbacks. If so then clear out the queue in one go.*/ + int ret; + fd_set fds; + FD_ZERO(&fds); + FD_SET(FD_SET_T upstream->fd, &fds); + struct timeval tval; + tval.tv_sec = 0; + tval.tv_usec = 0; + ret = select(upstream->fd+1, NULL, &fds, NULL, &tval); + if (ret == 0) { + while (upstream->write_queue) + upstream_write_cb(upstream); + } +} + /****************************/ /* TCP read/write functions */ /****************************/ @@ -663,24 +711,15 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) uint16_t query_id; intptr_t query_id_intptr; + int q = tcp_connected(netreq->upstream); + if (q != 0) + return q; + /* Do we have remaining data that we could not write before? */ if (! tcp->write_buf) { /* No, this is an initial write. Try to send */ - - /* Not keeping connections open? Then the first random number - * will do as the query id. - * - * Otherwise find a unique query_id not already written (or in - * the write_queue) for that upstream. Register this netreq - * by query_id in the process. - */ - if ((netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_TCP_SINGLE) || - (netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_UDP)) - query_id = arc4random(); - else do { + do { query_id = arc4random(); query_id_intptr = (intptr_t)query_id; netreq->node.key = (void *)query_id_intptr; @@ -774,10 +813,10 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) static int tls_requested(getdns_network_req *netreq) { - return (netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_TLS || - netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_STARTTLS) ? + return (netreq->transports[netreq->transport_current] == + GETDNS_TRANSPORT_TLS || + netreq->transports[netreq->transport_current] == + GETDNS_TRANSPORT_STARTTLS) ? 1 : 0; } @@ -786,16 +825,16 @@ tls_should_write(getdns_upstream *upstream) { /* Should messages be written on TLS upstream. Remember that for STARTTLS * the first message should got over TCP as the handshake isn't started yet.*/ - return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + return ((upstream->transport == GETDNS_TRANSPORT_TLS || + upstream->transport == GETDNS_TRANSPORT_STARTTLS) && upstream->tls_hs_state != GETDNS_HS_NONE) ? 1 : 0; } static int tls_should_read(getdns_upstream *upstream) { - return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + return ((upstream->transport == GETDNS_TRANSPORT_TLS || + upstream->transport == GETDNS_TRANSPORT_STARTTLS) && !(upstream->tls_hs_state == GETDNS_HS_FAILED || upstream->tls_hs_state == GETDNS_HS_NONE)) ? 1 : 0; } @@ -804,8 +843,8 @@ static int tls_failed(getdns_upstream *upstream) { /* No messages should be scheduled onto an upstream in this state */ - return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + return ((upstream->transport == GETDNS_TRANSPORT_TLS || + upstream->transport == GETDNS_TRANSPORT_STARTTLS) && upstream->tls_hs_state == GETDNS_HS_FAILED) ? 1: 0; } @@ -831,10 +870,11 @@ tls_create_object(getdns_context *context, int fd) static int tls_do_handshake(getdns_upstream *upstream) { - DEBUG_STUB("%s\n", __FUNCTION__); + DEBUG_STUB("--- %s\n", __FUNCTION__); int r; int want; ERR_clear_error(); + getdns_network_req *netreq = upstream->write_queue; while ((r = SSL_do_handshake(upstream->tls_obj)) != 1) { want = SSL_get_error(upstream->tls_obj, r); @@ -845,6 +885,16 @@ tls_do_handshake(getdns_upstream *upstream) GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_TLS, &upstream->event); + /* Reschedule for synchronous */ + if (netreq && netreq->event.write_cb) { + GETDNS_CLEAR_EVENT(netreq->owner->loop, &netreq->event); + GETDNS_SCHEDULE_EVENT( + netreq->owner->loop, upstream->fd, TIMEOUT_TLS, + getdns_eventloop_event_init( + &netreq->event, netreq, + netreq_upstream_read_cb, NULL, + stub_tls_timeout_cb)); + } upstream->tls_hs_state = GETDNS_HS_READ; return STUB_TCP_AGAIN; case SSL_ERROR_WANT_WRITE: @@ -853,6 +903,16 @@ tls_do_handshake(getdns_upstream *upstream) GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_TLS, &upstream->event); + /* Reschedule for synchronous */ + if (netreq && netreq->event.read_cb) { + GETDNS_CLEAR_EVENT(netreq->owner->loop, &netreq->event); + GETDNS_SCHEDULE_EVENT( + netreq->owner->loop, upstream->fd, TIMEOUT_TLS, + getdns_eventloop_event_init( + &netreq->event, netreq, + NULL, netreq_upstream_write_cb, + stub_tls_timeout_cb)); + } upstream->tls_hs_state = GETDNS_HS_WRITE; return STUB_TCP_AGAIN; default: @@ -867,15 +927,23 @@ tls_do_handshake(getdns_upstream *upstream) GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, getdns_eventloop_event_init(&upstream->event, upstream, NULL, upstream_write_cb, NULL)); + /* Reschedule for synchronous */ + /* TODO[TLS]: Re-instating full context->timeout here is wrong, as time has + passes since the netreq was originally scheduled, but we only hove one + timeout in sync mode.... Need a timer on requests really.... Worst case + is we add TIMEOUT_TLS to the total timeout, since TLS is likely to be + the first choice if it is used at all.*/ + if (netreq && (netreq->event.read_cb || netreq->event.write_cb)) + upstream_reschedule_netreq_events(upstream, netreq); return 0; } static int tls_connected(getdns_upstream* upstream) { - /* Already have a connection*/ + /* Already have a TLS connection*/ if (upstream->tls_hs_state == GETDNS_HS_DONE && - (upstream->tls_obj != NULL) && (upstream->fd != -1)) + (upstream->tls_obj != NULL)) return 0; /* Already tried and failed, so let the fallback code take care of things */ @@ -883,13 +951,12 @@ tls_connected(getdns_upstream* upstream) return STUB_TLS_SETUP_ERROR; /* Lets make sure the connection is up before we try a handshake*/ - int error = 0; - socklen_t len = (socklen_t)sizeof(error); - getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len); - if (error == EINPROGRESS || error == EWOULDBLOCK) - return STUB_TCP_AGAIN; /* try again */ - else if (error != 0) - return tls_cleanup(upstream); + int q = tcp_connected(upstream); + if (q != 0) { + if (q == STUB_TCP_ERROR) + tls_cleanup(upstream); + return q; + } return tls_do_handshake(upstream); } @@ -1040,6 +1107,7 @@ stub_tls_write(getdns_upstream *upstream, getdns_tcp_state *tcp, static void stub_udp_read_cb(void *userarg) { + DEBUG_STUB("%s\n", __FUNCTION__); getdns_network_req *netreq = (getdns_network_req *)userarg; getdns_dns_req *dnsreq = netreq->owner; getdns_upstream *upstream = netreq->upstream; @@ -1070,27 +1138,23 @@ stub_udp_read_cb(void *userarg) return; /* Client cookie didn't match? */ close(netreq->fd); - /* TODO: check not past end of transports*/ - getdns_base_transport_t next_transport = - netreq->dns_base_transports[netreq->transport + 1]; - if (GLDNS_TC_WIRE(netreq->response) && - next_transport == GETDNS_BASE_TRANSPORT_TCP) { - - if ((netreq->fd = socket( - upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) + if (GLDNS_TC_WIRE(netreq->response)) { + if (!(netreq->transport_current < netreq->transport_count)) goto done; - - getdns_sock_nonblock(netreq->fd); - if (connect(netreq->fd, (struct sockaddr *)&upstream->addr, - upstream->addr_len) == -1 && errno != EINPROGRESS) { - - close(netreq->fd); + getdns_transport_list_t next_transport = + netreq->transports[++netreq->transport_current]; + if (next_transport != GETDNS_TRANSPORT_TCP) goto done; - } + /* For now, special case where fallback should be on the same upstream*/ + if ((netreq->fd = upstream_connect(upstream, next_transport, + dnsreq)) == -1) + goto done; + upstream_schedule_netreq(netreq->upstream, netreq); GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - NULL, stub_tcp_write_cb, stub_timeout_cb)); + dnsreq->loop, netreq->upstream->fd, dnsreq->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, NULL, + ( dnsreq->loop != netreq->upstream->loop /* Synchronous lookup? */ + ? netreq_upstream_write_cb : NULL), stub_timeout_cb)); return; } @@ -1108,6 +1172,7 @@ done: static void stub_udp_write_cb(void *userarg) { + DEBUG_STUB("%s\n", __FUNCTION__); getdns_network_req *netreq = (getdns_network_req *)userarg; getdns_dns_req *dnsreq = netreq->owner; size_t pkt_len = netreq->response - netreq->query; @@ -1221,7 +1286,7 @@ stub_tcp_write_cb(void *userarg) static void upstream_read_cb(void *userarg) { - DEBUG_STUB("%s\n", __FUNCTION__); + DEBUG_STUB("--- READ: %s\n", __FUNCTION__); getdns_upstream *upstream = (getdns_upstream *)userarg; getdns_network_req *netreq; getdns_dns_req *dnsreq; @@ -1264,6 +1329,9 @@ upstream_read_cb(void *userarg) netreq->response_len = upstream->tcp.read_pos - upstream->tcp.read_buf; upstream->tcp.read_buf = NULL; + upstream->responses_recieved++; + /* TODO[TLS]: I don't think we should do this for TCP. We should stay + * on a working connection until we hit a problem.*/ upstream->upstreams->current = 0; /* netreq may die before setting timeout*/ idle_timeout = netreq->owner->context->idle_timeout; @@ -1272,26 +1340,6 @@ upstream_read_cb(void *userarg) netreq->secure = 0; netreq->bogus = 0; - stub_cleanup(netreq); - - /* More to read/write for syncronous lookups? */ - if (netreq->event.read_cb) { - dnsreq = netreq->owner; - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - if (upstream->netreq_by_query_id.count || - upstream->write_queue) - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, upstream->fd, - dnsreq->context->timeout, - getdns_eventloop_event_init( - &netreq->event, netreq, - ( upstream->netreq_by_query_id.count ? - netreq_upstream_read_cb : NULL ), - ( upstream->write_queue ? - netreq_upstream_write_cb : NULL), - stub_timeout_cb)); - } - if (netreq->owner == upstream->starttls_req) { dnsreq = netreq->owner; if (is_starttls_response(netreq)) { @@ -1309,42 +1357,37 @@ upstream_read_cb(void *userarg) netreq->owner->context->timeout, getdns_eventloop_event_init(&upstream->event, upstream, NULL, upstream_write_cb, NULL)); - } else - priv_getdns_check_dns_req_complete(netreq->owner); - - /* Nothing more to read? Then deschedule the reads.*/ - if (! upstream->netreq_by_query_id.count) { - upstream->event.read_cb = NULL; - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - if (upstream->event.write_cb) - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, - &upstream->event); - else { - upstream->event.timeout_cb = upstream_idle_timeout_cb; - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, idle_timeout, - &upstream->event); - } } + + /* This also reschedules events for the upstream*/ + stub_cleanup(netreq); + + /* More to read/write for syncronous lookups? */ + if (netreq->event.read_cb) + upstream_reschedule_netreq_events(upstream, netreq); + + if (netreq->owner != upstream->starttls_req) + priv_getdns_check_dns_req_complete(netreq->owner); } } static void netreq_upstream_read_cb(void *userarg) { + DEBUG_STUB("--- READ: %s\n", __FUNCTION__); upstream_read_cb(((getdns_network_req *)userarg)->upstream); } static void upstream_write_cb(void *userarg) { - DEBUG_STUB("%s\n", __FUNCTION__); getdns_upstream *upstream = (getdns_upstream *)userarg; getdns_network_req *netreq = upstream->write_queue; getdns_dns_req *dnsreq = netreq->owner; int q; + DEBUG_STUB("--- WRITE: %s: %p TYPE: %d\n", __FUNCTION__, netreq, + netreq->request_type); if (tls_requested(netreq) && tls_should_write(upstream)) q = stub_tls_write(upstream, &upstream->tcp, netreq); else @@ -1355,16 +1398,24 @@ upstream_write_cb(void *userarg) return; case STUB_TCP_ERROR: - stub_erred(netreq); - return; - + /* Problem with the TCP connection itself. Need to fallback.*/ + DEBUG_STUB("--- WRITE: Setting write error\n"); + upstream->tcp.write_error = 1; + /* Use policy of trying next upstream in this case. Need more work on + * TCP connection re-use.*/ + stub_next_upstream(netreq); + /* Fall through */ case STUB_TLS_SETUP_ERROR: /* Could not complete the TLS set up. Need to fallback.*/ - if (fallback_on_write(netreq) == STUB_TCP_ERROR) - message_erred(netreq); + stub_cleanup(netreq); + if (fallback_on_write(netreq) == STUB_TCP_ERROR) { + netreq->state = NET_REQ_FINISHED; + priv_getdns_check_dns_req_complete(netreq->owner); + } return; default: + upstream->writes_done++; netreq->query_id = (uint16_t) q; /* Unqueue the netreq from the write_queue */ if (!(upstream->write_queue = netreq->write_queue_tail)) { @@ -1418,6 +1469,9 @@ upstream_write_cb(void *userarg) static void netreq_upstream_write_cb(void *userarg) { + DEBUG_STUB("--- WRITE: %s: %p TYPE: %d\n", __FUNCTION__, + ((getdns_network_req *)userarg), + ((getdns_network_req *)userarg)->request_type); upstream_write_cb(((getdns_network_req *)userarg)->upstream); } @@ -1427,20 +1481,26 @@ netreq_upstream_write_cb(void *userarg) static int upstream_transport_valid(getdns_upstream *upstream, - getdns_base_transport_t transport) + getdns_transport_list_t transport) { - /* For single shot transports, use only the TCP upstream. */ - if (transport == GETDNS_BASE_TRANSPORT_UDP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) - return (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP ? 1:0); + /* Single shot UDP, uses same upstream as plain TCP. */ + if (transport == GETDNS_TRANSPORT_UDP) + return (upstream->transport == GETDNS_TRANSPORT_TCP ? 1:0); + /* If we got an error and have never managed to write to this TCP then + treat it as a hard failure */ + if (transport == GETDNS_TRANSPORT_TCP && + upstream->transport == GETDNS_TRANSPORT_TCP && + upstream->tcp.write_error != 0) { + return 0; + } /* Allow TCP messages to be sent on a STARTTLS upstream that hasn't * upgraded to avoid opening a new connection if one is aleady open. */ - if (transport == GETDNS_BASE_TRANSPORT_TCP && - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && + if (transport == GETDNS_TRANSPORT_TCP && + upstream->transport == GETDNS_TRANSPORT_STARTTLS && upstream->tls_hs_state == GETDNS_HS_FAILED) return 1; /* Otherwise, transport must match, and not have failed */ - if (upstream->dns_base_transport != transport) + if (upstream->transport != transport) return 0; if (tls_failed(upstream)) return 0; @@ -1448,27 +1508,34 @@ upstream_transport_valid(getdns_upstream *upstream, } static getdns_upstream * -upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) +upstream_select(getdns_network_req *netreq, getdns_transport_list_t transport) { + DEBUG_STUB(" %s\n", __FUNCTION__); getdns_upstream *upstream; getdns_upstreams *upstreams = netreq->owner->upstreams; size_t i; if (!upstreams->count) return NULL; - + + + /* Only do this when a new message is scheduled?*/ for (i = 0; i < upstreams->count; i++) if (upstreams->upstreams[i].to_retry <= 0) upstreams->upstreams[i].to_retry++; + /* TODO[TLS]: Should we create a tmp array of upstreams with correct*/ + /* transport type and/or maintain separate current for transports?*/ i = upstreams->current; + DEBUG_STUB(" current upstream: %d of %d \n",(int)i, (int)upstreams->count); do { if (upstreams->upstreams[i].to_retry > 0 && upstream_transport_valid(&upstreams->upstreams[i], transport)) { upstreams->current = i; + DEBUG_STUB(" selected upstream: %d\n",(int)i); return &upstreams->upstreams[i]; } - if (++i > upstreams->count) + if (++i >= upstreams->count) i = 0; } while (i != upstreams->current); @@ -1479,8 +1546,10 @@ upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) upstream = &upstreams->upstreams[i]; /* Need to check again that the transport is valid */ - if (!upstream_transport_valid(upstream, transport)) + if (!upstream_transport_valid(upstream, transport)) { + DEBUG_STUB(" ! No valid upstream available\n"); return NULL; + } upstream->back_off++; upstream->to_retry = 1; upstreams->current = upstream - upstreams->upstreams; @@ -1489,31 +1558,28 @@ upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) int -upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, +upstream_connect(getdns_upstream *upstream, getdns_transport_list_t transport, getdns_dns_req *dnsreq) { - DEBUG_STUB("%s\n", __FUNCTION__); int fd = -1; switch(transport) { - case GETDNS_BASE_TRANSPORT_UDP: + case GETDNS_TRANSPORT_UDP: if ((fd = socket( upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1) return -1; getdns_sock_nonblock(fd); return fd; - case GETDNS_BASE_TRANSPORT_TCP: + case GETDNS_TRANSPORT_TCP: /* Use existing if available*/ if (upstream->fd != -1) return upstream->fd; - /* Otherwise, fall through */ - case GETDNS_BASE_TRANSPORT_TCP_SINGLE: fd = tcp_connect(upstream, transport); upstream->loop = dnsreq->context->extension; upstream->fd = fd; break; - case GETDNS_BASE_TRANSPORT_TLS: + case GETDNS_TRANSPORT_TLS: /* Use existing if available*/ if (upstream->fd != -1 && !tls_failed(upstream)) return upstream->fd; @@ -1528,7 +1594,7 @@ upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, upstream->loop = dnsreq->context->extension; upstream->fd = fd; break; - case GETDNS_BASE_TRANSPORT_STARTTLS: + case GETDNS_TRANSPORT_STARTTLS: /* Use existing if available. Let the fallback code handle it if * STARTTLS isn't availble. */ if (upstream->fd != -1) @@ -1559,14 +1625,15 @@ upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, static getdns_upstream* find_upstream_for_specific_transport(getdns_network_req *netreq, - getdns_base_transport_t transport, + getdns_transport_list_t transport, int *fd) { - /* TODO[TLS]: Fallback through upstreams....?*/ getdns_upstream *upstream = upstream_select(netreq, transport); if (!upstream) return NULL; *fd = upstream_connect(upstream, transport, netreq->owner); + DEBUG_STUB(" %s: Found: %d %p fd:%d\n", __FUNCTION__, + transport, upstream, upstream->fd); return upstream; } @@ -1574,15 +1641,16 @@ static int find_upstream_for_netreq(getdns_network_req *netreq) { int fd = -1; - int i = netreq->transport; - for (; i < GETDNS_BASE_TRANSPORT_MAX && - netreq->dns_base_transports[i] != GETDNS_BASE_TRANSPORT_NONE; i++) { - netreq->upstream = find_upstream_for_specific_transport(netreq, - netreq->dns_base_transports[i], + getdns_upstream *upstream; + for (size_t i = netreq->transport_current; + i < netreq->transport_count; i++) { + upstream = find_upstream_for_specific_transport(netreq, + netreq->transports[i], &fd); - if (fd == -1 || !netreq->upstream) + if (fd == -1 || !upstream) continue; - netreq->transport = i; + netreq->transport_current = i; + netreq->upstream = upstream; return fd; } return -1; @@ -1593,92 +1661,110 @@ find_upstream_for_netreq(getdns_network_req *netreq) /***********************/ static int -move_netreq(getdns_network_req *netreq, getdns_upstream *upstream, - getdns_upstream *new_upstream) +fallback_on_write(getdns_network_req *netreq) { - DEBUG_STUB("%s\n", __FUNCTION__); - /* Remove from queue, clearing event and fd if we are the last*/ - if (!(upstream->write_queue = netreq->write_queue_tail)) { - upstream->write_queue_last = NULL; + + /* Deal with UDP and change error code*/ + + DEBUG_STUB("#-----> %s: %p TYPE: %d\n", __FUNCTION__, netreq, netreq->request_type); + getdns_upstream *upstream = netreq->upstream; + + /* Try to find a fallback transport*/ + getdns_return_t result = priv_getdns_submit_stub_request(netreq); + + /* For sync messages we must re-schedule the events on the old upstream + * here too. Must schedule this last to make sure it is called back first! */ + if (netreq->owner->loop != upstream->loop && upstream->write_queue) + upstream_reschedule_netreq_events(upstream, upstream->write_queue); + + if (result != GETDNS_RETURN_GOOD) + return STUB_TCP_ERROR; + + return (netreq->transports[netreq->transport_current] + == GETDNS_TRANSPORT_UDP) ? + netreq->fd : netreq->upstream->fd; +} + +static void +upstream_reschedule_events(getdns_upstream *upstream, size_t idle_timeout) { + + DEBUG_STUB("# %s: %p %d\n", __FUNCTION__, upstream, upstream->fd); + int reschedule = 0; + if (!upstream->write_queue && upstream->event.write_cb) { upstream->event.write_cb = NULL; + reschedule = 1; + } + if (upstream->write_queue && !upstream->event.write_cb) { + upstream->event.write_cb = upstream_write_cb; + reschedule = 1; + } + if (!upstream->netreq_by_query_id.count && upstream->event.read_cb) { + upstream->event.read_cb = NULL; + reschedule = 1; + } + if (upstream->netreq_by_query_id.count && !upstream->event.read_cb) { + upstream->event.read_cb = upstream_read_cb; + reschedule = 1; + } + if (reschedule) { GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - + if (upstream->event.read_cb || upstream->event.write_cb) + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + else { + DEBUG_STUB("# %s: *Idle connection %d* \n", + __FUNCTION__, upstream->fd); + upstream->event.timeout_cb = upstream_idle_timeout_cb; + if (upstream->tcp.write_error != 0) + idle_timeout = 0; + GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, + idle_timeout, &upstream->event); + } + } +} + +static void +upstream_reschedule_netreq_events(getdns_upstream *upstream, + getdns_network_req *netreq) { + if (netreq) { + DEBUG_STUB("# %s: %p: TYPE: %d\n", __FUNCTION__, + netreq, netreq->request_type); + getdns_dns_req *dnsreq = netreq->owner; + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + if (upstream->netreq_by_query_id.count || upstream->write_queue) + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, upstream->fd, dnsreq->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, + (upstream->netreq_by_query_id.count ? + netreq_upstream_read_cb : NULL ), + (upstream->write_queue ? + netreq_upstream_write_cb : NULL), + stub_timeout_cb)); + } + if (!upstream->netreq_by_query_id.count && !upstream->write_queue) { + /* This is a sync call, and the connection is idle. But we can't set a + * timeout since we won't have an event loop if there are no netreqs. + * Could set a timer and check it when the next req comes in but... + * chances are it will be on the same transport and if we have a new + * req the conneciton is no longer idle so probably better to re-use + * than shut and immediately open a new one! + * So we will have to be aggressive and shut the connection....*/ + DEBUG_STUB("# %s: **Closing connection %d**\n", + __FUNCTION__, upstream->fd); + if (upstream->tls_obj) { + SSL_shutdown(upstream->tls_obj); + SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; + } close(upstream->fd); upstream->fd = -1; } - netreq->write_queue_tail = NULL; - - /* Schedule with the new upstream */ - netreq->upstream = new_upstream; - upstream_schedule_netreq(new_upstream, netreq); - - /* TODO[TLS]: Timout need to be adjusted and rescheduled on the new fd ...*/ - /* Note, setup timeout should be shorter than message timeout for - * messages with fallback or don't have time to re-try. */ - - /* For sync messages we must re-schedule the events here.*/ - if (netreq->owner->loop != upstream->loop) { - /* Create an event for the new upstream*/ - GETDNS_CLEAR_EVENT(netreq->owner->loop, &netreq->event); - GETDNS_SCHEDULE_EVENT(netreq->owner->loop, new_upstream->fd, - netreq->owner->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - ( new_upstream->netreq_by_query_id.count ? - netreq_upstream_read_cb : NULL ), - ( new_upstream->write_queue ? - netreq_upstream_write_cb : NULL), - stub_timeout_cb)); - - /* Now one for the old upstream. Must schedule this last to make sure - * it is called back first....?*/ - if (upstream->write_queue) { - GETDNS_CLEAR_EVENT(netreq->owner->loop, &upstream->write_queue->event); - GETDNS_SCHEDULE_EVENT( - upstream->write_queue->owner->loop, upstream->fd, - upstream->write_queue->owner->context->timeout, - getdns_eventloop_event_init(&upstream->write_queue->event, - upstream->write_queue, - ( upstream->netreq_by_query_id.count ? - netreq_upstream_read_cb : NULL ), - ( upstream->write_queue ? - netreq_upstream_write_cb : NULL), - stub_timeout_cb)); - } - } - netreq->transport++; - return upstream->fd; -} - -static int -fallback_on_write(getdns_network_req *netreq) -{ - DEBUG_STUB("%s\n", __FUNCTION__); - /* TODO[TLS]: Fallback through all transports.*/ - getdns_base_transport_t next_transport = - netreq->dns_base_transports[netreq->transport + 1]; - if (next_transport == GETDNS_BASE_TRANSPORT_NONE) - return STUB_TCP_ERROR; - - if (netreq->dns_base_transports[netreq->transport] == - GETDNS_BASE_TRANSPORT_STARTTLS && - next_transport == GETDNS_BASE_TRANSPORT_TCP) { - /* Special case where can stay on same upstream*/ - netreq->transport++; - return netreq->upstream->fd; - } - getdns_upstream *upstream = netreq->upstream; - int fd; - getdns_upstream *new_upstream = - find_upstream_for_specific_transport(netreq, next_transport, &fd); - if (!new_upstream) - return STUB_TCP_ERROR; - return move_netreq(netreq, upstream, new_upstream); } static void upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) { - DEBUG_STUB("%s\n", __FUNCTION__); + DEBUG_STUB("# %s: %p TYPE: %d\n", __FUNCTION__, netreq, netreq->request_type); /* We have a connected socket and a global event loop */ assert(upstream->fd >= 0); assert(upstream->loop); @@ -1688,18 +1774,18 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) upstream->write_queue = upstream->write_queue_last = netreq; upstream->event.timeout_cb = NULL; GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + upstream->event.write_cb = upstream_write_cb; if (upstream->tls_hs_state == GETDNS_HS_WRITE || (upstream->starttls_req && upstream->starttls_req->netreqs[0] == netreq)) { /* Set a timeout on the upstream so we can catch failed setup*/ /* TODO[TLS]: When generic fallback supported, we should decide how * to split the timeout between transports. */ - GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, - netreq->owner->context->timeout / 2, - getdns_eventloop_event_init(&upstream->event, upstream, - NULL, upstream_write_cb, upstream_tls_timeout_cb)); + upstream->event.timeout_cb = upstream_tls_timeout_cb; + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, netreq->owner->context->timeout / 2, + &upstream->event); } else { - upstream->event.write_cb = upstream_write_cb; GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, &upstream->event); } @@ -1712,7 +1798,7 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) getdns_return_t priv_getdns_submit_stub_request(getdns_network_req *netreq) { - DEBUG_STUB("%s\n", __FUNCTION__); + DEBUG_STUB("--> %s\n", __FUNCTION__); int fd = -1; getdns_dns_req *dnsreq = netreq->owner; @@ -1722,24 +1808,25 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq) if (fd == -1) return GETDNS_RETURN_GENERIC_ERROR; - getdns_base_transport_t transport = - netreq->dns_base_transports[netreq->transport]; + getdns_transport_list_t transport = + netreq->transports[netreq->transport_current]; switch(transport) { - case GETDNS_BASE_TRANSPORT_UDP: - case GETDNS_BASE_TRANSPORT_TCP_SINGLE: + case GETDNS_TRANSPORT_UDP: netreq->fd = fd; + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, - NULL, (transport == GETDNS_BASE_TRANSPORT_UDP ? + NULL, (transport == GETDNS_TRANSPORT_UDP ? stub_udp_write_cb: stub_tcp_write_cb), stub_timeout_cb)); return GETDNS_RETURN_GOOD; - case GETDNS_BASE_TRANSPORT_STARTTLS: - case GETDNS_BASE_TRANSPORT_TLS: - case GETDNS_BASE_TRANSPORT_TCP: + case GETDNS_TRANSPORT_STARTTLS: + case GETDNS_TRANSPORT_TLS: + case GETDNS_TRANSPORT_TCP: upstream_schedule_netreq(netreq->upstream, netreq); /* TODO[TLS]: Change scheduling for sync calls. */ + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->upstream->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, NULL, diff --git a/src/test/getdns_query.c b/src/test/getdns_query.c index 83064557..f451bd4c 100644 --- a/src/test/getdns_query.c +++ b/src/test/getdns_query.c @@ -180,19 +180,22 @@ void callback(getdns_context *context, getdns_callback_type_t callback_type, fprintf(stdout, "ASYNC response:\n%s\n", response_str); free(response_str); } - fprintf(stderr, - "The callback with ID %llu was successfull.\n", + fprintf(stdout, + "Result: The callback with ID %llu was successfull.\n", (unsigned long long)trans_id); } else if (callback_type == GETDNS_CALLBACK_CANCEL) fprintf(stderr, - "The callback with ID %llu was cancelled. Exiting.\n", + "Result: The callback with ID %llu was cancelled. Exiting.\n", (unsigned long long)trans_id); - else + else { fprintf(stderr, - "The callback got a callback_type of %d. Exiting.\n", + "Result: The callback got a callback_type of %d. Exiting.\n", callback_type); - + fprintf(stderr, + "Error : '%s'\n", + getdns_get_errorstr_by_id(callback_type)); + } getdns_dict_destroy(response); response = NULL; } @@ -274,6 +277,11 @@ getdns_return_t parse_args(int argc, char **argv) " %d", r); break; } + } else if (arg[1] == '0') { + /* Unset all existing extensions*/ + getdns_dict_destroy(extensions); + extensions = getdns_dict_create(); + break; } else if ((r = getdns_dict_set_int(extensions, arg+1, GETDNS_EXTENSION_TRUE))) { fprintf(stderr, "Could not set extension " @@ -443,7 +451,7 @@ getdns_return_t parse_args(int argc, char **argv) return GETDNS_RETURN_GENERIC_ERROR; } size_t transport_count = 0; - getdns_transport_list_t transports[GETDNS_BASE_TRANSPORT_MAX]; + getdns_transport_list_t transports[GETDNS_TRANSPORTS_MAX]; if ((r = fill_transport_list(context, argv[i], transports, &transport_count)) || (r = getdns_context_set_dns_transport_list(context, transport_count, transports))){ @@ -522,8 +530,10 @@ main(int argc, char **argv) if (!fgets(line, 1024, stdin) || !*line) break; } else { - if (!fgets(line, 1024, fp) || !*line) + if (!fgets(line, 1024, fp) || !*line) { + fprintf(stdout,"End of file."); break; + } fprintf(stdout,"Found query: %s", line); } @@ -531,6 +541,10 @@ main(int argc, char **argv) linec = 1; if ( ! (token = strtok(line, " \t\f\n\r"))) continue; + if (*token == '#') { + fprintf(stdout,"Result: Skipping comment\n"); + continue; + } do linev[linec++] = token; while (linec < 256 && (token = strtok(NULL, " \t\f\n\r"))); @@ -596,9 +610,7 @@ main(int argc, char **argv) r = GETDNS_RETURN_GENERIC_ERROR; break; } - if (r) - goto done_destroy_extensions; - if (!quiet) { + if (response && !quiet) { if ((response_str = json ? getdns_print_json_dict(response, json == 1) : getdns_pretty_print_dict(response))) { @@ -611,10 +623,15 @@ main(int argc, char **argv) fprintf( stderr , "Could not print response\n"); } - } else if (r == GETDNS_RETURN_GOOD) - fprintf(stdout, "Response code was: GOOD\n"); - else if (interactive) - fprintf(stderr, "An error occurred: %d\n", r); + } + if (r == GETDNS_RETURN_GOOD) { + uint32_t status; + getdns_dict_get_int(response, "status", &status); + fprintf(stdout, "Response code was: GOOD. Status was: %s\n", + getdns_get_errorstr_by_id(status)); + } else + fprintf(stderr, "An error occurred: %d '%s'\n", r, + getdns_get_errorstr_by_id(r)); } } while (interactive); @@ -633,8 +650,7 @@ done_destroy_context: if (r == CONTINUE) return 0; - if (r) - fprintf(stderr, "An error occurred: %d\n", r); + fprintf(stdout, "\nAll done.\n"); return r; } diff --git a/src/types-internal.h b/src/types-internal.h index 53d9db2e..12049aef 100644 --- a/src/types-internal.h +++ b/src/types-internal.h @@ -99,6 +99,9 @@ struct getdns_upstream; #define TIMEOUT_FOREVER ((int64_t)-1) #define ASSERT_UNREACHABLE 0 +#define GETDNS_TRANSPORTS_MAX 4 +#define GETDNS_UPSTREAM_TRANSPORTS 3 + /** @} */ @@ -156,6 +159,7 @@ typedef struct getdns_tcp_state { uint8_t *write_buf; size_t write_buf_len; size_t written; + int write_error; uint8_t *read_buf; size_t read_buf_len; @@ -164,17 +168,6 @@ typedef struct getdns_tcp_state { } getdns_tcp_state; -/* TODO[TLS]: change this name to getdns_transport when API updated*/ -typedef enum getdns_base_transport { - GETDNS_BASE_TRANSPORT_MIN = 0, - GETDNS_BASE_TRANSPORT_NONE = 0, - GETDNS_BASE_TRANSPORT_UDP, - GETDNS_BASE_TRANSPORT_TCP_SINGLE, /* To be removed? */ - GETDNS_BASE_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */ - GETDNS_BASE_TRANSPORT_TCP, - GETDNS_BASE_TRANSPORT_TLS, - GETDNS_BASE_TRANSPORT_MAX -} getdns_base_transport_t; /** * Request data @@ -203,8 +196,9 @@ typedef struct getdns_network_req /* For stub resolving */ struct getdns_upstream *upstream; int fd; - getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; - int transport; + getdns_transport_list_t transports[GETDNS_TRANSPORTS_MAX]; + size_t transport_count; + size_t transport_current; getdns_eventloop_event event; getdns_tcp_state tcp; uint16_t query_id;