diff --git a/src/test/getdns_query.c b/src/test/getdns_query.c index a1feab82..3208351d 100644 --- a/src/test/getdns_query.c +++ b/src/test/getdns_query.c @@ -1104,10 +1104,10 @@ void read_line_cb(void *userarg) typedef struct dns_msg { getdns_transaction_t request_id; - getdns_dict *query; + getdns_dict *request; uint32_t rt; uint32_t do_bit; - uint32_t cd_bit; + uint32_t cd_bit; } dns_msg; #if defined(TRACE_DEBUG) && TRACE_DEBUG @@ -1128,16 +1128,18 @@ void servfail(dns_msg *msg, getdns_dict **resp_p) getdns_dict_destroy(*resp_p); if (!(*resp_p = getdns_dict_create())) return; - if (!getdns_dict_get_dict(msg->query, "header", &dict)) - getdns_dict_set_dict(*resp_p, "header", dict); - if (!getdns_dict_get_dict(msg->query, "question", &dict)) - getdns_dict_set_dict(*resp_p, "question", dict); + if (msg) { + if (!getdns_dict_get_dict(msg->request, "header", &dict)) + getdns_dict_set_dict(*resp_p, "header", dict); + if (!getdns_dict_get_dict(msg->request, "question", &dict)) + getdns_dict_set_dict(*resp_p, "question", dict); + (void) getdns_dict_set_int(*resp_p, "/header/ra", + msg->rt == GETDNS_RESOLUTION_RECURSING ? 1 : 0); + } (void) getdns_dict_set_int( *resp_p, "/header/rcode", GETDNS_RCODE_SERVFAIL); (void) getdns_dict_set_int(*resp_p, "/header/qr", 1); (void) getdns_dict_set_int(*resp_p, "/header/ad", 0); - (void) getdns_dict_set_int(*resp_p, "/header/ra", - msg->rt == GETDNS_RESOLUTION_RECURSING ? 1 : 0); } void request_cb(getdns_context *context, getdns_callback_type_t callback_type, @@ -1162,7 +1164,7 @@ void request_cb(getdns_context *context, getdns_callback_type_t callback_type, else if (!response) SERVFAIL("Missing response", 0, msg, &response); - else if ((r = getdns_dict_get_int(msg->query, "/header/id", &qid)) || + else if ((r = getdns_dict_get_int(msg->request, "/header/id", &qid)) || (r=getdns_dict_set_int(response,"/replies_tree/0/header/id",qid))) SERVFAIL("Could not copy QID", r, msg, &response); @@ -1206,7 +1208,7 @@ void request_cb(getdns_context *context, getdns_callback_type_t callback_type, _getdns_cancel_reply(context, msg->request_id); } if (msg) { - getdns_dict_destroy(msg->query); + getdns_dict_destroy(msg->request); free(msg); } if (response) @@ -1226,6 +1228,8 @@ void incoming_request_handler(getdns_context *context, getdns_list *list; getdns_transaction_t transaction_id; getdns_dict *qext = NULL; + dns_msg *msg = NULL; + getdns_dict *response = NULL; if (!query_extensions_spc && !(query_extensions_spc = getdns_dict_create())) @@ -1240,12 +1244,17 @@ void incoming_request_handler(getdns_context *context, fprintf(stderr, "Could not get query extensions from space: %s" , getdns_get_errorstr_by_id(r)); + if (!(msg = malloc(sizeof(dns_msg)))) + goto error; + /* pass through the header and the OPT record */ n = 0; + msg->request_id = request_id; + msg->request = request; msg->do_bit = msg->cd_bit = 0; msg->rt = GETDNS_RESOLUTION_STUB; - (void) getdns_dict_get_int(msg->query, "/additional/0/do", &msg->do_bit); - (void) getdns_dict_get_int(msg->query, "/header/cd", &msg->cd_bit); + (void) getdns_dict_get_int(request, "/additional/0/do", &msg->do_bit); + (void) getdns_dict_get_int(request, "/header/cd", &msg->cd_bit); if ((r = getdns_context_get_resolution_type(context, &msg->rt))) fprintf(stderr, "Could get resolution type from context: %s\n", getdns_get_errorstr_by_id(r)); @@ -1253,7 +1262,7 @@ void incoming_request_handler(getdns_context *context, if (msg->rt == GETDNS_RESOLUTION_STUB) { (void)getdns_dict_set_int( qext , "/add_opt_parameters/do_bit", msg->do_bit); - if (!getdns_dict_get_dict(msg->query, "header", &header)) + if (!getdns_dict_get_dict(request, "header", &header)) (void)getdns_dict_set_dict(qext, "header", header); } else if (getdns_dict_get_int(extensions,"dnssec_return_status",&n) || @@ -1266,27 +1275,27 @@ void incoming_request_handler(getdns_context *context, (void) getdns_dict_set_int(qext, "dnssec_return_all_statuses", msg->cd_bit ? GETDNS_EXTENSION_TRUE : GETDNS_EXTENSION_FALSE); - if (!getdns_dict_get_int(msg->query,"/additional/0/extended_rcode",&n)) + if (!getdns_dict_get_int(request, "/additional/0/extended_rcode",&n)) (void)getdns_dict_set_int( qext, "/add_opt_parameters/extended_rcode", n); - if (!getdns_dict_get_int(msg->query, "/additional/0/version", &n)) + if (!getdns_dict_get_int(request, "/additional/0/version", &n)) (void)getdns_dict_set_int( qext, "/add_opt_parameters/version", n); if (!getdns_dict_get_int( - msg->query, "/additional/0/udp_payload_size", &n)) + request, "/additional/0/udp_payload_size", &n)) (void)getdns_dict_set_int(qext, "/add_opt_parameters/maximum_udp_payload_size", n); if (!getdns_dict_get_list( - msg->query, "/additional/0/rdata/options", &list)) + request, "/additional/0/rdata/options", &list)) (void)getdns_dict_set_list(qext, "/add_opt_parameters/options", list); #if 0 do { - char *str = getdns_pretty_print_dict(msg->query); + char *str = getdns_pretty_print_dict(request); fprintf(stderr, "query: %s\n", str); free(str); str = getdns_pretty_print_dict(qext); @@ -1294,7 +1303,7 @@ void incoming_request_handler(getdns_context *context, free(str); } while (0); #endif - if ((r = getdns_dict_get_bindata(msg->query,"/question/qname",&qname))) + if ((r = getdns_dict_get_bindata(request,"/question/qname",&qname))) fprintf(stderr, "Could not get qname from query: %s\n", getdns_get_errorstr_by_id(r)); @@ -1302,11 +1311,11 @@ void incoming_request_handler(getdns_context *context, fprintf(stderr, "Could not convert qname: %s\n", getdns_get_errorstr_by_id(r)); - else if ((r=getdns_dict_get_int(msg->query,"/question/qtype",&qtype))) + else if ((r=getdns_dict_get_int(request,"/question/qtype",&qtype))) fprintf(stderr, "Could get qtype from query: %s\n", getdns_get_errorstr_by_id(r)); - else if ((r=getdns_dict_get_int(msg->query,"/question/qclass",&qclass))) + else if ((r=getdns_dict_get_int(request,"/question/qclass",&qclass))) fprintf(stderr, "Could get qclass from query: %s\n", getdns_get_errorstr_by_id(r)); @@ -1324,19 +1333,24 @@ void incoming_request_handler(getdns_context *context, free(qname_str); return; } - free(qname_str); +error: + if (qname_str) + free(qname_str); servfail(msg, &response); if (!response) /* No response, no reply */ - _getdns_cancel_reply(context, msg->request_id); + _getdns_cancel_reply(context, request_id); - else if ((r = getdns_reply(context, msg->request_id, response))) { + else if ((r = getdns_reply(context, request_id, response))) { fprintf(stderr, "Could not reply: %s\n", getdns_get_errorstr_by_id(r)); - _getdns_cancel_reply(context, msg->request_id); + _getdns_cancel_reply(context, request_id); + } + if (msg) { + if (msg->request) + getdns_dict_destroy(msg->request); + free(msg); } - getdns_dict_destroy(msg->query); - free(msg); if (response) getdns_dict_destroy(response); }