Re-factor of internal handing of transport list.

This commit is contained in:
Sara Dickinson 2015-06-19 18:28:29 +01:00
parent 0acdcc34b0
commit 635cf9e182
6 changed files with 202 additions and 232 deletions

View File

@ -69,11 +69,15 @@ typedef struct host_name_addrs {
uint8_t host_name[];
} host_name_addrs;
static getdns_transport_list_t
getdns_upstream_transports[GETDNS_UPSTREAM_TRANSPORTS] = {
GETDNS_TRANSPORT_TCP,
GETDNS_TRANSPORT_TLS,
GETDNS_TRANSPORT_STARTTLS
};
static in_port_t
getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = {
GETDNS_PORT_ZERO,
GETDNS_PORT_ZERO,
GETDNS_PORT_ZERO,
getdns_port_array[GETDNS_UPSTREAM_TRANSPORTS] = {
GETDNS_PORT_DNS,
GETDNS_PORT_DNS,
GETDNS_PORT_DNS_OVER_TLS
@ -81,9 +85,6 @@ getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = {
char*
getdns_port_str_array[] = {
GETDNS_STR_PORT_ZERO,
GETDNS_STR_PORT_ZERO,
GETDNS_STR_PORT_ZERO,
GETDNS_STR_PORT_DNS,
GETDNS_STR_PORT_DNS,
GETDNS_STR_PORT_DNS_OVER_TLS
@ -91,6 +92,7 @@ getdns_port_str_array[] = {
/* Private functions */
getdns_return_t create_default_namespaces(struct getdns_context *context);
getdns_return_t create_default_dns_transports(struct getdns_context *context);
static struct getdns_list *create_default_root_servers(void);
static getdns_return_t set_os_defaults(struct getdns_context *);
static int transaction_id_cmp(const void *, const void *);
@ -139,42 +141,22 @@ create_default_namespaces(struct getdns_context *context)
return GETDNS_RETURN_GOOD;
}
static getdns_transport_list_t *
get_dns_transport_list(getdns_context *context, int *count)
/**
* Helper to get default transports.
*/
getdns_return_t
create_default_dns_transports(struct getdns_context *context)
{
if (context == NULL)
return NULL;
context->dns_transports = GETDNS_XMALLOC(context->my_mf, getdns_transport_list_t, 2);
if(context->dns_transports == NULL)
return GETDNS_RETURN_GENERIC_ERROR;
/* Count how many we have*/
for (*count = 0; *count < GETDNS_BASE_TRANSPORT_MAX; (*count)++) {
if (context->dns_base_transports[*count] == GETDNS_BASE_TRANSPORT_NONE)
break;
}
context->dns_transports[0] = GETDNS_TRANSPORT_UDP;
context->dns_transports[1] = GETDNS_TRANSPORT_TCP;
context->dns_transport_count = 2;
context->dns_transport_current = 0;
// use normal malloc here so users can do normal free
getdns_transport_list_t * transports = malloc(*count * sizeof(getdns_transport_list_t));
if(transports == NULL)
return NULL;
for (int i = 0; i < (int)*count; i++) {
switch(context->dns_base_transports[i]) {
case GETDNS_BASE_TRANSPORT_UDP:
transports[i] = GETDNS_TRANSPORT_UDP;
break;
case GETDNS_BASE_TRANSPORT_TCP:
transports[i] = GETDNS_TRANSPORT_TCP;
break;
case GETDNS_BASE_TRANSPORT_TLS:
transports[i] = GETDNS_TRANSPORT_TLS;
break;
case GETDNS_BASE_TRANSPORT_STARTTLS:
transports[i] = GETDNS_TRANSPORT_STARTTLS;
break;
default:
break;
}
}
return transports;
return GETDNS_RETURN_GOOD;
}
static inline void canonicalize_dname(uint8_t *dname)
@ -621,7 +603,7 @@ upstream_init(getdns_upstream *upstream,
upstream->fd = -1;
upstream->tls_obj = NULL;
upstream->starttls_req = NULL;
upstream->dns_base_transport = GETDNS_BASE_TRANSPORT_TCP;
upstream->transport = GETDNS_TRANSPORT_TCP;
upstream->tls_hs_state = GETDNS_HS_NONE;
upstream->loop = NULL;
(void) getdns_eventloop_event_init(
@ -725,11 +707,8 @@ set_os_defaults(struct getdns_context *context)
token = parse + strcspn(parse, " \t\r\n");
*token = 0;
getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN;
for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) {
char *port_str = getdns_port_str_array[base_transport];
if (strncmp(port_str, GETDNS_STR_PORT_ZERO, 1) == 0)
continue;
for (size_t i = 0; i < GETDNS_UPSTREAM_TRANSPORTS; i++) {
char *port_str = getdns_port_str_array[i];
if ((s = getaddrinfo(parse, port_str, &hints, &result)))
continue;
if (!result)
@ -743,7 +722,7 @@ set_os_defaults(struct getdns_context *context)
upstream = &context->upstreams->
upstreams[context->upstreams->count++];
upstream_init(upstream, context->upstreams, result);
upstream->dns_base_transport = base_transport;
upstream->transport = getdns_upstream_transports[i];
freeaddrinfo(result);
}
}
@ -873,8 +852,8 @@ getdns_context_create_with_extended_memory_functions(
result->dnssec_allowed_skew = 0;
result->edns_maximum_udp_payload_size = -1;
result->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP;
result->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP;
if ((r = create_default_dns_transports(result)))
goto error;
result->limit_outstanding_queries = 0;
result->has_ta = priv_getdns_parse_ta_file(NULL, NULL);
result->return_dnssec_status = GETDNS_EXTENSION_FALSE;
@ -1201,62 +1180,61 @@ static getdns_return_t
getdns_set_base_dns_transports(struct getdns_context *context,
size_t transport_count, getdns_transport_list_t *transports)
{
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++)
context->dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE;
size_t i;
if ((int)transport_count == 0 || transports == NULL ||
(int)transport_count > GETDNS_BASE_TRANSPORT_MAX) {
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
if (transport_count == 0 || transports == NULL) {
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
}
for(i=0; i<transport_count; i++)
{
if( transports[i] != GETDNS_TRANSPORT_UDP
&& transports[i] != GETDNS_TRANSPORT_TCP
&& transports[i] != GETDNS_TRANSPORT_TLS
&& transports[i] != GETDNS_TRANSPORT_STARTTLS)
return GETDNS_RETURN_INVALID_PARAMETER;
}
for (size_t j = 0; j < transport_count; j++) {
switch(transports[j]) {
case GETDNS_TRANSPORT_UDP:
context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_UDP;
break;
case GETDNS_TRANSPORT_TCP:
context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_TCP;
break;
case GETDNS_TRANSPORT_TLS:
context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_TLS;
break;
case GETDNS_TRANSPORT_STARTTLS:
context->dns_base_transports[j] = GETDNS_BASE_TRANSPORT_STARTTLS;
break;
default:
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
}
}
return GETDNS_RETURN_GOOD;
GETDNS_FREE(context->my_mf, context->dns_transports);
/** duplicate **/
context->dns_transports = GETDNS_XMALLOC(context->my_mf,
getdns_transport_list_t, transport_count);
memcpy(context->dns_transports, transports,
transport_count * sizeof(getdns_transport_list_t));
context->dns_transport_count = transport_count;
dispatch_updated(context, GETDNS_CONTEXT_CODE_NAMESPACES);
return GETDNS_RETURN_GOOD;
}
static getdns_return_t
set_ub_dns_transport(struct getdns_context* context) {
/* These mappings are not exact because Unbound is configured differently,
so just map as close as possible from the first 1 or 2 transports. */
switch (context->dns_base_transports[0]) {
case GETDNS_BASE_TRANSPORT_UDP:
switch (context->dns_transports[0]) {
case GETDNS_TRANSPORT_UDP:
set_ub_string_opt(context, "do-udp:", "yes");
if (context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_TCP)
if (context->dns_transports[1] == GETDNS_TRANSPORT_TCP)
set_ub_string_opt(context, "do-tcp:", "yes");
else
set_ub_string_opt(context, "do-tcp:", "no");
break;
case GETDNS_BASE_TRANSPORT_TLS:
case GETDNS_TRANSPORT_TLS:
/* Note: If TLS is used in recursive mode this will try TLS on port
* 53... So this is prohibited when preparing for resolution.*/
if (context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE) {
if (context->dns_transport_count == 0) {
set_ub_string_opt(context, "ssl-upstream:", "yes");
set_ub_string_opt(context, "do-udp:", "no");
set_ub_string_opt(context, "do-tcp:", "yes");
break;
}
if (context->dns_base_transports[1] != GETDNS_BASE_TRANSPORT_TCP)
if (context->dns_transports[1] != GETDNS_TRANSPORT_TCP)
break;
/* Fallthrough */
case GETDNS_BASE_TRANSPORT_STARTTLS:
case GETDNS_BASE_TRANSPORT_TCP:
case GETDNS_TRANSPORT_STARTTLS:
case GETDNS_TRANSPORT_TCP:
/* Note: no STARTTLS or fallback to TCP available directly in unbound, so we just
* use TCP for now to make sure the messages are sent. */
set_ub_string_opt(context, "do-udp:", "no");
@ -1278,31 +1256,37 @@ getdns_context_set_dns_transport(struct getdns_context *context,
{
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++)
context->dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE;
size_t count = 2;
if (value == GETDNS_TRANSPORT_UDP_ONLY ||
value == GETDNS_TRANSPORT_TCP_ONLY ||
value == GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN ||
value == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN)
count = 1;
context->dns_transports = GETDNS_XMALLOC(context->my_mf,
getdns_transport_list_t, count);
switch (value) {
case GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP:
context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP;
context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP;
context->dns_transports[0] = GETDNS_TRANSPORT_UDP;
context->dns_transports[1] = GETDNS_TRANSPORT_TCP;
break;
case GETDNS_TRANSPORT_UDP_ONLY:
context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP;
context->dns_transports[0] = GETDNS_TRANSPORT_UDP;
break;
case GETDNS_TRANSPORT_TCP_ONLY:
case GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN:
context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TCP;
context->dns_transports[0] = GETDNS_TRANSPORT_TCP;
break;
case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN:
context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS;
context->dns_transports[0] = GETDNS_TRANSPORT_TLS;
break;
case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN:
context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS;
context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP;
context->dns_transports[0] = GETDNS_TRANSPORT_TLS;
context->dns_transports[1] = GETDNS_TRANSPORT_TCP;
break;
case GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN:
context->dns_base_transports[0] = GETDNS_BASE_TRANSPORT_STARTTLS;
context->dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP;
context->dns_transports[0] = GETDNS_TRANSPORT_STARTTLS;
context->dns_transports[1] = GETDNS_TRANSPORT_TCP;
break;
default:
return GETDNS_RETURN_CONTEXT_UPDATE_FAIL;
@ -1658,15 +1642,14 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context,
}
/* Loop to create upstreams as needed*/
getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN;
for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) {
for (size_t j = 0; j < GETDNS_UPSTREAM_TRANSPORTS; j++) {
uint32_t port;
struct addrinfo *ai;
port = getdns_port_array[base_transport];
port = getdns_port_array[j];
if (port == GETDNS_PORT_ZERO)
continue;
if (base_transport != GETDNS_BASE_TRANSPORT_TLS)
if (getdns_upstream_transports[j] != GETDNS_TRANSPORT_TLS)
(void) getdns_dict_get_int(dict, "port", &port);
else
(void) getdns_dict_get_int(dict, "tls_port", &port);
@ -1689,7 +1672,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context,
upstream = &upstreams->upstreams[upstreams->count];
upstream->addr.ss_family = addr.ss_family;
upstream_init(upstream, upstreams, ai);
upstream->dns_base_transport = base_transport;
upstream->transport = getdns_upstream_transports[j];
upstreams->count++;
freeaddrinfo(ai);
}
@ -1913,9 +1896,9 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context)
upstream = &upstreams->upstreams[i];
/*[TLS]: Use only the TLS subset of upstreams when TLS is the only thing
* used. All other cases must currently fallback to TCP for libunbound.*/
if (context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS &&
context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE &&
upstream->dns_base_transport != GETDNS_BASE_TRANSPORT_TLS)
if (context->dns_transports[0] == GETDNS_TRANSPORT_TLS &&
context->dns_transport_count ==1 &&
upstream->transport != GETDNS_TRANSPORT_TLS)
continue;
upstream_ntop_buf(upstream, addr, 1024);
ub_ctx_set_fwd(ctx, addr);
@ -2025,8 +2008,8 @@ getdns_context_prepare_for_resolution(struct getdns_context *context,
}
/* Block use of TLS ONLY in recursive mode as it won't work */
if (context->resolution_type == GETDNS_RESOLUTION_RECURSING &&
context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS &&
context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE)
context->dns_transports[0] == GETDNS_TRANSPORT_TLS &&
context->dns_transport_count == 1)
return GETDNS_RETURN_BAD_CONTEXT;
if (context->resolution_type_set == context->resolution_type)
@ -2321,17 +2304,16 @@ priv_get_context_settings(getdns_context* context) {
upstreams);
getdns_list_destroy(upstreams);
}
/* create a transport list */
getdns_list* transports = getdns_list_create_with_context(context);
if (transports) {
int transport_count;
getdns_transport_list_t *transport_list =
get_dns_transport_list(context, &transport_count);
for (int i = 0; i < transport_count; i++) {
r |= getdns_list_set_int(transports, i, transport_list[i]);
if (context->dns_transport_count > 0) {
/* create a namespace list */
size_t i;
getdns_list* transports = getdns_list_create_with_context(context);
if (transports) {
for (i = 0; i < context->dns_transport_count; ++i) {
r |= getdns_list_set_int(transports, i, context->dns_transports[i]);
}
r |= getdns_dict_set_list(result, "dns_transport_list", transports);
}
r |= getdns_dict_set_list(result, "dns_transport_list", transports);
free(transport_list);
}
if (context->namespace_count > 0) {
/* create a namespace list */
@ -2525,35 +2507,34 @@ getdns_context_get_dns_transport(getdns_context *context,
getdns_transport_t* value) {
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
RETURN_IF_NULL(value, GETDNS_RETURN_INVALID_PARAMETER);
int count;
getdns_transport_list_t *transport_list =
get_dns_transport_list(context, &count);
if (!count)
int count = context->dns_transport_count;
getdns_transport_list_t *transports = context->dns_transports;
if (!count)
return GETDNS_RETURN_WRONG_TYPE_REQUESTED;
/* Best effort mapping for backwards compatibility*/
if (transport_list[0] == GETDNS_TRANSPORT_UDP) {
if (transports[0] == GETDNS_TRANSPORT_UDP) {
if (count == 1)
*value = GETDNS_TRANSPORT_UDP_ONLY;
else if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP)
else if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP)
*value = GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP;
else
return GETDNS_RETURN_WRONG_TYPE_REQUESTED;
}
if (transport_list[0] == GETDNS_TRANSPORT_TCP) {
if (transports[0] == GETDNS_TRANSPORT_TCP) {
if (count == 1)
*value = GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN;
}
if (transport_list[0] == GETDNS_TRANSPORT_TLS) {
if (transports[0] == GETDNS_TRANSPORT_TLS) {
if (count == 1)
*value = GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN;
else if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP)
else if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP)
*value = GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN;
else
return GETDNS_RETURN_WRONG_TYPE_REQUESTED;
}
if (transport_list[0] == GETDNS_TRANSPORT_STARTTLS) {
if (count == 2 && transport_list[1] == GETDNS_TRANSPORT_TCP)
if (transports[0] == GETDNS_TRANSPORT_STARTTLS) {
if (count == 2 && transports[1] == GETDNS_TRANSPORT_TCP)
*value = GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN;
else
return GETDNS_RETURN_WRONG_TYPE_REQUESTED;
@ -2567,16 +2548,15 @@ getdns_context_get_dns_transport_list(getdns_context *context,
RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER);
RETURN_IF_NULL(transport_count, GETDNS_RETURN_INVALID_PARAMETER);
RETURN_IF_NULL(transports, GETDNS_RETURN_INVALID_PARAMETER);
int count;
getdns_transport_list_t *transport_list =
get_dns_transport_list(context, &count);
*transport_count = count;
if (!transport_count) {
*transport_count = context->dns_transport_count;
if (!context->dns_transport_count) {
*transports = NULL;
return GETDNS_RETURN_GOOD;
}
*transports = transport_list;
// use normal malloc here so users can do normal free
*transports = malloc(context->dns_transport_count * sizeof(getdns_transport_list_t));
memcpy(*transports, context->dns_transports,
context->dns_transport_count * sizeof(getdns_transport_list_t));
return GETDNS_RETURN_GOOD;
}

View File

@ -91,7 +91,7 @@ typedef struct getdns_upstream {
/* For sharing a TCP socket to this upstream */
int fd;
getdns_base_transport_t dns_base_transport;
getdns_transport_list_t transport;
SSL* tls_obj;
getdns_tls_hs_state_t tls_hs_state;
getdns_dns_req * starttls_req;
@ -138,10 +138,13 @@ struct getdns_context {
struct getdns_list *suffix;
struct getdns_list *dnssec_trust_anchors;
getdns_upstreams *upstreams;
getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX];
uint16_t limit_outstanding_queries;
uint32_t dnssec_allowed_skew;
getdns_transport_list_t *dns_transports;
size_t dns_transport_count;
size_t dns_transport_current;
uint8_t edns_extended_rcode;
uint8_t edns_version;
uint8_t edns_do_bit;

View File

@ -62,6 +62,7 @@ network_req_cleanup(getdns_network_req *net_req)
if (net_req->response && (net_req->response < net_req->wire_data ||
net_req->response > net_req->wire_data+ net_req->wire_data_sz))
GETDNS_FREE(net_req->owner->my_mf, net_req->response);
GETDNS_FREE(net_req->owner->my_mf, net_req->transports);
}
static int
@ -89,9 +90,13 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner,
net_req->upstream = NULL;
net_req->fd = -1;
for (i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++)
net_req->dns_base_transports[i] = owner->context->dns_base_transports[i];
net_req->transport = 0;
net_req->transports = GETDNS_XMALLOC(net_req->owner->my_mf,
getdns_transport_list_t,
owner->context->dns_transport_count);
memcpy(owner->context->dns_transports, net_req->transports,
owner->context->dns_transport_count * sizeof(getdns_transport_list_t));
net_req->transport_count = owner->context->dns_transport_count;
net_req->transport_current = 0;
memset(&net_req->event, 0, sizeof(net_req->event));
memset(&net_req->tcp, 0, sizeof(net_req->tcp));
net_req->query_id = 0;

View File

@ -59,6 +59,9 @@ static void upstream_read_cb(void *userarg);
static void upstream_write_cb(void *userarg);
static void upstream_schedule_netreq(getdns_upstream *upstream,
getdns_network_req *netreq);
static int upstream_connect(getdns_upstream *upstream,
getdns_transport_list_t transport,
getdns_dns_req *dnsreq);
static void netreq_upstream_read_cb(void *userarg);
static void netreq_upstream_write_cb(void *userarg);
static int fallback_on_write(getdns_network_req *netreq);
@ -354,7 +357,7 @@ getdns_sock_nonblock(int sockfd)
}
static int
tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport)
tcp_connect(getdns_upstream *upstream, getdns_transport_list_t transport)
{
int fd = -1;
if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1)
@ -363,9 +366,8 @@ tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport)
getdns_sock_nonblock(fd);
#ifdef USE_TCP_FASTOPEN
/* Leave the connect to the later call to sendto() if using TCP*/
if (transport == GETDNS_BASE_TRANSPORT_TCP ||
transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE ||
transport == GETDNS_BASE_TRANSPORT_STARTTLS)
if (transport == GETDNS_TRANSPORT_TCP ||
transport == GETDNS_TRANSPORT_STARTTLS)
return fd;
#endif
if (connect(fd, (struct sockaddr *)&upstream->addr,
@ -667,20 +669,7 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq)
if (! tcp->write_buf) {
/* No, this is an initial write. Try to send
*/
/* Not keeping connections open? Then the first random number
* will do as the query id.
*
* Otherwise find a unique query_id not already written (or in
* the write_queue) for that upstream. Register this netreq
* by query_id in the process.
*/
if ((netreq->dns_base_transports[netreq->transport] ==
GETDNS_BASE_TRANSPORT_TCP_SINGLE) ||
(netreq->dns_base_transports[netreq->transport] ==
GETDNS_BASE_TRANSPORT_UDP))
query_id = arc4random();
else do {
do {
query_id = arc4random();
query_id_intptr = (intptr_t)query_id;
netreq->node.key = (void *)query_id_intptr;
@ -774,10 +763,10 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq)
static int
tls_requested(getdns_network_req *netreq)
{
return (netreq->dns_base_transports[netreq->transport] ==
GETDNS_BASE_TRANSPORT_TLS ||
netreq->dns_base_transports[netreq->transport] ==
GETDNS_BASE_TRANSPORT_STARTTLS) ?
return (netreq->transports[netreq->transport_current] ==
GETDNS_TRANSPORT_TLS ||
netreq->transports[netreq->transport_current] ==
GETDNS_TRANSPORT_STARTTLS) ?
1 : 0;
}
@ -786,16 +775,16 @@ tls_should_write(getdns_upstream *upstream)
{
/* Should messages be written on TLS upstream. Remember that for STARTTLS
* the first message should got over TCP as the handshake isn't started yet.*/
return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS ||
upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) &&
return ((upstream->transport == GETDNS_TRANSPORT_TLS ||
upstream->transport == GETDNS_TRANSPORT_STARTTLS) &&
upstream->tls_hs_state != GETDNS_HS_NONE) ? 1 : 0;
}
static int
tls_should_read(getdns_upstream *upstream)
{
return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS ||
upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) &&
return ((upstream->transport == GETDNS_TRANSPORT_TLS ||
upstream->transport == GETDNS_TRANSPORT_STARTTLS) &&
!(upstream->tls_hs_state == GETDNS_HS_FAILED ||
upstream->tls_hs_state == GETDNS_HS_NONE)) ? 1 : 0;
}
@ -804,8 +793,8 @@ static int
tls_failed(getdns_upstream *upstream)
{
/* No messages should be scheduled onto an upstream in this state */
return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS ||
upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) &&
return ((upstream->transport == GETDNS_TRANSPORT_TLS ||
upstream->transport == GETDNS_TRANSPORT_STARTTLS) &&
upstream->tls_hs_state == GETDNS_HS_FAILED) ? 1: 0;
}
@ -1070,23 +1059,16 @@ stub_udp_read_cb(void *userarg)
return; /* Client cookie didn't match? */
close(netreq->fd);
/* TODO: check not past end of transports*/
getdns_base_transport_t next_transport =
netreq->dns_base_transports[netreq->transport + 1];
if (GLDNS_TC_WIRE(netreq->response) &&
next_transport == GETDNS_BASE_TRANSPORT_TCP) {
if ((netreq->fd = socket(
upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1)
if (GLDNS_TC_WIRE(netreq->response)) {
if (!(netreq->transport_current < netreq->transport_count))
goto done;
getdns_sock_nonblock(netreq->fd);
if (connect(netreq->fd, (struct sockaddr *)&upstream->addr,
upstream->addr_len) == -1 && errno != EINPROGRESS) {
close(netreq->fd);
netreq->transport_current++;
if (netreq->transport_current != GETDNS_TRANSPORT_TCP)
goto done;
}
if ((netreq->fd = upstream_connect(upstream, netreq->transport_current,
dnsreq)) == -1)
goto done;
GETDNS_SCHEDULE_EVENT(
dnsreq->loop, netreq->fd, dnsreq->context->timeout,
getdns_eventloop_event_init(&netreq->event, netreq,
@ -1427,20 +1409,19 @@ netreq_upstream_write_cb(void *userarg)
static int
upstream_transport_valid(getdns_upstream *upstream,
getdns_base_transport_t transport)
getdns_transport_list_t transport)
{
/* For single shot transports, use only the TCP upstream. */
if (transport == GETDNS_BASE_TRANSPORT_UDP ||
transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE)
return (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP ? 1:0);
/* Single shot UDP, uses same upstream as plain TCP. */
if (transport == GETDNS_TRANSPORT_UDP)
return (upstream->transport == GETDNS_TRANSPORT_TCP ? 1:0);
/* Allow TCP messages to be sent on a STARTTLS upstream that hasn't
* upgraded to avoid opening a new connection if one is aleady open. */
if (transport == GETDNS_BASE_TRANSPORT_TCP &&
upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS &&
if (transport == GETDNS_TRANSPORT_TCP &&
upstream->transport == GETDNS_TRANSPORT_STARTTLS &&
upstream->tls_hs_state == GETDNS_HS_FAILED)
return 1;
/* Otherwise, transport must match, and not have failed */
if (upstream->dns_base_transport != transport)
if (upstream->transport != transport)
return 0;
if (tls_failed(upstream))
return 0;
@ -1448,7 +1429,7 @@ upstream_transport_valid(getdns_upstream *upstream,
}
static getdns_upstream *
upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport)
upstream_select(getdns_network_req *netreq, getdns_transport_list_t transport)
{
getdns_upstream *upstream;
getdns_upstreams *upstreams = netreq->owner->upstreams;
@ -1489,31 +1470,29 @@ upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport)
int
upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport,
upstream_connect(getdns_upstream *upstream, getdns_transport_list_t transport,
getdns_dns_req *dnsreq)
{
DEBUG_STUB("%s\n", __FUNCTION__);
int fd = -1;
switch(transport) {
case GETDNS_BASE_TRANSPORT_UDP:
case GETDNS_TRANSPORT_UDP:
if ((fd = socket(
upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1)
return -1;
getdns_sock_nonblock(fd);
return fd;
case GETDNS_BASE_TRANSPORT_TCP:
case GETDNS_TRANSPORT_TCP:
/* Use existing if available*/
if (upstream->fd != -1)
return upstream->fd;
/* Otherwise, fall through */
case GETDNS_BASE_TRANSPORT_TCP_SINGLE:
fd = tcp_connect(upstream, transport);
upstream->loop = dnsreq->context->extension;
upstream->fd = fd;
break;
case GETDNS_BASE_TRANSPORT_TLS:
case GETDNS_TRANSPORT_TLS:
/* Use existing if available*/
if (upstream->fd != -1 && !tls_failed(upstream))
return upstream->fd;
@ -1528,7 +1507,7 @@ upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport,
upstream->loop = dnsreq->context->extension;
upstream->fd = fd;
break;
case GETDNS_BASE_TRANSPORT_STARTTLS:
case GETDNS_TRANSPORT_STARTTLS:
/* Use existing if available. Let the fallback code handle it if
* STARTTLS isn't availble. */
if (upstream->fd != -1)
@ -1559,7 +1538,7 @@ upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport,
static getdns_upstream*
find_upstream_for_specific_transport(getdns_network_req *netreq,
getdns_base_transport_t transport,
getdns_transport_list_t transport,
int *fd)
{
/* TODO[TLS]: Fallback through upstreams....?*/
@ -1574,15 +1553,13 @@ static int
find_upstream_for_netreq(getdns_network_req *netreq)
{
int fd = -1;
int i = netreq->transport;
for (; i < GETDNS_BASE_TRANSPORT_MAX &&
netreq->dns_base_transports[i] != GETDNS_BASE_TRANSPORT_NONE; i++) {
for (size_t i = 0; i < netreq->transport_count; i++) {
netreq->upstream = find_upstream_for_specific_transport(netreq,
netreq->dns_base_transports[i],
netreq->transports[i],
&fd);
if (fd == -1 || !netreq->upstream)
continue;
netreq->transport = i;
netreq->transport_current = i;
return fd;
}
return -1;
@ -1645,7 +1622,7 @@ move_netreq(getdns_network_req *netreq, getdns_upstream *upstream,
stub_timeout_cb));
}
}
netreq->transport++;
netreq->transport_current++;
return upstream->fd;
}
@ -1654,16 +1631,18 @@ fallback_on_write(getdns_network_req *netreq)
{
DEBUG_STUB("%s\n", __FUNCTION__);
/* TODO[TLS]: Fallback through all transports.*/
getdns_base_transport_t next_transport =
netreq->dns_base_transports[netreq->transport + 1];
if (next_transport == GETDNS_BASE_TRANSPORT_NONE)
if (netreq->transport_current = netreq->transport_count - 1)
return STUB_TCP_ERROR;
if (netreq->dns_base_transports[netreq->transport] ==
GETDNS_BASE_TRANSPORT_STARTTLS &&
next_transport == GETDNS_BASE_TRANSPORT_TCP) {
/* Special case where can stay on same upstream*/
netreq->transport++;
getdns_transport_list_t next_transport =
netreq->transports[netreq->transport_current + 1];
if (netreq->transports[netreq->transport_current] ==
GETDNS_TRANSPORT_STARTTLS &&
next_transport == GETDNS_TRANSPORT_TCP) {
/* TODO[TLS]: Check this is always OK....
* Special case where can stay on same upstream*/
netreq->transport_current++;
return netreq->upstream->fd;
}
getdns_upstream *upstream = netreq->upstream;
@ -1722,22 +1701,21 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq)
if (fd == -1)
return GETDNS_RETURN_GENERIC_ERROR;
getdns_base_transport_t transport =
netreq->dns_base_transports[netreq->transport];
getdns_transport_list_t transport =
netreq->transports[netreq->transport_current];
switch(transport) {
case GETDNS_BASE_TRANSPORT_UDP:
case GETDNS_BASE_TRANSPORT_TCP_SINGLE:
case GETDNS_TRANSPORT_UDP:
netreq->fd = fd;
GETDNS_SCHEDULE_EVENT(
dnsreq->loop, netreq->fd, dnsreq->context->timeout,
getdns_eventloop_event_init(&netreq->event, netreq,
NULL, (transport == GETDNS_BASE_TRANSPORT_UDP ?
NULL, (transport == GETDNS_TRANSPORT_UDP ?
stub_udp_write_cb: stub_tcp_write_cb), stub_timeout_cb));
return GETDNS_RETURN_GOOD;
case GETDNS_BASE_TRANSPORT_STARTTLS:
case GETDNS_BASE_TRANSPORT_TLS:
case GETDNS_BASE_TRANSPORT_TCP:
case GETDNS_TRANSPORT_STARTTLS:
case GETDNS_TRANSPORT_TLS:
case GETDNS_TRANSPORT_TCP:
upstream_schedule_netreq(netreq->upstream, netreq);
/* TODO[TLS]: Change scheduling for sync calls. */
GETDNS_SCHEDULE_EVENT(

View File

@ -443,7 +443,7 @@ getdns_return_t parse_args(int argc, char **argv)
return GETDNS_RETURN_GENERIC_ERROR;
}
size_t transport_count = 0;
getdns_transport_list_t transports[GETDNS_BASE_TRANSPORT_MAX];
getdns_transport_list_t transports[strlen(argv[])];
if ((r = fill_transport_list(context, argv[i], transports, &transport_count)) ||
(r = getdns_context_set_dns_transport_list(context,
transport_count, transports))){

View File

@ -99,6 +99,9 @@ struct getdns_upstream;
#define TIMEOUT_FOREVER ((int64_t)-1)
#define ASSERT_UNREACHABLE 0
#define GETDNS_TRANSPORTS_MAX 4
#define GETDNS_UPSTREAM_TRANSPORTS 3
/** @}
*/
@ -164,17 +167,17 @@ typedef struct getdns_tcp_state {
} getdns_tcp_state;
/* TODO[TLS]: change this name to getdns_transport when API updated*/
typedef enum getdns_base_transport {
GETDNS_BASE_TRANSPORT_MIN = 0,
GETDNS_BASE_TRANSPORT_NONE = 0,
GETDNS_BASE_TRANSPORT_UDP,
GETDNS_BASE_TRANSPORT_TCP_SINGLE, /* To be removed? */
GETDNS_BASE_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */
GETDNS_BASE_TRANSPORT_TCP,
GETDNS_BASE_TRANSPORT_TLS,
GETDNS_BASE_TRANSPORT_MAX
} getdns_base_transport_t;
// /* TODO[TLS]: change this name to getdns_transport when API updated*/
// typedef enum getdns_base_transport {
// GETDNS_TRANSPORT_MIN = 0,
// GETDNS_TRANSPORT_NONE = 0,
// GETDNS_TRANSPORT_UDP,
// GETDNS_TRANSPORT_TCP_SINGLE, /* To be removed? */
// GETDNS_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */
// GETDNS_TRANSPORT_TCP,
// GETDNS_TRANSPORT_TLS,
// GETDNS_TRANSPORT_MAX
// } getdns_base_transport_t;
/**
* Request data
@ -203,8 +206,9 @@ typedef struct getdns_network_req
/* For stub resolving */
struct getdns_upstream *upstream;
int fd;
getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX];
int transport;
getdns_transport_list_t *transports;
size_t transport_count;
size_t transport_current;
getdns_eventloop_event event;
getdns_tcp_state tcp;
uint16_t query_id;