diff --git a/src/context.c b/src/context.c index 1c6caddd..45e4b472 100644 --- a/src/context.c +++ b/src/context.c @@ -470,29 +470,22 @@ upstreams_resize(getdns_upstreams *upstreams, size_t size) return r; } -static void -upstreams_cleanup(getdns_upstreams *upstreams) +void +priv_getdns_upstreams_dereference(getdns_upstreams *upstreams) { - if (!upstreams) - return; - for (int i = 0; i < (int)upstreams->count; i++) { - if (upstreams->upstreams[i].tls_obj != NULL) { - SSL_shutdown(upstreams->upstreams[i].tls_obj); - SSL_free(upstreams->upstreams[i].tls_obj); - upstreams->upstreams[i].tls_obj = NULL; - } - if (upstreams->upstreams[i].fd != -1) { - close(upstreams->upstreams[i].fd); - upstreams->upstreams[i].fd = -1; - } - } -} + size_t i; -static void -upstreams_dereference(getdns_upstreams *upstreams) -{ - if (upstreams && --upstreams->referenced == 0) + if (upstreams && --upstreams->referenced == 0) { + for (i = 0; i < upstreams->count; i++) { + if (upstreams->upstreams[i].tls_obj != NULL) { + SSL_shutdown(upstreams->upstreams[i].tls_obj); + SSL_free(upstreams->upstreams[i].tls_obj); + } + if (upstreams->upstreams[i].fd != -1) + close(upstreams->upstreams[i].fd); + } GETDNS_FREE(upstreams->mf, upstreams); + } } static uint8_t* @@ -537,11 +530,9 @@ upstream_ntop_buf(getdns_upstream *upstream, getdns_transport_t transport, if (transport == GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) (void) snprintf(buf + strlen(buf), len - strlen(buf), "@%d", GETDNS_TLS_PORT); - else { - if (upstream_port(upstream) != 53 && upstream_port(upstream) != 0) - (void) snprintf(buf + strlen(buf), len - strlen(buf), - "@%d", (int)upstream_port(upstream)); - } + else if (upstream_port(upstream) != 53 && upstream_port(upstream) != 0) + (void) snprintf(buf + strlen(buf), len - strlen(buf), + "@%d", (int)upstream_port(upstream)); } static int @@ -919,8 +910,7 @@ getdns_context_destroy(struct getdns_context *context) getdns_traverse_postorder(&context->local_hosts, destroy_local_host, context); - upstreams_cleanup(context->upstreams); - upstreams_dereference(context->upstreams); + priv_getdns_upstreams_dereference(context->upstreams); GETDNS_FREE(context->my_mf, context); } /* getdns_context_destroy */ @@ -1524,8 +1514,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, upstreams->count++; freeaddrinfo(ai); } - upstreams_dereference(context->upstreams); - /*Don't the existing upstreams need to be handled before overwritting here?*/ + priv_getdns_upstreams_dereference(context->upstreams); context->upstreams = upstreams; dispatch_updated(context, GETDNS_CONTEXT_CODE_UPSTREAM_RECURSIVE_SERVERS); @@ -1535,7 +1524,7 @@ getdns_context_set_upstream_recursive_servers(struct getdns_context *context, invalid_parameter: r = GETDNS_RETURN_INVALID_PARAMETER; error: - upstreams_dereference(upstreams); + priv_getdns_upstreams_dereference(upstreams); return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; } /* getdns_context_set_upstream_recursive_servers */ @@ -1731,14 +1720,14 @@ getdns_cancel_callback(getdns_context *context, } /* getdns_cancel_callback */ static getdns_return_t -ub_setup_stub(struct ub_ctx *ctx, struct getdns_context *context) +ub_setup_stub(struct ub_ctx *ctx, getdns_context *context) { getdns_return_t r = GETDNS_RETURN_GOOD; size_t i; getdns_upstream *upstream; char addr[1024]; - getdns_upstreams *upstreams = context->upstreams; + (void) ub_ctx_set_fwd(ctx, NULL); for (i = 0; i < upstreams->count; i++) { upstream = &upstreams->upstreams[i]; @@ -1840,21 +1829,21 @@ getdns_context_prepare_for_resolution(struct getdns_context *context, /* TODO: move this transport logic to a separate functions*/ if (context->resolution_type == GETDNS_RESOLUTION_STUB) { switch (context->dns_transport) { - case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: - case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: - if (context->tls_ctx == NULL) { + case GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN: + case GETDNS_TRANSPORT_TLS_FIRST_AND_FALL_BACK_TO_TCP_KEEP_CONNECTIONS_OPEN: + if (context->tls_ctx == NULL) { #ifdef HAVE_LIBTLS1_2 - /* Create client context, use TLS v1.2 only for now */ - context->tls_ctx = SSL_CTX_new(TLSv1_2_client_method()); + /* Create client context, use TLS v1.2 only for now */ + context->tls_ctx = SSL_CTX_new(TLSv1_2_client_method()); #endif - if(!context->tls_ctx && context->dns_transport == - GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) { - return GETDNS_RETURN_BAD_CONTEXT; - } + if(!context->tls_ctx && context->dns_transport == + GETDNS_TRANSPORT_TLS_ONLY_KEEP_CONNECTIONS_OPEN) { + return GETDNS_RETURN_BAD_CONTEXT; } - break; - default: - break; + } + break; + default: + break; } } /* Block use of TLS ONLY in recursive mode as it won't work */ diff --git a/src/context.h b/src/context.h index 569d42c5..12a2c263 100644 --- a/src/context.h +++ b/src/context.h @@ -233,4 +233,6 @@ void priv_getdns_context_ub_read_cb(void *userarg); getdns_base_transport_t priv_get_base_transport(getdns_transport_t transport, int level); +void priv_getdns_upstreams_dereference(getdns_upstreams *upstreams); + #endif /* _GETDNS_CONTEXT_H_ */ diff --git a/src/request-internal.c b/src/request-internal.c index 99e5c7d6..0889b3fc 100644 --- a/src/request-internal.c +++ b/src/request-internal.c @@ -177,8 +177,7 @@ dns_req_free(getdns_dns_req * req) return; } - if (req->upstreams && --req->upstreams->referenced == 0) - GETDNS_FREE(req->upstreams->mf, req->upstreams); + priv_getdns_upstreams_dereference(req->upstreams); /* cleanup network requests */ for (net_req = req->netreqs; *net_req; net_req++)