diff --git a/src/openssl/tls.c b/src/openssl/tls.c index 0e0e0f93..a3267c52 100644 --- a/src/openssl/tls.c +++ b/src/openssl/tls.c @@ -284,8 +284,7 @@ getdns_return_t _getdns_tls_connection_shutdown(_getdns_tls_connection* conn) if (!conn || !conn->ssl) return GETDNS_RETURN_INVALID_PARAMETER; - switch(SSL_shutdown(conn->ssl)) - { + switch (SSL_shutdown(conn->ssl)) { case 0: return GETDNS_RETURN_CONTEXT_UPDATE_FAIL; case 1: return GETDNS_RETURN_GOOD; default: return GETDNS_RETURN_GENERIC_ERROR; @@ -356,8 +355,7 @@ getdns_return_t _getdns_tls_connection_do_handshake(_getdns_tls_connection* conn if (r == 1) return GETDNS_RETURN_GOOD; err = SSL_get_error(conn->ssl, r); - switch(err) - { + switch (err) { case SSL_ERROR_WANT_READ: return GETDNS_RETURN_TLS_WANT_READ; @@ -380,6 +378,32 @@ getdns_return_t _getdns_tls_connection_is_session_reused(_getdns_tls_connection* return GETDNS_RETURN_TLS_CONNECTION_FRESH; } +getdns_return_t _getdns_tls_connection_read(_getdns_tls_connection* conn, uint8_t* buf, size_t to_read, size_t* read) +{ + int sread; + + if (!conn || !conn->ssl || !read) + return -GETDNS_RETURN_INVALID_PARAMETER; + + ERR_clear_error(); + sread = SSL_read(conn->ssl, buf, to_read); + if (sread <= 0) { + switch (SSL_get_error(conn->ssl, sread)) { + case SSL_ERROR_WANT_READ: + return GETDNS_RETURN_TLS_WANT_READ; + + case SSL_ERROR_WANT_WRITE: + return GETDNS_RETURN_TLS_WANT_WRITE; + + default: + return GETDNS_RETURN_GENERIC_ERROR; + } + } + + *read = sread; + return GETDNS_RETURN_GOOD; +} + getdns_return_t _getdns_tls_session_free(_getdns_tls_session* s) { if (!s || !s->ssl) @@ -389,8 +413,6 @@ getdns_return_t _getdns_tls_session_free(_getdns_tls_session* s) return GETDNS_RETURN_GOOD; } - - getdns_return_t _getdns_tls_get_api_information(getdns_dict* dict) { if (! getdns_dict_set_int( diff --git a/src/openssl/tls.h b/src/openssl/tls.h index 6dfc503d..44a50bd2 100644 --- a/src/openssl/tls.h +++ b/src/openssl/tls.h @@ -105,6 +105,21 @@ getdns_return_t _getdns_tls_connection_do_handshake(_getdns_tls_connection* conn */ getdns_return_t _getdns_tls_connection_is_session_reused(_getdns_tls_connection* conn); +/** + * Read from TLS. + * + * @param conn the connection. + * @param buf the buffer to read to. + * @param to_read the number of bytes to read. + * @param read pointer to holder for the number of bytes read. + * @return GETDNS_RETURN_GOOD if some bytes were read. + * @return GETDNS_RETURN_INVALID_PARAMETER if conn is null or has no SSL. + * @return GETDNS_RETURN_TLS_WANT_READ if the read needs to be retried. + * @return GETDNS_RETURN_TLS_WANT_WRITE if handshake isn't finished. + * @return GETDNS_RETURN_GENERIC_ERROR if read failed. + */ +getdns_return_t _getdns_tls_connection_read(_getdns_tls_connection* conn, uint8_t* buf, size_t to_read, size_t* read); + getdns_return_t _getdns_tls_session_free(_getdns_tls_session* s); getdns_return_t _getdns_tls_get_api_information(getdns_dict* dict); diff --git a/src/stub.c b/src/stub.c index 3acbb2ea..28d2e983 100644 --- a/src/stub.c +++ b/src/stub.c @@ -1273,10 +1273,10 @@ static int stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, struct mem_funcs *mf) { - ssize_t read; + size_t read; uint8_t *buf; size_t buf_size; - SSL* tls_obj = upstream->tls_obj->ssl; + _getdns_tls_connection* tls_obj = upstream->tls_obj; int q = tls_connected(upstream); if (q != 0) @@ -1292,16 +1292,17 @@ stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, tcp->to_read = 2; /* Packet size */ } - ERR_clear_error(); - read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); - if (read <= 0) { - /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake - renegotiation. Need to keep handshake state to do that.*/ - int want = SSL_get_error(tls_obj, read); - if (want == SSL_ERROR_WANT_READ) { + switch ((int)_getdns_tls_connection_read(tls_obj, tcp->read_pos, tcp->to_read, &read)) { + case GETDNS_RETURN_GOOD: + break; + + case GETDNS_RETURN_TLS_WANT_READ: return STUB_TCP_RETRY; /* Come back later */ - } else - return STUB_TCP_ERROR; + + default: + /* TODO[TLS]: Handle GETDNS_RETURN_TLS_WANT_WRITE which means handshake + renegotiation. Need to keep handshake state to do that.*/ + return STUB_TCP_ERROR; } tcp->to_read -= read; tcp->read_pos += read; @@ -1333,15 +1334,17 @@ stub_tls_read(getdns_upstream *upstream, getdns_tcp_state *tcp, /* Ready to start reading the packet */ tcp->read_pos = tcp->read_buf; - read = SSL_read(tls_obj, tcp->read_pos, tcp->to_read); - if (read <= 0) { - /* TODO[TLS]: Handle SSL_ERROR_WANT_WRITE which means handshake + switch ((int)_getdns_tls_connection_read(tls_obj, tcp->read_pos, tcp->to_read, &read)) { + case GETDNS_RETURN_GOOD: + break; + + case GETDNS_RETURN_TLS_WANT_READ: + return STUB_TCP_RETRY; /* Come back later */ + + default: + /* TODO[TLS]: Handle GETDNS_RETURN_TLS_WANT_WRITE which means handshake renegotiation. Need to keep handshake state to do that.*/ - int want = SSL_get_error(tls_obj, read); - if (want == SSL_ERROR_WANT_READ) { - return STUB_TCP_RETRY; /* read more later */ - } else - return STUB_TCP_ERROR; + return STUB_TCP_ERROR; } tcp->to_read -= read; tcp->read_pos += read;