diff --git a/src/context.c b/src/context.c index 3683c749..8dfde4c9 100644 --- a/src/context.c +++ b/src/context.c @@ -54,11 +54,11 @@ #include "list.h" #define GETDNS_PORT_ZERO 0 -#define GETDNS_PORT_TCP 53 -#define GETDNS_PORT_TLS 1021 +#define GETDNS_PORT_DNS 53 +#define GETDNS_PORT_DNS_OVER_TLS 1021 #define GETDNS_STR_PORT_ZERO "0" -#define GETDNS_STR_PORT_TCP "53" -#define GETDNS_STR_PORT_TLS "1021" +#define GETDNS_STR_PORT_DNS "53" +#define GETDNS_STR_PORT_DNS_OVER_TLS "1021" void *plain_mem_funcs_user_arg = MF_PLAIN; @@ -74,9 +74,9 @@ getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = { GETDNS_PORT_ZERO, GETDNS_PORT_ZERO, GETDNS_PORT_ZERO, - GETDNS_PORT_TCP, - GETDNS_PORT_TCP, - GETDNS_PORT_TLS + GETDNS_PORT_DNS, + GETDNS_PORT_DNS, + GETDNS_PORT_DNS_OVER_TLS }; char* @@ -84,9 +84,9 @@ getdns_port_str_array[] = { GETDNS_STR_PORT_ZERO, GETDNS_STR_PORT_ZERO, GETDNS_STR_PORT_ZERO, - GETDNS_STR_PORT_TCP, - GETDNS_STR_PORT_TCP, - GETDNS_STR_PORT_TLS + GETDNS_STR_PORT_DNS, + GETDNS_STR_PORT_DNS, + GETDNS_STR_PORT_DNS_OVER_TLS }; /* Private functions */ @@ -267,7 +267,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in *)sa)->sin_port); - if (port != GETDNS_PORT_ZERO && port != GETDNS_PORT_TCP && + if (port != GETDNS_PORT_ZERO && port != GETDNS_PORT_DNS && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -283,7 +283,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in6 *)sa)->sin6_port); - if (port != GETDNS_PORT_TCP && port != GETDNS_PORT_TCP && + if (port != GETDNS_PORT_DNS && port != GETDNS_PORT_DNS && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -553,7 +553,7 @@ upstream_ntop_buf(getdns_upstream *upstream, char *buf, size_t len) if (upstream_scope_id(upstream)) (void) snprintf(buf + strlen(buf), len - strlen(buf), "%%%d", (int)*upstream_scope_id(upstream)); - else if (upstream_port(upstream) != GETDNS_PORT_TCP && upstream_port(upstream) != GETDNS_PORT_ZERO) + else if (upstream_port(upstream) != GETDNS_PORT_DNS && upstream_port(upstream) != GETDNS_PORT_ZERO) (void) snprintf(buf + strlen(buf), len - strlen(buf), "@%d", (int)upstream_port(upstream)); } @@ -1521,7 +1521,6 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, struct addrinfo *ai; getdns_upstream *upstream; - /* So should we be throwing away the port the user set?*/ port = getdns_port_array[base_transport]; if (port == GETDNS_PORT_ZERO) continue; @@ -1569,6 +1568,8 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, if (getaddrinfo(addrstr, portstr, &hints, &ai)) goto invalid_parameter; + /* TODO[TLS]: Should probably check that the upstream doesn't + * already exist (in case user has specified port explicitly)*/ upstream_init(upstream, upstreams, ai); upstream->dns_base_transport = base_transport; upstreams->count++; @@ -1796,9 +1797,9 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) * used. All other cases must currently fallback to TCP for libunbound.*/ if (context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE && - upstream_port(upstream) != GETDNS_PORT_TLS) + upstream_port(upstream) != GETDNS_PORT_DNS_OVER_TLS) continue; - else if (upstream_port(upstream) != GETDNS_PORT_TCP) + else if (upstream_port(upstream) != GETDNS_PORT_DNS) continue; upstream_ntop_buf(upstream, addr, 1024); ub_ctx_set_fwd(ctx, addr); diff --git a/src/stub.c b/src/stub.c index 9542205b..e4dfc910 100755 --- a/src/stub.c +++ b/src/stub.c @@ -596,10 +596,14 @@ stub_udp_write_cb(void *userarg) static int transport_valid(struct getdns_upstream *upstream, getdns_base_transport_t transport) { - /* For single shot transports, use any upstream. */ + /* For single shot transports, use only the TCP upstream. */ if (transport == GETDNS_BASE_TRANSPORT_UDP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) - return 1; + transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) { + if (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP) + return 1; + else + return 0; + } /* Allow TCP messages to be sent on a STARTTLS upstream that hasn't upgraded * to avoid opening a new connection if one is aleady open. */ if (transport == GETDNS_BASE_TRANSPORT_TCP && @@ -621,7 +625,7 @@ static getdns_upstream * pick_upstream(getdns_network_req *netreq, getdns_base_transport_t transport) { getdns_upstream *upstream; - getdns_upstreams *upstreams = netreq->owner->context->upstreams; + getdns_upstreams *upstreams = netreq->owner->upstreams; size_t i; if (!upstreams->count) @@ -648,6 +652,9 @@ pick_upstream(getdns_network_req *netreq, getdns_base_transport_t transport) transport_valid(&upstreams->upstreams[i], transport)) upstream = &upstreams->upstreams[i]; + /* Need to check again that the transport is valid */ + if (!transport_valid(upstream, transport)) + return NULL; upstream->back_off++; upstream->to_retry = 1; upstreams->current = upstream - upstreams->upstreams; @@ -1493,8 +1500,9 @@ find_upstream_for_netreq(getdns_network_req *netreq) netreq->upstream = find_upstream_for_specific_transport(netreq, netreq->dns_base_transports[i], &fd); - if (fd == -1) + if (fd == -1 || !netreq->upstream) continue; + netreq->dns_base_transport = &netreq->dns_base_transports[i]; return fd; } return -1;