diff --git a/src/context.c b/src/context.c index 8dfde4c9..8dc571fa 100644 --- a/src/context.c +++ b/src/context.c @@ -703,6 +703,7 @@ set_os_defaults(struct getdns_context *context) upstream = &context->upstreams-> upstreams[context->upstreams->count++]; + fprintf(stderr, "[TLS]: OS: creating upstream %d, %p, with port %s with transport %d\n", (int)context->upstreams->count, (void*)upstream, port_str, base_transport); upstream_init(upstream, context->upstreams, result); upstream->dns_base_transport = base_transport; } @@ -1487,6 +1488,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, getdns_return_t r; size_t count = 0; size_t i; + //size_t upstreams_limit; getdns_upstreams *upstreams; char addrstr[1024], portstr[1024], *eos; struct addrinfo hints; @@ -1507,70 +1509,84 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, hints.ai_addr = NULL; hints.ai_next = NULL; - /* TODO[TLS]: Resize on the fly to avoid hardcoding this*/ upstreams = upstreams_create(context, count*3); + //upstreams_limit = count; for (i = 0; i < count; i++) { + getdns_dict *dict; + getdns_bindata *address_type; + getdns_bindata *address_data; + struct sockaddr_storage addr; + + getdns_bindata *scope_id; + getdns_upstream *upstream; + + if ((r = getdns_list_get_dict(upstream_list, i, &dict))) + goto error; + + if ((r = getdns_dict_get_bindata( + dict, "address_type",&address_type))) + goto error; + if (address_type->size < 4) + goto invalid_parameter; + if (strncmp((char *)address_type->data, "IPv4", 4) == 0) + addr.ss_family = AF_INET; + else if (strncmp((char *)address_type->data, "IPv6", 4) == 0) + addr.ss_family = AF_INET6; + else goto invalid_parameter; + + if ((r = getdns_dict_get_bindata( + dict, "address_data", &address_data))) + goto error; + if ((addr.ss_family == AF_INET && + address_data->size != 4) || + (addr.ss_family == AF_INET6 && + address_data->size != 16)) + goto invalid_parameter; + if (inet_ntop(addr.ss_family, address_data->data, + addrstr, 1024) == NULL) + goto invalid_parameter; + + if (getdns_dict_get_bindata(dict, "scope_id", &scope_id) == + GETDNS_RETURN_GOOD) { + if (strlen(addrstr) + scope_id->size > 1022) + goto invalid_parameter; + eos = &addrstr[strlen(addrstr)]; + *eos++ = '%'; + (void) memcpy(eos, scope_id->data, scope_id->size); + eos[scope_id->size] = 0; + } + /* Loop to create upstreams as needed*/ getdns_base_transport_t base_transport = GETDNS_BASE_TRANSPORT_MIN; for (; base_transport < GETDNS_BASE_TRANSPORT_MAX; base_transport++) { - getdns_dict *dict; - getdns_bindata *address_type; - getdns_bindata *address_data; uint32_t port; - getdns_bindata *scope_id; struct addrinfo *ai; - getdns_upstream *upstream; - port = getdns_port_array[base_transport]; if (port == GETDNS_PORT_ZERO) continue; - upstream = &upstreams->upstreams[upstreams->count]; - if ((r = getdns_list_get_dict(upstream_list, i, &dict))) - goto error; - - if ((r = getdns_dict_get_bindata( - dict, "address_type",&address_type))) - goto error; - if (address_type->size < 4) - goto invalid_parameter; - if (strncmp((char *)address_type->data, "IPv4", 4) == 0) - upstream->addr.ss_family = AF_INET; - else if (strncmp((char *)address_type->data, "IPv6", 4) == 0) - upstream->addr.ss_family = AF_INET6; - else goto invalid_parameter; - - if ((r = getdns_dict_get_bindata( - dict, "address_data", &address_data))) - goto error; - if ((upstream->addr.ss_family == AF_INET && - address_data->size != 4) || - (upstream->addr.ss_family == AF_INET6 && - address_data->size != 16)) - goto invalid_parameter; - if (inet_ntop(upstream->addr.ss_family, address_data->data, - addrstr, 1024) == NULL) - goto invalid_parameter; - - (void) getdns_dict_get_int(dict, "port", &port); + /* TODO[TLS]:Respect the user port for TCP and STARTTLS, but for + * now hardcode the TLS port */ + if (base_transport != GETDNS_BASE_TRANSPORT_TLS) + (void) getdns_dict_get_int(dict, "port", &port); (void) snprintf(portstr, 1024, "%d", (int)port); - if (getdns_dict_get_bindata(dict, "scope_id", &scope_id) == - GETDNS_RETURN_GOOD) { - if (strlen(addrstr) + scope_id->size > 1022) - goto invalid_parameter; - eos = &addrstr[strlen(addrstr)]; - *eos++ = '%'; - (void) memcpy(eos, scope_id->data, scope_id->size); - eos[scope_id->size] = 0; - } - if (getaddrinfo(addrstr, portstr, &hints, &ai)) goto invalid_parameter; /* TODO[TLS]: Should probably check that the upstream doesn't - * already exist (in case user has specified port explicitly)*/ + * already exist (in case user has specified TLS port explicitly and + * to prevent duplicates) */ + + /* TODO[TLS]: Grow array when needed. This causes a crash later.... + if (upstreams->count == upstreams_limit) + upstreams = upstreams_resize( + upstreams, (upstreams_limit *= 2)); */ + + upstream = &upstreams->upstreams[upstreams->count]; + upstream->addr.ss_family = addr.ss_family; upstream_init(upstream, upstreams, ai); + fprintf(stderr, "[TLS]: creating upstream %d, %p, with port %d with transport %d\n", (int)upstreams->count, (void*)upstream,(int)port, base_transport); upstream->dns_base_transport = base_transport; upstreams->count++; freeaddrinfo(ai); @@ -1912,8 +1928,8 @@ getdns_context_prepare_for_resolution(struct getdns_context *context, } /* Block use of TLS ONLY in recursive mode as it won't work */ /* TODO[TLS]: Check if TLS is the only option in the list*/ - if (context->resolution_type == GETDNS_RESOLUTION_RECURSING - && context->dns_transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) + if (context->resolution_type == GETDNS_RESOLUTION_RECURSING && + context->dns_transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) return GETDNS_RETURN_BAD_CONTEXT; if (context->resolution_type_set == context->resolution_type) diff --git a/src/request-internal.c b/src/request-internal.c index f83d0805..b4267022 100644 --- a/src/request-internal.c +++ b/src/request-internal.c @@ -91,7 +91,7 @@ network_req_init(getdns_network_req *net_req, getdns_dns_req *owner, net_req->fd = -1; for (i = 0; i < GETDNS_BASE_TRANSPORT_MAX; i++) net_req->dns_base_transports[i] = owner->context->dns_base_transports[i]; - net_req->dns_base_transport = net_req->dns_base_transports; + net_req->transport = 0; memset(&net_req->event, 0, sizeof(net_req->event)); memset(&net_req->tcp, 0, sizeof(net_req->tcp)); net_req->query_id = 0; diff --git a/src/stub.c b/src/stub.c index e4dfc910..c22a68f4 100755 --- a/src/stub.c +++ b/src/stub.c @@ -58,7 +58,13 @@ static void upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq); static void netreq_upstream_read_cb(void *userarg); static void netreq_upstream_write_cb(void *userarg); -static int fallback_on_write(getdns_network_req *netreq); +static int fallback_on_write(getdns_network_req *netreq); + +static void stub_tcp_write_cb(void *userarg); + +/*****************************/ +/* General utility functions */ +/*****************************/ static void rollover_secret() @@ -231,7 +237,7 @@ match_and_process_server_cookie( } static int -create_starttls_request(getdns_dns_req *dnsreq, getdns_upstream *upstream, +create_starttls_request(getdns_dns_req *dnsreq, getdns_upstream *upstream, getdns_eventloop *loop) { getdns_return_t r = GETDNS_RETURN_GOOD; @@ -246,14 +252,13 @@ create_starttls_request(getdns_dns_req *dnsreq, getdns_upstream *upstream, } upstream->starttls_req = dns_req_new(dnsreq->context, loop, "STARTTLS", GETDNS_RRTYPE_TXT, extensions); - /*TODO[STARTTLS]: TO BIT*/ + /*TODO[TLS]: TO BIT*/ if (upstream->starttls_req == NULL) return 0; getdns_dict_destroy(extensions); upstream->starttls_req->netreqs[0]->upstream = upstream; return 1; - } static int @@ -272,8 +277,8 @@ dname_equal(uint8_t *s1, uint8_t *s2) } static int -is_starttls_response(getdns_network_req *netreq) { - +is_starttls_response(getdns_network_req *netreq) +{ priv_getdns_rr_iter rr_iter_storage, *rr_iter; priv_getdns_rdf_iter rdf_iter_storage, *rdf_iter; uint16_t rr_type; @@ -286,7 +291,8 @@ is_starttls_response(getdns_network_req *netreq) { /* Servers that are not STARTTLS aware will refuse the CH query*/ if (LDNS_RCODE_NOERROR != GLDNS_RCODE_WIRE(netreq->response)) { - fprintf(stderr, "[STARTTLS] STARTTLS response had error %d\n", GLDNS_RCODE_WIRE(netreq->response)); + fprintf(stderr, "[TLS] STARTTLS response had error %d\n", + GLDNS_RCODE_WIRE(netreq->response)); return 0; } @@ -319,10 +325,12 @@ is_starttls_response(getdns_network_req *netreq) { starttls_name = priv_getdns_rdf_if_or_as_decompressed( rdf_iter,starttls_name_space,&starttls_name_len); if (dname_equal(starttls_name, owner_name)) { - fprintf(stderr, "[STARTTLS] STARTTLS response received :%s:\n", (char*)starttls_name); + fprintf(stderr, "[TLS] STARTTLS response received :%s:\n", + (char*)starttls_name); return 1; } else { - fprintf(stderr, "[STARTTLS] NO_TLS response received :%s:\n", (char*)starttls_name); + fprintf(stderr, "[TLS] NO_TLS response received :%s:\n", + (char*)starttls_name); return 0; } continue; @@ -330,8 +338,6 @@ is_starttls_response(getdns_network_req *netreq) { return 0; } - - /** best effort to set nonblocking */ static void getdns_sock_nonblock(int sockfd) @@ -352,6 +358,36 @@ getdns_sock_nonblock(int sockfd) #endif } +static int +tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) +{ + + int fd = -1; + if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) + return -1; + + getdns_sock_nonblock(fd); +#ifdef USE_TCP_FASTOPEN + /* Leave the connect to the later call to sendto() if using TCP*/ + if (transport == GETDNS_BASE_TRANSPORT_TCP || + transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE || + transport == GETDNS_BASE_TRANSPORT_STARTTLS) + return fd; +#endif + if (connect(fd, (struct sockaddr *)&upstream->addr, + upstream->addr_len) == -1) { + if (errno != EINPROGRESS) { + close(fd); + return -1; + } + } + return fd; +} + +/**************************/ +/* Error/cleanup functions*/ +/**************************/ + static void stub_next_upstream(getdns_network_req *netreq) { @@ -444,10 +480,10 @@ upstream_erred(getdns_upstream *upstream) } close(upstream->fd); upstream->fd = -1; - /*TODO[TLS]: Upstream errors don't trigger the user callback....*/ } + static void -message_erred(getdns_network_req *netreq) +message_erred(getdns_network_req *netreq) { stub_cleanup(netreq); netreq->state = NET_REQ_FINISHED; @@ -483,6 +519,14 @@ stub_timeout_cb(void *userarg) fprintf(stderr,"[TLS]: TIMEOUT(stub_timeout_cb)\n"); getdns_network_req *netreq = (getdns_network_req *)userarg; + + /* For now, mark a STARTTLS timeout as a failured negotiation and allow + * fallback but don't close the connection. */ + if (is_starttls_response(netreq)) { + netreq->upstream->tls_hs_state = GETDNS_HS_FAILED; + stub_next_upstream(netreq); + stub_cleanup(netreq); + } stub_next_upstream(netreq); stub_cleanup(netreq); @@ -490,7 +534,471 @@ stub_timeout_cb(void *userarg) (void) getdns_context_request_timed_out(netreq->owner); } -static void stub_tcp_write_cb(void *userarg); +/****************************/ +/* TCP read/write functions */ +/****************************/ + +static int +stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf) +{ + ssize_t read; + uint8_t *buf; + size_t buf_size; + + if (!tcp->read_buf) { + /* First time tcp read, create a buffer for reading */ + if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) + return STUB_TCP_ERROR; + + tcp->read_buf_len = 4096; + tcp->read_pos = tcp->read_buf; + tcp->to_read = 2; /* Packet size */ + } + read = recv(fd, tcp->read_pos, tcp->to_read, 0); + if (read == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) + return STUB_TCP_AGAIN; + else + return STUB_TCP_ERROR; + } else if (read == 0) { + /* Remote end closed the socket */ + /* TODO: Try to reconnect */ + return STUB_TCP_ERROR; + } + tcp->to_read -= read; + tcp->read_pos += read; + + if ((int)tcp->to_read > 0) + return STUB_TCP_AGAIN; + + read = tcp->read_pos - tcp->read_buf; + if (read == 2) { + /* Read the packet size short */ + tcp->to_read = gldns_read_uint16(tcp->read_buf); + + if (tcp->to_read < GLDNS_HEADER_SIZE) + return STUB_TCP_ERROR; + + /* Resize our buffer if needed */ + if (tcp->to_read > tcp->read_buf_len) { + buf_size = tcp->read_buf_len; + while (tcp->to_read > buf_size) + buf_size *= 2; + + if (!(buf = GETDNS_XREALLOC(*mf, + tcp->read_buf, uint8_t, buf_size))) + return STUB_TCP_ERROR; + + tcp->read_buf = buf; + tcp->read_buf_len = buf_size; + } + /* Ready to start reading the packet */ + tcp->read_pos = tcp->read_buf; + return STUB_TCP_AGAIN; + } + return GLDNS_ID_WIRE(tcp->read_buf); +} + +/* stub_tcp_write(fd, tcp, netreq) + * will return STUB_TCP_AGAIN when we need to come back again, + * STUB_TCP_ERROR on error and a query_id on successfull sent. + */ +static int +stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) +{ + + size_t pkt_len = netreq->response - netreq->query; + ssize_t written; + uint16_t query_id; + intptr_t query_id_intptr; + + /* Do we have remaining data that we could not write before? */ + if (! tcp->write_buf) { + /* No, this is an initial write. Try to send + */ + + /* Not keeping connections open? Then the first random number + * will do as the query id. + * + * Otherwise find a unique query_id not already written (or in + * the write_queue) for that upstream. Register this netreq + * by query_id in the process. + */ + if ((netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_TCP_SINGLE) || + (netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_UDP)) + query_id = arc4random(); + else do { + query_id = arc4random(); + query_id_intptr = (intptr_t)query_id; + netreq->node.key = (void *)query_id_intptr; + + } while (!getdns_rbtree_insert( + &netreq->upstream->netreq_by_query_id, &netreq->node)); + + GLDNS_ID_SET(netreq->query, query_id); + if (netreq->opt) { + /* no limits on the max udp payload size with tcp */ + gldns_write_uint16(netreq->opt + 3, 65535); + + if (netreq->owner->edns_cookies) { + netreq->response = attach_edns_cookie( + netreq->upstream, netreq->opt); + pkt_len = netreq->response - netreq->query; + gldns_write_uint16(netreq->query - 2, pkt_len); + } + } + /* We have an initialized packet buffer. + * Lets see how much of it we can write + */ +#ifdef USE_TCP_FASTOPEN + /* We use sendto() here which will do both a connect and send */ + written = sendto(fd, netreq->query - 2, pkt_len + 2, + MSG_FASTOPEN, (struct sockaddr *)&(netreq->upstream->addr), + netreq->upstream->addr_len); + /* If pipelining we will find that the connection is already up so + just fall back to a 'normal' write. */ + if (written == -1 && errno == EISCONN) + written = write(fd, netreq->query - 2, pkt_len + 2); + + if ((written == -1 && (errno == EAGAIN || + errno == EWOULDBLOCK || + /* Add the error case where the connection is in progress which is when + a cookie is not available (e.g. when doing the first request to an + upstream). We must let the handshake complete since non-blocking. */ + errno == EINPROGRESS)) || + written < pkt_len + 2) { +#else + written = write(fd, netreq->query - 2, pkt_len + 2); + if ((written == -1 && (errno == EAGAIN || + errno == EWOULDBLOCK)) || + written < pkt_len + 2) { +#endif + /* We couldn't write the whole packet. + * We have to return with STUB_TCP_AGAIN. + * Setup tcp to track the state. + */ + tcp->write_buf = netreq->query - 2; + tcp->write_buf_len = pkt_len + 2; + tcp->written = written >= 0 ? written : 0; + + return STUB_TCP_AGAIN; + + } else if (written == -1) + return STUB_TCP_ERROR; + + /* We were able to write everything! Start reading. */ + return (int) query_id; + + } else {/* if (! tcp->write_buf) */ + + /* Coming back from an earlier unfinished write or handshake. + * Try to send remaining data */ + written = write(fd, tcp->write_buf + tcp->written, + tcp->write_buf_len - tcp->written); + if (written == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) + return STUB_TCP_AGAIN; + else + return STUB_TCP_ERROR; + } + tcp->written += written; + if (tcp->written < tcp->write_buf_len) + /* Still more to send */ + return STUB_TCP_AGAIN; + + query_id = (int)GLDNS_ID_WIRE(tcp->write_buf + 2); + /* Done. Start reading */ + tcp->write_buf = NULL; + return query_id; + + } /* if (! tcp->write_buf) */ +} + +/*************************/ +/* TLS Utility functions */ +/*************************/ + +static int +tls_requested(getdns_network_req *netreq) +{ + return (netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_TLS || + netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_STARTTLS) ? + 1 : 0; +} + +static int +tls_should_write(getdns_upstream *upstream) +{ + /* Should messages be written on TLS upstream. Remember that for STARTTLS + * the first message should got over TCP as the handshake isn't started yet.*/ + return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + upstream->tls_hs_state != GETDNS_HS_NONE) ? 1 : 0; +} + +static int +tls_should_read(getdns_upstream *upstream) +{ + return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + !(upstream->tls_hs_state == GETDNS_HS_FAILED || + upstream->tls_hs_state == GETDNS_HS_NONE)) ? 1 : 0; +} + +static int +tls_failed(getdns_upstream *upstream) +{ + /* No messages should be scheduled onto an upstream in this state */ + return ((upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS) && + upstream->tls_hs_state == GETDNS_HS_FAILED) ? 1: 0; +} + +static SSL* +tls_create_object(getdns_context *context, int fd) +{ + /* Create SSL instance */ + if (context->tls_ctx == NULL) + return NULL; + SSL* ssl = SSL_new(context->tls_ctx); + if(!ssl) + return NULL; + /* Connect the SSL object with a file descriptor */ + if(!SSL_set_fd(ssl,fd)) { + SSL_free(ssl); + return NULL; + } + SSL_set_connect_state(ssl); + (void) SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY); + return ssl; +} + +static int +tls_do_handshake(getdns_upstream *upstream) +{ + fprintf(stderr,"[TLS]: TLS(tls_do_handshake)\n"); + + int r; + int want; + ERR_clear_error(); + while ((r = SSL_do_handshake(upstream->tls_obj)) != 1) + { + want = SSL_get_error(upstream->tls_obj, r); + switch (want) { + case SSL_ERROR_WANT_READ: + fprintf(stderr,"[TLS]: SSL_ERROR_WANT_READ\n"); + upstream->event.read_cb = upstream_read_cb; + upstream->event.write_cb = NULL; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + upstream->tls_hs_state = GETDNS_HS_READ; + return STUB_TCP_AGAIN; + case SSL_ERROR_WANT_WRITE: + fprintf(stderr,"[TLS]: SSL_ERROR_WANT_WRITE\n"); + upstream->event.read_cb = NULL; + upstream->event.write_cb = upstream_write_cb; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + upstream->tls_hs_state = GETDNS_HS_WRITE; + return STUB_TCP_AGAIN; + default: + SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; + upstream->tls_hs_state = GETDNS_HS_FAILED; + return STUB_TLS_SETUP_ERROR; + } + } + upstream->tls_hs_state = GETDNS_HS_DONE; + upstream->event.read_cb = NULL; + upstream->event.write_cb = upstream_write_cb; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + return 0; +} + +static int +tls_connected(getdns_upstream* upstream) +{ + /* Already have a connection*/ + if (upstream->tls_hs_state == GETDNS_HS_DONE && + (upstream->tls_obj != NULL) && (upstream->fd != -1)) + return 0; + + /* Already tried and failed, so let the fallback code take care of things */ + if (upstream->tls_hs_state == GETDNS_HS_FAILED) + return STUB_TLS_SETUP_ERROR; + + /* Lets make sure the connection is up before we try a handshake*/ + int error = 0; + socklen_t len = (socklen_t)sizeof(error); + /* TODO: This doesn't handle the case where the far end doesn't do a reset + * as is the case with e.g. 8.8.8.8. For that case the timeout kicks in + * and the user callback fails the message without the chance to fallback.*/ + getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len); + if (error == EINPROGRESS || error == EWOULDBLOCK) + return STUB_TCP_AGAIN; /* try again */ + else if (error != 0) { + + fprintf(stderr,"[TLS]: TLS(tls_connected): died gettting connection\n"); + SSL_free(upstream->tls_obj); + upstream->tls_obj = NULL; + upstream->tls_hs_state = GETDNS_HS_FAILED; + return STUB_TLS_SETUP_ERROR; + } + + return tls_do_handshake(upstream); +} + +/***************************/ +/* TLS read/write functions*/ +/***************************/ + +static int +stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, + struct mem_funcs *mf) +{ + ssize_t read; + uint8_t *buf; + size_t buf_size; + SSL* tls_obj = upstream->tls_obj; + + int q = tls_connected(upstream); + if (q != 0) + return q; + + if (!tcp->read_buf) { + /* First time tls read, create a buffer for reading */ + if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) + return STUB_TCP_ERROR; + + tcp->read_buf_len = 4096; + tcp->read_pos = tcp->read_buf; + tcp->to_read = 2; /* Packet size */ + } + + ERR_clear_error(); + read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); + if (read <= 0) { + /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake + renegotiation. Need to keep handshake state to do that.*/ + int want = SSL_get_error(tls_obj, read); + if (want == SSL_ERROR_WANT_READ) { + return STUB_TCP_AGAIN; /* read more later */ + } else + return STUB_TCP_ERROR; + } + tcp->to_read -= read; + tcp->read_pos += read; + + if ((int)tcp->to_read > 0) + return STUB_TCP_AGAIN; + + read = tcp->read_pos - tcp->read_buf; + if (read == 2) { + /* Read the packet size short */ + tcp->to_read = gldns_read_uint16(tcp->read_buf); + + if (tcp->to_read < GLDNS_HEADER_SIZE) + return STUB_TCP_ERROR; + + /* Resize our buffer if needed */ + if (tcp->to_read > tcp->read_buf_len) { + buf_size = tcp->read_buf_len; + while (tcp->to_read > buf_size) + buf_size *= 2; + + if (!(buf = GETDNS_XREALLOC(*mf, + tcp->read_buf, uint8_t, buf_size))) + return STUB_TCP_ERROR; + + tcp->read_buf = buf; + tcp->read_buf_len = buf_size; + } + + /* Ready to start reading the packet */ + tcp->read_pos = tcp->read_buf; + read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); + if (read <= 0) { + /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake + renegotiation. Need to keep handshake state to do that.*/ + int want = SSL_get_error(tls_obj, read); + if (want == SSL_ERROR_WANT_READ) { + return STUB_TCP_AGAIN; /* read more later */ + } else + return STUB_TCP_ERROR; + } + tcp->to_read -= read; + tcp->read_pos += read; + if ((int)tcp->to_read > 0) + return STUB_TCP_AGAIN; + } + return GLDNS_ID_WIRE(tcp->read_buf); +} + +static int +stub_tls_write(getdns_upstream *upstream, getdns_tcp_state *tcp, + getdns_network_req *netreq) +{ + size_t pkt_len = netreq->response - netreq->query; + ssize_t written; + uint16_t query_id; + intptr_t query_id_intptr; + SSL* tls_obj = upstream->tls_obj; + + int q = tls_connected(upstream); + if (q != 0) + return q; + + /* Do we have remaining data that we could not write before? */ + if (! tcp->write_buf) { + /* No, this is an initial write. Try to send + */ + + /* Find a unique query_id not already written (or in + * the write_queue) for that upstream. Register this netreq + * by query_id in the process. + */ + do { + query_id = ldns_get_random(); + query_id_intptr = (intptr_t)query_id; + netreq->node.key = (void *)query_id_intptr; + + } while (!getdns_rbtree_insert( + &netreq->upstream->netreq_by_query_id, &netreq->node)); + + GLDNS_ID_SET(netreq->query, query_id); + if (netreq->opt) + /* no limits on the max udp payload size with tcp */ + gldns_write_uint16(netreq->opt + 3, 65535); + + /* We have an initialized packet buffer. + * Lets see how much of it we can write */ + + /* TODO[TLS]: Handle error cases, partial writes, renegotiation etc. */ + ERR_clear_error(); + written = SSL_write(tls_obj, netreq->query - 2, pkt_len + 2); + if (written <= 0) + return STUB_TCP_ERROR; + + /* We were able to write everything! Start reading. */ + return (int) query_id; + + } + + return STUB_TCP_ERROR; +} + +/**************************/ +/* UDP callback functions */ +/**************************/ + static void stub_udp_read_cb(void *userarg) { @@ -594,133 +1102,9 @@ stub_udp_write_cb(void *userarg) stub_udp_read_cb, NULL, stub_timeout_cb)); } -static int -transport_valid(struct getdns_upstream *upstream, getdns_base_transport_t transport) { - /* For single shot transports, use only the TCP upstream. */ - if (transport == GETDNS_BASE_TRANSPORT_UDP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) { - if (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP) - return 1; - else - return 0; - } - /* Allow TCP messages to be sent on a STARTTLS upstream that hasn't upgraded - * to avoid opening a new connection if one is aleady open. */ - if (transport == GETDNS_BASE_TRANSPORT_TCP && - upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && - upstream->tls_hs_state == GETDNS_HS_FAILED) - return 1; - /* Otherwise, transport must match */ - if (upstream->dns_base_transport != transport) - return 0; - /* But don't use if upgrade failed for (START)TLS*/ - if ((transport == GETDNS_BASE_TRANSPORT_TLS || - transport == GETDNS_BASE_TRANSPORT_STARTTLS) - && upstream->tls_hs_state == GETDNS_HS_FAILED) - return 0; - return 1; -} - -static getdns_upstream * -pick_upstream(getdns_network_req *netreq, getdns_base_transport_t transport) -{ - getdns_upstream *upstream; - getdns_upstreams *upstreams = netreq->owner->upstreams; - size_t i; - - if (!upstreams->count) - return NULL; - - for (i = 0; i < upstreams->count; i++) - if (upstreams->upstreams[i].to_retry <= 0) - upstreams->upstreams[i].to_retry++; - - i = upstreams->current; - do { - if (upstreams->upstreams[i].to_retry > 0 && - transport_valid(&upstreams->upstreams[i], transport)) { - upstreams->current = i; - return &upstreams->upstreams[i]; - } - if (++i > upstreams->count) - i = 0; - } while (i != upstreams->current); - - upstream = upstreams->upstreams; - for (i = 1; i < upstreams->count; i++) - if (upstreams->upstreams[i].back_off < upstream->back_off && - transport_valid(&upstreams->upstreams[i], transport)) - upstream = &upstreams->upstreams[i]; - - /* Need to check again that the transport is valid */ - if (!transport_valid(upstream, transport)) - return NULL; - upstream->back_off++; - upstream->to_retry = 1; - upstreams->current = upstream - upstreams->upstreams; - return upstream; -} - -static int -stub_tcp_read(int fd, getdns_tcp_state *tcp, struct mem_funcs *mf) -{ - ssize_t read; - uint8_t *buf; - size_t buf_size; - - if (!tcp->read_buf) { - /* First time tcp read, create a buffer for reading */ - if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) - return STUB_TCP_ERROR; - - tcp->read_buf_len = 4096; - tcp->read_pos = tcp->read_buf; - tcp->to_read = 2; /* Packet size */ - } - read = recv(fd, tcp->read_pos, tcp->to_read, 0); - if (read == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) - return STUB_TCP_AGAIN; - else - return STUB_TCP_ERROR; - } else if (read == 0) { - /* Remote end closed the socket */ - /* TODO: Try to reconnect */ - return STUB_TCP_ERROR; - } - tcp->to_read -= read; - tcp->read_pos += read; - - if ((int)tcp->to_read > 0) - return STUB_TCP_AGAIN; - - read = tcp->read_pos - tcp->read_buf; - if (read == 2) { - /* Read the packet size short */ - tcp->to_read = gldns_read_uint16(tcp->read_buf); - - if (tcp->to_read < GLDNS_HEADER_SIZE) - return STUB_TCP_ERROR; - - /* Resize our buffer if needed */ - if (tcp->to_read > tcp->read_buf_len) { - buf_size = tcp->read_buf_len; - while (tcp->to_read > buf_size) - buf_size *= 2; - - if (!(buf = GETDNS_XREALLOC(*mf, - tcp->read_buf, uint8_t, buf_size))) - return STUB_TCP_ERROR; - - tcp->read_buf = buf; - tcp->read_buf_len = buf_size; - } - /* Ready to start reading the packet */ - tcp->read_pos = tcp->read_buf; - return STUB_TCP_AGAIN; - } - return GLDNS_ID_WIRE(tcp->read_buf); -} +/**************************/ +/* TCP callback functions*/ +/**************************/ static void stub_tcp_read_cb(void *userarg) @@ -765,192 +1149,35 @@ stub_tcp_read_cb(void *userarg) } } -static SSL* -create_tls_object(getdns_context *context, int fd) +static void +stub_tcp_write_cb(void *userarg) { - /* Create SSL instance */ - if (context->tls_ctx == NULL) - return NULL; - SSL* ssl = SSL_new(context->tls_ctx); - if(!ssl) - return NULL; - /* Connect the SSL object with a file descriptor */ - if(!SSL_set_fd(ssl,fd)) { - SSL_free(ssl); - return NULL; + getdns_network_req *netreq = (getdns_network_req *)userarg; + getdns_dns_req *dnsreq = netreq->owner; + int q; + + switch ((q = stub_tcp_write(netreq->fd, &netreq->tcp, netreq))) { + case STUB_TCP_AGAIN: + return; + + case STUB_TCP_ERROR: + stub_erred(netreq); + return; + + default: + netreq->query_id = (uint16_t) q; + GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); + GETDNS_SCHEDULE_EVENT( + dnsreq->loop, netreq->fd, dnsreq->context->timeout, + getdns_eventloop_event_init(&netreq->event, netreq, + stub_tcp_read_cb, NULL, stub_timeout_cb)); + return; } - SSL_set_connect_state(ssl); - (void) SSL_set_mode(ssl, SSL_MODE_AUTO_RETRY); - return ssl; } -static int -do_tls_handshake(getdns_upstream *upstream) -{ - - fprintf(stderr,"[TLS]: TLS(do_tls_handshake)\n"); - - int r; - int want; - ERR_clear_error(); - while ((r = SSL_do_handshake(upstream->tls_obj)) != 1) - { - want = SSL_get_error(upstream->tls_obj, r); - switch (want) { - case SSL_ERROR_WANT_READ: - fprintf(stderr,"[TLS]: SSL_ERROR_WANT_READ\n"); - upstream->event.read_cb = upstream_read_cb; - upstream->event.write_cb = NULL; - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, &upstream->event); - upstream->tls_hs_state = GETDNS_HS_READ; - return STUB_TCP_AGAIN; - case SSL_ERROR_WANT_WRITE: - fprintf(stderr,"[TLS]: SSL_ERROR_WANT_WRITE\n"); - upstream->event.read_cb = NULL; - upstream->event.write_cb = upstream_write_cb; - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, &upstream->event); - upstream->tls_hs_state = GETDNS_HS_WRITE; - return STUB_TCP_AGAIN; - default: - SSL_free(upstream->tls_obj); - upstream->tls_obj = NULL; - upstream->tls_hs_state = GETDNS_HS_FAILED; - return STUB_TLS_SETUP_ERROR; - } - } - upstream->tls_hs_state = GETDNS_HS_DONE; - upstream->event.read_cb = NULL; - upstream->event.write_cb = upstream_write_cb; - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, &upstream->event); - return 0; -} - -static int -tls_handshake_active(getdns_tls_hs_state_t hs_state) -{ - return (hs_state == GETDNS_HS_FAILED || - hs_state == GETDNS_HS_NONE) ? 0 : 1; -} - -static int -check_tls(getdns_upstream* upstream) -{ - /* Already have a connection*/ - if (upstream->tls_hs_state == GETDNS_HS_DONE && - (upstream->tls_obj != NULL) && (upstream->fd != -1)) - return 0; - - /* This upstream can't be used, so let the fallback code take care of things */ - if (upstream->tls_hs_state == GETDNS_HS_FAILED) - return STUB_TLS_SETUP_ERROR; - - /* Lets make sure the connection is up before we try a handshake*/ - int error = 0; - socklen_t len = (socklen_t)sizeof(error); - /* TODO: This doesn't handle the case where the far end doesn't do a reset - * as is the case with e.g. 8.8.8.8. For that case the timeout kicks in - * and the user callback fails the message without the chance to fallback.*/ - getsockopt(upstream->fd, SOL_SOCKET, SO_ERROR, (void*)&error, &len); - if (error == EINPROGRESS || error == EWOULDBLOCK) - return STUB_TCP_AGAIN; /* try again */ - else if (error != 0) { - - fprintf(stderr,"[TLS]: TLS(check_tls): died gettting connection\n"); - SSL_free(upstream->tls_obj); - upstream->tls_obj = NULL; - upstream->tls_hs_state = GETDNS_HS_FAILED; - return STUB_TLS_SETUP_ERROR; - } - - return do_tls_handshake(upstream); -} - -static int -stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, struct mem_funcs *mf) -{ - ssize_t read; - uint8_t *buf; - size_t buf_size; - SSL* tls_obj = upstream->tls_obj; - - int q = check_tls(upstream); - if (q != 0) - return q; - - if (!tcp->read_buf) { - /* First time tls read, create a buffer for reading */ - if (!(tcp->read_buf = GETDNS_XMALLOC(*mf, uint8_t, 4096))) - return STUB_TCP_ERROR; - - tcp->read_buf_len = 4096; - tcp->read_pos = tcp->read_buf; - tcp->to_read = 2; /* Packet size */ - } - - ERR_clear_error(); - read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); - if (read <= 0) { - /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake - renegotiation. Need to keep handshake state to do that.*/ - int want = SSL_get_error(tls_obj, read); - if (want == SSL_ERROR_WANT_READ) { - return STUB_TCP_AGAIN; /* read more later */ - } else - return STUB_TCP_ERROR; - } - tcp->to_read -= read; - tcp->read_pos += read; - - if ((int)tcp->to_read > 0) - return STUB_TCP_AGAIN; - - read = tcp->read_pos - tcp->read_buf; - if (read == 2) { - /* Read the packet size short */ - tcp->to_read = gldns_read_uint16(tcp->read_buf); - - if (tcp->to_read < GLDNS_HEADER_SIZE) - return STUB_TCP_ERROR; - - /* Resize our buffer if needed */ - if (tcp->to_read > tcp->read_buf_len) { - buf_size = tcp->read_buf_len; - while (tcp->to_read > buf_size) - buf_size *= 2; - - if (!(buf = GETDNS_XREALLOC(*mf, - tcp->read_buf, uint8_t, buf_size))) - return STUB_TCP_ERROR; - - tcp->read_buf = buf; - tcp->read_buf_len = buf_size; - } - - /* Ready to start reading the packet */ - tcp->read_pos = tcp->read_buf; - read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); - if (read <= 0) { - /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake - renegotiation. Need to keep handshake state to do that.*/ - int want = SSL_get_error(tls_obj, read); - if (want == SSL_ERROR_WANT_READ) { - return STUB_TCP_AGAIN; /* read more later */ - } else - return STUB_TCP_ERROR; - } - tcp->to_read -= read; - tcp->read_pos += read; - if ((int)tcp->to_read > 0) - return STUB_TCP_AGAIN; - } - return GLDNS_ID_WIRE(tcp->read_buf); -} +/**************************/ +/* Upstream callback functions*/ +/**************************/ static void upstream_read_cb(void *userarg) @@ -966,9 +1193,7 @@ upstream_read_cb(void *userarg) fprintf(stderr,"[TLS]: **********CALLBACK***********\n"); fprintf(stderr,"[TLS]: READ(upstream_read_cb): on %d\n", upstream->fd); - if (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && - tls_handshake_active(upstream->tls_hs_state))) + if (tls_should_read(upstream)) q = stub_tls_read(upstream, &upstream->tcp, &upstream->upstreams->mf); else @@ -1030,9 +1255,9 @@ upstream_read_cb(void *userarg) if (netreq->owner == upstream->starttls_req) { dnsreq = netreq->owner; - fprintf(stderr, "[STARTTLS] processing STARTTLS response!\n"); + fprintf(stderr, "[TLS]: processing STARTTLS response!\n"); if (is_starttls_response(netreq)) { - upstream->tls_obj = create_tls_object(dnsreq->context, upstream->fd); + upstream->tls_obj = tls_create_object(dnsreq->context, upstream->fd); if (upstream->tls_obj == NULL) { fprintf(stderr,"[TLS]: could not create tls object\n"); upstream->tls_hs_state = GETDNS_HS_FAILED; @@ -1045,12 +1270,13 @@ upstream_read_cb(void *userarg) // Now reschedule the writes on this connection upstream->event.write_cb = upstream_write_cb; - fprintf(stderr, "[STARTTLS] method: upstream_schedule_netreq -> re-instating writes\n"); + fprintf(stderr, "[TLS] method: upstream_schedule_netreq ->" + "re-instating writes\n"); GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); GETDNS_SCHEDULE_EVENT(upstream->loop, upstream->fd, TIMEOUT_FOREVER, &upstream->event); } else { - fprintf(stderr, "[STARTTLS] processing standard response....\n"); + fprintf(stderr, "[TLS]: processing standard response....\n"); priv_getdns_check_dns_req_complete(netreq->owner); } @@ -1072,201 +1298,6 @@ netreq_upstream_read_cb(void *userarg) upstream_read_cb(((getdns_network_req *)userarg)->upstream); } -/* stub_tcp_write(fd, tcp, netreq) - * will return STUB_TCP_AGAIN when we need to come back again, - * STUB_TCP_ERROR on error and a query_id on successfull sent. - */ -static int -stub_tcp_write(int fd, getdns_tcp_state *tcp, getdns_network_req *netreq) -{ - - size_t pkt_len = netreq->response - netreq->query; - ssize_t written; - uint16_t query_id; - intptr_t query_id_intptr; - - /* Do we have remaining data that we could not write before? */ - if (! tcp->write_buf) { - /* No, this is an initial write. Try to send - */ - - /* Not keeping connections open? Then the first random number - * will do as the query id. - * - * Otherwise find a unique query_id not already written (or in - * the write_queue) for that upstream. Register this netreq - * by query_id in the process. - */ - if ((*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) || - (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_UDP)) - query_id = arc4random(); - else do { - query_id = arc4random(); - query_id_intptr = (intptr_t)query_id; - netreq->node.key = (void *)query_id_intptr; - - } while (!getdns_rbtree_insert( - &netreq->upstream->netreq_by_query_id, &netreq->node)); - - GLDNS_ID_SET(netreq->query, query_id); - if (netreq->opt) { - /* no limits on the max udp payload size with tcp */ - gldns_write_uint16(netreq->opt + 3, 65535); - - if (netreq->owner->edns_cookies) { - netreq->response = attach_edns_cookie( - netreq->upstream, netreq->opt); - pkt_len = netreq->response - netreq->query; - gldns_write_uint16(netreq->query - 2, pkt_len); - } - } - /* We have an initialized packet buffer. - * Lets see how much of it we can write - */ -#ifdef USE_TCP_FASTOPEN - /* We use sendto() here which will do both a connect and send */ - written = sendto(fd, netreq->query - 2, pkt_len + 2, - MSG_FASTOPEN, (struct sockaddr *)&(netreq->upstream->addr), - netreq->upstream->addr_len); - /* If pipelining we will find that the connection is already up so - just fall back to a 'normal' write. */ - if (written == -1 && errno == EISCONN) - written = write(fd, netreq->query - 2, pkt_len + 2); - - if ((written == -1 && (errno == EAGAIN || - errno == EWOULDBLOCK || - /* Add the error case where the connection is in progress which is when - a cookie is not available (e.g. when doing the first request to an - upstream). We must let the handshake complete since non-blocking. */ - errno == EINPROGRESS)) || - written < pkt_len + 2) { -#else - written = write(fd, netreq->query - 2, pkt_len + 2); - if ((written == -1 && (errno == EAGAIN || - errno == EWOULDBLOCK)) || - written < pkt_len + 2) { -#endif - /* We couldn't write the whole packet. - * We have to return with STUB_TCP_AGAIN. - * Setup tcp to track the state. - */ - tcp->write_buf = netreq->query - 2; - tcp->write_buf_len = pkt_len + 2; - tcp->written = written >= 0 ? written : 0; - - return STUB_TCP_AGAIN; - - } else if (written == -1) - return STUB_TCP_ERROR; - - /* We were able to write everything! Start reading. */ - return (int) query_id; - - } else {/* if (! tcp->write_buf) */ - - /* Coming back from an earlier unfinished write or handshake. - * Try to send remaining data */ - written = write(fd, tcp->write_buf + tcp->written, - tcp->write_buf_len - tcp->written); - if (written == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) - return STUB_TCP_AGAIN; - else - return STUB_TCP_ERROR; - } - tcp->written += written; - if (tcp->written < tcp->write_buf_len) - /* Still more to send */ - return STUB_TCP_AGAIN; - - query_id = (int)GLDNS_ID_WIRE(tcp->write_buf + 2); - /* Done. Start reading */ - tcp->write_buf = NULL; - return query_id; - - } /* if (! tcp->write_buf) */ -} - -static void -stub_tcp_write_cb(void *userarg) -{ - getdns_network_req *netreq = (getdns_network_req *)userarg; - getdns_dns_req *dnsreq = netreq->owner; - int q; - - switch ((q = stub_tcp_write(netreq->fd, &netreq->tcp, netreq))) { - case STUB_TCP_AGAIN: - return; - - case STUB_TCP_ERROR: - stub_erred(netreq); - return; - - default: - netreq->query_id = (uint16_t) q; - GETDNS_CLEAR_EVENT(dnsreq->loop, &netreq->event); - GETDNS_SCHEDULE_EVENT( - dnsreq->loop, netreq->fd, dnsreq->context->timeout, - getdns_eventloop_event_init(&netreq->event, netreq, - stub_tcp_read_cb, NULL, stub_timeout_cb)); - return; - } -} - -static int -stub_tls_write(getdns_upstream *upstream, getdns_tcp_state *tcp, - getdns_network_req *netreq) -{ - size_t pkt_len = netreq->response - netreq->query; - ssize_t written; - uint16_t query_id; - intptr_t query_id_intptr; - SSL* tls_obj = upstream->tls_obj; - - int q = check_tls(upstream); - if (q != 0) - return q; - - /* Do we have remaining data that we could not write before? */ - if (! tcp->write_buf) { - /* No, this is an initial write. Try to send - */ - - /* Find a unique query_id not already written (or in - * the write_queue) for that upstream. Register this netreq - * by query_id in the process. - */ - do { - query_id = ldns_get_random(); - query_id_intptr = (intptr_t)query_id; - netreq->node.key = (void *)query_id_intptr; - - } while (!getdns_rbtree_insert( - &netreq->upstream->netreq_by_query_id, &netreq->node)); - - GLDNS_ID_SET(netreq->query, query_id); - if (netreq->opt) - /* no limits on the max udp payload size with tcp */ - gldns_write_uint16(netreq->opt + 3, 65535); - - /* We have an initialized packet buffer. - * Lets see how much of it we can write */ - - // TODO[TLS]: Handle error cases, partial writes, renegotiation etc. - ERR_clear_error(); - written = SSL_write(tls_obj, netreq->query - 2, pkt_len + 2); - if (written <= 0) - return STUB_TCP_ERROR; - - /* We were able to write everything! Start reading. */ - return (int) query_id; - - } - - return STUB_TCP_ERROR; -} - - static void upstream_write_cb(void *userarg) { @@ -1275,14 +1306,11 @@ upstream_write_cb(void *userarg) getdns_dns_req *dnsreq = netreq->owner; int q; - fprintf(stderr,"[TLS]: **********CALLBACK***********\n"); fprintf(stderr,"[TLS]: WRITE(upstream_write_cb): upstream fd %d, SEND" " netreq %p \n", upstream->fd, netreq); - if (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_TLS || - (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && - upstream->tls_hs_state != GETDNS_HS_NONE)) + if (tls_requested(netreq) && tls_should_write(upstream)) q = stub_tls_write(upstream, &upstream->tcp, netreq); else q = stub_tcp_write(upstream->fd, &upstream->tcp, netreq); @@ -1330,7 +1358,7 @@ upstream_write_cb(void *userarg) if (upstream->starttls_req) { /* Now deschedule any further writes on this connection until we get the STARTTLS answer*/ - fprintf(stderr, "[STARTTLS] method: upstream_write_cb -> STARTTTLS -" + fprintf(stderr, "[TLS] method: upstream_write_cb -> STARTTTLS -" "clearing upstream->event.write_cb\n"); upstream->event.write_cb = NULL; GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); @@ -1358,57 +1386,78 @@ netreq_upstream_write_cb(void *userarg) upstream_write_cb(((getdns_network_req *)userarg)->upstream); } -static void -upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) -{ - /* We have a connected socket and a global event loop */ - assert(upstream->fd >= 0); - assert(upstream->loop); - - - fprintf(stderr,"[TLS]: SCHEDULE(upstream_schedule_netreq): fd %d\n", upstream->fd); - - /* Append netreq to write_queue */ - if (!upstream->write_queue) { - upstream->write_queue = upstream->write_queue_last = netreq; - upstream->event.write_cb = upstream_write_cb; - GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); - GETDNS_SCHEDULE_EVENT(upstream->loop, - upstream->fd, TIMEOUT_FOREVER, &upstream->event); - } else { - upstream->write_queue_last->write_queue_tail = netreq; - upstream->write_queue_last = netreq; - } -} +/*****************************/ +/* Upstream utility functions*/ +/*****************************/ static int -tcp_connect(getdns_upstream *upstream, getdns_base_transport_t transport) +upstream_transport_valid(getdns_upstream *upstream, + getdns_base_transport_t transport) { - - int fd = -1; - if ((fd = socket(upstream->addr.ss_family, SOCK_STREAM, IPPROTO_TCP)) == -1) - return -1; - - getdns_sock_nonblock(fd); -#ifdef USE_TCP_FASTOPEN - /* Leave the connect to the later call to sendto() if using TCP*/ - if (transport == GETDNS_BASE_TRANSPORT_TCP || - transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE || - transport == GETDNS_BASE_TRANSPORT_STARTTLS) - return fd; -#endif - if (connect(fd, (struct sockaddr *)&upstream->addr, - upstream->addr_len) == -1) { - if (errno != EINPROGRESS) { - close(fd); - return -1; - } + /* For single shot transports, use only the TCP upstream. */ + fprintf(stderr,"[TLS]: upstream_transport_valid checking upstream %p against transport %d\n",(void*)upstream, transport); + if (transport == GETDNS_BASE_TRANSPORT_UDP || + transport == GETDNS_BASE_TRANSPORT_TCP_SINGLE) + return (upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_TCP ? 1:0); + /* Allow TCP messages to be sent on a STARTTLS upstream that hasn't upgraded + * to avoid opening a new connection if one is aleady open. */ + if (transport == GETDNS_BASE_TRANSPORT_TCP && + upstream->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && + upstream->tls_hs_state == GETDNS_HS_FAILED) + return 1; + /* Otherwise, transport must match, and not have failed */ + if (upstream->dns_base_transport != transport) + return 0; + if (tls_failed(upstream)) { + fprintf(stderr,"[TLS]: tls_failed\n"); + return 0; } - return fd; + return 1; } +static getdns_upstream * +upstream_select(getdns_network_req *netreq, getdns_base_transport_t transport) +{ + getdns_upstream *upstream; + getdns_upstreams *upstreams = netreq->owner->upstreams; + size_t i; + + if (!upstreams->count) + return NULL; + + for (i = 0; i < upstreams->count; i++) + if (upstreams->upstreams[i].to_retry <= 0) + upstreams->upstreams[i].to_retry++; + + i = upstreams->current; + do { + if (upstreams->upstreams[i].to_retry > 0 && + upstream_transport_valid(&upstreams->upstreams[i], transport)) { + upstreams->current = i; + return &upstreams->upstreams[i]; + } + if (++i > upstreams->count) + i = 0; + } while (i != upstreams->current); + + upstream = upstreams->upstreams; + for (i = 0; i < upstreams->count; i++) + if (upstreams->upstreams[i].back_off < upstream->back_off && + upstream_transport_valid(&upstreams->upstreams[i], transport)) + upstream = &upstreams->upstreams[i]; + + /* Need to check again that the transport is valid */ + if (!upstream_transport_valid(upstream, transport)) + return NULL; + upstream->back_off++; + upstream->to_retry = 1; + upstreams->current = upstream - upstreams->upstreams; + return upstream; +} + + int -connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport, +upstream_connect(getdns_upstream *upstream, getdns_base_transport_t transport, getdns_dns_req *dnsreq) { int fd = -1; @@ -1433,11 +1482,11 @@ connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport case GETDNS_BASE_TRANSPORT_TLS: /* Use existing if available*/ - if (upstream->fd != 1 && tls_handshake_active(upstream->tls_hs_state)) + if (upstream->fd != -1 && !tls_failed(upstream)) return upstream->fd; fd = tcp_connect(upstream, transport); if (fd == -1) return -1; - upstream->tls_obj = create_tls_object(dnsreq->context, fd); + upstream->tls_obj = tls_create_object(dnsreq->context, fd); if (upstream->tls_obj == NULL) { fprintf(stderr,"[TLS]: could not create tls object\n"); close(fd); @@ -1460,11 +1509,11 @@ connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport upstream->loop = dnsreq->context->extension; upstream->fd = fd; upstream_schedule_netreq(upstream, starttls_netreq); - /* Schedule at least the timeout locally. + /* Schedule at least the timeout locally, but use half the context value. * And also the write if we perform a synchronous lookup */ /* TODO[TLS]: How should we handle timeout on STARTTLS negotiation?*/ GETDNS_SCHEDULE_EVENT( - dnsreq->loop, upstream->fd, dnsreq->context->timeout, + dnsreq->loop, upstream->fd, dnsreq->context->timeout / 2, getdns_eventloop_event_init(&starttls_netreq->event, starttls_netreq, NULL, (dnsreq->loop != upstream->loop ? netreq_upstream_write_cb : NULL), stub_timeout_cb)); @@ -1473,7 +1522,7 @@ connect_to_upstream(getdns_upstream *upstream, getdns_base_transport_t transport return -1; /* Nothing to do*/ } - fprintf(stderr,"[TLS]: CONNECT(connect_to_upstream):" + fprintf(stderr,"[TLS]: CONNECT(upstream_connect):" " created new connection %d\n", fd); return fd; } @@ -1484,10 +1533,12 @@ find_upstream_for_specific_transport(getdns_network_req *netreq, int *fd) { /* TODO[TLS]: Fallback through upstreams....?*/ - getdns_upstream *upstream = pick_upstream(netreq, transport); + getdns_upstream *upstream = upstream_select(netreq, transport); + fprintf(stderr,"[TLS]: find_upstream_for_specific_transport selected " + "upstream %p for %d\n", (void*)upstream, transport); if (!upstream) return NULL; - *fd = connect_to_upstream(upstream, transport, netreq->owner); + *fd = upstream_connect(upstream, transport, netreq->owner); return upstream; } @@ -1495,19 +1546,24 @@ static int find_upstream_for_netreq(getdns_network_req *netreq) { int fd = -1; - for (int i = 0; i < GETDNS_BASE_TRANSPORT_MAX && + int i = netreq->transport; + for (; i < GETDNS_BASE_TRANSPORT_MAX && netreq->dns_base_transports[i] != GETDNS_BASE_TRANSPORT_NONE; i++) { netreq->upstream = find_upstream_for_specific_transport(netreq, netreq->dns_base_transports[i], &fd); if (fd == -1 || !netreq->upstream) continue; - netreq->dns_base_transport = &netreq->dns_base_transports[i]; + netreq->transport = i; return fd; } return -1; } +/************************/ +/* Scheduling functions */ +/***********************/ + static int move_netreq(getdns_network_req *netreq, getdns_upstream *upstream, getdns_upstream *new_upstream) @@ -1563,7 +1619,7 @@ move_netreq(getdns_network_req *netreq, getdns_upstream *upstream, stub_timeout_cb)); } } - netreq->dns_base_transport++; + netreq->transport++; return upstream->fd; } @@ -1574,26 +1630,50 @@ fallback_on_write(getdns_network_req *netreq) fprintf(stderr,"[TLS]: FALLBACK(fallback_on_write)\n"); /* TODO[TLS]: Fallback through all transports.*/ - getdns_base_transport_t *next_transport = netreq->dns_base_transport; - if (*(++next_transport) == GETDNS_BASE_TRANSPORT_NONE) + getdns_base_transport_t next_transport = + netreq->dns_base_transports[netreq->transport + 1]; + if (next_transport == GETDNS_BASE_TRANSPORT_NONE) return STUB_TCP_ERROR; - if (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_STARTTLS && - *next_transport == GETDNS_BASE_TRANSPORT_TCP) { + if (netreq->dns_base_transports[netreq->transport] == + GETDNS_BASE_TRANSPORT_STARTTLS && + next_transport == GETDNS_BASE_TRANSPORT_TCP) { fprintf(stderr,"[TLS]: FALLBACK(fallback_on_write) STARTTLS->TCP\n"); /* Special case where can stay on same upstream*/ - netreq->dns_base_transport++; + netreq->transport++; return netreq->upstream->fd; } getdns_upstream *upstream = netreq->upstream; int fd; getdns_upstream *new_upstream = - find_upstream_for_specific_transport(netreq, *next_transport, &fd); + find_upstream_for_specific_transport(netreq, next_transport, &fd); if (!new_upstream) return STUB_TCP_ERROR; return move_netreq(netreq, upstream, new_upstream); } +static void +upstream_schedule_netreq(getdns_upstream *upstream, getdns_network_req *netreq) +{ + /* We have a connected socket and a global event loop */ + assert(upstream->fd >= 0); + assert(upstream->loop); + + fprintf(stderr,"[TLS]: SCHEDULE(upstream_schedule_netreq): fd %d\n", upstream->fd); + + /* Append netreq to write_queue */ + if (!upstream->write_queue) { + upstream->write_queue = upstream->write_queue_last = netreq; + upstream->event.write_cb = upstream_write_cb; + GETDNS_CLEAR_EVENT(upstream->loop, &upstream->event); + GETDNS_SCHEDULE_EVENT(upstream->loop, + upstream->fd, TIMEOUT_FOREVER, &upstream->event); + } else { + upstream->write_queue_last->write_queue_tail = netreq; + upstream->write_queue_last = netreq; + } +} + getdns_return_t priv_getdns_submit_stub_request(getdns_network_req *netreq) { @@ -1606,14 +1686,15 @@ priv_getdns_submit_stub_request(getdns_network_req *netreq) if (fd == -1) return GETDNS_RETURN_GENERIC_ERROR; - switch(*netreq->dns_base_transport) { + getdns_base_transport_t transport = netreq->dns_base_transports[netreq->transport]; + switch(transport) { case GETDNS_BASE_TRANSPORT_UDP: case GETDNS_BASE_TRANSPORT_TCP_SINGLE: netreq->fd = fd; GETDNS_SCHEDULE_EVENT( dnsreq->loop, netreq->fd, dnsreq->context->timeout, getdns_eventloop_event_init(&netreq->event, netreq, - NULL, (*netreq->dns_base_transport == GETDNS_BASE_TRANSPORT_UDP ? + NULL, (transport == GETDNS_BASE_TRANSPORT_UDP ? stub_udp_write_cb: stub_tcp_write_cb), stub_timeout_cb)); return GETDNS_RETURN_GOOD; diff --git a/src/types-internal.h b/src/types-internal.h index 96216b06..ebcd4261 100644 --- a/src/types-internal.h +++ b/src/types-internal.h @@ -169,7 +169,7 @@ typedef enum getdns_base_transport { GETDNS_BASE_TRANSPORT_NONE = 0, GETDNS_BASE_TRANSPORT_UDP, GETDNS_BASE_TRANSPORT_TCP_SINGLE, /* To be removed? */ - GETDNS_BASE_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback when scheduling*/ + GETDNS_BASE_TRANSPORT_STARTTLS, /* Define before TCP to allow fallback */ GETDNS_BASE_TRANSPORT_TCP, GETDNS_BASE_TRANSPORT_TLS, GETDNS_BASE_TRANSPORT_MAX @@ -203,7 +203,7 @@ typedef struct getdns_network_req struct getdns_upstream *upstream; int fd; getdns_base_transport_t dns_base_transports[GETDNS_BASE_TRANSPORT_MAX]; - getdns_base_transport_t *dns_base_transport; + int transport; getdns_eventloop_event event; getdns_tcp_state tcp; uint16_t query_id;