First pass at making handshake async. Lots of issues with this code still

- timeouts are not being rescheduled on fallback
- several error cases are not being handled correctly (e.g. 8.8.8.8) and a user callback is not always called
- the fallback mechanism is not generic (specific to tls to tcp)
This commit is contained in:
Sara Dickinson 2015-04-19 17:16:58 +01:00
parent 2a6fc74314
commit f2ae55858f
5 changed files with 405 additions and 283 deletions

View File

@ -62,6 +62,18 @@ typedef struct host_name_addrs {
uint8_t host_name[];
} host_name_addrs;
static in_port_t
getdns_port_array[GETDNS_PORT_LAST] = {
GETDNS_PORT_NUM_TCP,
GETDNS_PORT_NUM_TLS
};
// char*
// getdns_port_str_array[] = {
// "53",
// "1021"
// };
/* Private functions */
getdns_return_t create_default_namespaces(struct getdns_context *context);
static struct getdns_list *create_default_root_servers(void);
@ -240,7 +252,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa)
break;
port = ntohs(((struct sockaddr_in *)sa)->sin_port);
if (port != 0 && port != 53 &&
if (port != 0 && port != GETDNS_PORT_NUM_TCP &&
getdns_dict_set_int(address, "port", (uint32_t)port))
break;
@ -256,7 +268,7 @@ sockaddr_dict(getdns_context *context, struct sockaddr *sa)
break;
port = ntohs(((struct sockaddr_in6 *)sa)->sin6_port);
if (port != 0 && port != 53 &&
if (port != 0 && port != GETDNS_PORT_NUM_TCP &&
getdns_dict_set_int(address, "port", (uint32_t)port))
break;
@ -527,10 +539,7 @@ upstream_ntop_buf(getdns_upstream *upstream, getdns_transport_t transport,
if (upstream_scope_id(upstream))
(void) snprintf(buf + strlen(buf), len - strlen(buf),
"%%%d", (int)*upstream_scope_id(upstream));
if (transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN)
(void) snprintf(buf + strlen(buf), len - strlen(buf),
"@%d", GETDNS_TLS_PORT);
else if (upstream_port(upstream) != 53 && upstream_port(upstream) != 0)
else if (upstream_port(upstream) != GETDNS_PORT_NUM_TCP && upstream_port(upstream) != 0)
(void) snprintf(buf + strlen(buf), len - strlen(buf),
"@%d", (int)upstream_port(upstream));
}
@ -557,6 +566,10 @@ upstream_init(getdns_upstream *upstream,
/* For sharing a socket to this upstream with TCP */
upstream->fd = -1;
upstream->tls_obj = NULL;
upstream->base_transport = (upstream_port(upstream) == GETDNS_PORT_NUM_TLS ?
GETDNS_TRANSPORT_TLS :
GETDNS_TRANSPORT_TCP);
upstream->tls_hs_state = GETDNS_HS_NONE;
upstream->loop = NULL;
(void) getdns_eventloop_event_init(
&upstream->event, upstream, NULL, NULL, NULL);
@ -659,7 +672,11 @@ set_os_defaults(struct getdns_context *context)
token = parse + strcspn(parse, " \t\r\n");
*token = 0;
if ((s = getaddrinfo(parse, "53", &hints, &result)))
//getdns_port_type_t port_type = GETDNS_PORT_FIRST;
//for (; port_type < GETDNS_PORT_LAST; port_type++) {
// TODO[TLS]: Seeing strange crash in ub_create_ctx when using the loop here....
//fprintf(stderr,"creating upstream %s\n", parse);
if ((s = getaddrinfo(parse, "53", /*getdns_port_str_array[port_type],*/ &hints, &result)))
continue;
/* No lookups, so maximal 1 result */
@ -673,6 +690,7 @@ set_os_defaults(struct getdns_context *context)
upstream = &context->upstreams->
upstreams[context->upstreams->count++];
upstream_init(upstream, context->upstreams, result);
//}
freeaddrinfo(result);
}
fclose(in);
@ -1456,8 +1474,11 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context,
hints.ai_addr = NULL;
hints.ai_next = NULL;
upstreams = upstreams_create(context, count);
upstreams = upstreams_create(context, count*2);
for (i = 0; i < count; i++) {
/* Loop twice to create TCP and TLS upstreams*/
getdns_port_type_t port_type = GETDNS_PORT_FIRST;
for (; port_type < GETDNS_PORT_LAST; port_type++) {
getdns_dict *dict;
getdns_bindata *address_type;
getdns_bindata *address_data;
@ -1493,7 +1514,8 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context,
addrstr, 1024) == NULL)
goto invalid_parameter;
port = 53;
/* So should we be throwing away the port the user set?*/
port = (uint32_t)(int)getdns_port_array[port_type];
(void) getdns_dict_get_int(dict, "port", &port);
(void) snprintf(portstr, 1024, "%d", (int)port);
@ -1514,6 +1536,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context,
upstreams->count++;
freeaddrinfo(ai);
}
}
priv_getdns_upstreams_dereference(context->upstreams);
context->upstreams = upstreams;
dispatch_updated(context,
@ -1729,6 +1752,7 @@ ub_setup_stub(struct ub_ctx *ctx, getdns_context *context)
getdns_upstreams *upstreams = context->upstreams;
(void) ub_ctx_set_fwd(ctx, NULL);
/*TODO[TLS]: Order the upstreams so the TLS ones are first if doing TLS*/
for (i = 0; i < upstreams->count; i++) {
upstream = &upstreams->upstreams[i];
upstream_ntop_buf(upstream, context->dns_transport, addr, 1024);

View File

@ -49,7 +49,10 @@ struct ub_ctx;
#define GETDNS_FN_RESOLVCONF "/etc/resolv.conf"
#define GETDNS_FN_HOSTS "/etc/hosts"
#define GETDNS_TLS_PORT 1021
#define GETDNS_PORT_NUM_TCP 53
#define GETDNS_PORT_NUM_TLS 1021
#define GETDNS_PORT_STR_TCP "53"
#define GETDNS_PORT_STR_TLS "1021"
enum filechgs { GETDNS_FCHG_ERRORS = -1
, GETDNS_FCHG_NOERROR = 0
@ -80,6 +83,21 @@ typedef enum getdns_base_transport {
GETDNS_TRANSPORT_TLS
} getdns_base_transport_t;
typedef enum getdns_port_type {
GETDNS_PORT_FIRST = 0,
GETDNS_PORT_TCP = 0,
GETDNS_PORT_TLS = 1,
GETDNS_PORT_LAST = 2
} getdns_port_type_t;
typedef enum getdns_tls_hs_state {
GETDNS_HS_NONE,
GETDNS_HS_WRITE,
GETDNS_HS_READ,
GETDNS_HS_DONE,
GETDNS_HS_FAILED
} getdns_tls_hs_state_t;
typedef struct getdns_upstream {
struct getdns_upstreams *upstreams;
@ -93,6 +111,8 @@ typedef struct getdns_upstream {
/* For sharing a TCP socket to this upstream */
int fd;
SSL* tls_obj;
getdns_base_transport_t base_transport;
getdns_tls_hs_state_t tls_hs_state;
getdns_eventloop_event event;
getdns_eventloop *loop;
getdns_tcp_state tcp;

View File

@ -89,6 +89,7 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner,
net_req->upstream = NULL;
net_req->fd = -1;
net_req->transport = GETDNS_TRANSPORT_UDP_FIRST_AND_FALL_BACK_TO_TCP;
memset(&net_req->event, 0, sizeof(net_req->event));
memset(&net_req->tcp, 0, sizeof(net_req->tcp));
net_req->query_id = 0;

View File

@ -41,10 +41,23 @@
#include "util-internal.h"
#include "general.h"
#define STUB_TLS_SETUP_ERROR -3
#define STUB_TCP_AGAIN -2
#define STUB_TCP_ERROR -1
static time_t secret_rollover_time = 0;
static uint32_t secret = 0;
static uint32_t prev_secret = 0;
static void upstream_read_cb(void *userarg);
static void upstream_write_cb(void *userarg);
static int tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport);
static int connect_to_upstream(getdns_upstream *upstream,
getdns_base_transport_t transport,
getdns_context *context);
static void upstream_schedule_netreq(getdns_upstream *upstream,
getdns_network_req *netreq);
static void
rollover_secret()
{
@ -305,7 +318,7 @@ static void
upstream_erred(getdns_upstream *upstream)
{
getdns_network_req *netreq;
fprintf(stderr,"[TLS]: upstream_erred\n");
while ((netreq = upstream->write_queue)) {
stub_cleanup(netreq);
netreq->state = NET_REQ_FINISHED;
@ -318,8 +331,6 @@ upstream_erred(getdns_upstream *upstream)
netreq->state = NET_REQ_FINISHED;
priv_getdns_check_dns_req_complete(netreq->owner);
}
/* TODO[TLS]: When we get an error (which is probably a timeout) and are
* using to keep connections open should we leave the connection up here? */
if (upstream->tls_obj) {
SSL_shutdown(upstream->tls_obj);
SSL_free(upstream->tls_obj);
@ -327,6 +338,7 @@ upstream_erred(getdns_upstream *upstream)
}
close(upstream->fd);
upstream->fd = -1;
/*TODO[TLS]: Upstream errors don't trigger the user callback....*/
}
void
@ -339,8 +351,11 @@ priv_getdns_cancel_stub_request(getdns_network_req *netreq)
static void
stub_erred(getdns_network_req *netreq)
{
fprintf(stderr,"[TLS]: stub_erred\n");
stub_next_upstream(netreq);
stub_cleanup(netreq);
/* TODO[TLS]: When we get an error (which is probably a timeout) and are
* using to keep connections open should we leave the connection up here? */
if (netreq->fd >= 0) close(netreq->fd);
netreq->state = NET_REQ_FINISHED;
priv_getdns_check_dns_req_complete(netreq->owner);
@ -349,6 +364,7 @@ stub_erred(getdns_network_req *netreq)
static void
stub_timeout_cb(void *userarg)
{
fprintf(stderr,"[TLS]: stub_timeout_cb\n");
getdns_network_req *netreq = (getdns_network_req *)userarg;
stub_next_upstream(netreq);
@ -460,8 +476,19 @@ stub_udp_write_cb(void *userarg)
stub_udp_read_cb, NULL, stub_timeout_cb));
}
static int
transport_matches(struct getdns_upstream *upstream, getdns_base_transport_t transport) {
if (upstream->base_transport != transport)
return 0;
if (transport == GETDNS_TRANSPORT_TLS &&
upstream->tls_hs_state == GETDNS_HS_FAILED)
return 0;
return 1;
}
static getdns_upstream *
pick_upstream(getdns_dns_req *dnsreq)
pick_upstream(getdns_dns_req *dnsreq, int level)
{
getdns_upstream *upstream;
size_t i;
@ -469,13 +496,17 @@ pick_upstream(getdns_dns_req *dnsreq)
if (!dnsreq->upstreams->count)
return NULL;
getdns_base_transport_t transport = priv_get_base_transport(
dnsreq->context->dns_transport, level);
for (i = 0; i < dnsreq->upstreams->count; i++)
if (dnsreq->upstreams->upstreams[i].to_retry <= 0)
dnsreq->upstreams->upstreams[i].to_retry++;
i = dnsreq->upstreams->current;
do {
if (dnsreq->upstreams->upstreams[i].to_retry > 0) {
if (dnsreq->upstreams->upstreams[i].to_retry > 0 &&
transport_matches(&dnsreq->upstreams->upstreams[i], transport)) {
dnsreq->upstreams->current = i;
return &dnsreq->upstreams->upstreams[i];
}
@ -485,8 +516,8 @@ pick_upstream(getdns_dns_req *dnsreq)
upstream = dnsreq->upstreams->upstreams;
for (i = 1; i < dnsreq->upstreams->count; i++)
if (dnsreq->upstreams->upstreams[i].back_off <
upstream->back_off)
if (dnsreq->upstreams->upstreams[i].back_off < upstream->back_off &&
transport_matches(&dnsreq->upstreams->upstreams[i], transport))
upstream = &dnsreq->upstreams->upstreams[i];
upstream->back_off++;
@ -495,9 +526,6 @@ pick_upstream(getdns_dns_req *dnsreq)
return upstream;
}
#define STUB_TCP_AGAIN -2
#define STUB_TCP_ERROR -1
static int
stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf)
{
@ -602,117 +630,174 @@ stub_tcp_read_cb(void *userarg)
}
}
/** wait for a socket to become ready */
static int
sock_wait(int sockfd)
{
int ret;
fd_set fds;
FD_ZERO(&fds);
FD_SET(FD_SET_T sockfd, &fds);
/*TODO[TLS]: Pick up this timeout from the context*/
struct timeval timeout = {5, 0 };
ret = select(sockfd+1, NULL, &fds, NULL, &timeout);
if(ret == 0)
/* timeout expired */
return 0;
else if(ret == -1)
/* error */
return 0;
return 1;
}
static int
sock_connected(int sockfd)
{
/* wait(write) until connected or error */
while(1) {
int error = 0;
socklen_t len = (socklen_t)sizeof(error);
if(!sock_wait(sockfd)) {
close(sockfd);
return -1;
}
/* check if there is a pending error for nonblocking connect */
if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, (void*)&error, &len) < 0) {
error = errno; /* on solaris errno is error */
}
if (error == EINPROGRESS || error == EWOULDBLOCK)
continue; /* try again */
else if (error != 0) {
close(sockfd);
return -1;
}
/* connected */
break;
}
return sockfd;
}
/* The connection testing and handshake should be handled by integrating this
* with the event loop framework, but for now just implement a standalone
* handshake method.*/
static SSL*
do_tls_handshake(getdns_dns_req *dnsreq, getdns_upstream *upstream)
create_tls_object(getdns_context *context, int fd)
{
/*Lets make sure the connection is up before we try a handshake*/
if (errno == EINPROGRESS && sock_connected(upstream->fd) == -1) {
return NULL;
}
/* Create SSL instance */
if (dnsreq->context->tls_ctx == NULL)
if (context->tls_ctx == NULL)
return NULL;
SSL* ssl = SSL_new(dnsreq->context->tls_ctx);
SSL* ssl = SSL_new(context->tls_ctx);
if(!ssl) {
return NULL;
}
/* Connect the SSL object with a file descriptor */
if(!SSL_set_fd(ssl, upstream->fd)) {
if(!SSL_set_fd(ssl,fd)) {
SSL_free(ssl);
return NULL;
}
SSL_set_connect_state(ssl);
(void) SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY);
int r;
int want;
fd_set fds;
FD_ZERO(&fds);
FD_SET(upstream->fd, &fds);
struct timeval timeout = {dnsreq->context->timeout/1000, 0 };
while ((r = SSL_do_handshake(ssl)) != 1)
{
want = SSL_get_error(ssl, r);
switch (want) {
case SSL_ERROR_WANT_READ:
if (select(upstream->fd + 1, &fds, NULL, NULL, &timeout) == 0) {
SSL_free(ssl);
return NULL;
}
break;
case SSL_ERROR_WANT_WRITE:
if (select(upstream->fd + 1, NULL, &fds, NULL, &timeout) == 0) {
SSL_free(ssl);
return NULL;
}
break;
default:
SSL_free(ssl);
return NULL;
}
}
return ssl;
}
static int
stub_tls_read(SSL* tls_obj, getdns_tcp_state *tcp, struct mem_funcs *mf)
do_tls_handshake(getdns_upstream *upstream)
{
fprintf(stderr,"[TLS]: do_tls_handshake\n");
int r;
int want;
ERR_clear_error();
while ((r = SSL_do_handshake(upstream->tls_obj)) != 1)
{
want = SSL_get_error(upstream->tls_obj, r);
switch (want) {
case SSL_ERROR_WANT_READ:
fprintf(stderr,"[TLS]: SSL_ERROR_WANT_READ\n");
upstream->event.read_cb = upstream_read_cb;
upstream->event.write_cb = NULL;
GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event);
GETDNS_SCHEDULE_EVENT(upstream->loop,
upstream->fd, TIMEOUT_FOREVER, &upstream->event);
upstream->tls_hs_state = GETDNS_HS_READ;
return STUB_TCP_AGAIN;
case SSL_ERROR_WANT_WRITE:
fprintf(stderr,"[TLS]: SSL_ERROR_WANT_WRITE\n");
upstream->event.read_cb = NULL;
upstream->event.write_cb = upstream_write_cb;
GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event);
GETDNS_SCHEDULE_EVENT(upstream->loop,
upstream->fd, TIMEOUT_FOREVER, &upstream->event);
upstream->tls_hs_state = GETDNS_HS_WRITE;
return STUB_TCP_AGAIN;
default:
SSL_free(upstream->tls_obj);
upstream->tls_obj = NULL;
upstream->tls_hs_state = GETDNS_HS_FAILED;
upstream->fd = -1;
return STUB_TLS_SETUP_ERROR;
}
}
upstream->tls_hs_state = GETDNS_HS_DONE;
upstream->event.read_cb = NULL;
upstream->event.write_cb = upstream_write_cb;
GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event);
GETDNS_SCHEDULE_EVENT(upstream->loop,
upstream->fd, TIMEOUT_FOREVER, &upstream->event);
return 0;
}
/* TODO[TLS]: Could think about fallback on read error aswell.*/
static int
fallback_on_write(getdns_network_req *netreq) {
/* This should really check if any request in the queue can fallback...*/
if (netreq->transport != GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN)
return STUB_TCP_ERROR;
/* Deal with old upstream */
getdns_upstream *upstream = netreq->upstream;
upstream->write_queue = NULL;
upstream->write_queue_last = NULL;
upstream->event.write_cb = NULL;
GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event);
/* Now set up new upstream */
getdns_upstream *new_upstream = pick_upstream(netreq->owner, 1);
if (!new_upstream)
return STUB_TCP_ERROR;
/* get transport generically*/
int fd = connect_to_upstream(new_upstream, GETDNS_TRANSPORT_TCP, netreq->owner->context);
if (fd == -1)
return STUB_TCP_ERROR;
fprintf(stderr,"[TLS]: tcp_fallback to %d \n", new_upstream->fd);
getdns_network_req *next_req;
while (netreq != NULL) {
next_req = netreq->write_queue_tail;
if (netreq->transport == GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN) {
netreq->upstream = new_upstream;
upstream_schedule_netreq(new_upstream, netreq);
/* TODO: Timout need to be adjusted and rescheduled on the new fd ....*/
/* Note, setup timeout should be shorter than message timeout for
* messages with fallback or don't have time to re-try. */
}
/*else.... leave request to timeout?*/
netreq = next_req;
}
return STUB_TCP_AGAIN;
}
static int
setup_tls(getdns_upstream* upstream)
{
int ret;
/* Already have a connection*/
if (upstream->tls_hs_state == GETDNS_HS_DONE &&
(upstream->tls_obj != NULL) && (upstream->fd != -1))
return 0;
/* Lets make sure the connection is up before we try a handshake*/
int error = 0;
socklen_t len = (socklen_t)sizeof(error);
/* TODO: This doesn't handle the case where the far end doesn't do a reset
* as is the case with e.g. 8.8.8.8. For that case the timeout kicks in
* and the user callback fails the message without the chance to fallback...
* Note that acutally the TCP code doesn't check the connection state before
* doing a first write either....
* Perhaps we should have a write_timeout_cb on the write and then schedule
* the stub_timeout_cb for matching the response??? */
getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len);
if (error == EINPROGRESS || error == EWOULDBLOCK) {
fprintf(stderr,"[TLS]: blocking.......\n");
return STUB_TCP_AGAIN; /* try again */
}
else if (error != 0) {
fprintf(stderr,"[TLS]: died gettting connection\n");
SSL_free(upstream->tls_obj);
upstream->tls_obj = NULL;
upstream->tls_hs_state = GETDNS_HS_FAILED;
upstream->fd = -1;
return STUB_TLS_SETUP_ERROR;
}
ret = do_tls_handshake(upstream);
switch (ret) {
case STUB_TCP_AGAIN:
return ret;
case STUB_TCP_ERROR:
fprintf(stderr,"[TLS]: W: Handshake has failed %d\n", upstream->tls_hs_state);
return STUB_TLS_SETUP_ERROR;
default:
fprintf(stderr,"[TLS]: W:after handshake %d, %s\n", upstream->tls_hs_state, upstream->tls_obj== NULL? "NULL":"Not NULL" );
return 0;
}
}
static int
stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, struct mem_funcs *mf)
{
ssize_t read;
uint8_t *buf;
size_t buf_size;
SSL* tls_obj = upstream->tls_obj;
int q = setup_tls(upstream);
if (q != 0)
return q;
if (!tcp->read_buf) {
/* First time tls read, create a buffer for reading */
@ -795,8 +880,11 @@ upstream_read_cb(void *userarg)
uint16_t query_id;
intptr_t query_id_intptr;
fprintf(stderr,"[TLS]: upstream_read_cb on %d\n", upstream->fd);
if (upstream->tls_obj)
q = stub_tls_read(upstream->tls_obj, &upstream->tcp,
q = stub_tls_read(upstream, &upstream->tcp,
&upstream->upstreams->mf);
else
q = stub_tcp_read(upstream->fd, &upstream->tcp,
@ -1018,12 +1106,17 @@ stub_tcp_write_cb(void *userarg)
}
static int
stub_tls_write(SSL* tls_obj, getdns_tcp_state *tcp, getdns_network_req *netreq)
stub_tls_write(getdns_upstream *upstream, getdns_tcp_state *tcp, getdns_network_req *netreq)
{
size_t pkt_len = netreq->response - netreq->query;
ssize_t written;
uint16_t query_id;
intptr_t query_id_intptr;
SSL* tls_obj = upstream->tls_obj;
int q = setup_tls(upstream);
if (q != 0)
return q;
/* Do we have remaining data that we could not write before? */
if (! tcp->write_buf) {
@ -1073,8 +1166,11 @@ upstream_write_cb(void *userarg)
getdns_dns_req *dnsreq = netreq->owner;
int q;
fprintf(stderr,"[TLS]: method: upstream_write_cb %d\n", upstream->fd);
if (upstream->tls_obj)
q = stub_tls_write(upstream->tls_obj, &upstream->tcp, netreq);
q = stub_tls_write(upstream, &upstream->tcp, netreq);
else
q = stub_tcp_write(upstream->fd, &upstream->tcp, netreq);
@ -1086,8 +1182,16 @@ upstream_write_cb(void *userarg)
stub_erred(netreq);
return;
case STUB_TLS_SETUP_ERROR:
/* Could not complete the TLS set up. Need to fallback on this upstream
* if possible.*/
if (fallback_on_write(netreq) == STUB_TCP_ERROR)
stub_erred(netreq);
return;
default:
netreq->query_id = (uint16_t) q;
fprintf(stderr,"[TLS]: method: upstream_write_cb, successfull write %d\n", upstream->fd);
/* Unqueue the netreq from the write_queue */
if (!(upstream->write_queue = netreq->write_queue_tail)) {
@ -1137,6 +1241,8 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq)
assert(upstream->fd >= 0);
assert(upstream->loop);
fprintf(stderr,"[TLS]: method: upstream_schedule_netreq %d\n", upstream->fd);
/* Append netreq to write_queue */
if (!upstream->write_queue) {
upstream->write_queue = upstream->write_queue_last = netreq;
@ -1150,39 +1256,12 @@ upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq)
}
}
static in_port_t
get_port(struct sockaddr_storage* addr)
{
return ntohs(addr->ss_family == AF_INET
? ((struct sockaddr_in *)addr)->sin_port
: ((struct sockaddr_in6*)addr)->sin6_port);
}
static void
set_port(struct sockaddr_storage* addr, in_port_t port)
{
addr->ss_family == AF_INET
? (((struct sockaddr_in *)addr)->sin_port = htons(port))
: (((struct sockaddr_in6*)addr)->sin6_port = htons(port));
}
static int
tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) {
tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport)
{
int fd = -1;
struct sockaddr_storage connect_addr;
struct sockaddr_storage* addr = &upstream->addr;
socklen_t addr_len = upstream->addr_len;
/* TODO[TLS]: For now, override the port to a hardcoded value*/
if (transport == GETDNS_TRANSPORT_TLS &&
(int)get_port(addr) != GETDNS_TLS_PORT) {
connect_addr = upstream->addr;
addr = &connect_addr;
set_port(addr, GETDNS_TLS_PORT);
}
if ((fd = socket(addr->ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1)
if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1)
return -1;
getdns_sock_nonblock(fd);
@ -1192,8 +1271,8 @@ tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) {
transport == GETDNS_TRANSPORT_TCP_SINGLE)
return fd;
#endif
if (connect(fd, (struct sockaddr *)addr,
addr_len) == -1) {
if (connect(fd, (struct sockaddr *)&upstream->addr,
upstream->addr_len) == -1) {
if (errno != EINPROGRESS) {
close(fd);
return -1;
@ -1202,106 +1281,103 @@ tcp_connect (getdns_upstream *upstream, getdns_base_transport_t transport) {
return fd;
}
int
connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport,
getdns_context *context)
{
if ((transport == GETDNS_TRANSPORT_TCP ||
transport == GETDNS_TRANSPORT_TLS)
&& upstream->fd != -1) {
fprintf(stderr,"[TLS]: method: tcp_connect using existing fd %d\n", upstream->fd);
return upstream->fd;
}
int fd;
switch(transport) {
case GETDNS_TRANSPORT_UDP:
if ((fd = socket(
upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1)
return -1;
getdns_sock_nonblock(fd);
return fd;
case GETDNS_TRANSPORT_TCP_SINGLE:
case GETDNS_TRANSPORT_TCP:
fd = tcp_connect(upstream, transport);
break;
case GETDNS_TRANSPORT_TLS:
fd = tcp_connect(upstream, transport);
if (fd == -1 ||
(upstream->tls_obj = create_tls_object(context, fd)) == NULL ) {
close(fd);
return -1;
}
upstream->tls_hs_state = GETDNS_HS_WRITE;
break;
default:
return -1;
/* Nothing to do*/
}
if (fd != -1) {
upstream->loop = context->extension;
upstream->fd = fd;
}
fprintf(stderr,"[TLS]: method: tcp_connect created new connection %d\n", fd);
return fd;
}
getdns_return_t
priv_getdns_submit_stub_request(getdns_network_req *netreq)
{
getdns_dns_req *dnsreq = netreq->owner;
getdns_upstream *upstream = pick_upstream(dnsreq);
if (!upstream)
return GETDNS_RETURN_GENERIC_ERROR;
// Work out the primary and fallback transport options
/* TODO[TLS - 1]: This will become a double while loop trying all the upstreams on all the
* transports for a connection since we need a fd to schedule on, using previous known capabilities
* All other set up is done async*/
/* Work out the primary and fallback transport options */
getdns_base_transport_t transport = priv_get_base_transport(
dnsreq->context->dns_transport,0);
getdns_base_transport_t fb_transport = priv_get_base_transport(
dnsreq->context->dns_transport,1);
getdns_upstream *upstream = pick_upstream(dnsreq, 0);
if (!upstream)
return GETDNS_RETURN_GENERIC_ERROR;
int fd = connect_to_upstream(upstream, transport, dnsreq->context);
if (fd == -1) {
if (fb_transport == GETDNS_TRANSPORT_NONE)
return GETDNS_RETURN_GENERIC_ERROR;
upstream = pick_upstream(dnsreq, 1);
if ((fd = connect_to_upstream(upstream, fb_transport, dnsreq->context)) == -1)
return GETDNS_RETURN_GENERIC_ERROR;
}
netreq->upstream = upstream;
netreq->transport = dnsreq->context->dns_transport;
switch(transport) {
case GETDNS_TRANSPORT_UDP:
if ((netreq->fd = socket(
upstream->addr.ss_family, SOCK_DGRAM, IPPROTO_UDP)) == -1)
return GETDNS_RETURN_GENERIC_ERROR;
getdns_sock_nonblock(netreq->fd);
netreq->upstream = upstream;
GETDNS_SCHEDULE_EVENT(
dnsreq->loop, netreq->fd, dnsreq->context->timeout,
getdns_eventloop_event_init(&netreq->event, netreq,
NULL, stub_udp_write_cb, stub_timeout_cb));
return GETDNS_RETURN_GOOD;
case GETDNS_TRANSPORT_TCP_SINGLE:
if ((netreq->fd = tcp_connect(upstream, transport)) == -1)
return GETDNS_RETURN_GENERIC_ERROR;
netreq->upstream = upstream;
netreq->fd = fd;
GETDNS_SCHEDULE_EVENT(
dnsreq->loop, netreq->fd, dnsreq->context->timeout,
getdns_eventloop_event_init(&netreq->event, netreq,
NULL, stub_tcp_write_cb, stub_timeout_cb));
NULL, (transport == GETDNS_TRANSPORT_UDP ? stub_udp_write_cb:
stub_tcp_write_cb), stub_timeout_cb));
return GETDNS_RETURN_GOOD;
case GETDNS_TRANSPORT_TCP:
case GETDNS_TRANSPORT_TLS:
/* In coming comments, "global" means "context wide" */
/* Are we the first? (Is global socket initialized?) */
if (upstream->fd == -1) {
/* TODO[TLS]: We should remember on the context if we had to fallback
* for this upstream so when re-connecting from a dropped TCP
* connection we don't retry TLS. */
int fallback = 0;
/* We are the first. Make global socket and connect. */
if ((upstream->fd = tcp_connect(upstream, transport)) == -1) {
if (fb_transport == GETDNS_TRANSPORT_NONE)
return GETDNS_RETURN_GENERIC_ERROR;
if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1)
return GETDNS_RETURN_GENERIC_ERROR;
fallback = 1;
}
/* Now do a handshake for TLS. Note waiting for this to succeed or
* timeout blocks the scheduling of any messages for this upstream*/
if (transport == GETDNS_TRANSPORT_TLS && (fallback == 0)) {
upstream->tls_obj = do_tls_handshake(dnsreq, upstream);
if (!upstream->tls_obj) {
if (fb_transport == GETDNS_TRANSPORT_NONE)
return GETDNS_RETURN_GENERIC_ERROR;
close(upstream->fd);
if ((upstream->fd = tcp_connect(upstream, fb_transport)) == -1)
return GETDNS_RETURN_GENERIC_ERROR;
}
}
/* Attach to the global event loop
* so it can do it's own scheduling
*/
upstream->loop = dnsreq->context->extension;
} else {
/* Cater for the case of the user downgrading and existing TLS
connection to TCP for some reason...*/
if (transport == GETDNS_TRANSPORT_TCP && upstream->tls_obj) {
SSL_shutdown(upstream->tls_obj);
SSL_free(upstream->tls_obj);
upstream->tls_obj = NULL;
upstream->fd = fd;
}
}
netreq->upstream = upstream;
/* We have a context wide socket.
* Now schedule the write request.
*/
upstream_schedule_netreq(upstream, netreq);
/* Schedule at least the timeout locally.
* And also the write if we perform a synchronous lookup
*/
/* TODO[TLS]: Timeout handling for async calls must change....
* Maybe even change scheduling for sync calls here too*/
GETDNS_SCHEDULE_EVENT(
dnsreq->loop, upstream->fd, dnsreq->context->timeout,
getdns_eventloop_event_init(&netreq->event, netreq, NULL,

View File

@ -191,6 +191,7 @@ typedef struct getdns_network_req
/* For stub resolving */
struct getdns_upstream *upstream;
int fd;
getdns_transport_t transport;
getdns_eventloop_event event;
getdns_tcp_state tcp;
uint16_t query_id;