diff --git a/src/context.c b/src/context.c index d6f0f1fb..f9b9d8b1 100644 --- a/src/context.c +++ b/src/context.c @@ -416,6 +416,7 @@ getdns_context_create_with_extended_memory_functions( if (!result) { return GETDNS_RETURN_GENERIC_ERROR; } + result->processing = 0; result->destroying = 0; result->my_mf.mf_arg = userarg; result->my_mf.mf.ext.malloc = malloc; @@ -534,6 +535,12 @@ getdns_context_destroy(struct getdns_context *context) if (context == NULL) { return; } + // If being destroyed during getdns callback, just flag it + // and destroy. See getdns_context_process_async + if (context->processing > 0) { + context->processing++; + return; + } context->destroying = 1; cancel_outstanding_requests(context, 1); getdns_extension_detach_eventloop(context); @@ -1505,12 +1512,24 @@ getdns_context_get_num_pending_requests(struct getdns_context* context, /* process async reqs */ getdns_return_t getdns_context_process_async(struct getdns_context* context) { RETURN_IF_NULL(context, GETDNS_RETURN_INVALID_PARAMETER); + context->processing = 1; if (ub_poll(context->unbound_ctx)) { if (ub_process(context->unbound_ctx) != 0) { /* need an async return code? */ return GETDNS_RETURN_GENERIC_ERROR; } } + if (context->processing > 1) { + // destroyed during callbacks + // clear flag so destroy continues + context->processing = 0; + getdns_context_destroy(context); + // return bad context now that the context + // is destroyed + return GETDNS_RETURN_BAD_CONTEXT; + } + // reset the processing flag + context->processing = 0; if (context->extension != NULL) { /* no need to process timeouts since it is delegated * to the extension */ diff --git a/src/context.h b/src/context.h index 6ccc5164..211af949 100644 --- a/src/context.h +++ b/src/context.h @@ -90,6 +90,7 @@ struct getdns_context { getdns_update_callback update_callback; + int processing; int destroying; struct mem_funcs mf; diff --git a/src/extension/libev.c b/src/extension/libev.c index 4a506ee7..ba978363 100644 --- a/src/extension/libev.c +++ b/src/extension/libev.c @@ -59,7 +59,10 @@ request_count_changed(uint32_t request_count, struct getdns_libev_data *ev_data) static void getdns_libev_cb(struct ev_loop *loop, struct ev_io *handle, int revents) { struct getdns_context* context = (struct getdns_context*) handle->data; - getdns_context_process_async(context); + if (getdns_context_process_async(context) == GETDNS_RETURN_BAD_CONTEXT) { + // context destroyed + return; + } uint32_t rc = getdns_context_get_num_pending_requests(context, NULL); struct getdns_libev_data* ev_data = (struct getdns_libev_data*) getdns_context_get_extension_data(context); diff --git a/src/extension/libevent.c b/src/extension/libevent.c index d0652f80..00b5e81e 100644 --- a/src/extension/libevent.c +++ b/src/extension/libevent.c @@ -88,7 +88,10 @@ request_count_changed(uint32_t request_count, struct event_data *ev_data) { static void getdns_libevent_cb(evutil_socket_t fd, short what, void *userarg) { struct getdns_context* context = (struct getdns_context*) userarg; - getdns_context_process_async(context); + if (getdns_context_process_async(context) == GETDNS_RETURN_BAD_CONTEXT) { + // context destroyed + return; + } uint32_t rc = getdns_context_get_num_pending_requests(context, NULL); struct event_data* ev_data = (struct event_data*) getdns_context_get_extension_data(context); diff --git a/src/extension/libuv.c b/src/extension/libuv.c index 9f6eac62..40a5f9e8 100644 --- a/src/extension/libuv.c +++ b/src/extension/libuv.c @@ -53,7 +53,10 @@ static void request_count_changed(uint32_t request_count, struct getdns_libuv_da static void getdns_libuv_cb(uv_poll_t* handle, int status, int events) { struct getdns_context* context = (struct getdns_context*) handle->data; - getdns_context_process_async(context); + if (getdns_context_process_async(context) == GETDNS_RETURN_BAD_CONTEXT) { + // context destroyed + return; + } uint32_t rc = getdns_context_get_num_pending_requests(context, NULL); struct getdns_libuv_data* uv_data = (struct getdns_libuv_data*) getdns_context_get_extension_data(context); diff --git a/src/test/check_getdns_common.c b/src/test/check_getdns_common.c index e0be357d..f122d1fa 100644 --- a/src/test/check_getdns_common.c +++ b/src/test/check_getdns_common.c @@ -258,6 +258,17 @@ void assert_ptr_in_answer(struct extracted_response *ex_response) ck_assert_msg(ptr_records == 1, "Expected to find one PTR record in answer section, got %d", ptr_records); } +void destroy_callbackfn(struct getdns_context *context, + getdns_callback_type_t callback_type, + struct getdns_dict *response, + void *userarg, + getdns_transaction_t transaction_id) { + int* flag = (int*)userarg; + *flag = 1; + getdns_dict_destroy(response); + getdns_context_destroy(context); +} + /* * callbackfn is the callback function given to all * asynchronous query tests. It is expected to only diff --git a/src/test/check_getdns_common.h b/src/test/check_getdns_common.h index 81cbf0a6..397d4f1a 100644 --- a/src/test/check_getdns_common.h +++ b/src/test/check_getdns_common.h @@ -186,6 +186,12 @@ */ void assert_ptr_in_answer(struct extracted_response *ex_response); + + void destroy_callbackfn(struct getdns_context *context, + getdns_callback_type_t callback_type, + struct getdns_dict *response, + void *userarg, + getdns_transaction_t transaction_id); /* * callbackfn is the callback function given to all * asynchronous query tests. It is expected to only diff --git a/src/test/check_getdns_context_destroy.h b/src/test/check_getdns_context_destroy.h index 7a836dd2..b7c80595 100644 --- a/src/test/check_getdns_context_destroy.h +++ b/src/test/check_getdns_context_destroy.h @@ -177,6 +177,31 @@ } END_TEST + START_TEST (getdns_context_destroy_7) + { + /* + * destroy called immediately following getdns_address + * expect: callback should be called before getdns_context_destroy() returns + */ + struct getdns_context *context = NULL; + void* eventloop = NULL; + getdns_transaction_t transaction_id = 0; + + int flag = 0; /* Initialize flag */ + + CONTEXT_CREATE(TRUE); + EVENT_BASE_CREATE; + + ASSERT_RC(getdns_address(context, "google.com", NULL, + &flag, &transaction_id, destroy_callbackfn), + GETDNS_RETURN_GOOD, "Return code from getdns_address()"); + + RUN_EVENT_LOOP; + + ck_assert_msg(flag == 1, "flag should == 1, got %d", flag); + } + END_TEST + void verify_getdns_context_destroy(struct extracted_response *ex_response) { /* @@ -208,6 +233,7 @@ tcase_add_test(tc_pos, getdns_context_destroy_4); tcase_add_test(tc_pos, getdns_context_destroy_5); tcase_add_test(tc_pos, getdns_context_destroy_6); + tcase_add_test(tc_pos, getdns_context_destroy_7); suite_add_tcase(s, tc_pos); return s; diff --git a/src/test/check_getdns_selectloop.c b/src/test/check_getdns_selectloop.c index 7724a08b..3baf3b9a 100644 --- a/src/test/check_getdns_selectloop.c +++ b/src/test/check_getdns_selectloop.c @@ -44,7 +44,10 @@ void run_event_loop_impl(struct getdns_context* context, void* eventloop) { FD_ZERO(&read_fds); FD_SET(fd, &read_fds); select(fd + 1, &read_fds, NULL, NULL, &tv); - getdns_context_process_async(context); + if (getdns_context_process_async(context) != GETDNS_RETURN_GOOD) { + // context destroyed + break; + } } }