diff --git a/src/test/getdns_context_set_listen_addresses.c b/src/test/getdns_context_set_listen_addresses.c index 4b98a6f3..5b1edb9c 100644 --- a/src/test/getdns_context_set_listen_addresses.c +++ b/src/test/getdns_context_set_listen_addresses.c @@ -47,17 +47,6 @@ typedef struct listen_data { getdns_request_handler_t handler; } listen_data; - -typedef struct dns_msg { - listen_data *ld; -} dns_msg; - -typedef struct udp_msg { - listen_data *ld; - struct sockaddr_storage remote_in; - socklen_t addrlen; -} udp_msg; - typedef struct tcp_to_write tcp_to_write; struct tcp_to_write { size_t write_buf_len; @@ -66,10 +55,17 @@ struct tcp_to_write { uint8_t write_buf[]; }; -typedef struct downstream { +typedef struct connection { listen_data *ld; struct sockaddr_storage remote_in; socklen_t addrlen; +} connection; + +typedef struct tcp_connection { + listen_data *ld; + struct sockaddr_storage remote_in; + socklen_t addrlen; + int fd; getdns_eventloop_event event; @@ -80,14 +76,9 @@ typedef struct downstream { tcp_to_write *to_write; size_t to_answer; -} downstream; +} tcp_connection; -typedef struct tcp_msg { - listen_data *ld; - downstream *conn; -} tcp_msg; - -static void downstream_destroy(downstream *conn) +static void tcp_connection_destroy(tcp_connection *conn) { struct mem_funcs *mf; getdns_eventloop *loop; @@ -116,7 +107,7 @@ static void downstream_destroy(downstream *conn) static void tcp_write_cb(void *userarg) { - downstream *conn = (downstream *)userarg; + tcp_connection *conn = (tcp_connection *)userarg; struct mem_funcs *mf; getdns_eventloop *loop; @@ -131,7 +122,7 @@ static void tcp_write_cb(void *userarg) if (getdns_context_get_eventloop(conn->ld->context, &loop)) return; - /* Reset downstream idle timeout */ + /* Reset tcp_connection idle timeout */ loop->vmt->clear(loop, &conn->event); if (!conn->to_write) { @@ -147,7 +138,7 @@ static void tcp_write_cb(void *userarg) /* IO error, close connection */ conn->event.read_cb = conn->event.write_cb = conn->event.timeout_cb = NULL; - downstream_destroy(conn); + tcp_connection_destroy(conn); return; } to_write->written += written; @@ -166,12 +157,20 @@ void _getdns_cancel_reply(getdns_context *context, getdns_transaction_t request_id) { /* TODO: Check request_id at context->outbound_requests */ - dns_msg *msg = (dns_msg *)(intptr_t)request_id; + connection *conn = (connection *)(intptr_t)request_id; struct mem_funcs *mf; - if (context && msg && - (mf = priv_getdns_context_mf(msg->ld->context))) - GETDNS_FREE(*mf, msg); + if (!context || !conn) + return; + + if (conn->ld->transport == GETDNS_TRANSPORT_TCP) { + tcp_connection *conn = (tcp_connection *)(intptr_t)request_id; + if (conn->to_answer > 0) + conn->to_answer--; + + } else if (conn->ld->transport == GETDNS_TRANSPORT_UDP && + (mf = priv_getdns_context_mf(conn->ld->context))) + GETDNS_FREE(*mf, conn); } getdns_return_t @@ -179,35 +178,34 @@ getdns_reply( getdns_context *context, getdns_transaction_t request_id, getdns_dict *reply) { /* TODO: Check request_id at context->outbound_requests */ - dns_msg *msg = (dns_msg *)(intptr_t)request_id; + connection *conn = (connection *)(intptr_t)request_id; struct mem_funcs *mf; getdns_eventloop *loop; uint8_t buf[65536]; size_t len; getdns_return_t r; - if (!context || !reply || !msg) + if (!context || !reply || !conn) return GETDNS_RETURN_INVALID_PARAMETER; - if (!(mf = priv_getdns_context_mf(msg->ld->context))) + if (!(mf = priv_getdns_context_mf(conn->ld->context))) return GETDNS_RETURN_GENERIC_ERROR;; - if ((r = getdns_context_get_eventloop(msg->ld->context, &loop))) + if ((r = getdns_context_get_eventloop(conn->ld->context, &loop))) return r; len = sizeof(buf); if ((r = getdns_msg_dict2wire_buf(reply, buf, &len))) return r; - else if (msg->ld->transport == GETDNS_TRANSPORT_UDP) { - udp_msg *msg = (udp_msg *)(intptr_t)request_id; - - if (sendto(msg->ld->fd, buf, len, 0, - (struct sockaddr *)&msg->remote_in, msg->addrlen) == -1) + else if (conn->ld->transport == GETDNS_TRANSPORT_UDP) { + if (sendto(conn->ld->fd, buf, len, 0, + (struct sockaddr *)&conn->remote_in, conn->addrlen) == -1) ; /* IO error, TODO: cleanup this listener */ + GETDNS_FREE(*mf, conn); - } else if (msg->ld->transport == GETDNS_TRANSPORT_TCP) { - tcp_msg *msg = (tcp_msg *)(intptr_t)request_id; + } else if (conn->ld->transport == GETDNS_TRANSPORT_TCP) { + tcp_connection *conn = (tcp_connection *)(intptr_t)request_id; tcp_to_write **to_write_p; tcp_to_write *to_write = (tcp_to_write *)GETDNS_XMALLOC( *mf, uint8_t, sizeof(tcp_to_write) + len + 2); @@ -223,30 +221,29 @@ getdns_reply( (void) memcpy(to_write->write_buf + 2, buf, len); /* Appen to_write to conn->to_write list */ - for ( to_write_p = &msg->conn->to_write + for ( to_write_p = &conn->to_write ; *to_write_p ; to_write_p = &(*to_write_p)->next) ; /* pass */ *to_write_p = to_write; - loop->vmt->clear(loop, &msg->conn->event); - msg->conn->event.write_cb = tcp_write_cb; + loop->vmt->clear(loop, &conn->event); + conn->event.write_cb = tcp_write_cb; + if (conn->to_answer > 0) + conn->to_answer--; (void) loop->vmt->schedule(loop, - msg->conn->fd, DOWNSTREAM_IDLE_TIMEOUT, - &msg->conn->event); + conn->fd, DOWNSTREAM_IDLE_TIMEOUT, + &conn->event); } /* TODO: other transport types */ - if (msg) - GETDNS_FREE(*mf, msg); return r; } static void tcp_read_cb(void *userarg) { - downstream *conn = (downstream *)userarg; + tcp_connection *conn = (tcp_connection *)userarg; ssize_t bytes_read; - tcp_msg *msg; getdns_return_t r; struct mem_funcs *mf; getdns_eventloop *loop; @@ -260,7 +257,7 @@ static void tcp_read_cb(void *userarg) if ((r = getdns_context_get_eventloop(conn->ld->context, &loop))) return; - /* Reset downstream idle timeout */ + /* Reset tcp_connection idle timeout */ loop->vmt->clear(loop, &conn->event); (void) loop->vmt->schedule(loop, conn->fd, DOWNSTREAM_IDLE_TIMEOUT, &conn->event); @@ -270,12 +267,12 @@ static void tcp_read_cb(void *userarg) return; /* Come back to do the read later */ /* IO error, close connection */ - downstream_destroy(conn); + tcp_connection_destroy(conn); return; } if (bytes_read == 0) { /* remote end closed connection, cleanup */ - downstream_destroy(conn); + tcp_connection_destroy(conn); return; } assert(bytes_read <= conn->to_read); @@ -295,41 +292,33 @@ static void tcp_read_cb(void *userarg) if (!(conn->read_buf = GETDNS_XMALLOC( *mf, uint8_t, conn->read_buf_len))) { /* Memory error */ - downstream_destroy(conn); + tcp_connection_destroy(conn); return; } } if (conn->to_read < 12) { /* Request smaller than DNS header, FORMERR */ - downstream_destroy(conn); + tcp_connection_destroy(conn); return; } conn->read_pos = conn->read_buf; return; /* Read DNS message */ } - if (!(msg = GETDNS_MALLOC(*mf, tcp_msg))) { - /* Memory error */ - downstream_destroy(conn); - return; - } - msg->ld = conn->ld; - msg->conn = conn; if ((r = getdns_wire2msg_dict(conn->read_buf, (conn->read_pos - conn->read_buf), &request_dict))) ; /* FROMERR on input, ignore */ else { - conn->to_answer += 1; + conn->to_answer++; /* Call request handler */ conn->ld->handler( - conn->ld->context, request_dict, (intptr_t)msg); + conn->ld->context, request_dict, (intptr_t)conn); conn->read_pos = conn->read_buf; conn->to_read = 2; return; /* Read more requests */ } - GETDNS_FREE(*mf, msg); conn->read_pos = conn->read_buf; conn->to_read = 2; /* Read more requests */ @@ -337,17 +326,28 @@ static void tcp_read_cb(void *userarg) static void tcp_timeout_cb(void *userarg) { - downstream *conn = (downstream *)userarg; + tcp_connection *conn = (tcp_connection *)userarg; assert(userarg); - downstream_destroy(conn); + if (conn->to_answer) { + getdns_eventloop *loop; + + if (getdns_context_get_eventloop(conn->ld->context, &loop)) + return; + + loop->vmt->clear(loop, &conn->event); + (void) loop->vmt->schedule(loop, + conn->fd, DOWNSTREAM_IDLE_TIMEOUT, + &conn->event); + } else + tcp_connection_destroy(conn); } static void tcp_accept_cb(void *userarg) { listen_data *ld = (listen_data *)userarg; - downstream *conn; + tcp_connection *conn; struct mem_funcs *mf; getdns_eventloop *loop; getdns_return_t r; @@ -360,10 +360,10 @@ static void tcp_accept_cb(void *userarg) if ((r = getdns_context_get_eventloop(ld->context, &loop))) return; - if (!(conn = GETDNS_MALLOC(*mf, downstream))) + if (!(conn = GETDNS_MALLOC(*mf, tcp_connection))) return; - (void) memset(conn, 0, sizeof(downstream)); + (void) memset(conn, 0, sizeof(tcp_connection)); conn->ld = ld; conn->addrlen = sizeof(conn->remote_in); @@ -390,7 +390,7 @@ static void tcp_accept_cb(void *userarg) static void udp_read_cb(void *userarg) { listen_data *ld = (listen_data *)userarg; - udp_msg *msg; + connection *conn; struct mem_funcs *mf; getdns_dict *request_dict; @@ -404,13 +404,13 @@ static void udp_read_cb(void *userarg) if (!(mf = priv_getdns_context_mf(ld->context))) return; - if (!(msg = GETDNS_MALLOC(*mf, udp_msg))) + if (!(conn = GETDNS_MALLOC(*mf, connection))) return; - msg->ld = ld; - msg->addrlen = sizeof(msg->remote_in); + conn->ld = ld; + conn->addrlen = sizeof(conn->remote_in); if ((len = recvfrom(ld->fd, buf, sizeof(buf), 0, - (struct sockaddr *)&msg->remote_in, &msg->addrlen)) == -1) + (struct sockaddr *)&conn->remote_in, &conn->addrlen)) == -1) ; /* IO error, TODO: cleanup this listener */ else if ((r = getdns_wire2msg_dict(buf, len, &request_dict))) @@ -418,10 +418,10 @@ static void udp_read_cb(void *userarg) else { /* Call request handler */ - ld->handler(ld->context, request_dict, (intptr_t)msg); + ld->handler(ld->context, request_dict, (intptr_t)conn); return; } - GETDNS_FREE(*mf, msg); + GETDNS_FREE(*mf, conn); } getdns_return_t getdns_context_set_listen_addresses(getdns_context *context,