diff --git a/src/test/getdns_context_set_listen_addresses.c b/src/test/getdns_context_set_listen_addresses.c index ff96ec40..86faee28 100644 --- a/src/test/getdns_context_set_listen_addresses.c +++ b/src/test/getdns_context_set_listen_addresses.c @@ -35,20 +35,39 @@ #define DOWNSTREAM_IDLE_TIMEOUT 5000 #define TCP_LISTEN_BACKLOG 16 +typedef struct listen_set listen_set; +typedef enum listen_set_action { + to_stay, to_add, to_remove +} listen_set_action; + typedef struct connection connection; -typedef struct listen_data { +typedef struct listener listener; +struct listener { getdns_eventloop_event event; socklen_t addr_len; struct sockaddr_storage addr; int fd; getdns_transport_list_t transport; - getdns_context *context; - /* Should be per context eventually */ - getdns_request_handler_t handler; + listen_set_action action; + listener *to_replace; + listen_set *set; + /* Should be per context eventually */ connection *connections; -} listen_data; +}; + +/* listen set is temporarily a singly linked list node, to associate the set + * with a context. Eventually it has to become a context attribute. + */ +struct listen_set { + getdns_context *context; + listen_set *next; + getdns_request_handler_t handler; + + size_t count; + listener items[]; +}; typedef struct tcp_to_write tcp_to_write; struct tcp_to_write { @@ -59,7 +78,7 @@ struct tcp_to_write { }; struct connection { - listen_data *ld; + listener *l; struct sockaddr_storage remote_in; socklen_t addrlen; @@ -69,7 +88,7 @@ struct connection { typedef struct tcp_connection { /* A TCP connection is a connection */ - listen_data *ld; + listener *l; struct sockaddr_storage remote_in; socklen_t addrlen; @@ -89,6 +108,8 @@ typedef struct tcp_connection { size_t to_answer; } tcp_connection; + +static void free_listen_set_when_done(listen_set *set); static void tcp_connection_destroy(tcp_connection *conn) { struct mem_funcs *mf; @@ -96,26 +117,31 @@ static void tcp_connection_destroy(tcp_connection *conn) tcp_to_write *cur, *next; - if (!(mf = priv_getdns_context_mf(conn->ld->context))) + if (!(mf = priv_getdns_context_mf(conn->l->set->context))) return; - if (getdns_context_get_eventloop(conn->ld->context, &loop)) + if (getdns_context_get_eventloop(conn->l->set->context, &loop)) return; if (conn->event.read_cb||conn->event.write_cb||conn->event.timeout_cb) loop->vmt->clear(loop, &conn->event); - if (conn->fd >= 0) { - if (close(conn->fd) == -1) - ; /* Whatever */ - } + + if (conn->fd >= 0) + (void) close(conn->fd); GETDNS_FREE(*mf, conn->read_buf); + for (cur = conn->to_write; cur; cur = next) { next = cur->next; GETDNS_FREE(*mf, cur); } + if (conn->to_answer > 0) + return; + /* Unlink this connection */ if ((*conn->prev_next = conn->next)) conn->next->prev_next = conn->prev_next; + + free_listen_set_when_done(conn->l->set); GETDNS_FREE(*mf, conn); } @@ -130,15 +156,15 @@ static void tcp_write_cb(void *userarg) assert(userarg); - if (!(mf = priv_getdns_context_mf(conn->ld->context))) + if (!(mf = priv_getdns_context_mf(conn->l->set->context))) return; - if (getdns_context_get_eventloop(conn->ld->context, &loop)) + if (getdns_context_get_eventloop(conn->l->set->context, &loop)) return; /* Reset tcp_connection idle timeout */ loop->vmt->clear(loop, &conn->event); - + if (!conn->to_write) { conn->event.write_cb = NULL; (void) loop->vmt->schedule(loop, conn->fd, @@ -146,7 +172,8 @@ static void tcp_write_cb(void *userarg) return; } to_write = conn->to_write; - if ((written = write(conn->fd, &to_write->write_buf[to_write->written], + if (conn->fd == -1 || + (written = write(conn->fd, &to_write->write_buf[to_write->written], to_write->write_buf_len - to_write->written)) == -1) { /* IO error, close connection */ @@ -177,18 +204,22 @@ _getdns_cancel_reply(getdns_context *context, getdns_transaction_t request_id) if (!context || !conn) return; - if (conn->ld->transport == GETDNS_TRANSPORT_TCP) { + if (conn->l->transport == GETDNS_TRANSPORT_TCP) { tcp_connection *conn = (tcp_connection *)(intptr_t)request_id; - if (conn->to_answer > 0) - conn->to_answer--; - } else if (conn->ld->transport == GETDNS_TRANSPORT_UDP && - (mf = priv_getdns_context_mf(conn->ld->context))) { + if (conn->to_answer > 0 && --conn->to_answer == 0 && + conn->fd == -1) + tcp_connection_destroy(conn); + + } else if (conn->l->transport == GETDNS_TRANSPORT_UDP && + (mf = priv_getdns_context_mf(conn->l->set->context))) { + listen_set *set = conn->l->set; /* Unlink this connection */ if ((*conn->prev_next = conn->next)) conn->next->prev_next = conn->prev_next; GETDNS_FREE(*mf, conn); + free_listen_set_when_done(set); } } @@ -207,34 +238,47 @@ getdns_reply( if (!context || !reply || !conn) return GETDNS_RETURN_INVALID_PARAMETER; - if (!(mf = priv_getdns_context_mf(conn->ld->context))) + if (!(mf = priv_getdns_context_mf(conn->l->set->context))) return GETDNS_RETURN_GENERIC_ERROR;; - if ((r = getdns_context_get_eventloop(conn->ld->context, &loop))) + if ((r = getdns_context_get_eventloop(conn->l->set->context, &loop))) return r; len = sizeof(buf); if ((r = getdns_msg_dict2wire_buf(reply, buf, &len))) return r; - else if (conn->ld->transport == GETDNS_TRANSPORT_UDP) { - if (sendto(conn->ld->fd, buf, len, 0, - (struct sockaddr *)&conn->remote_in, conn->addrlen) == -1) - ; /* IO error, TODO: cleanup this listener */ + else if (conn->l->transport == GETDNS_TRANSPORT_UDP) { + listener *l = conn->l; + if (conn->l->fd >= 0 && sendto(conn->l->fd, buf, len, 0, + (struct sockaddr *)&conn->remote_in, conn->addrlen) == -1) { + /* IO error, cleanup this listener */ + loop->vmt->clear(loop, &conn->l->event); + close(conn->l->fd); + conn->l->fd = -1; + } /* Unlink this connection */ if ((*conn->prev_next = conn->next)) conn->next->prev_next = conn->prev_next; GETDNS_FREE(*mf, conn); + if (l->fd < 0) + free_listen_set_when_done(l->set); - } else if (conn->ld->transport == GETDNS_TRANSPORT_TCP) { + } else if (conn->l->transport == GETDNS_TRANSPORT_TCP) { tcp_connection *conn = (tcp_connection *)(intptr_t)request_id; tcp_to_write **to_write_p; - tcp_to_write *to_write = (tcp_to_write *)GETDNS_XMALLOC( - *mf, uint8_t, sizeof(tcp_to_write) + len + 2); + tcp_to_write *to_write; - if (!to_write) + if (conn->fd == -1) { + if (conn->to_answer > 0) + --conn->to_answer; + tcp_connection_destroy(conn); + return GETDNS_RETURN_GOOD; + } + if (!(to_write = (tcp_to_write *)GETDNS_XMALLOC( + *mf, uint8_t, sizeof(tcp_to_write) + len + 2))) return GETDNS_RETURN_MEMORY_ERROR; to_write->write_buf_len = len + 2; @@ -275,10 +319,10 @@ static void tcp_read_cb(void *userarg) assert(userarg); - if (!(mf = priv_getdns_context_mf(conn->ld->context))) + if (!(mf = priv_getdns_context_mf(conn->l->set->context))) return; - if ((r = getdns_context_get_eventloop(conn->ld->context, &loop))) + if ((r = getdns_context_get_eventloop(conn->l->set->context, &loop))) return; /* Reset tcp_connection idle timeout */ @@ -336,8 +380,8 @@ static void tcp_read_cb(void *userarg) conn->to_answer++; /* Call request handler */ - conn->ld->handler( - conn->ld->context, request_dict, (intptr_t)conn); + conn->l->set->handler( + conn->l->set->context, request_dict, (intptr_t)conn); conn->read_pos = conn->read_buf; conn->to_read = 2; @@ -357,7 +401,7 @@ static void tcp_timeout_cb(void *userarg) if (conn->to_answer) { getdns_eventloop *loop; - if (getdns_context_get_eventloop(conn->ld->context, &loop)) + if (getdns_context_get_eventloop(conn->l->set->context, &loop)) return; loop->vmt->clear(loop, &conn->event); @@ -370,7 +414,7 @@ static void tcp_timeout_cb(void *userarg) static void tcp_accept_cb(void *userarg) { - listen_data *ld = (listen_data *)userarg; + listener *l = (listener *)userarg; tcp_connection *conn; struct mem_funcs *mf; getdns_eventloop *loop; @@ -378,10 +422,10 @@ static void tcp_accept_cb(void *userarg) assert(userarg); - if (!(mf = priv_getdns_context_mf(ld->context))) + if (!(mf = priv_getdns_context_mf(l->set->context))) return; - if ((r = getdns_context_get_eventloop(ld->context, &loop))) + if ((r = getdns_context_get_eventloop(l->set->context, &loop))) return; if (!(conn = GETDNS_MALLOC(*mf, tcp_connection))) @@ -389,12 +433,16 @@ static void tcp_accept_cb(void *userarg) (void) memset(conn, 0, sizeof(tcp_connection)); - conn->ld = ld; + conn->l = l; conn->addrlen = sizeof(conn->remote_in); - if ((conn->fd = accept(ld->fd, + if ((conn->fd = accept(l->fd, (struct sockaddr *)&conn->remote_in, &conn->addrlen)) == -1) { - /* IO error, TODO: cleanup this listener? */ + /* IO error, cleanup this listener */ + loop->vmt->clear(loop, &l->event); + close(l->fd); + l->fd = -1; GETDNS_FREE(*mf, conn); + return; } if (!(conn->read_buf = malloc(DNS_REQUEST_SZ))) { /* Memory error */ @@ -409,10 +457,10 @@ static void tcp_accept_cb(void *userarg) conn->event.timeout_cb = tcp_timeout_cb; /* Insert connection */ - if ((conn->next = ld->connections)) + if ((conn->next = l->connections)) conn->next->prev_next = &conn->next; - conn->prev_next = &ld->connections; - ld->connections = (connection *)conn; + conn->prev_next = &l->connections; + l->connections = (connection *)conn; (void) loop->vmt->schedule(loop, conn->fd, DOWNSTREAM_IDLE_TIMEOUT, &conn->event); @@ -420,9 +468,10 @@ static void tcp_accept_cb(void *userarg) static void udp_read_cb(void *userarg) { - listen_data *ld = (listen_data *)userarg; + listener *l = (listener *)userarg; connection *conn; struct mem_funcs *mf; + getdns_eventloop *loop; getdns_dict *request_dict; /* Maximum reasonable size for requests */ @@ -432,35 +481,194 @@ static void udp_read_cb(void *userarg) assert(userarg); - if (!(mf = priv_getdns_context_mf(ld->context))) + if (l->fd == -1) + return; + + if (!(mf = priv_getdns_context_mf(l->set->context))) + return; + + if ((r = getdns_context_get_eventloop(l->set->context, &loop))) return; if (!(conn = GETDNS_MALLOC(*mf, connection))) return; - conn->ld = ld; + conn->l = l; conn->addrlen = sizeof(conn->remote_in); - if ((len = recvfrom(ld->fd, buf, sizeof(buf), 0, - (struct sockaddr *)&conn->remote_in, &conn->addrlen)) == -1) - ; /* IO error, TODO: cleanup this listener */ + if ((len = recvfrom(l->fd, buf, sizeof(buf), 0, + (struct sockaddr *)&conn->remote_in, &conn->addrlen)) == -1) { + /* IO error, cleanup this listener. */ + loop->vmt->clear(loop, &l->event); + close(l->fd); + l->fd = -1; - else if ((r = getdns_wire2msg_dict(buf, len, &request_dict))) + } else if ((r = getdns_wire2msg_dict(buf, len, &request_dict))) ; /* FROMERR on input, ignore */ else { /* Insert connection */ - if ((conn->next = ld->connections)) + if ((conn->next = l->connections)) conn->next->prev_next = &conn->next; - conn->prev_next = &ld->connections; - ld->connections = conn; + conn->prev_next = &l->connections; + l->connections = conn; /* Call request handler */ - ld->handler(ld->context, request_dict, (intptr_t)conn); + l->set->handler(l->set->context, request_dict, (intptr_t)conn); return; } GETDNS_FREE(*mf, conn); } +static void rm_listen_set(listen_set **root, listen_set *set) +{ + assert(root); + + while (*root && *root != set) + root = &(*root)->next; + + *root = set->next; + set->next = NULL; +} + +static listen_set *lookup_listen_set(listen_set *root, getdns_context *key) +{ + while (root && root->context != key) + root = root->next; + + return root; +} + +static void free_listen_set_when_done(listen_set *set) +{ + struct mem_funcs *mf; + size_t i; + + assert(set); + assert(set->context); + + if (!(mf = priv_getdns_context_mf(set->context))) + return; + + for (i = 0; i < set->count; i++) { + listener *l = &set->items[i]; + + if (l->fd >= 0) + return; + + if (l->connections) + return; + } + GETDNS_FREE(*mf, set); +} + +static void remove_listeners(listen_set *set) +{ + struct mem_funcs *mf; + getdns_eventloop *loop; + size_t i; + + assert(set); + assert(set->context); + + if (!(mf = priv_getdns_context_mf(set->context))) + return; + + if (getdns_context_get_eventloop(set->context, &loop)) + return; + + for (i = 0; i < set->count; i++) { + listener *l = &set->items[i]; + tcp_connection **conn_p; + + if (l->action != to_remove || l->fd == -1) + continue; + + loop->vmt->clear(loop, &l->event); + close(l->fd); + l->fd = -1; + + if (l->transport != GETDNS_TRANSPORT_TCP) + continue; + + conn_p = (tcp_connection **)&l->connections; + while (*conn_p) { + tcp_connection_destroy(*conn_p); + if (*conn_p && (*conn_p)->to_answer > 0) + conn_p = (tcp_connection **)&(*conn_p)->next; + } + } + free_listen_set_when_done(set); +} + +static getdns_return_t add_listeners(listen_set *set) +{ + static const int enable = 1; + + struct mem_funcs *mf; + getdns_eventloop *loop; + size_t i; + getdns_return_t r; + + assert(set); + assert(set->context); + + if (!(mf = priv_getdns_context_mf(set->context))) + return GETDNS_RETURN_GENERIC_ERROR; + + if ((r = getdns_context_get_eventloop(set->context, &loop))) + return r; + + r = GETDNS_RETURN_GENERIC_ERROR; + for (i = 0; i < set->count; i++) { + listener *l = &set->items[i]; + + if (l->action != to_add) + continue; + + if (l->transport != GETDNS_TRANSPORT_UDP && + l->transport != GETDNS_TRANSPORT_TCP) + continue; + + if ((l->fd = socket(l->addr.ss_family, + ( l->transport == GETDNS_TRANSPORT_UDP + ? SOCK_DGRAM : SOCK_STREAM), 0)) == -1) + /* IO error */ + break; + + if (setsockopt(l->fd, SOL_SOCKET, SO_REUSEADDR, + &enable, sizeof(int)) < 0) + ; /* Ignore */ + + if (bind(l->fd, (struct sockaddr *)&l->addr, + l->addr_len) == -1) + /* IO error */ + break; + + if (l->transport == GETDNS_TRANSPORT_UDP) { + l->event.userarg = l; + l->event.read_cb = udp_read_cb; + if ((r = loop->vmt->schedule( + loop, l->fd, -1, &l->event))) + break; + + } else if (listen(l->fd, TCP_LISTEN_BACKLOG) == -1) + /* IO error */ + break; + + else { + l->event.userarg = l; + l->event.read_cb = tcp_accept_cb; + if ((r = loop->vmt->schedule( + loop, l->fd, -1, &l->event))) + break; + } + } + if (i < set->count) + return r; + + return GETDNS_RETURN_GOOD; +} + getdns_return_t getdns_context_set_listen_addresses(getdns_context *context, getdns_request_handler_t request_handler, getdns_list *listen_addresses) { @@ -469,20 +677,19 @@ getdns_return_t getdns_context_set_listen_addresses(getdns_context *context, static const uint32_t transport_ports[] = { 53, 53 }; static const size_t n_transports = sizeof( listen_transports) / sizeof(*listen_transports); + static listen_set *root = NULL; - /* Things that should (eventually) be stored in the getdns_context */ - size_t listen_count; - listen_data *listening; - struct mem_funcs *mf; - getdns_eventloop *loop; + listen_set *current_set; + listen_set *new_set; + size_t new_set_count; + + struct mem_funcs *mf; + getdns_eventloop *loop; /* auxiliary variables */ getdns_return_t r; size_t i; - size_t t; struct addrinfo hints; - char addrstr[1024], portstr[1024], *eos; - const int enable = 1; /* For SO_REUSEADDR */ if (!(mf = priv_getdns_context_mf(context))) return GETDNS_RETURN_GENERIC_ERROR; @@ -490,29 +697,50 @@ getdns_return_t getdns_context_set_listen_addresses(getdns_context *context, if ((r = getdns_context_get_eventloop(context, &loop))) return r; - if ((r = getdns_list_get_length(listen_addresses, &listen_count))) + if (listen_addresses == NULL) + new_set_count = 0; + + else if ((r = getdns_list_get_length(listen_addresses, &new_set_count))) return r; - if (!listen_count) + if ((current_set = lookup_listen_set(root, context))) { + for (i = 0; i < current_set->count; i++) + current_set->items[i].action = to_remove; + } + if (new_set_count == 0) { + if (!current_set) + return GETDNS_RETURN_GOOD; + + rm_listen_set(&root, current_set); + remove_listeners(current_set); /* Is already remove */ return GETDNS_RETURN_GOOD; + } + if (!request_handler) + return GETDNS_RETURN_INVALID_PARAMETER; - if (!(listening = GETDNS_XMALLOC( - *mf, listen_data, listen_count * n_transports))) + if (!(new_set = (listen_set *)GETDNS_XMALLOC(*mf, uint8_t, + sizeof(listen_set) + + sizeof(listener) * new_set_count * n_transports))) return GETDNS_RETURN_MEMORY_ERROR; - (void) memset(listening, 0, - sizeof(listen_data) * n_transports * listen_count); - (void) memset(&hints, 0, sizeof(struct addrinfo)); + new_set->context = context; + new_set->next = root; + new_set->handler = request_handler; + new_set->count = new_set_count * n_transports; + (void) memset(new_set->items, 0, + sizeof(listener) * new_set_count * n_transports); (void) memset(&hints, 0, sizeof(struct addrinfo)); hints.ai_family = AF_UNSPEC; hints.ai_flags = AI_NUMERICHOST; - for (i = 0; !r && i < listen_count; i++) { + for (i = 0; !r && i < new_set_count; i++) { getdns_dict *dict = NULL; getdns_bindata *address_data; struct sockaddr_storage addr; getdns_bindata *scope_id; + char addrstr[1024], *eos; + size_t t; if ((r = getdns_list_get_dict(listen_addresses, i, &dict))) { if ((r = getdns_list_get_bindata( @@ -548,13 +776,16 @@ getdns_return_t getdns_context_set_listen_addresses(getdns_context *context, eos[scope_id->size] = 0; } for (t = 0; !r && t < n_transports; t++) { + char portstr[1024]; getdns_transport_list_t transport = listen_transports[t]; uint32_t port = transport_ports[t]; struct addrinfo *ai; - listen_data *ld = &listening[i * n_transports + t]; + listener *l = &new_set->items[i*n_transports + t]; + size_t j; + listener *cl; - ld->fd = -1; + l->fd = -1; if (dict) (void) getdns_dict_get_int(dict, ( transport == GETDNS_TRANSPORT_TLS @@ -569,60 +800,62 @@ getdns_return_t getdns_context_set_listen_addresses(getdns_context *context, if (!ai) continue; - ld->addr.ss_family = addr.ss_family; - ld->addr_len = ai->ai_addrlen; - (void) memcpy(&ld->addr, ai->ai_addr, ai->ai_addrlen); - ld->transport = transport; - ld->handler = request_handler; - ld->context = context; - ld->connections = NULL; + l->addr.ss_family = addr.ss_family; + l->addr_len = ai->ai_addrlen; + (void) memcpy(&l->addr, ai->ai_addr, ai->ai_addrlen); + l->transport = transport; + l->set = new_set; + l->connections = NULL; freeaddrinfo(ai); + + /* Now determine the action */ + if (!current_set) { + l->action = to_add; + continue; + } + for (j = 0; j < current_set->count; j++) { + cl = ¤t_set->items[j]; + + if (l->transport == cl->transport && + l->addr_len == cl->addr_len && + !memcmp(&l->addr, &cl->addr, l->addr_len)) + break; + } + if (j == current_set->count) { + /* Not found */ + l->action = to_add; + continue; + } + l->action = cl->action = to_stay; + l->fd = cl->fd; + l->connections = cl->connections; + l->event = cl->event; + /* So the event can be rescheduled */ + l->to_replace = cl; } } - if (r) { - GETDNS_FREE(*mf, listening); - listening = NULL; + if ((r = add_listeners(new_set))) { + for (i = 0; i < new_set->count; i++) + new_set->items[i].action = to_remove; - } else for (i = 0; !r && i < listen_count * n_transports; i++) { - listen_data *ld = &listening[i]; + remove_listeners(new_set); + return r; + } + /* Reschedule all stayers */ + for (i = 0; i < new_set->count; i++) { + listener *l = &new_set->items[i]; - if (ld->transport != GETDNS_TRANSPORT_UDP && - ld->transport != GETDNS_TRANSPORT_TCP) - continue; - - if ((ld->fd = socket(ld->addr.ss_family, - ( ld->transport == GETDNS_TRANSPORT_UDP - ? SOCK_DGRAM : SOCK_STREAM), 0)) == -1) - /* IO error, TODO: report? */ - continue; - - if (setsockopt(ld->fd, SOL_SOCKET, SO_REUSEADDR, - &enable, sizeof(int)) < 0) - ; /* Ignore */ - - if (bind(ld->fd, (struct sockaddr *)&ld->addr, - ld->addr_len) == -1) { - /* IO error, TODO: report? */ - (void) close(ld->fd); - ld->fd = -1; - } - if (ld->transport == GETDNS_TRANSPORT_UDP) { - ld->event.userarg = ld; - ld->event.read_cb = udp_read_cb; - (void) loop->vmt->schedule( - loop, ld->fd, -1, &ld->event); - - } else if (listen(ld->fd, TCP_LISTEN_BACKLOG) == -1) { - /* IO error, TODO: report? */ - (void) close(ld->fd); - ld->fd = -1; - } else { - ld->event.userarg = ld; - ld->event.read_cb = tcp_accept_cb; - (void) loop->vmt->schedule( - loop, ld->fd, -1, &ld->event); + if (l->action == to_stay) { + loop->vmt->clear(loop, &l->to_replace->event); + /* assume success on reschedule */ + (void) loop->vmt->schedule(loop, l->fd, -1, &l->event); } } - return r; + if (current_set) { + rm_listen_set(&root, current_set); + remove_listeners(current_set); /* Is already remove */ + } + root = new_set; + return GETDNS_RETURN_GOOD; }