diff --git a/src/const-info.c b/src/const-info.c index 5c5a3d99..026d5da8 100755 --- a/src/const-info.c +++ b/src/const-info.c @@ -41,6 +41,7 @@ static struct const_info consts_info[] = { { 543, "GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN", GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN_TEXT }, { 544, "GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN", GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN_TEXT }, { 545, "GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN", GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN_TEXT }, + { 546, "GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN", GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN_TEXT }, { 550, "GETDNS_APPEND_NAME_ALWAYS", GETDNS_APPEND_NAME_ALWAYS_TEXT }, { 551, "GETDNS_APPEND_NAME_ONLY_TO_SINGLE_LABEL_AFTER_FAILURE", GETDNS_APPEND_NAME_ONLY_TO_SINGLE_LABEL_AFTER_FAILURE_TEXT }, { 552, "GETDNS_APPEND_NAME_ONLY_TO_MULTIPLE_LABEL_NAME_AFTER_FAILURE", GETDNS_APPEND_NAME_ONLY_TO_MULTIPLE_LABEL_NAME_AFTER_FAILURE_TEXT }, diff --git a/src/context.c b/src/context.c index 45e4b472..aabcc5e0 100644 --- a/src/context.c +++ b/src/context.c @@ -53,6 +53,13 @@ #include "stub.h" #include "list.h" +#define GETDNS_PORT_ZERO 0 +#define GETDNS_PORT_DNS 53 +#define GETDNS_PORT_DNS_OVER_TLS 1021 +#define GETDNS_STR_PORT_ZERO "0" +#define GETDNS_STR_PORT_DNS "53" +#define GETDNS_STR_PORT_DNS_OVER_TLS "1021" + void *plain_mem_funcs_user_arg = MF_PLAIN; typedef struct host_name_addrs { @@ -62,6 +69,26 @@ typedef struct host_name_addrs { uint8_t host_name[]; } host_name_addrs; +static in_port_t +getdns_port_array[GETDNS_BASE_TRANSPORT_MAX] = { + GETDNS_PORT_ZERO, + GETDNS_PORT_ZERO, + GETDNS_PORT_ZERO, + GETDNS_PORT_DNS, + GETDNS_PORT_DNS, + GETDNS_PORT_DNS_OVER_TLS +}; + +char* +getdns_port_str_array[] = { + GETDNS_STR_PORT_ZERO, + GETDNS_STR_PORT_ZERO, + GETDNS_STR_PORT_ZERO, + GETDNS_STR_PORT_DNS, + GETDNS_STR_PORT_DNS, + GETDNS_STR_PORT_DNS_OVER_TLS +}; + /* Private functions */ getdns_return_t create_default_namespaces(struct getdns_context *context); static struct getdns_list *create_default_root_servers(void); @@ -240,7 +267,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in *)sa)->sin_port); - if (port != 0 && port != 53 && + if (port != GETDNS_PORT_ZERO && port != GETDNS_PORT_DNS && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -256,7 +283,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in6 *)sa)->sin6_port); - if (port != 0 && port != 53 && + if (port != GETDNS_PORT_DNS && port != GETDNS_PORT_DNS && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -514,8 +541,7 @@ upstream_scope_id(getdns_upstream *upstream) } static void -upstream_ntop_buf(getdns_upstream *upstream, getdns_transport_t transport, - char *buf, size_t len) +upstream_ntop_buf(getdns_upstream *upstream, char *buf, size_t len) { /* Also possible but prints scope_id by name (nor parsed by unbound) * @@ -527,10 +553,7 @@ upstream_ntop_buf(getdns_upstream *upstream, getdns_transport_t transport, if (upstream_scope_id(upstream)) (void) snprintf(buf + strlen(buf), len - strlen(buf), "%%%d", (int)*upstream_scope_id(upstream)); - if (transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) - (void) snprintf(buf + strlen(buf), len - strlen(buf), - "@%d", GETDNS_TLS_PORT); - else if (upstream_port(upstream) != 53 && upstream_port(upstream) != 0) + else if (upstream_port(upstream) != GETDNS_PORT_DNS && upstream_port(upstream) != GETDNS_PORT_ZERO) (void) snprintf(buf + strlen(buf), len - strlen(buf), "@%d", (int)upstream_port(upstream)); } @@ -557,6 +580,9 @@ upstream_init(getdns_upstream *upstream, /* For sharing a socket to this upstream with TCP */ upstream->fd = -1; upstream->tls_obj = NULL; + upstream->starttls_req = NULL; + upstream->dns_base_transport = GETDNS_BASE_TRANSPORT_TCP; + upstream->tls_hs_state = GETDNS_HS_NONE; upstream->loop = NULL; (void) getdns_eventloop_event_init( &upstream->event, upstream, NULL, NULL, NULL); @@ -659,20 +685,27 @@ set_os_defaults(struct getdns_context *context) token = parse + strcspn(parse, " \t\r\n"); *token = 0; - if ((s = getaddrinfo(parse, "53", &hints, &result))) - continue; + getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; + for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { + char *port_str = getdns_port_str_array[base_transport]; + if (strncmp(port_str, GETDNS_STR_PORT_ZERO, 1) == 0) + continue; + if ((s = getaddrinfo(parse, port_str, &hints, &result))) + continue; - /* No lookups, so maximal 1 result */ - if (! result) continue; + /* No lookups, so maximal 1 result */ + if (! result) continue; - /* Grow array when needed */ - if (context->upstreams->count == upstreams_limit) - context->upstreams = upstreams_resize( - context->upstreams, (upstreams_limit *= 2)); + /* Grow array when needed */ + if (context->upstreams->count == upstreams_limit) + context->upstreams = upstreams_resize( + context->upstreams, (upstreams_limit *= 2)); - upstream = &context->upstreams-> - upstreams[context->upstreams->count++]; - upstream_init(upstream, context->upstreams, result); + upstream = &context->upstreams-> + upstreams[context->upstreams->count++]; + upstream_init(upstream, context->upstreams, result); + upstream->dns_base_transport = base_transport; + } freeaddrinfo(result); } fclose(in); @@ -801,6 +834,7 @@ getdns_context_create_with_extended_memory_functions( result->dnssec_allowed_skew = 0; result->edns_maximum_udp_payload_size = -1; result->dns_transport = GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP; + priv_set_base_dns_transports(result->dns_base_transports, result->dns_transport); result->limit_outstanding_queries = 0; result->has_ta = priv_getdns_parse_ta_file(NULL, NULL); result->return_dnssec_status = GETDNS_EXTENSION_FALSE; @@ -1124,31 +1158,43 @@ getdns_context_set_namespaces(struct getdns_context *context, return GETDNS_RETURN_GOOD; } /* getdns_context_set_namespaces */ -getdns_base_transport_t -priv_get_base_transport(getdns_transport_t transport, int level) { - if (!(level == 0 || level == 1)) return GETDNS_TRANSPORT_NONE; - switch (transport) { - case GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP: - if (level == 0) return GETDNS_TRANSPORT_UDP; - if (level == 1) return GETDNS_TRANSPORT_TCP; - case GETDNS_TRANSPORT_UDP_ONLY: - if (level == 0) return GETDNS_TRANSPORT_UDP; - if (level == 1) return GETDNS_TRANSPORT_NONE; - case GETDNS_TRANSPORT_TCP_ONLY: - if (level == 0) return GETDNS_TRANSPORT_TCP_SINGLE; - if (level == 1) return GETDNS_TRANSPORT_NONE; - case GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN: - if (level == 0) return GETDNS_TRANSPORT_TCP; - if (level == 1) return GETDNS_TRANSPORT_NONE; - case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: - if (level == 0) return GETDNS_TRANSPORT_TLS; - if (level == 1) return GETDNS_TRANSPORT_NONE; - case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - if (level == 0) return GETDNS_TRANSPORT_TLS; - if (level == 1) return GETDNS_TRANSPORT_TCP; - default: - return GETDNS_TRANSPORT_NONE; - } +/* TODO[TLS]: Modify further when API changed.*/ +getdns_return_t +priv_set_base_dns_transports(getdns_base_transport_t *dns_base_transports, + getdns_transport_t value) +{ + for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) + dns_base_transports[i] = GETDNS_BASE_TRANSPORT_NONE; + switch (value) { + case GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; + dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP_SINGLE; + break; + case GETDNS_TRANSPORT_UDP_ONLY: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_UDP; + break; + case GETDNS_TRANSPORT_TCP_ONLY: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TCP_SINGLE; + break; + case GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TCP; + break; + case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS; + break; + case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_TLS; + dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + break; + case GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: + dns_base_transports[0] = GETDNS_BASE_TRANSPORT_STARTTLS; + dns_base_transports[1] = GETDNS_BASE_TRANSPORT_TCP; + break; + + default: + return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; + } + return GETDNS_RETURN_GOOD; } static getdns_return_t @@ -1170,12 +1216,12 @@ set_ub_dns_transport(struct getdns_context* context, set_ub_string_opt(context, "do-tcp:", "yes"); break; case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: - /* Hum. If used in recursive mode this will try TLS on port 53... - * So we need to fix or document that or delay setting it until - * resolution.*/ + /* Note: If TLS is used in recursive mode this will try TLS on port + * 53... So this is prohibited when preparing for resolution.*/ set_ub_string_opt(context, "ssl-upstream:", "yes"); - /* Fall through*/ + /* Fall through */ case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: + case GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: /* Note: no fallback to TCP available directly in unbound, so we just * use TCP for now to make sure the messages are sent. */ set_ub_string_opt(context, "do-udp:", "no"); @@ -1195,17 +1241,22 @@ getdns_return_t getdns_context_set_dns_transport(struct getdns_context *context, getdns_transport_t value) { + /* TODO[TLS]: Modify further when API changed.*/ RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); /* Note that the call below does not have any effect in unbound after the - * ctx is finalised. So will not apply for recursive mode or stub + dnssec. + * ctx is finalised so for recursive mode or stub + dnssec only the first + * transport specified on the first query is used. * However the method returns success as otherwise the transport could not - * be reset for stub mode..... + * be reset for stub mode. * Also, not all transport options supported in libunbound yet */ if (set_ub_dns_transport(context, value) != GETDNS_RETURN_GOOD) { return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; } if (value != context->dns_transport) { + /*TODO[TLS]: remove this line when API updated*/ context->dns_transport = value; + if (priv_set_base_dns_transports(context->dns_base_transports, value) != GETDNS_RETURN_GOOD) + return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; dispatch_updated(context, GETDNS_CONTEXT_CODE_DNS_TRANSPORT); } @@ -1436,6 +1487,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, getdns_return_t r; size_t count = 0; size_t i; + //size_t upstreams_limit; getdns_upstreams *upstreams; char addrstr[1024], portstr[1024], *eos; struct addrinfo hints; @@ -1456,17 +1508,17 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, hints.ai_addr = NULL; hints.ai_next = NULL; - upstreams = upstreams_create(context, count); + upstreams = upstreams_create(context, count*3); + //upstreams_limit = count; for (i = 0; i < count; i++) { getdns_dict *dict; getdns_bindata *address_type; getdns_bindata *address_data; - uint32_t port; + struct sockaddr_storage addr; + getdns_bindata *scope_id; - struct addrinfo *ai; getdns_upstream *upstream; - upstream = &upstreams->upstreams[upstreams->count]; if ((r = getdns_list_get_dict(upstream_list, i, &dict))) goto error; @@ -1476,27 +1528,23 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, if (address_type->size < 4) goto invalid_parameter; if (strncmp((char *)address_type->data, "IPv4", 4) == 0) - upstream->addr.ss_family = AF_INET; + addr.ss_family = AF_INET; else if (strncmp((char *)address_type->data, "IPv6", 4) == 0) - upstream->addr.ss_family = AF_INET6; + addr.ss_family = AF_INET6; else goto invalid_parameter; if ((r = getdns_dict_get_bindata( dict, "address_data", &address_data))) goto error; - if ((upstream->addr.ss_family == AF_INET && + if ((addr.ss_family == AF_INET && address_data->size != 4) || - (upstream->addr.ss_family == AF_INET6 && + (addr.ss_family == AF_INET6 && address_data->size != 16)) goto invalid_parameter; - if (inet_ntop(upstream->addr.ss_family, address_data->data, + if (inet_ntop(addr.ss_family, address_data->data, addrstr, 1024) == NULL) goto invalid_parameter; - port = 53; - (void) getdns_dict_get_int(dict, "port", &port); - (void) snprintf(portstr, 1024, "%d", (int)port); - if (getdns_dict_get_bindata(dict, "scope_id", &scope_id) == GETDNS_RETURN_GOOD) { if (strlen(addrstr) + scope_id->size > 1022) @@ -1507,12 +1555,40 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, eos[scope_id->size] = 0; } - if (getaddrinfo(addrstr, portstr, &hints, &ai)) - goto invalid_parameter; + /* Loop to create upstreams as needed*/ + getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; + for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { + uint32_t port; + struct addrinfo *ai; + port = getdns_port_array[base_transport]; + if (port == GETDNS_PORT_ZERO) + continue; - upstream_init(upstream, upstreams, ai); - upstreams->count++; - freeaddrinfo(ai); + if (base_transport != GETDNS_BASE_TRANSPORT_TLS) + (void) getdns_dict_get_int(dict, "port", &port); + else + (void) getdns_dict_get_int(dict, "tls-port", &port); + (void) snprintf(portstr, 1024, "%d", (int)port); + + if (getaddrinfo(addrstr, portstr, &hints, &ai)) + goto invalid_parameter; + + /* TODO[TLS]: Should probably check that the upstream doesn't + * already exist (in case user has specified TLS port explicitly and + * to prevent duplicates) */ + + /* TODO[TLS]: Grow array when needed. This causes a crash later.... + if (upstreams->count == upstreams_limit) + upstreams = upstreams_resize( + upstreams, (upstreams_limit *= 2)); */ + + upstream = &upstreams->upstreams[upstreams->count]; + upstream->addr.ss_family = addr.ss_family; + upstream_init(upstream, upstreams, ai); + upstream->dns_base_transport = base_transport; + upstreams->count++; + freeaddrinfo(ai); + } } priv_getdns_upstreams_dereference(context->upstreams); context->upstreams = upstreams; @@ -1731,7 +1807,13 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) (void) ub_ctx_set_fwd(ctx, NULL); for (i = 0; i < upstreams->count; i++) { upstream = &upstreams->upstreams[i]; - upstream_ntop_buf(upstream, context->dns_transport, addr, 1024); + /*[TLS]: Use only the TLS subset of upstreams when TLS is the only thing + * used. All other cases must currently fallback to TCP for libunbound.*/ + if (context->dns_base_transports[0] == GETDNS_BASE_TRANSPORT_TLS && + context->dns_base_transports[1] == GETDNS_BASE_TRANSPORT_NONE && + upstream->dns_base_transport != GETDNS_BASE_TRANSPORT_TLS) + continue; + upstream_ntop_buf(upstream, addr, 1024); ub_ctx_set_fwd(ctx, addr); } @@ -1826,29 +1908,24 @@ getdns_context_prepare_for_resolution(struct getdns_context *context, } /* Transport can in theory be set per query in stub mode */ - /* TODO: move this transport logic to a separate functions*/ if (context->resolution_type == GETDNS_RESOLUTION_STUB) { - switch (context->dns_transport) { - case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: - case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - if (context->tls_ctx == NULL) { + /*TODO[TLS]: Check if TLS is in the list of transports.*/ + if (context->tls_ctx == NULL) { #ifdef HAVE_LIBTLS1_2 - /* Create client context, use TLS v1.2 only for now */ - context->tls_ctx = SSL_CTX_new(TLSv1_2_client_method()); + /* Create client context, use TLS v1.2 only for now */ + context->tls_ctx = SSL_CTX_new(TLSv1_2_client_method()); #endif - if(!context->tls_ctx && context->dns_transport == - GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) { - return GETDNS_RETURN_BAD_CONTEXT; - } - } - break; - default: - break; + /* TODO[TLS]: Check if TLS is the only option in the list*/ + // if(!context->tls_ctx && context->dns_transport == + // GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) { + // return GETDNS_RETURN_BAD_CONTEXT; + // } } } /* Block use of TLS ONLY in recursive mode as it won't work */ - if (context->resolution_type == GETDNS_RESOLUTION_RECURSING - && context->dns_transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) + /* TODO[TLS]: Check if TLS is the only option in the list*/ + if (context->resolution_type == GETDNS_RESOLUTION_RECURSING && + context->dns_transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) return GETDNS_RETURN_BAD_CONTEXT; if (context->resolution_type_set == context->resolution_type) diff --git a/src/context.h b/src/context.h index 12a2c263..3ad6b342 100644 --- a/src/context.h +++ b/src/context.h @@ -49,7 +49,6 @@ struct ub_ctx; #define GETDNS_FN_RESOLVCONF "/etc/resolv.conf" #define GETDNS_FN_HOSTS "/etc/hosts" -#define GETDNS_TLS_PORT 1021 enum filechgs { GETDNS_FCHG_ERRORS = -1 , GETDNS_FCHG_NOERROR = 0 @@ -72,13 +71,13 @@ struct filechg { struct stat *prevstat; }; -typedef enum getdns_base_transport { - GETDNS_TRANSPORT_NONE, - GETDNS_TRANSPORT_UDP, - GETDNS_TRANSPORT_TCP_SINGLE, - GETDNS_TRANSPORT_TCP, - GETDNS_TRANSPORT_TLS -} getdns_base_transport_t; +typedef enum getdns_tls_hs_state { + GETDNS_HS_NONE, + GETDNS_HS_WRITE, + GETDNS_HS_READ, + GETDNS_HS_DONE, + GETDNS_HS_FAILED +} getdns_tls_hs_state_t; typedef struct getdns_upstream { struct getdns_upstreams *upstreams; @@ -92,7 +91,10 @@ typedef struct getdns_upstream { /* For sharing a TCP socket to this upstream */ int fd; + getdns_base_transport_t dns_base_transport; SSL* tls_obj; + getdns_tls_hs_state_t tls_hs_state; + getdns_dns_req * starttls_req; getdns_eventloop_event event; getdns_eventloop *loop; getdns_tcp_state tcp; @@ -136,6 +138,7 @@ struct getdns_context { struct getdns_list *dnssec_trust_anchors; getdns_upstreams *upstreams; getdns_transport_t dns_transport; + getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; uint16_t limit_outstanding_queries; uint32_t dnssec_allowed_skew; @@ -231,7 +234,8 @@ int filechg_check(struct getdns_context *context, struct filechg *fchg); void priv_getdns_context_ub_read_cb(void *userarg); -getdns_base_transport_t priv_get_base_transport(getdns_transport_t transport, int level); +getdns_return_t priv_set_base_dns_transports(getdns_base_transport_t *, + getdns_transport_t); void priv_getdns_upstreams_dereference(getdns_upstreams *upstreams); diff --git a/src/getdns/getdns.h.in b/src/getdns/getdns.h.in index f69a2dfc..a2d6ea49 100755 --- a/src/getdns/getdns.h.in +++ b/src/getdns/getdns.h.in @@ -165,7 +165,8 @@ typedef enum getdns_transport_t { GETDNS_TRANSPORT_TCP_ONLY = 542, GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN = 543, GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN = 544, - GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN = 545 + GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN = 545, + GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN = 546 } getdns_transport_t; /** @@ -178,6 +179,7 @@ typedef enum getdns_transport_t { #define GETDNS_TRANSPORT_TCP_ONLY_KEEP_CONNECTIONS_OPEN_TEXT "See getdns_context_set_dns_transport()" #define GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN_TEXT "See getdns_context_set_dns_transport()" #define GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN_TEXT "See getdns_context_set_dns_transport()" +#define GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN_TEXT "See getdns_context_set_dns_transport()" /** @} */ diff --git a/src/request-internal.c b/src/request-internal.c index 0889b3fc..b4267022 100644 --- a/src/request-internal.c +++ b/src/request-internal.c @@ -89,6 +89,9 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner, net_req->upstream = NULL; net_req->fd = -1; + for (i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) + net_req->dns_base_transports[i] = owner->context->dns_base_transports[i]; + net_req->transport = 0; memset(&net_req->event, 0, sizeof(net_req->event)); memset(&net_req->tcp, 0, sizeof(net_req->tcp)); net_req->query_id = 0; diff --git a/src/stub.c b/src/stub.c index 0914aa57..4653620e 100755 --- a/src/stub.c +++ b/src/stub.c @@ -36,15 +36,49 @@ #include "stub.h" #include "gldns/gbuffer.h" #include "gldns/pkthdr.h" +#include "gldns/rrdef.h" +#include "gldns/str2wire.h" +#include "rr-iter.h" #include "context.h" #include #include "util-internal.h" #include "general.h" +#define STUB_TLS_SETUP_ERROR -4 +#define STUB_TCP_AGAIN -3 +#define STUB_TCP_ERROR -2 + +/* Don't currently have access to the context whilst doing handshake */ +#define TIMEOUT_TLS 2500 + +#define STUB_DEBUG 0 + static time_t secret_rollover_time = 0; static uint32_t secret = 0; static uint32_t prev_secret = 0; +static void upstream_read_cb(void *userarg); +static void upstream_write_cb(void *userarg); +static void upstream_schedule_netreq(getdns_upstream *upstream, + getdns_network_req *netreq); +static void netreq_upstream_read_cb(void *userarg); +static void netreq_upstream_write_cb(void *userarg); +static int fallback_on_write(getdns_network_req *netreq); + +static void stub_tcp_write_cb(void *userarg); + +/*****************************/ +/* General utility functions */ +/*****************************/ + +static void +stub_debug(const char *function_name) +{ +#ifdef STUB_DEBUG + fprintf(stderr,"[STUB DEBUG]: %s\n", function_name); +#endif +} + static void rollover_secret() { @@ -215,6 +249,99 @@ match_and_process_server_cookie( return 0; } +static int +create_starttls_request(getdns_dns_req *dnsreq, getdns_upstream *upstream, + getdns_eventloop *loop) +{ + getdns_return_t r = GETDNS_RETURN_GOOD; + getdns_dict* extensions = getdns_dict_create_with_context(dnsreq->context); + if (!extensions) { + return 0; + } + r = getdns_dict_set_int(extensions, "specify_class", GLDNS_RR_CLASS_CH); + if (r != GETDNS_RETURN_GOOD) { + getdns_dict_destroy(extensions); + return 0; + } + upstream->starttls_req = dns_req_new(dnsreq->context, loop, + "STARTTLS", GETDNS_RRTYPE_TXT, extensions); + /*TODO[TLS]: TO BIT*/ + if (upstream->starttls_req == NULL) + return 0; + getdns_dict_destroy(extensions); + + upstream->starttls_req->netreqs[0]->upstream = upstream; + return 1; +} + +static int +dname_equal(uint8_t *s1, uint8_t *s2) +{ + uint8_t i; + for (;;) { + if (*s1 != *s2) + return 0; + else if (!*s1) + return 1; + for (i = *s1++, s2++; i > 0; i--, s1++, s2++) + if ((*s1 & 0xDF) != (*s2 & 0xDF)) + return 0; + } +} + +static int +is_starttls_response(getdns_network_req *netreq) +{ + priv_getdns_rr_iter rr_iter_storage, *rr_iter; + priv_getdns_rdf_iter rdf_iter_storage, *rdf_iter; + uint16_t rr_type; + gldns_pkt_section section; + uint8_t starttls_name_space[256], + *starttls_name = starttls_name_space; + uint8_t owner_name_space[256], *owner_name; + size_t starttls_name_len = 256, owner_name_len; + + /* Servers that are not STARTTLS aware will refuse the CH query*/ + if (LDNS_RCODE_NOERROR != GLDNS_RCODE_WIRE(netreq->response)) + return 0; + + if (GLDNS_ANCOUNT(netreq->response) != 1) + return 0; + + (void) gldns_str2wire_dname_buf( + netreq->owner->name, starttls_name_space, &starttls_name_len); + + for ( rr_iter = priv_getdns_rr_iter_init(&rr_iter_storage + , netreq->response + , netreq->response_len) + ; rr_iter + ; rr_iter = priv_getdns_rr_iter_next(rr_iter)) { + + section = priv_getdns_rr_iter_section(rr_iter); + rr_type = gldns_read_uint16(rr_iter->rr_type); + if (section != GLDNS_SECTION_ANSWER || rr_type != GETDNS_RRTYPE_TXT) + continue; + + owner_name = priv_getdns_owner_if_or_as_decompressed( + rr_iter, owner_name_space, &owner_name_len); + if (!dname_equal(starttls_name, owner_name)) + continue; + + if (!(rdf_iter = priv_getdns_rdf_iter_init( + &rdf_iter_storage, rr_iter))) + continue; + /* re-use the starttls_name for the response dname*/ + starttls_name = priv_getdns_rdf_if_or_as_decompressed( + rdf_iter,starttls_name_space,&starttls_name_len); + if (dname_equal(starttls_name, owner_name)) + return 1; + else + return 0; + continue; + } + return 0; +} + /** best effort to set nonblocking */ static void getdns_sock_nonblock(int sockfd) @@ -235,6 +362,35 @@ getdns_sock_nonblock(int sockfd) #endif } +static int +tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) +{ + int fd = -1; + if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) + return -1; + + getdns_sock_nonblock(fd); +#ifdef USE_TCP_FASTOPEN + /* Leave the connect to the later call to sendto() if using TCP*/ + if (transport == GETDNS_BASE_TRANSPORT_TCP || + transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE || + transport == GETDNS_BASE_TRANSPORT_STARTTLS) + return fd; +#endif + if (connect(fd, (struct sockaddr *)&upstream->addr, + upstream->addr_len) == -1) { + if (errno != EINPROGRESS) { + close(fd); + return -1; + } + } + return fd; +} + +/**************************/ +/* Error/cleanup functions*/ +/**************************/ + static void stub_next_upstream(getdns_network_req *netreq) { @@ -243,6 +399,10 @@ stub_next_upstream(getdns_network_req *netreq) if (! --netreq->upstream->to_retry) netreq->upstream->to_retry = -(netreq->upstream->back_off *= 2); + /*[TLS]:TODO - This works because the next message won't try the exact + * same upstream (and the next message may not use the same transport), + * but the next message will find the next matching one thanks to logic in + * upstream_select, but this could be better */ if (++dnsreq->upstreams->current > dnsreq->upstreams->count) dnsreq->upstreams->current = 0; } @@ -301,6 +461,20 @@ stub_cleanup(getdns_network_req *netreq) } } +static int +tls_cleanup(getdns_upstream *upstream) +{ + SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; + upstream->tls_hs_state = GETDNS_HS_FAILED; + /* Reset timeout on failure*/ + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, + getdns_eventloop_event_init(&upstream->event, upstream, + NULL, upstream_write_cb, NULL)); + return STUB_TLS_SETUP_ERROR; +} + static void upstream_erred(getdns_upstream *upstream) { @@ -318,8 +492,6 @@ upstream_erred(getdns_upstream *upstream) netreq->state = NET_REQ_FINISHED; priv_getdns_check_dns_req_complete(netreq->owner); } - /* TODO[TLS]: When we get an error (which is probably a timeout) and are - * using to keep connections open should we leave the connection up here? */ if (upstream->tls_obj) { SSL_shutdown(upstream->tls_obj); SSL_free(upstream->tls_obj); @@ -329,6 +501,14 @@ upstream_erred(getdns_upstream *upstream) upstream->fd = -1; } +static void +message_erred(getdns_network_req *netreq) +{ + stub_cleanup(netreq); + netreq->state = NET_REQ_FINISHED; + priv_getdns_check_dns_req_complete(netreq->owner); +} + void priv_getdns_cancel_stub_request(getdns_network_req *netreq) { @@ -341,6 +521,8 @@ stub_erred(getdns_network_req *netreq) { stub_next_upstream(netreq); stub_cleanup(netreq); + /* TODO[TLS]: When we get an error (which is probably a timeout) and are + * using to keep connections open should we leave the connection up here? */ if (netreq->fd >= 0) close(netreq->fd); netreq->state = NET_REQ_FINISHED; priv_getdns_check_dns_req_complete(netreq->owner); @@ -350,6 +532,14 @@ static void stub_timeout_cb(void *userarg) { getdns_network_req *netreq = (getdns_network_req *)userarg; + + /* For now, mark a STARTTLS timeout as a failured negotiation and allow + * fallback but don't close the connection. */ + if (is_starttls_response(netreq)) { + netreq->upstream->tls_hs_state = GETDNS_HS_FAILED; + stub_next_upstream(netreq); + stub_cleanup(netreq); + } stub_next_upstream(netreq); stub_cleanup(netreq); @@ -357,146 +547,35 @@ stub_timeout_cb(void *userarg) (void) getdns_context_request_timed_out(netreq->owner); } -static void stub_tcp_write_cb(void *userarg); static void -stub_udp_read_cb(void *userarg) +upstream_tls_timeout_cb(void *userarg) { - getdns_network_req *netreq = (getdns_network_req *)userarg; - getdns_dns_req *dnsreq = netreq->owner; - getdns_upstream *upstream = netreq->upstream; + stub_debug(__FUNCTION__); + getdns_upstream *upstream = (getdns_upstream *)userarg; + /* Clean up and trigger a write to let the fallback code to its job */ + tls_cleanup(upstream); - ssize_t read; - - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - - read = recvfrom(netreq->fd, netreq->response, - netreq->max_udp_payload_size + 1, /* If read == max_udp_payload_size - * then all is good. If read == - * max_udp_payload_size + 1, then - * we receive more then requested! - * i.e. overflow - */ - 0, NULL, NULL); - if (read == -1 && (errno = EAGAIN || errno == EWOULDBLOCK)) - return; - - if (read < GLDNS_HEADER_SIZE) - return; /* Not DNS */ - - if (GLDNS_ID_WIRE(netreq->response) != netreq->query_id) - return; /* Cache poisoning attempt ;) */ - - if (netreq->owner->edns_cookies && match_and_process_server_cookie( - upstream, netreq->response, read)) - return; /* Client cookie didn't match? */ - - close(netreq->fd); - if (GLDNS_TC_WIRE(netreq->response) && - dnsreq->context->dns_transport == - GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP) { - - if ((netreq->fd = socket( - upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) - goto done; - - getdns_sock_nonblock(netreq->fd); - if (connect(netreq->fd, (struct sockaddr *)&upstream->addr, - upstream->addr_len) == -1 && errno != EINPROGRESS) { - - close(netreq->fd); - goto done; - } - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - NULL, stub_tcp_write_cb, stub_timeout_cb)); - - return; + /* Need to handle the case where the far end doesn't respond to a + * TCP SYN and doesn't do a reset (as is the case with e.g. 8.8.8.8@1021). + * For that case the socket never becomes writable so doesn't trigger any + * callbacks. If so then clear out the queue in one go.*/ + int ret; + fd_set fds; + FD_ZERO(&fds); + FD_SET(FD_SET_T upstream->fd, &fds); + struct timeval tval; + tval.tv_sec = 0; + tval.tv_usec = 0; + ret = select(upstream->fd+1, NULL, &fds, NULL, &tval); + if (ret == 0) { + while (upstream->write_queue) + upstream_write_cb(upstream); } - netreq->response_len = read; - dnsreq->upstreams->current = 0; - - /* TODO: DNSSEC */ - netreq->secure = 0; - netreq->bogus = 0; -done: - netreq->state = NET_REQ_FINISHED; - priv_getdns_check_dns_req_complete(dnsreq); } -static void -stub_udp_write_cb(void *userarg) -{ - getdns_network_req *netreq = (getdns_network_req *)userarg; - getdns_dns_req *dnsreq = netreq->owner; - size_t pkt_len = netreq->response - netreq->query; - - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - - netreq->query_id = arc4random(); - GLDNS_ID_SET(netreq->query, netreq->query_id); - if (netreq->opt) { - if (netreq->edns_maximum_udp_payload_size == -1) - gldns_write_uint16(netreq->opt + 3, - ( netreq->max_udp_payload_size = - netreq->upstream->addr.ss_family == AF_INET6 - ? 1232 : 1432)); - if (netreq->owner->edns_cookies) { - netreq->response = attach_edns_cookie( - netreq->upstream, netreq->opt); - pkt_len = netreq->response - netreq->query; - } - } - - if ((ssize_t)pkt_len != sendto(netreq->fd, netreq->query, pkt_len, 0, - (struct sockaddr *)&netreq->upstream->addr, - netreq->upstream->addr_len)) { - close(netreq->fd); - return; - } - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - stub_udp_read_cb, NULL, stub_timeout_cb)); -} - -static getdns_upstream * -pick_upstream(getdns_dns_req *dnsreq) -{ - getdns_upstream *upstream; - size_t i; - - if (!dnsreq->upstreams->count) - return NULL; - - for (i = 0; i < dnsreq->upstreams->count; i++) - if (dnsreq->upstreams->upstreams[i].to_retry <= 0) - dnsreq->upstreams->upstreams[i].to_retry++; - - i = dnsreq->upstreams->current; - do { - if (dnsreq->upstreams->upstreams[i].to_retry > 0) { - dnsreq->upstreams->current = i; - return &dnsreq->upstreams->upstreams[i]; - } - if (++i > dnsreq->upstreams->count) - i = 0; - } while (i != dnsreq->upstreams->current); - - upstream = dnsreq->upstreams->upstreams; - for (i = 1; i < dnsreq->upstreams->count; i++) - if (dnsreq->upstreams->upstreams[i].back_off < - upstream->back_off) - upstream = &dnsreq->upstreams->upstreams[i]; - - upstream->back_off++; - upstream->to_retry = 1; - dnsreq->upstreams->current = upstream - dnsreq->upstreams->upstreams; - return upstream; -} - -#define STUB_TCP_AGAIN -2 -#define STUB_TCP_ERROR -1 +/****************************/ +/* TCP read/write functions */ +/****************************/ static int stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf) @@ -559,321 +638,6 @@ stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf) return GLDNS_ID_WIRE(tcp->read_buf); } -static void -stub_tcp_read_cb(void *userarg) -{ - getdns_network_req *netreq = (getdns_network_req *)userarg; - getdns_dns_req *dnsreq = netreq->owner; - int q; - - switch ((q = stub_tcp_read(netreq->fd, &netreq->tcp, - &dnsreq->context->mf))) { - - case STUB_TCP_AGAIN: - return; - - case STUB_TCP_ERROR: - stub_erred(netreq); - return; - - default: - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - if (q != netreq->query_id) - return; - if (netreq->owner->edns_cookies && - match_and_process_server_cookie( - netreq->upstream, netreq->tcp.read_buf, - netreq->tcp.read_pos - netreq->tcp.read_buf)) - return; /* Client cookie didn't match? */ - netreq->state = NET_REQ_FINISHED; - netreq->response = netreq->tcp.read_buf; - netreq->response_len = - netreq->tcp.read_pos - netreq->tcp.read_buf; - netreq->tcp.read_buf = NULL; - dnsreq->upstreams->current = 0; - - /* TODO: DNSSEC */ - netreq->secure = 0; - netreq->bogus = 0; - - stub_cleanup(netreq); - close(netreq->fd); - priv_getdns_check_dns_req_complete(dnsreq); - } -} - -/** wait for a socket to become ready */ -static int -sock_wait(int sockfd) -{ - int ret; - fd_set fds; - FD_ZERO(&fds); - FD_SET(FD_SET_T sockfd, &fds); - /*TODO[TLS]: Pick up this timeout from the context*/ - struct timeval timeout = {5, 0 }; - ret = select(sockfd+1, NULL, &fds, NULL, &timeout); - if(ret == 0) - /* timeout expired */ - return 0; - else if(ret == -1) - /* error */ - return 0; - return 1; -} - -static int -sock_connected(int sockfd) -{ - /* wait(write) until connected or error */ - while(1) { - int error = 0; - socklen_t len = (socklen_t)sizeof(error); - - if(!sock_wait(sockfd)) { - close(sockfd); - return -1; - } - - /* check if there is a pending error for nonblocking connect */ - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, (void*)&error, &len) < 0) { - error = errno; /* on solaris errno is error */ - } - if (error == EINPROGRESS || error == EWOULDBLOCK) - continue; /* try again */ - else if (error != 0) { - close(sockfd); - return -1; - } - /* connected */ - break; - } - return sockfd; -} - -/* The connection testing and handshake should be handled by integrating this - * with the event loop framework, but for now just implement a standalone - * handshake method.*/ -static SSL* -do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream) -{ - /*Lets make sure the connection is up before we try a handshake*/ - if (errno == EINPROGRESS && sock_connected(upstream->fd) == -1) { - return NULL; - } - - /* Create SSL instance */ - if (dnsreq->context->tls_ctx == NULL) - return NULL; - SSL* ssl = SSL_new(dnsreq->context->tls_ctx); - if(!ssl) { - return NULL; - } - /* Connect the SSL object with a file descriptor */ - if(!SSL_set_fd(ssl, upstream->fd)) { - SSL_free(ssl); - return NULL; - } - SSL_set_connect_state(ssl); - (void) SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY); - - int r; - int want; - fd_set fds; - FD_ZERO(&fds); - FD_SET(upstream->fd, &fds); - struct timeval timeout = {dnsreq->context->timeout/1000, 0 }; - while ((r = SSL_do_handshake(ssl)) != 1) - { - want = SSL_get_error(ssl, r); - switch (want) { - case SSL_ERROR_WANT_READ: - if (select(upstream->fd + 1, &fds, NULL, NULL, &timeout) == 0) { - SSL_free(ssl); - return NULL; - } - break; - case SSL_ERROR_WANT_WRITE: - if (select(upstream->fd + 1, NULL, &fds, NULL, &timeout) == 0) { - SSL_free(ssl); - return NULL; - } - break; - default: - SSL_free(ssl); - return NULL; - } - } - return ssl; -} - -static int -stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf) -{ - ssize_t read; - uint8_t *buf; - size_t buf_size; - - if (!tcp->read_buf) { - /* First time tls read, create a buffer for reading */ - if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) - return STUB_TCP_ERROR; - - tcp->read_buf_len = 4096; - tcp->read_pos = tcp->read_buf; - tcp->to_read = 2; /* Packet size */ - } - - ERR_clear_error(); - read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); - if (read <= 0) { - /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake - renegotiation. Need to keep handshake state to do that.*/ - int want = SSL_get_error(tls_obj, read); - if (want == SSL_ERROR_WANT_READ) { - return STUB_TCP_AGAIN; /* read more later */ - } else - return STUB_TCP_ERROR; - } - tcp->to_read -= read; - tcp->read_pos += read; - - if ((int)tcp->to_read > 0) - return STUB_TCP_AGAIN; - - read = tcp->read_pos - tcp->read_buf; - if (read == 2) { - /* Read the packet size short */ - tcp->to_read = gldns_read_uint16(tcp->read_buf); - - if (tcp->to_read < GLDNS_HEADER_SIZE) - return STUB_TCP_ERROR; - - /* Resize our buffer if needed */ - if (tcp->to_read > tcp->read_buf_len) { - buf_size = tcp->read_buf_len; - while (tcp->to_read > buf_size) - buf_size *= 2; - - if (!(buf = GETDNS_XREALLOC(*mf, - tcp->read_buf, uint8_t, buf_size))) - return STUB_TCP_ERROR; - - tcp->read_buf = buf; - tcp->read_buf_len = buf_size; - } - - /* Ready to start reading the packet */ - tcp->read_pos = tcp->read_buf; - read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); - if (read <= 0) { - /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake - renegotiation. Need to keep handshake state to do that.*/ - int want = SSL_get_error(tls_obj, read); - if (want == SSL_ERROR_WANT_READ) { - return STUB_TCP_AGAIN; /* read more later */ - } else - return STUB_TCP_ERROR; - } - tcp->to_read -= read; - tcp->read_pos += read; - if ((int)tcp->to_read > 0) - return STUB_TCP_AGAIN; - } - return GLDNS_ID_WIRE(tcp->read_buf); -} - -static void netreq_upstream_read_cb(void *userarg); -static void netreq_upstream_write_cb(void *userarg); -static void -upstream_read_cb(void *userarg) -{ - getdns_upstream *upstream = (getdns_upstream *)userarg; - getdns_network_req *netreq; - getdns_dns_req *dnsreq; - int q; - uint16_t query_id; - intptr_t query_id_intptr; - - if (upstream->tls_obj) - q = stub_tls_read(upstream->tls_obj, &upstream->tcp, - &upstream->upstreams->mf); - else - q = stub_tcp_read(upstream->fd, &upstream->tcp, - &upstream->upstreams->mf); - - switch (q) { - case STUB_TCP_AGAIN: - return; - - case STUB_TCP_ERROR: - upstream_erred(upstream); - return; - - default: - - /* Lookup netreq */ - query_id = (uint16_t) q; - query_id_intptr = (intptr_t) query_id; - netreq = (getdns_network_req *)getdns_rbtree_delete( - &upstream->netreq_by_query_id, (void *)query_id_intptr); - if (! netreq) /* maybe canceled */ { - /* reset read buffer */ - upstream->tcp.read_pos = upstream->tcp.read_buf; - upstream->tcp.to_read = 2; - return; - } - - netreq->state = NET_REQ_FINISHED; - netreq->response = upstream->tcp.read_buf; - netreq->response_len = - upstream->tcp.read_pos - upstream->tcp.read_buf; - upstream->tcp.read_buf = NULL; - upstream->upstreams->current = 0; - - /* TODO: DNSSEC */ - netreq->secure = 0; - netreq->bogus = 0; - - stub_cleanup(netreq); - - /* More to read/write for syncronous lookups? */ - if (netreq->event.read_cb) { - dnsreq = netreq->owner; - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - if (upstream->netreq_by_query_id.count || - upstream->write_queue) - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, upstream->fd, - dnsreq->context->timeout, - getdns_eventloop_event_init( - &netreq->event, netreq, - ( upstream->netreq_by_query_id.count ? - netreq_upstream_read_cb : NULL ), - ( upstream->write_queue ? - netreq_upstream_write_cb : NULL), - stub_timeout_cb)); - } - priv_getdns_check_dns_req_complete(netreq->owner); - - /* Nothing more to read? Then deschedule the reads.*/ - if (! upstream->netreq_by_query_id.count) { - upstream->event.read_cb = NULL; - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - if (upstream->event.write_cb) - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, - &upstream->event); - } - } -} - -static void -netreq_upstream_read_cb(void *userarg) -{ - upstream_read_cb(((getdns_network_req *)userarg)->upstream); -} - /* stub_tcp_write(fd, tcp, netreq) * will return STUB_TCP_AGAIN when we need to come back again, * STUB_TCP_ERROR on error and a query_id on successfull sent. @@ -881,7 +645,6 @@ netreq_upstream_read_cb(void *userarg) static int stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) { - getdns_dns_req *dnsreq = netreq->owner; size_t pkt_len = netreq->response - netreq->query; ssize_t written; @@ -900,9 +663,10 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) * the write_queue) for that upstream. Register this netreq * by query_id in the process. */ - if ((dnsreq->context->dns_transport == GETDNS_TRANSPORT_TCP_ONLY) || - (dnsreq->context->dns_transport == GETDNS_TRANSPORT_UDP_ONLY) || - (dnsreq->context->dns_transport == GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP)) + if ((netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_TCP_SINGLE) || + (netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_UDP)) query_id = arc4random(); else do { query_id = arc4random(); @@ -991,39 +755,232 @@ stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) } /* if (! tcp->write_buf) */ } -static void -stub_tcp_write_cb(void *userarg) +/*************************/ +/* TLS Utility functions */ +/*************************/ + +static int +tls_requested(getdns_network_req *netreq) { - getdns_network_req *netreq = (getdns_network_req *)userarg; - getdns_dns_req *dnsreq = netreq->owner; - int q; - - switch ((q = stub_tcp_write(netreq->fd, &netreq->tcp, netreq))) { - case STUB_TCP_AGAIN: - return; - - case STUB_TCP_ERROR: - stub_erred(netreq); - return; - - default: - netreq->query_id = (uint16_t) q; - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - stub_tcp_read_cb, NULL, stub_timeout_cb)); - return; - } + return (netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_TLS || + netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_STARTTLS) ? + 1 : 0; } static int -stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq) +tls_should_write(getdns_upstream *upstream) +{ + /* Should messages be written on TLS upstream. Remember that for STARTTLS + * the first message should got over TCP as the handshake isn't started yet.*/ + return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + upstream->tls_hs_state != GETDNS_HS_NONE) ? 1 : 0; +} + +static int +tls_should_read(getdns_upstream *upstream) +{ + return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + !(upstream->tls_hs_state == GETDNS_HS_FAILED || + upstream->tls_hs_state == GETDNS_HS_NONE)) ? 1 : 0; +} + +static int +tls_failed(getdns_upstream *upstream) +{ + /* No messages should be scheduled onto an upstream in this state */ + return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + upstream->tls_hs_state == GETDNS_HS_FAILED) ? 1: 0; +} + +static SSL* +tls_create_object(getdns_context *context, int fd) +{ + /* Create SSL instance */ + if (context->tls_ctx == NULL) + return NULL; + SSL* ssl = SSL_new(context->tls_ctx); + if(!ssl) + return NULL; + /* Connect the SSL object with a file descriptor */ + if(!SSL_set_fd(ssl,fd)) { + SSL_free(ssl); + return NULL; + } + SSL_set_connect_state(ssl); + (void) SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY); + return ssl; +} + +static int +tls_do_handshake(getdns_upstream *upstream) +{ + stub_debug(__FUNCTION__); + int r; + int want; + ERR_clear_error(); + while ((r = SSL_do_handshake(upstream->tls_obj)) != 1) + { + want = SSL_get_error(upstream->tls_obj, r); + switch (want) { + case SSL_ERROR_WANT_READ: + upstream->event.read_cb = upstream_read_cb; + upstream->event.write_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_TLS, &upstream->event); + upstream->tls_hs_state = GETDNS_HS_READ; + return STUB_TCP_AGAIN; + case SSL_ERROR_WANT_WRITE: + upstream->event.read_cb = NULL; + upstream->event.write_cb = upstream_write_cb; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_TLS, &upstream->event); + upstream->tls_hs_state = GETDNS_HS_WRITE; + return STUB_TCP_AGAIN; + default: + return tls_cleanup(upstream); + } + } + upstream->tls_hs_state = GETDNS_HS_DONE; + upstream->event.read_cb = NULL; + upstream->event.write_cb = upstream_write_cb; + /* Reset timeout on success*/ + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, + getdns_eventloop_event_init(&upstream->event, upstream, + NULL, upstream_write_cb, NULL)); + return 0; +} + +static int +tls_connected(getdns_upstream* upstream) +{ + /* Already have a connection*/ + if (upstream->tls_hs_state == GETDNS_HS_DONE && + (upstream->tls_obj != NULL) && (upstream->fd != -1)) + return 0; + + /* Already tried and failed, so let the fallback code take care of things */ + if (upstream->tls_hs_state == GETDNS_HS_FAILED) + return STUB_TLS_SETUP_ERROR; + + /* Lets make sure the connection is up before we try a handshake*/ + int error = 0; + socklen_t len = (socklen_t)sizeof(error); + getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len); + if (error == EINPROGRESS || error == EWOULDBLOCK) + return STUB_TCP_AGAIN; /* try again */ + else if (error != 0) + return tls_cleanup(upstream); + + return tls_do_handshake(upstream); +} + +/***************************/ +/* TLS read/write functions*/ +/***************************/ + +static int +stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, + struct mem_funcs *mf) +{ + ssize_t read; + uint8_t *buf; + size_t buf_size; + SSL* tls_obj = upstream->tls_obj; + + int q = tls_connected(upstream); + if (q != 0) + return q; + + if (!tcp->read_buf) { + /* First time tls read, create a buffer for reading */ + if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) + return STUB_TCP_ERROR; + + tcp->read_buf_len = 4096; + tcp->read_pos = tcp->read_buf; + tcp->to_read = 2; /* Packet size */ + } + + ERR_clear_error(); + read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); + if (read <= 0) { + /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake + renegotiation. Need to keep handshake state to do that.*/ + int want = SSL_get_error(tls_obj, read); + if (want == SSL_ERROR_WANT_READ) { + return STUB_TCP_AGAIN; /* read more later */ + } else + return STUB_TCP_ERROR; + } + tcp->to_read -= read; + tcp->read_pos += read; + + if ((int)tcp->to_read > 0) + return STUB_TCP_AGAIN; + + read = tcp->read_pos - tcp->read_buf; + if (read == 2) { + /* Read the packet size short */ + tcp->to_read = gldns_read_uint16(tcp->read_buf); + + if (tcp->to_read < GLDNS_HEADER_SIZE) + return STUB_TCP_ERROR; + + /* Resize our buffer if needed */ + if (tcp->to_read > tcp->read_buf_len) { + buf_size = tcp->read_buf_len; + while (tcp->to_read > buf_size) + buf_size *= 2; + + if (!(buf = GETDNS_XREALLOC(*mf, + tcp->read_buf, uint8_t, buf_size))) + return STUB_TCP_ERROR; + + tcp->read_buf = buf; + tcp->read_buf_len = buf_size; + } + + /* Ready to start reading the packet */ + tcp->read_pos = tcp->read_buf; + read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); + if (read <= 0) { + /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake + renegotiation. Need to keep handshake state to do that.*/ + int want = SSL_get_error(tls_obj, read); + if (want == SSL_ERROR_WANT_READ) { + return STUB_TCP_AGAIN; /* read more later */ + } else + return STUB_TCP_ERROR; + } + tcp->to_read -= read; + tcp->read_pos += read; + if ((int)tcp->to_read > 0) + return STUB_TCP_AGAIN; + } + return GLDNS_ID_WIRE(tcp->read_buf); +} + +static int +stub_tls_write(getdns_upstream *upstream, getdns_tcp_state *tcp, + getdns_network_req *netreq) { size_t pkt_len = netreq->response - netreq->query; ssize_t written; uint16_t query_id; intptr_t query_id_intptr; + SSL* tls_obj = upstream->tls_obj; + + int q = tls_connected(upstream); + if (q != 0) + return q; /* Do we have remaining data that we could not write before? */ if (! tcp->write_buf) { @@ -1050,7 +1007,7 @@ stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq) /* We have an initialized packet buffer. * Lets see how much of it we can write */ - // TODO[TLS]: Handle error cases, partial writes, renegotiation etc. + /* TODO[TLS]: Handle error cases, partial writes, renegotiation etc. */ ERR_clear_error(); written = SSL_write(tls_obj, netreq->query - 2, pkt_len + 2); if (written <= 0) @@ -1064,17 +1021,312 @@ stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq) return STUB_TCP_ERROR; } +/**************************/ +/* UDP callback functions */ +/**************************/ + +static void +stub_udp_read_cb(void *userarg) +{ + getdns_network_req *netreq = (getdns_network_req *)userarg; + getdns_dns_req *dnsreq = netreq->owner; + getdns_upstream *upstream = netreq->upstream; + + ssize_t read; + + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + + read = recvfrom(netreq->fd, netreq->response, + netreq->max_udp_payload_size + 1, /* If read == max_udp_payload_size + * then all is good. If read == + * max_udp_payload_size + 1, then + * we receive more then requested! + * i.e. overflow + */ + 0, NULL, NULL); + if (read == -1 && (errno = EAGAIN || errno == EWOULDBLOCK)) + return; + + if (read < GLDNS_HEADER_SIZE) + return; /* Not DNS */ + + if (GLDNS_ID_WIRE(netreq->response) != netreq->query_id) + return; /* Cache poisoning attempt ;) */ + + if (netreq->owner->edns_cookies && match_and_process_server_cookie( + upstream, netreq->response, read)) + return; /* Client cookie didn't match? */ + + close(netreq->fd); + /*TODO[TLS]: Switch this to use the transport fallback list*/ + if (GLDNS_TC_WIRE(netreq->response) && + dnsreq->context->dns_transport == + GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP) { + + if ((netreq->fd = socket( + upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) + goto done; + + getdns_sock_nonblock(netreq->fd); + if (connect(netreq->fd, (struct sockaddr *)&upstream->addr, + upstream->addr_len) == -1 && errno != EINPROGRESS) { + + close(netreq->fd); + goto done; + } + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, netreq->fd, dnsreq->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, + NULL, stub_tcp_write_cb, stub_timeout_cb)); + + return; + } + netreq->response_len = read; + dnsreq->upstreams->current = 0; + + /* TODO: DNSSEC */ + netreq->secure = 0; + netreq->bogus = 0; +done: + netreq->state = NET_REQ_FINISHED; + priv_getdns_check_dns_req_complete(dnsreq); +} + +static void +stub_udp_write_cb(void *userarg) +{ + getdns_network_req *netreq = (getdns_network_req *)userarg; + getdns_dns_req *dnsreq = netreq->owner; + size_t pkt_len = netreq->response - netreq->query; + + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + + netreq->query_id = arc4random(); + GLDNS_ID_SET(netreq->query, netreq->query_id); + if (netreq->opt) { + if (netreq->edns_maximum_udp_payload_size == -1) + gldns_write_uint16(netreq->opt + 3, + ( netreq->max_udp_payload_size = + netreq->upstream->addr.ss_family == AF_INET6 + ? 1232 : 1432)); + if (netreq->owner->edns_cookies) { + netreq->response = attach_edns_cookie( + netreq->upstream, netreq->opt); + pkt_len = netreq->response - netreq->query; + } + } + + if ((ssize_t)pkt_len != sendto(netreq->fd, netreq->query, pkt_len, 0, + (struct sockaddr *)&netreq->upstream->addr, + netreq->upstream->addr_len)) { + close(netreq->fd); + return; + } + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, netreq->fd, dnsreq->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, + stub_udp_read_cb, NULL, stub_timeout_cb)); +} + +/**************************/ +/* TCP callback functions*/ +/**************************/ + +static void +stub_tcp_read_cb(void *userarg) +{ + getdns_network_req *netreq = (getdns_network_req *)userarg; + getdns_dns_req *dnsreq = netreq->owner; + int q; + + switch ((q = stub_tcp_read(netreq->fd, &netreq->tcp, + &dnsreq->context->mf))) { + + case STUB_TCP_AGAIN: + return; + + case STUB_TCP_ERROR: + stub_erred(netreq); + return; + + default: + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + if (q != netreq->query_id) + return; + if (netreq->owner->edns_cookies && + match_and_process_server_cookie( + netreq->upstream, netreq->tcp.read_buf, + netreq->tcp.read_pos - netreq->tcp.read_buf)) + return; /* Client cookie didn't match? */ + netreq->state = NET_REQ_FINISHED; + netreq->response = netreq->tcp.read_buf; + netreq->response_len = + netreq->tcp.read_pos - netreq->tcp.read_buf; + netreq->tcp.read_buf = NULL; + dnsreq->upstreams->current = 0; + + /* TODO: DNSSEC */ + netreq->secure = 0; + netreq->bogus = 0; + + stub_cleanup(netreq); + close(netreq->fd); + priv_getdns_check_dns_req_complete(dnsreq); + } +} + +static void +stub_tcp_write_cb(void *userarg) +{ + getdns_network_req *netreq = (getdns_network_req *)userarg; + getdns_dns_req *dnsreq = netreq->owner; + int q; + + switch ((q = stub_tcp_write(netreq->fd, &netreq->tcp, netreq))) { + case STUB_TCP_AGAIN: + return; + + case STUB_TCP_ERROR: + stub_erred(netreq); + return; + + default: + netreq->query_id = (uint16_t) q; + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, netreq->fd, dnsreq->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, + stub_tcp_read_cb, NULL, stub_timeout_cb)); + return; + } +} + +/**************************/ +/* Upstream callback functions*/ +/**************************/ + +static void +upstream_read_cb(void *userarg) +{ + stub_debug(__FUNCTION__); + getdns_upstream *upstream = (getdns_upstream *)userarg; + getdns_network_req *netreq; + getdns_dns_req *dnsreq; + int q; + uint16_t query_id; + intptr_t query_id_intptr; + + if (tls_should_read(upstream)) + q = stub_tls_read(upstream, &upstream->tcp, + &upstream->upstreams->mf); + else + q = stub_tcp_read(upstream->fd, &upstream->tcp, + &upstream->upstreams->mf); + + switch (q) { + case STUB_TCP_AGAIN: + return; + + case STUB_TCP_ERROR: + upstream_erred(upstream); + return; + + default: + + /* Lookup netreq */ + query_id = (uint16_t) q; + query_id_intptr = (intptr_t) query_id; + netreq = (getdns_network_req *)getdns_rbtree_delete( + &upstream->netreq_by_query_id, (void *)query_id_intptr); + if (! netreq) /* maybe canceled */ { + /* reset read buffer */ + upstream->tcp.read_pos = upstream->tcp.read_buf; + upstream->tcp.to_read = 2; + return; + } + + netreq->state = NET_REQ_FINISHED; + netreq->response = upstream->tcp.read_buf; + netreq->response_len = + upstream->tcp.read_pos - upstream->tcp.read_buf; + upstream->tcp.read_buf = NULL; + upstream->upstreams->current = 0; + + /* TODO: DNSSEC */ + netreq->secure = 0; + netreq->bogus = 0; + + stub_cleanup(netreq); + + /* More to read/write for syncronous lookups? */ + if (netreq->event.read_cb) { + dnsreq = netreq->owner; + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + if (upstream->netreq_by_query_id.count || + upstream->write_queue) + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, upstream->fd, + dnsreq->context->timeout, + getdns_eventloop_event_init( + &netreq->event, netreq, + ( upstream->netreq_by_query_id.count ? + netreq_upstream_read_cb : NULL ), + ( upstream->write_queue ? + netreq_upstream_write_cb : NULL), + stub_timeout_cb)); + } + + if (netreq->owner == upstream->starttls_req) { + dnsreq = netreq->owner; + if (is_starttls_response(netreq)) { + upstream->tls_obj = tls_create_object(dnsreq->context, + upstream->fd); + if (upstream->tls_obj == NULL) + upstream->tls_hs_state = GETDNS_HS_FAILED; + upstream->tls_hs_state = GETDNS_HS_WRITE; + } else + upstream->tls_hs_state = GETDNS_HS_FAILED; + dns_req_free(upstream->starttls_req); + upstream->starttls_req = NULL; + + /* Now reschedule the writes on this connection */ + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, + netreq->owner->context->timeout, + getdns_eventloop_event_init(&upstream->event, upstream, + NULL, upstream_write_cb, NULL)); + } else + priv_getdns_check_dns_req_complete(netreq->owner); + + /* Nothing more to read? Then deschedule the reads.*/ + if (! upstream->netreq_by_query_id.count) { + upstream->event.read_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + if (upstream->event.write_cb) + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, + &upstream->event); + } + } +} + +static void +netreq_upstream_read_cb(void *userarg) +{ + upstream_read_cb(((getdns_network_req *)userarg)->upstream); +} static void upstream_write_cb(void *userarg) { + stub_debug(__FUNCTION__); getdns_upstream *upstream = (getdns_upstream *)userarg; getdns_network_req *netreq = upstream->write_queue; getdns_dns_req *dnsreq = netreq->owner; int q; - if (upstream->tls_obj) - q = stub_tls_write(upstream->tls_obj, &upstream->tcp, netreq); + if (tls_requested(netreq) && tls_should_write(upstream)) + q = stub_tls_write(upstream, &upstream->tcp, netreq); else q = stub_tcp_write(upstream->fd, &upstream->tcp, netreq); @@ -1086,9 +1338,14 @@ upstream_write_cb(void *userarg) stub_erred(netreq); return; + case STUB_TLS_SETUP_ERROR: + /* Could not complete the TLS set up. Need to fallback.*/ + if (fallback_on_write(netreq) == STUB_TCP_ERROR) + message_erred(netreq); + return; + default: netreq->query_id = (uint16_t) q; - /* Unqueue the netreq from the write_queue */ if (!(upstream->write_queue = netreq->write_queue_tail)) { upstream->write_queue_last = NULL; @@ -1109,6 +1366,14 @@ upstream_write_cb(void *userarg) GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, &upstream->event); } + if (upstream->starttls_req) { + /* Now deschedule any further writes on this connection until we get + * the STARTTLS answer*/ + upstream->event.write_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + } /* With synchonous lookups, schedule the read locally too */ if (netreq->event.write_cb) { GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); @@ -1116,7 +1381,7 @@ upstream_write_cb(void *userarg) dnsreq->loop, upstream->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, netreq_upstream_read_cb, - ( upstream->write_queue ? + (upstream->write_queue && !upstream->starttls_req ? netreq_upstream_write_cb : NULL), stub_timeout_cb)); } @@ -1130,9 +1395,264 @@ netreq_upstream_write_cb(void *userarg) upstream_write_cb(((getdns_network_req *)userarg)->upstream); } +/*****************************/ +/* Upstream utility functions*/ +/*****************************/ + +static int +upstream_transport_valid(getdns_upstream *upstream, + getdns_base_transport_t transport) +{ + /* For single shot transports, use only the TCP upstream. */ + if (transport == GETDNS_BASE_TRANSPORT_UDP || + transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) + return (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP ? 1:0); + /* Allow TCP messages to be sent on a STARTTLS upstream that hasn't + * upgraded to avoid opening a new connection if one is aleady open. */ + if (transport == GETDNS_BASE_TRANSPORT_TCP && + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && + upstream->tls_hs_state == GETDNS_HS_FAILED) + return 1; + /* Otherwise, transport must match, and not have failed */ + if (upstream->dns_base_transport != transport) + return 0; + if (tls_failed(upstream)) + return 0; + return 1; +} + +static getdns_upstream * +upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) +{ + getdns_upstream *upstream; + getdns_upstreams *upstreams = netreq->owner->upstreams; + size_t i; + + if (!upstreams->count) + return NULL; + + for (i = 0; i < upstreams->count; i++) + if (upstreams->upstreams[i].to_retry <= 0) + upstreams->upstreams[i].to_retry++; + + i = upstreams->current; + do { + if (upstreams->upstreams[i].to_retry > 0 && + upstream_transport_valid(&upstreams->upstreams[i], transport)) { + upstreams->current = i; + return &upstreams->upstreams[i]; + } + if (++i > upstreams->count) + i = 0; + } while (i != upstreams->current); + + upstream = upstreams->upstreams; + for (i = 0; i < upstreams->count; i++) + if (upstreams->upstreams[i].back_off < upstream->back_off && + upstream_transport_valid(&upstreams->upstreams[i], transport)) + upstream = &upstreams->upstreams[i]; + + /* Need to check again that the transport is valid */ + if (!upstream_transport_valid(upstream, transport)) + return NULL; + upstream->back_off++; + upstream->to_retry = 1; + upstreams->current = upstream - upstreams->upstreams; + return upstream; +} + + +int +upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, + getdns_dns_req *dnsreq) +{ + stub_debug(__FUNCTION__); + int fd = -1; + switch(transport) { + case GETDNS_BASE_TRANSPORT_UDP: + if ((fd = socket( + upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1) + return -1; + getdns_sock_nonblock(fd); + return fd; + + case GETDNS_BASE_TRANSPORT_TCP: + /* Use existing if available*/ + if (upstream->fd != -1) + return upstream->fd; + /* Otherwise, fall through */ + case GETDNS_BASE_TRANSPORT_TCP_SINGLE: + fd = tcp_connect(upstream, transport); + upstream->loop = dnsreq->context->extension; + upstream->fd = fd; + break; + + case GETDNS_BASE_TRANSPORT_TLS: + /* Use existing if available*/ + if (upstream->fd != -1 && !tls_failed(upstream)) + return upstream->fd; + fd = tcp_connect(upstream, transport); + if (fd == -1) return -1; + upstream->tls_obj = tls_create_object(dnsreq->context, fd); + if (upstream->tls_obj == NULL) { + close(fd); + return -1; + } + upstream->tls_hs_state = GETDNS_HS_WRITE; + upstream->loop = dnsreq->context->extension; + upstream->fd = fd; + break; + case GETDNS_BASE_TRANSPORT_STARTTLS: + /* Use existing if available. Let the fallback code handle it if + * STARTTLS isn't availble. */ + if (upstream->fd != -1) + return upstream->fd; + fd = tcp_connect(upstream, transport); + if (fd == -1) return -1; + if (!create_starttls_request(dnsreq, upstream, dnsreq->loop)) + return GETDNS_RETURN_GENERIC_ERROR; + getdns_network_req *starttls_netreq = upstream->starttls_req->netreqs[0]; + upstream->loop = dnsreq->context->extension; + upstream->fd = fd; + upstream_schedule_netreq(upstream, starttls_netreq); + /* Schedule at least the timeout locally, but use less than half the + * context value so by default this timeouts before the TIMEOUT_TLS. + * And also the write if we perform a synchronous lookup */ + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, upstream->fd, dnsreq->context->timeout / 3, + getdns_eventloop_event_init(&starttls_netreq->event, + starttls_netreq, NULL, (dnsreq->loop != upstream->loop + ? netreq_upstream_write_cb : NULL), stub_timeout_cb)); + break; + default: + return -1; + /* Nothing to do*/ + } + return fd; +} + +static getdns_upstream* +find_upstream_for_specific_transport(getdns_network_req *netreq, + getdns_base_transport_t transport, + int *fd) +{ + /* TODO[TLS]: Fallback through upstreams....?*/ + getdns_upstream *upstream = upstream_select(netreq, transport); + if (!upstream) + return NULL; + *fd = upstream_connect(upstream, transport, netreq->owner); + return upstream; +} + +static int +find_upstream_for_netreq(getdns_network_req *netreq) +{ + int fd = -1; + int i = netreq->transport; + for (; i < GETDNS_BASE_TRANSPORT_MAX && + netreq->dns_base_transports[i] != GETDNS_BASE_TRANSPORT_NONE; i++) { + netreq->upstream = find_upstream_for_specific_transport(netreq, + netreq->dns_base_transports[i], + &fd); + if (fd == -1 || !netreq->upstream) + continue; + netreq->transport = i; + return fd; + } + return -1; +} + +/************************/ +/* Scheduling functions */ +/***********************/ + +static int +move_netreq(getdns_network_req *netreq, getdns_upstream *upstream, + getdns_upstream *new_upstream) +{ + stub_debug(__FUNCTION__); + /* Remove from queue, clearing event and fd if we are the last*/ + if (!(upstream->write_queue = netreq->write_queue_tail)) { + upstream->write_queue_last = NULL; + upstream->event.write_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + + close(upstream->fd); + upstream->fd = -1; + } + netreq->write_queue_tail = NULL; + + /* Schedule with the new upstream */ + netreq->upstream = new_upstream; + upstream_schedule_netreq(new_upstream, netreq); + + /* TODO[TLS]: Timout need to be adjusted and rescheduled on the new fd ...*/ + /* Note, setup timeout should be shorter than message timeout for + * messages with fallback or don't have time to re-try. */ + + /* For sync messages we must re-schedule the events here.*/ + if (netreq->owner->loop != upstream->loop) { + /* Create an event for the new upstream*/ + GETDNS_CLEAR_EVENT(netreq->owner->loop, &netreq->event); + GETDNS_SCHEDULE_EVENT(netreq->owner->loop, new_upstream->fd, + netreq->owner->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, + ( new_upstream->netreq_by_query_id.count ? + netreq_upstream_read_cb : NULL ), + ( new_upstream->write_queue ? + netreq_upstream_write_cb : NULL), + stub_timeout_cb)); + + /* Now one for the old upstream. Must schedule this last to make sure + * it is called back first....?*/ + if (upstream->write_queue) { + GETDNS_CLEAR_EVENT(netreq->owner->loop, &upstream->write_queue->event); + GETDNS_SCHEDULE_EVENT( + upstream->write_queue->owner->loop, upstream->fd, + upstream->write_queue->owner->context->timeout, + getdns_eventloop_event_init(&upstream->write_queue->event, + upstream->write_queue, + ( upstream->netreq_by_query_id.count ? + netreq_upstream_read_cb : NULL ), + ( upstream->write_queue ? + netreq_upstream_write_cb : NULL), + stub_timeout_cb)); + } + } + netreq->transport++; + return upstream->fd; +} + +static int +fallback_on_write(getdns_network_req *netreq) +{ + stub_debug(__FUNCTION__); + /* TODO[TLS]: Fallback through all transports.*/ + getdns_base_transport_t next_transport = + netreq->dns_base_transports[netreq->transport + 1]; + if (next_transport == GETDNS_BASE_TRANSPORT_NONE) + return STUB_TCP_ERROR; + + if (netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_STARTTLS && + next_transport == GETDNS_BASE_TRANSPORT_TCP) { + /* Special case where can stay on same upstream*/ + netreq->transport++; + return netreq->upstream->fd; + } + getdns_upstream *upstream = netreq->upstream; + int fd; + getdns_upstream *new_upstream = + find_upstream_for_specific_transport(netreq, next_transport, &fd); + if (!new_upstream) + return STUB_TCP_ERROR; + return move_netreq(netreq, upstream, new_upstream); +} + static void upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) { + stub_debug(__FUNCTION__); /* We have a connected socket and a global event loop */ assert(upstream->fd >= 0); assert(upstream->loop); @@ -1140,172 +1660,63 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) /* Append netreq to write_queue */ if (!upstream->write_queue) { upstream->write_queue = upstream->write_queue_last = netreq; - upstream->event.write_cb = upstream_write_cb; GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, &upstream->event); + if (upstream->tls_hs_state == GETDNS_HS_WRITE || + (upstream->starttls_req && + upstream->starttls_req->netreqs[0] == netreq)) { + /* Set a timeout on the upstream so we can catch failed setup*/ + /* TODO[TLS]: When generic fallback supported, we should decide how + * to split the timeout between transports. */ + GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, + netreq->owner->context->timeout / 2, + getdns_eventloop_event_init(&upstream->event, upstream, + NULL, upstream_write_cb, upstream_tls_timeout_cb)); + } else { + upstream->event.write_cb = upstream_write_cb; + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + } } else { upstream->write_queue_last->write_queue_tail = netreq; upstream->write_queue_last = netreq; } } -static in_port_t -get_port(struct sockaddr_storage* addr) -{ - return ntohs(addr->ss_family == AF_INET - ? ((struct sockaddr_in *)addr)->sin_port - : ((struct sockaddr_in6*)addr)->sin6_port); -} - -static void -set_port(struct sockaddr_storage* addr, in_port_t port) -{ - addr->ss_family == AF_INET - ? (((struct sockaddr_in *)addr)->sin_port = htons(port)) - : (((struct sockaddr_in6*)addr)->sin6_port = htons(port)); -} - -static int -tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) { - - int fd =-1; - struct sockaddr_storage connect_addr; - struct sockaddr_storage* addr = &upstream->addr; - socklen_t addr_len = upstream->addr_len; - - /* TODO[TLS]: For now, override the port to a hardcoded value*/ - if (transport == GETDNS_TRANSPORT_TLS && - (int)get_port(addr) != GETDNS_TLS_PORT) { - connect_addr = upstream->addr; - addr = &connect_addr; - set_port(addr, GETDNS_TLS_PORT); - } - - if ((fd = socket(addr->ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) - return -1; - - getdns_sock_nonblock(fd); -#ifdef USE_TCP_FASTOPEN - /* Leave the connect to the later call to sendto() if using TCP*/ - if (transport == GETDNS_TRANSPORT_TCP || - transport == GETDNS_TRANSPORT_TCP_SINGLE) - return fd; -#endif - if (connect(fd, (struct sockaddr *)addr, - addr_len) == -1) { - if (errno != EINPROGRESS) { - close(fd); - return -1; - } - } - return fd; -} - getdns_return_t priv_getdns_submit_stub_request(getdns_network_req *netreq) { - getdns_dns_req *dnsreq = netreq->owner; - getdns_upstream *upstream = pick_upstream(dnsreq); + stub_debug(__FUNCTION__); + int fd = -1; + getdns_dns_req *dnsreq = netreq->owner; - if (!upstream) - return GETDNS_RETURN_GENERIC_ERROR; + /* This does a best effort to get a initial fd. + * All other set up is done async*/ + fd = find_upstream_for_netreq(netreq); + if (fd == -1) + return GETDNS_RETURN_GENERIC_ERROR; - // Work out the primary and fallback transport options - getdns_base_transport_t transport = priv_get_base_transport( - dnsreq->context->dns_transport,0); - getdns_base_transport_t fb_transport = priv_get_base_transport( - dnsreq->context->dns_transport,1); + getdns_base_transport_t transport = + netreq->dns_base_transports[netreq->transport]; switch(transport) { - case GETDNS_TRANSPORT_UDP: - - if ((netreq->fd = socket( - upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - - getdns_sock_nonblock(netreq->fd); - netreq->upstream = upstream; - + case GETDNS_BASE_TRANSPORT_UDP: + case GETDNS_BASE_TRANSPORT_TCP_SINGLE: + netreq->fd = fd; GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, - NULL, stub_udp_write_cb, stub_timeout_cb)); - - return GETDNS_RETURN_GOOD; - - case GETDNS_TRANSPORT_TCP_SINGLE: - - if ((netreq->fd = tcp_connect(upstream, transport)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - netreq->upstream = upstream; - - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - NULL, stub_tcp_write_cb, stub_timeout_cb)); - + NULL, (transport == GETDNS_BASE_TRANSPORT_UDP ? + stub_udp_write_cb: stub_tcp_write_cb), stub_timeout_cb)); return GETDNS_RETURN_GOOD; - case GETDNS_TRANSPORT_TCP: - case GETDNS_TRANSPORT_TLS: - - /* In coming comments, "global" means "context wide" */ - - /* Are we the first? (Is global socket initialized?) */ - if (upstream->fd == -1) { - /* TODO[TLS]: We should remember on the context if we had to fallback - * for this upstream so when re-connecting from a dropped TCP - * connection we don't retry TLS. */ - int fallback = 0; - - /* We are the first. Make global socket and connect. */ - if ((upstream->fd = tcp_connect(upstream, transport)) == -1) { - if (fb_transport == GETDNS_TRANSPORT_NONE) - return GETDNS_RETURN_GENERIC_ERROR; - if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - fallback = 1; - } - - /* Now do a handshake for TLS. Note waiting for this to succeed or - * timeout blocks the scheduling of any messages for this upstream*/ - if (transport == GETDNS_TRANSPORT_TLS && (fallback == 0)) { - upstream->tls_obj = do_tls_handshake(dnsreq, upstream); - if (!upstream->tls_obj) { - if (fb_transport == GETDNS_TRANSPORT_NONE) - return GETDNS_RETURN_GENERIC_ERROR; - close(upstream->fd); - if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - } - } - /* Attach to the global event loop - * so it can do it's own scheduling - */ - upstream->loop = dnsreq->context->extension; - } else { - /* Cater for the case of the user downgrading and existing TLS - connection to TCP for some reason...*/ - if (transport == GETDNS_TRANSPORT_TCP && upstream->tls_obj) { - SSL_shutdown(upstream->tls_obj); - SSL_free(upstream->tls_obj); - upstream->tls_obj = NULL; - } - } - netreq->upstream = upstream; - - /* We have a context wide socket. - * Now schedule the write request. - */ - upstream_schedule_netreq(upstream, netreq); - - /* Schedule at least the timeout locally. - * And also the write if we perform a synchronous lookup - */ + case GETDNS_BASE_TRANSPORT_STARTTLS: + case GETDNS_BASE_TRANSPORT_TLS: + case GETDNS_BASE_TRANSPORT_TCP: + upstream_schedule_netreq(netreq->upstream, netreq); + /* TODO[TLS]: Change scheduling for sync calls. */ GETDNS_SCHEDULE_EVENT( - dnsreq->loop, upstream->fd, dnsreq->context->timeout, + dnsreq->loop, netreq->upstream->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, NULL, - ( dnsreq->loop != upstream->loop /* Synchronous lookup? */ + ( dnsreq->loop != netreq->upstream->loop /* Synchronous lookup? */ ? netreq_upstream_write_cb : NULL), stub_timeout_cb)); return GETDNS_RETURN_GOOD; @@ -1314,4 +1725,4 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq) } } -/* stub.c */ +/* stub.c */ \ No newline at end of file diff --git a/src/test/getdns_query.c b/src/test/getdns_query.c index abe21529..3fcf0920 100644 --- a/src/test/getdns_query.c +++ b/src/test/getdns_query.c @@ -54,6 +54,7 @@ ipaddr_dict(getdns_context *context, char *ipstr) getdns_dict *r = getdns_dict_create_with_context(context); char *s = strchr(ipstr, '%'), *scope_id_str = ""; char *p = strchr(ipstr, '@'), *portstr = ""; + char *t = strchr(ipstr, '#'), *tls_portstr = ""; uint8_t buf[sizeof(struct in6_addr)]; getdns_bindata addr; @@ -68,6 +69,10 @@ ipaddr_dict(getdns_context *context, char *ipstr) *p = 0; portstr = p + 1; } + if (t) { + *t = 0; + tls_portstr = t + 1; + } if (strchr(ipstr, ':')) { getdns_dict_util_set_string(r, "address_type", "IPv6"); addr.size = 16; @@ -86,6 +91,8 @@ ipaddr_dict(getdns_context *context, char *ipstr) getdns_dict_set_bindata(r, "address_data", &addr); if (*portstr) getdns_dict_set_int(r, "port", (int32_t)atoi(portstr)); + if (*tls_portstr) + getdns_dict_set_int(r, "tls-port", (int32_t)atoi(tls_portstr)); if (*scope_id_str) getdns_dict_util_set_string(r, "scope_id", scope_id_str); @@ -121,6 +128,7 @@ print_usage(FILE *out, const char *progname) fprintf(out, "\t-O\tSet transport to TCP only keep connections open\n"); fprintf(out, "\t-L\tSet transport to TLS only keep connections open\n"); fprintf(out, "\t-E\tSet transport to TLS with TCP fallback only keep connections open\n"); + fprintf(out, "\t-R\tSet transport to STARTTLS with TCP fallback only keep connections open\n"); fprintf(out, "\t-u\tSet transport to UDP with TCP fallback\n"); fprintf(out, "\t-U\tSet transport to UDP only\n"); fprintf(out, "\t-B\tBatch mode. Schedule all messages before processing responses.\n"); @@ -369,6 +377,10 @@ getdns_return_t parse_args(int argc, char **argv) getdns_context_set_dns_transport(context, GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN); break; + case 'R': + getdns_context_set_dns_transport(context, + GETDNS_TRANSPORT_STARTTLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN); + break; case 'u': getdns_context_set_dns_transport(context, GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP); diff --git a/src/types-internal.h b/src/types-internal.h index 5220d599..53d9db2e 100644 --- a/src/types-internal.h +++ b/src/types-internal.h @@ -164,6 +164,18 @@ typedef struct getdns_tcp_state { } getdns_tcp_state; +/* TODO[TLS]: change this name to getdns_transport when API updated*/ +typedef enum getdns_base_transport { + GETDNS_BASE_TRANSPORT_MIN = 0, + GETDNS_BASE_TRANSPORT_NONE = 0, + GETDNS_BASE_TRANSPORT_UDP, + GETDNS_BASE_TRANSPORT_TCP_SINGLE, /* To be removed? */ + GETDNS_BASE_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */ + GETDNS_BASE_TRANSPORT_TCP, + GETDNS_BASE_TRANSPORT_TLS, + GETDNS_BASE_TRANSPORT_MAX +} getdns_base_transport_t; + /** * Request data **/ @@ -191,6 +203,8 @@ typedef struct getdns_network_req /* For stub resolving */ struct getdns_upstream *upstream; int fd; + getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; + int transport; getdns_eventloop_event event; getdns_tcp_state tcp; uint16_t query_id;