diff --git a/ssl/handshake.cc b/ssl/handshake.cc index 573d3357..8087310f 100644 --- a/ssl/handshake.cc +++ b/ssl/handshake.cc @@ -126,6 +126,7 @@ SSL_HANDSHAKE::SSL_HANDSHAKE(SSL *ssl_arg) needs_psk_binder(0), received_hello_retry_request(0), received_custom_extension(0), + handshake_finalized(0), accept_psk_mode(0), cert_request(0), certificate_status_expected(0), diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc index 5a0902db..6afd44fb 100644 --- a/ssl/handshake_client.cc +++ b/ssl/handshake_client.cc @@ -1709,6 +1709,7 @@ static enum ssl_hs_wait_t do_finish_client_handshake(SSL_HANDSHAKE *hs) { hs->new_session.reset(); } + hs->handshake_finalized = 1; ssl->s3->initial_handshake_complete = 1; ssl_update_cache(hs, SSL_SESS_CACHE_CLIENT); diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc index eb79693e..71590cf6 100644 --- a/ssl/handshake_server.cc +++ b/ssl/handshake_server.cc @@ -1568,6 +1568,7 @@ static enum ssl_hs_wait_t do_finish_server_handshake(SSL_HANDSHAKE *hs) { ssl->s3->established_session->not_resumable = 0; } + hs->handshake_finalized = 1; ssl->s3->initial_handshake_complete = 1; ssl_update_cache(hs, SSL_SESS_CACHE_SERVER); diff --git a/ssl/internal.h b/ssl/internal.h index 480723d3..5fa38695 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -1243,6 +1243,10 @@ struct SSL_HANDSHAKE { unsigned received_custom_extension:1; + // handshake_finalized is true once the handshake has completed, at which + // point accessors should use the established state. + unsigned handshake_finalized:1; + // accept_psk_mode stores whether the client's PSK mode is compatible with our // preferences. unsigned accept_psk_mode:1; diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index e2106160..1f1461fc 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc @@ -2336,7 +2336,11 @@ int SSL_is_init_finished(const SSL *ssl) { } int SSL_in_init(const SSL *ssl) { - return ssl->s3->hs != NULL; + // This returns false once all the handshake state has been finalized, to + // allow callbacks and getters based on SSL_in_init to return the correct + // values. + SSL_HANDSHAKE *hs = ssl->s3->hs; + return hs != nullptr && !hs->handshake_finalized; } int SSL_in_false_start(const SSL *ssl) { diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc index d86efab1..0c974381 100644 --- a/ssl/test/bssl_shim.cc +++ b/ssl/test/bssl_shim.cc @@ -887,6 +887,12 @@ static void InfoCallback(const SSL *ssl, int type, int val) { // test fails. abort(); } + // This callback is called when the handshake completes. |SSL_get_session| + // must continue to work and |SSL_in_init| must return false. + if (SSL_in_init(ssl) || SSL_get_session(ssl) == nullptr) { + fprintf(stderr, "Invalid state for SSL_CB_HANDSHAKE_DONE.\n"); + abort(); + } GetTestState(ssl)->handshake_done = true; // Callbacks may be called again on a new handshake. @@ -896,6 +902,14 @@ static void InfoCallback(const SSL *ssl, int type, int val) { } static int NewSessionCallback(SSL *ssl, SSL_SESSION *session) { + // This callback is called as the handshake completes. |SSL_get_session| + // must continue to work and, historically, |SSL_in_init| returned false at + // this point. + if (SSL_in_init(ssl) || SSL_get_session(ssl) == nullptr) { + fprintf(stderr, "Invalid state for NewSessionCallback.\n"); + abort(); + } + GetTestState(ssl)->got_new_session = true; GetTestState(ssl)->new_session.reset(session); return 1;