diff --git a/src/context.c b/src/context.c index 45e4b472..e1f31e24 100644 --- a/src/context.c +++ b/src/context.c @@ -62,6 +62,18 @@ typedef struct host_name_addrs { uint8_t host_name[]; } host_name_addrs; +static in_port_t +getdns_port_array[GETDNS_PORT_LAST] = { + GETDNS_PORT_NUM_TCP, + GETDNS_PORT_NUM_TLS +}; + +// char* +// getdns_port_str_array[] = { +// "53", +// "1021" +// }; + /* Private functions */ getdns_return_t create_default_namespaces(struct getdns_context *context); static struct getdns_list *create_default_root_servers(void); @@ -240,7 +252,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in *)sa)->sin_port); - if (port != 0 && port != 53 && + if (port != 0 && port != GETDNS_PORT_NUM_TCP && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -256,7 +268,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa) break; port = ntohs(((struct sockaddr_in6 *)sa)->sin6_port); - if (port != 0 && port != 53 && + if (port != 0 && port != GETDNS_PORT_NUM_TCP && getdns_dict_set_int(address, "port", (uint32_t)port)) break; @@ -527,10 +539,7 @@ upstream_ntop_buf(getdns_upstream *upstream, getdns_transport_t transport, if (upstream_scope_id(upstream)) (void) snprintf(buf + strlen(buf), len - strlen(buf), "%%%d", (int)*upstream_scope_id(upstream)); - if (transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) - (void) snprintf(buf + strlen(buf), len - strlen(buf), - "@%d", GETDNS_TLS_PORT); - else if (upstream_port(upstream) != 53 && upstream_port(upstream) != 0) + else if (upstream_port(upstream) != GETDNS_PORT_NUM_TCP && upstream_port(upstream) != 0) (void) snprintf(buf + strlen(buf), len - strlen(buf), "@%d", (int)upstream_port(upstream)); } @@ -557,6 +566,10 @@ upstream_init(getdns_upstream *upstream, /* For sharing a socket to this upstream with TCP */ upstream->fd = -1; upstream->tls_obj = NULL; + upstream->base_transport = (upstream_port(upstream) == GETDNS_PORT_NUM_TLS ? + GETDNS_TRANSPORT_TLS : + GETDNS_TRANSPORT_TCP); + upstream->tls_hs_state = GETDNS_HS_NONE; upstream->loop = NULL; (void) getdns_eventloop_event_init( &upstream->event, upstream, NULL, NULL, NULL); @@ -659,20 +672,25 @@ set_os_defaults(struct getdns_context *context) token = parse + strcspn(parse, " \t\r\n"); *token = 0; - if ((s = getaddrinfo(parse, "53", &hints, &result))) - continue; + //getdns_port_type_t port_type = GETDNS_PORT_FIRST; + //for (; port_type < GETDNS_PORT_LAST; port_type++) { + // TODO[TLS]: Seeing strange crash in ub_create_ctx when using the loop here.... + //fprintf(stderr,"creating upstream %s\n", parse); + if ((s = getaddrinfo(parse, "53", /*getdns_port_str_array[port_type],*/ &hints, &result))) + continue; - /* No lookups, so maximal 1 result */ - if (! result) continue; + /* No lookups, so maximal 1 result */ + if (! result) continue; - /* Grow array when needed */ - if (context->upstreams->count == upstreams_limit) - context->upstreams = upstreams_resize( - context->upstreams, (upstreams_limit *= 2)); + /* Grow array when needed */ + if (context->upstreams->count == upstreams_limit) + context->upstreams = upstreams_resize( + context->upstreams, (upstreams_limit *= 2)); - upstream = &context->upstreams-> - upstreams[context->upstreams->count++]; - upstream_init(upstream, context->upstreams, result); + upstream = &context->upstreams-> + upstreams[context->upstreams->count++]; + upstream_init(upstream, context->upstreams, result); + //} freeaddrinfo(result); } fclose(in); @@ -1456,63 +1474,68 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, hints.ai_addr = NULL; hints.ai_next = NULL; - upstreams = upstreams_create(context, count); + upstreams = upstreams_create(context, count*2); for (i = 0; i < count; i++) { - getdns_dict *dict; - getdns_bindata *address_type; - getdns_bindata *address_data; - uint32_t port; - getdns_bindata *scope_id; - struct addrinfo *ai; - getdns_upstream *upstream; + /* Loop twice to create TCP and TLS upstreams*/ + getdns_port_type_t port_type = GETDNS_PORT_FIRST; + for (; port_type < GETDNS_PORT_LAST; port_type++) { + getdns_dict *dict; + getdns_bindata *address_type; + getdns_bindata *address_data; + uint32_t port; + getdns_bindata *scope_id; + struct addrinfo *ai; + getdns_upstream *upstream; - upstream = &upstreams->upstreams[upstreams->count]; - if ((r = getdns_list_get_dict(upstream_list, i, &dict))) - goto error; + upstream = &upstreams->upstreams[upstreams->count]; + if ((r = getdns_list_get_dict(upstream_list, i, &dict))) + goto error; - if ((r = getdns_dict_get_bindata( - dict, "address_type",&address_type))) - goto error; - if (address_type->size < 4) - goto invalid_parameter; - if (strncmp((char *)address_type->data, "IPv4", 4) == 0) - upstream->addr.ss_family = AF_INET; - else if (strncmp((char *)address_type->data, "IPv6", 4) == 0) - upstream->addr.ss_family = AF_INET6; - else goto invalid_parameter; - - if ((r = getdns_dict_get_bindata( - dict, "address_data", &address_data))) - goto error; - if ((upstream->addr.ss_family == AF_INET && - address_data->size != 4) || - (upstream->addr.ss_family == AF_INET6 && - address_data->size != 16)) - goto invalid_parameter; - if (inet_ntop(upstream->addr.ss_family, address_data->data, - addrstr, 1024) == NULL) - goto invalid_parameter; - - port = 53; - (void) getdns_dict_get_int(dict, "port", &port); - (void) snprintf(portstr, 1024, "%d", (int)port); - - if (getdns_dict_get_bindata(dict, "scope_id", &scope_id) == - GETDNS_RETURN_GOOD) { - if (strlen(addrstr) + scope_id->size > 1022) + if ((r = getdns_dict_get_bindata( + dict, "address_type",&address_type))) + goto error; + if (address_type->size < 4) goto invalid_parameter; - eos = &addrstr[strlen(addrstr)]; - *eos++ = '%'; - (void) memcpy(eos, scope_id->data, scope_id->size); - eos[scope_id->size] = 0; + if (strncmp((char *)address_type->data, "IPv4", 4) == 0) + upstream->addr.ss_family = AF_INET; + else if (strncmp((char *)address_type->data, "IPv6", 4) == 0) + upstream->addr.ss_family = AF_INET6; + else goto invalid_parameter; + + if ((r = getdns_dict_get_bindata( + dict, "address_data", &address_data))) + goto error; + if ((upstream->addr.ss_family == AF_INET && + address_data->size != 4) || + (upstream->addr.ss_family == AF_INET6 && + address_data->size != 16)) + goto invalid_parameter; + if (inet_ntop(upstream->addr.ss_family, address_data->data, + addrstr, 1024) == NULL) + goto invalid_parameter; + + /* So should we be throwing away the port the user set?*/ + port = (uint32_t)(int)getdns_port_array[port_type]; + (void) getdns_dict_get_int(dict, "port", &port); + (void) snprintf(portstr, 1024, "%d", (int)port); + + if (getdns_dict_get_bindata(dict, "scope_id", &scope_id) == + GETDNS_RETURN_GOOD) { + if (strlen(addrstr) + scope_id->size > 1022) + goto invalid_parameter; + eos = &addrstr[strlen(addrstr)]; + *eos++ = '%'; + (void) memcpy(eos, scope_id->data, scope_id->size); + eos[scope_id->size] = 0; + } + + if (getaddrinfo(addrstr, portstr, &hints, &ai)) + goto invalid_parameter; + + upstream_init(upstream, upstreams, ai); + upstreams->count++; + freeaddrinfo(ai); } - - if (getaddrinfo(addrstr, portstr, &hints, &ai)) - goto invalid_parameter; - - upstream_init(upstream, upstreams, ai); - upstreams->count++; - freeaddrinfo(ai); } priv_getdns_upstreams_dereference(context->upstreams); context->upstreams = upstreams; @@ -1729,6 +1752,7 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) getdns_upstreams *upstreams = context->upstreams; (void) ub_ctx_set_fwd(ctx, NULL); + /*TODO[TLS]: Order the upstreams so the TLS ones are first if doing TLS*/ for (i = 0; i < upstreams->count; i++) { upstream = &upstreams->upstreams[i]; upstream_ntop_buf(upstream, context->dns_transport, addr, 1024); diff --git a/src/context.h b/src/context.h index 12a2c263..f2c1f698 100644 --- a/src/context.h +++ b/src/context.h @@ -49,7 +49,10 @@ struct ub_ctx; #define GETDNS_FN_RESOLVCONF "/etc/resolv.conf" #define GETDNS_FN_HOSTS "/etc/hosts" -#define GETDNS_TLS_PORT 1021 +#define GETDNS_PORT_NUM_TCP 53 +#define GETDNS_PORT_NUM_TLS 1021 +#define GETDNS_PORT_STR_TCP "53" +#define GETDNS_PORT_STR_TLS "1021" enum filechgs { GETDNS_FCHG_ERRORS = -1 , GETDNS_FCHG_NOERROR = 0 @@ -80,6 +83,21 @@ typedef enum getdns_base_transport { GETDNS_TRANSPORT_TLS } getdns_base_transport_t; +typedef enum getdns_port_type { + GETDNS_PORT_FIRST = 0, + GETDNS_PORT_TCP = 0, + GETDNS_PORT_TLS = 1, + GETDNS_PORT_LAST = 2 +} getdns_port_type_t; + +typedef enum getdns_tls_hs_state { + GETDNS_HS_NONE, + GETDNS_HS_WRITE, + GETDNS_HS_READ, + GETDNS_HS_DONE, + GETDNS_HS_FAILED +} getdns_tls_hs_state_t; + typedef struct getdns_upstream { struct getdns_upstreams *upstreams; @@ -93,6 +111,8 @@ typedef struct getdns_upstream { /* For sharing a TCP socket to this upstream */ int fd; SSL* tls_obj; + getdns_base_transport_t base_transport; + getdns_tls_hs_state_t tls_hs_state; getdns_eventloop_event event; getdns_eventloop *loop; getdns_tcp_state tcp; diff --git a/src/request-internal.c b/src/request-internal.c index 0889b3fc..e9bb3b22 100644 --- a/src/request-internal.c +++ b/src/request-internal.c @@ -89,6 +89,7 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner, net_req->upstream = NULL; net_req->fd = -1; + net_req->transport = GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP; memset(&net_req->event, 0, sizeof(net_req->event)); memset(&net_req->tcp, 0, sizeof(net_req->tcp)); net_req->query_id = 0; diff --git a/src/stub.c b/src/stub.c index 0914aa57..775ff9ad 100755 --- a/src/stub.c +++ b/src/stub.c @@ -41,10 +41,23 @@ #include "util-internal.h" #include "general.h" +#define STUB_TLS_SETUP_ERROR -3 +#define STUB_TCP_AGAIN -2 +#define STUB_TCP_ERROR -1 + static time_t secret_rollover_time = 0; static uint32_t secret = 0; static uint32_t prev_secret = 0; +static void upstream_read_cb(void *userarg); +static void upstream_write_cb(void *userarg); +static int tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport); +static int connect_to_upstream(getdns_upstream *upstream, + getdns_base_transport_t transport, + getdns_context *context); +static void upstream_schedule_netreq(getdns_upstream *upstream, + getdns_network_req *netreq); + static void rollover_secret() { @@ -305,7 +318,7 @@ static void upstream_erred(getdns_upstream *upstream) { getdns_network_req *netreq; - + fprintf(stderr,"[TLS]: upstream_erred\n"); while ((netreq = upstream->write_queue)) { stub_cleanup(netreq); netreq->state = NET_REQ_FINISHED; @@ -318,8 +331,6 @@ upstream_erred(getdns_upstream *upstream) netreq->state = NET_REQ_FINISHED; priv_getdns_check_dns_req_complete(netreq->owner); } - /* TODO[TLS]: When we get an error (which is probably a timeout) and are - * using to keep connections open should we leave the connection up here? */ if (upstream->tls_obj) { SSL_shutdown(upstream->tls_obj); SSL_free(upstream->tls_obj); @@ -327,6 +338,7 @@ upstream_erred(getdns_upstream *upstream) } close(upstream->fd); upstream->fd = -1; + /*TODO[TLS]: Upstream errors don't trigger the user callback....*/ } void @@ -339,8 +351,11 @@ priv_getdns_cancel_stub_request(getdns_network_req *netreq) static void stub_erred(getdns_network_req *netreq) { + fprintf(stderr,"[TLS]: stub_erred\n"); stub_next_upstream(netreq); stub_cleanup(netreq); + /* TODO[TLS]: When we get an error (which is probably a timeout) and are + * using to keep connections open should we leave the connection up here? */ if (netreq->fd >= 0) close(netreq->fd); netreq->state = NET_REQ_FINISHED; priv_getdns_check_dns_req_complete(netreq->owner); @@ -349,6 +364,7 @@ stub_erred(getdns_network_req *netreq) static void stub_timeout_cb(void *userarg) { + fprintf(stderr,"[TLS]: stub_timeout_cb\n"); getdns_network_req *netreq = (getdns_network_req *)userarg; stub_next_upstream(netreq); @@ -460,8 +476,19 @@ stub_udp_write_cb(void *userarg) stub_udp_read_cb, NULL, stub_timeout_cb)); } + +static int +transport_matches(struct getdns_upstream *upstream, getdns_base_transport_t transport) { + if (upstream->base_transport != transport) + return 0; + if (transport == GETDNS_TRANSPORT_TLS && + upstream->tls_hs_state == GETDNS_HS_FAILED) + return 0; + return 1; +} + static getdns_upstream * -pick_upstream(getdns_dns_req *dnsreq) +pick_upstream(getdns_dns_req *dnsreq, int level) { getdns_upstream *upstream; size_t i; @@ -469,13 +496,17 @@ pick_upstream(getdns_dns_req *dnsreq) if (!dnsreq->upstreams->count) return NULL; + getdns_base_transport_t transport = priv_get_base_transport( + dnsreq->context->dns_transport, level); + for (i = 0; i < dnsreq->upstreams->count; i++) if (dnsreq->upstreams->upstreams[i].to_retry <= 0) dnsreq->upstreams->upstreams[i].to_retry++; i = dnsreq->upstreams->current; do { - if (dnsreq->upstreams->upstreams[i].to_retry > 0) { + if (dnsreq->upstreams->upstreams[i].to_retry > 0 && + transport_matches(&dnsreq->upstreams->upstreams[i], transport)) { dnsreq->upstreams->current = i; return &dnsreq->upstreams->upstreams[i]; } @@ -485,8 +516,8 @@ pick_upstream(getdns_dns_req *dnsreq) upstream = dnsreq->upstreams->upstreams; for (i = 1; i < dnsreq->upstreams->count; i++) - if (dnsreq->upstreams->upstreams[i].back_off < - upstream->back_off) + if (dnsreq->upstreams->upstreams[i].back_off < upstream->back_off && + transport_matches(&dnsreq->upstreams->upstreams[i], transport)) upstream = &dnsreq->upstreams->upstreams[i]; upstream->back_off++; @@ -495,9 +526,6 @@ pick_upstream(getdns_dns_req *dnsreq) return upstream; } -#define STUB_TCP_AGAIN -2 -#define STUB_TCP_ERROR -1 - static int stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf) { @@ -602,117 +630,174 @@ stub_tcp_read_cb(void *userarg) } } -/** wait for a socket to become ready */ -static int -sock_wait(int sockfd) -{ - int ret; - fd_set fds; - FD_ZERO(&fds); - FD_SET(FD_SET_T sockfd, &fds); - /*TODO[TLS]: Pick up this timeout from the context*/ - struct timeval timeout = {5, 0 }; - ret = select(sockfd+1, NULL, &fds, NULL, &timeout); - if(ret == 0) - /* timeout expired */ - return 0; - else if(ret == -1) - /* error */ - return 0; - return 1; -} - -static int -sock_connected(int sockfd) -{ - /* wait(write) until connected or error */ - while(1) { - int error = 0; - socklen_t len = (socklen_t)sizeof(error); - - if(!sock_wait(sockfd)) { - close(sockfd); - return -1; - } - - /* check if there is a pending error for nonblocking connect */ - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, (void*)&error, &len) < 0) { - error = errno; /* on solaris errno is error */ - } - if (error == EINPROGRESS || error == EWOULDBLOCK) - continue; /* try again */ - else if (error != 0) { - close(sockfd); - return -1; - } - /* connected */ - break; - } - return sockfd; -} - -/* The connection testing and handshake should be handled by integrating this - * with the event loop framework, but for now just implement a standalone - * handshake method.*/ static SSL* -do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream) +create_tls_object(getdns_context *context, int fd) { - /*Lets make sure the connection is up before we try a handshake*/ - if (errno == EINPROGRESS && sock_connected(upstream->fd) == -1) { - return NULL; - } - /* Create SSL instance */ - if (dnsreq->context->tls_ctx == NULL) + if (context->tls_ctx == NULL) return NULL; - SSL* ssl = SSL_new(dnsreq->context->tls_ctx); + SSL* ssl = SSL_new(context->tls_ctx); if(!ssl) { return NULL; } /* Connect the SSL object with a file descriptor */ - if(!SSL_set_fd(ssl, upstream->fd)) { + if(!SSL_set_fd(ssl,fd)) { SSL_free(ssl); return NULL; } SSL_set_connect_state(ssl); (void) SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY); - - int r; - int want; - fd_set fds; - FD_ZERO(&fds); - FD_SET(upstream->fd, &fds); - struct timeval timeout = {dnsreq->context->timeout/1000, 0 }; - while ((r = SSL_do_handshake(ssl)) != 1) - { - want = SSL_get_error(ssl, r); - switch (want) { - case SSL_ERROR_WANT_READ: - if (select(upstream->fd + 1, &fds, NULL, NULL, &timeout) == 0) { - SSL_free(ssl); - return NULL; - } - break; - case SSL_ERROR_WANT_WRITE: - if (select(upstream->fd + 1, NULL, &fds, NULL, &timeout) == 0) { - SSL_free(ssl); - return NULL; - } - break; - default: - SSL_free(ssl); - return NULL; - } - } return ssl; } static int -stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf) +do_tls_handshake(getdns_upstream *upstream) +{ + + fprintf(stderr,"[TLS]: do_tls_handshake\n"); + + int r; + int want; + ERR_clear_error(); + while ((r = SSL_do_handshake(upstream->tls_obj)) != 1) + { + want = SSL_get_error(upstream->tls_obj, r); + switch (want) { + case SSL_ERROR_WANT_READ: + fprintf(stderr,"[TLS]: SSL_ERROR_WANT_READ\n"); + upstream->event.read_cb = upstream_read_cb; + upstream->event.write_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + upstream->tls_hs_state = GETDNS_HS_READ; + return STUB_TCP_AGAIN; + case SSL_ERROR_WANT_WRITE: + fprintf(stderr,"[TLS]: SSL_ERROR_WANT_WRITE\n"); + upstream->event.read_cb = NULL; + upstream->event.write_cb = upstream_write_cb; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + upstream->tls_hs_state = GETDNS_HS_WRITE; + return STUB_TCP_AGAIN; + default: + SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; + upstream->tls_hs_state = GETDNS_HS_FAILED; + upstream->fd = -1; + return STUB_TLS_SETUP_ERROR; + } + } + upstream->tls_hs_state = GETDNS_HS_DONE; + upstream->event.read_cb = NULL; + upstream->event.write_cb = upstream_write_cb; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + return 0; +} + +/* TODO[TLS]: Could think about fallback on read error aswell.*/ +static int +fallback_on_write(getdns_network_req *netreq) { + + /* This should really check if any request in the queue can fallback...*/ + if (netreq->transport != GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN) + return STUB_TCP_ERROR; + + /* Deal with old upstream */ + getdns_upstream *upstream = netreq->upstream; + upstream->write_queue = NULL; + upstream->write_queue_last = NULL; + upstream->event.write_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + + /* Now set up new upstream */ + getdns_upstream *new_upstream = pick_upstream(netreq->owner, 1); + if (!new_upstream) + return STUB_TCP_ERROR; + + /* get transport generically*/ + int fd = connect_to_upstream(new_upstream, GETDNS_TRANSPORT_TCP, netreq->owner->context); + if (fd == -1) + return STUB_TCP_ERROR; + + fprintf(stderr,"[TLS]: tcp_fallback to %d \n", new_upstream->fd); + getdns_network_req *next_req; + while (netreq != NULL) { + next_req = netreq->write_queue_tail; + if (netreq->transport == GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN) { + netreq->upstream = new_upstream; + upstream_schedule_netreq(new_upstream, netreq); + /* TODO: Timout need to be adjusted and rescheduled on the new fd ....*/ + /* Note, setup timeout should be shorter than message timeout for + * messages with fallback or don't have time to re-try. */ + } + /*else.... leave request to timeout?*/ + netreq = next_req; + } + + return STUB_TCP_AGAIN; +} + +static int +setup_tls(getdns_upstream* upstream) +{ + int ret; + /* Already have a connection*/ + if (upstream->tls_hs_state == GETDNS_HS_DONE && + (upstream->tls_obj != NULL) && (upstream->fd != -1)) + return 0; + + /* Lets make sure the connection is up before we try a handshake*/ + int error = 0; + socklen_t len = (socklen_t)sizeof(error); + /* TODO: This doesn't handle the case where the far end doesn't do a reset + * as is the case with e.g. 8.8.8.8. For that case the timeout kicks in + * and the user callback fails the message without the chance to fallback... + * Note that acutally the TCP code doesn't check the connection state before + * doing a first write either.... + * Perhaps we should have a write_timeout_cb on the write and then schedule + * the stub_timeout_cb for matching the response??? */ + getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len); + if (error == EINPROGRESS || error == EWOULDBLOCK) { + fprintf(stderr,"[TLS]: blocking.......\n"); + return STUB_TCP_AGAIN; /* try again */ + } + else if (error != 0) { + fprintf(stderr,"[TLS]: died gettting connection\n"); + SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; + upstream->tls_hs_state = GETDNS_HS_FAILED; + upstream->fd = -1; + return STUB_TLS_SETUP_ERROR; + } + + ret = do_tls_handshake(upstream); + switch (ret) { + case STUB_TCP_AGAIN: + return ret; + case STUB_TCP_ERROR: + fprintf(stderr,"[TLS]: W: Handshake has failed %d\n", upstream->tls_hs_state); + return STUB_TLS_SETUP_ERROR; + default: + fprintf(stderr,"[TLS]: W:after handshake %d, %s\n", upstream->tls_hs_state, upstream->tls_obj== NULL? "NULL":"Not NULL" ); + return 0; + } +} + +static int +stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, struct mem_funcs *mf) { ssize_t read; uint8_t *buf; size_t buf_size; + SSL* tls_obj = upstream->tls_obj; + + int q = setup_tls(upstream); + if (q != 0) + return q; if (!tcp->read_buf) { /* First time tls read, create a buffer for reading */ @@ -794,9 +879,12 @@ upstream_read_cb(void *userarg) int q; uint16_t query_id; intptr_t query_id_intptr; + + + fprintf(stderr,"[TLS]: upstream_read_cb on %d\n", upstream->fd); if (upstream->tls_obj) - q = stub_tls_read(upstream->tls_obj, &upstream->tcp, + q = stub_tls_read(upstream, &upstream->tcp, &upstream->upstreams->mf); else q = stub_tcp_read(upstream->fd, &upstream->tcp, @@ -1018,12 +1106,17 @@ stub_tcp_write_cb(void *userarg) } static int -stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq) +stub_tls_write(getdns_upstream *upstream, getdns_tcp_state *tcp, getdns_network_req *netreq) { size_t pkt_len = netreq->response - netreq->query; ssize_t written; uint16_t query_id; intptr_t query_id_intptr; + SSL* tls_obj = upstream->tls_obj; + + int q = setup_tls(upstream); + if (q != 0) + return q; /* Do we have remaining data that we could not write before? */ if (! tcp->write_buf) { @@ -1073,8 +1166,11 @@ upstream_write_cb(void *userarg) getdns_dns_req *dnsreq = netreq->owner; int q; + + fprintf(stderr,"[TLS]: method: upstream_write_cb %d\n", upstream->fd); + if (upstream->tls_obj) - q = stub_tls_write(upstream->tls_obj, &upstream->tcp, netreq); + q = stub_tls_write(upstream, &upstream->tcp, netreq); else q = stub_tcp_write(upstream->fd, &upstream->tcp, netreq); @@ -1086,8 +1182,16 @@ upstream_write_cb(void *userarg) stub_erred(netreq); return; + case STUB_TLS_SETUP_ERROR: + /* Could not complete the TLS set up. Need to fallback on this upstream + * if possible.*/ + if (fallback_on_write(netreq) == STUB_TCP_ERROR) + stub_erred(netreq); + return; + default: netreq->query_id = (uint16_t) q; + fprintf(stderr,"[TLS]: method: upstream_write_cb, successfull write %d\n", upstream->fd); /* Unqueue the netreq from the write_queue */ if (!(upstream->write_queue = netreq->write_queue_tail)) { @@ -1137,6 +1241,8 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) assert(upstream->fd >= 0); assert(upstream->loop); + fprintf(stderr,"[TLS]: method: upstream_schedule_netreq %d\n", upstream->fd); + /* Append netreq to write_queue */ if (!upstream->write_queue) { upstream->write_queue = upstream->write_queue_last = netreq; @@ -1150,39 +1256,12 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) } } -static in_port_t -get_port(struct sockaddr_storage* addr) -{ - return ntohs(addr->ss_family == AF_INET - ? ((struct sockaddr_in *)addr)->sin_port - : ((struct sockaddr_in6*)addr)->sin6_port); -} - -static void -set_port(struct sockaddr_storage* addr, in_port_t port) -{ - addr->ss_family == AF_INET - ? (((struct sockaddr_in *)addr)->sin_port = htons(port)) - : (((struct sockaddr_in6*)addr)->sin6_port = htons(port)); -} - static int -tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) { +tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) +{ - int fd =-1; - struct sockaddr_storage connect_addr; - struct sockaddr_storage* addr = &upstream->addr; - socklen_t addr_len = upstream->addr_len; - - /* TODO[TLS]: For now, override the port to a hardcoded value*/ - if (transport == GETDNS_TRANSPORT_TLS && - (int)get_port(addr) != GETDNS_TLS_PORT) { - connect_addr = upstream->addr; - addr = &connect_addr; - set_port(addr, GETDNS_TLS_PORT); - } - - if ((fd = socket(addr->ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) + int fd = -1; + if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) return -1; getdns_sock_nonblock(fd); @@ -1192,8 +1271,8 @@ tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) { transport == GETDNS_TRANSPORT_TCP_SINGLE) return fd; #endif - if (connect(fd, (struct sockaddr *)addr, - addr_len) == -1) { + if (connect(fd, (struct sockaddr *)&upstream->addr, + upstream->addr_len) == -1) { if (errno != EINPROGRESS) { close(fd); return -1; @@ -1202,106 +1281,103 @@ tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) { return fd; } +int +connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport, + getdns_context *context) +{ + + if ((transport == GETDNS_TRANSPORT_TCP || + transport == GETDNS_TRANSPORT_TLS) + && upstream->fd != -1) { + fprintf(stderr,"[TLS]: method: tcp_connect using existing fd %d\n", upstream->fd); + return upstream->fd; + } + + int fd; + switch(transport) { + case GETDNS_TRANSPORT_UDP: + if ((fd = socket( + upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1) + return -1; + getdns_sock_nonblock(fd); + return fd; + + case GETDNS_TRANSPORT_TCP_SINGLE: + case GETDNS_TRANSPORT_TCP: + fd = tcp_connect(upstream, transport); + break; + + case GETDNS_TRANSPORT_TLS: + fd = tcp_connect(upstream, transport); + if (fd == -1 || + (upstream->tls_obj = create_tls_object(context, fd)) == NULL ) { + close(fd); + return -1; + } + upstream->tls_hs_state = GETDNS_HS_WRITE; + break; + default: + return -1; + /* Nothing to do*/ + } + if (fd != -1) { + upstream->loop = context->extension; + upstream->fd = fd; + } + fprintf(stderr,"[TLS]: method: tcp_connect created new connection %d\n", fd); + return fd; +} + getdns_return_t priv_getdns_submit_stub_request(getdns_network_req *netreq) { - getdns_dns_req *dnsreq = netreq->owner; - getdns_upstream *upstream = pick_upstream(dnsreq); + getdns_dns_req *dnsreq = netreq->owner; - if (!upstream) - return GETDNS_RETURN_GENERIC_ERROR; - - // Work out the primary and fallback transport options + /* TODO[TLS - 1]: This will become a double while loop trying all the upstreams on all the + * transports for a connection since we need a fd to schedule on, using previous known capabilities + * All other set up is done async*/ + /* Work out the primary and fallback transport options */ getdns_base_transport_t transport = priv_get_base_transport( dnsreq->context->dns_transport,0); getdns_base_transport_t fb_transport = priv_get_base_transport( dnsreq->context->dns_transport,1); + getdns_upstream *upstream = pick_upstream(dnsreq, 0); + if (!upstream) + return GETDNS_RETURN_GENERIC_ERROR; + int fd = connect_to_upstream(upstream, transport, dnsreq->context); + if (fd == -1) { + if (fb_transport == GETDNS_TRANSPORT_NONE) + return GETDNS_RETURN_GENERIC_ERROR; + upstream = pick_upstream(dnsreq, 1); + if ((fd = connect_to_upstream(upstream, fb_transport, dnsreq->context)) == -1) + return GETDNS_RETURN_GENERIC_ERROR; + } + + netreq->upstream = upstream; + netreq->transport = dnsreq->context->dns_transport; + switch(transport) { case GETDNS_TRANSPORT_UDP: - - if ((netreq->fd = socket( - upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - - getdns_sock_nonblock(netreq->fd); - netreq->upstream = upstream; - - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - NULL, stub_udp_write_cb, stub_timeout_cb)); - - return GETDNS_RETURN_GOOD; - case GETDNS_TRANSPORT_TCP_SINGLE: - - if ((netreq->fd = tcp_connect(upstream, transport)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - netreq->upstream = upstream; - + netreq->fd = fd; GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, - NULL, stub_tcp_write_cb, stub_timeout_cb)); - + NULL, (transport == GETDNS_TRANSPORT_UDP ? stub_udp_write_cb: + stub_tcp_write_cb), stub_timeout_cb)); return GETDNS_RETURN_GOOD; case GETDNS_TRANSPORT_TCP: case GETDNS_TRANSPORT_TLS: /* In coming comments, "global" means "context wide" */ - - /* Are we the first? (Is global socket initialized?) */ if (upstream->fd == -1) { - /* TODO[TLS]: We should remember on the context if we had to fallback - * for this upstream so when re-connecting from a dropped TCP - * connection we don't retry TLS. */ - int fallback = 0; - - /* We are the first. Make global socket and connect. */ - if ((upstream->fd = tcp_connect(upstream, transport)) == -1) { - if (fb_transport == GETDNS_TRANSPORT_NONE) - return GETDNS_RETURN_GENERIC_ERROR; - if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - fallback = 1; - } - - /* Now do a handshake for TLS. Note waiting for this to succeed or - * timeout blocks the scheduling of any messages for this upstream*/ - if (transport == GETDNS_TRANSPORT_TLS && (fallback == 0)) { - upstream->tls_obj = do_tls_handshake(dnsreq, upstream); - if (!upstream->tls_obj) { - if (fb_transport == GETDNS_TRANSPORT_NONE) - return GETDNS_RETURN_GENERIC_ERROR; - close(upstream->fd); - if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1) - return GETDNS_RETURN_GENERIC_ERROR; - } - } - /* Attach to the global event loop - * so it can do it's own scheduling - */ upstream->loop = dnsreq->context->extension; - } else { - /* Cater for the case of the user downgrading and existing TLS - connection to TCP for some reason...*/ - if (transport == GETDNS_TRANSPORT_TCP && upstream->tls_obj) { - SSL_shutdown(upstream->tls_obj); - SSL_free(upstream->tls_obj); - upstream->tls_obj = NULL; - } + upstream->fd = fd; } - netreq->upstream = upstream; - - /* We have a context wide socket. - * Now schedule the write request. - */ upstream_schedule_netreq(upstream, netreq); - - /* Schedule at least the timeout locally. - * And also the write if we perform a synchronous lookup - */ + /* TODO[TLS]: Timeout handling for async calls must change.... + * Maybe even change scheduling for sync calls here too*/ GETDNS_SCHEDULE_EVENT( dnsreq->loop, upstream->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, NULL, @@ -1314,4 +1390,4 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq) } } -/* stub.c */ +/* stub.c */ \ No newline at end of file diff --git a/src/types-internal.h b/src/types-internal.h index 5220d599..7ca48d7b 100644 --- a/src/types-internal.h +++ b/src/types-internal.h @@ -191,6 +191,7 @@ typedef struct getdns_network_req /* For stub resolving */ struct getdns_upstream *upstream; int fd; + getdns_transport_t transport; getdns_eventloop_event event; getdns_tcp_state tcp; uint16_t query_id;