diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 796949d8..62c12d97 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h @@ -4051,6 +4051,8 @@ struct ssl_ctx_st { int freelist_max_len; }; +typedef struct ssl_handshake_st SSL_HANDSHAKE; + struct ssl_st { /* method is the method table corresponding to the current protocol (DTLS or * TLS). */ @@ -4089,7 +4091,7 @@ struct ssl_st { * with a better mechanism. */ BIO *bbio; - int (*handshake_func)(SSL *); + int (*handshake_func)(SSL_HANDSHAKE *hs); BUF_MEM *init_buf; /* buffer used during init */ diff --git a/ssl/handshake_client.c b/ssl/handshake_client.c index 70d8d96d..8d503a55 100644 --- a/ssl/handshake_client.c +++ b/ssl/handshake_client.c @@ -186,7 +186,8 @@ static int ssl3_send_next_proto(SSL *ssl); static int ssl3_send_channel_id(SSL *ssl); static int ssl3_get_new_session_ticket(SSL *ssl); -int ssl3_connect(SSL *ssl) { +int ssl3_connect(SSL_HANDSHAKE *hs) { + SSL *const ssl = hs->ssl; int ret = -1; int state, skip = 0; @@ -205,12 +206,6 @@ int ssl3_connect(SSL *ssl) { case SSL_ST_CONNECT: ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1); - ssl->s3->hs = ssl_handshake_new(tls13_client_handshake); - if (ssl->s3->hs == NULL) { - ret = -1; - goto end; - } - if (!ssl_init_wbio_buffer(ssl)) { ret = -1; goto end; @@ -277,7 +272,7 @@ int ssl3_connect(SSL *ssl) { break; case SSL3_ST_CR_CERT_STATUS_A: - if (ssl->s3->hs->certificate_status_expected) { + if (hs->certificate_status_expected) { ret = ssl3_get_cert_status(ssl); if (ret <= 0) { goto end; @@ -332,7 +327,7 @@ int ssl3_connect(SSL *ssl) { case SSL3_ST_CW_CERT_A: case SSL3_ST_CW_CERT_B: case SSL3_ST_CW_CERT_C: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_send_client_certificate(ssl); if (ret <= 0) { goto end; @@ -355,7 +350,7 @@ int ssl3_connect(SSL *ssl) { case SSL3_ST_CW_CERT_VRFY_A: case SSL3_ST_CW_CERT_VRFY_B: case SSL3_ST_CW_CERT_VRFY_C: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_send_cert_verify(ssl); if (ret <= 0) { goto end; @@ -383,7 +378,7 @@ int ssl3_connect(SSL *ssl) { case SSL3_ST_CW_NEXT_PROTO_A: case SSL3_ST_CW_NEXT_PROTO_B: - if (ssl->s3->hs->next_proto_neg_seen) { + if (hs->next_proto_neg_seen) { ret = ssl3_send_next_proto(ssl); if (ret <= 0) { goto end; @@ -441,14 +436,14 @@ int ssl3_connect(SSL *ssl) { case SSL3_ST_FALSE_START: ssl->state = SSL3_ST_CR_SESSION_TICKET_A; - ssl->s3->hs->in_false_start = 1; + hs->in_false_start = 1; ssl_free_wbio_buffer(ssl); ret = 1; goto end; case SSL3_ST_CR_SESSION_TICKET_A: - if (ssl->s3->hs->ticket_expected) { + if (hs->ticket_expected) { ret = ssl3_get_new_session_ticket(ssl); if (ret <= 0) { goto end; @@ -543,9 +538,6 @@ int ssl3_connect(SSL *ssl) { ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT); } - ssl_handshake_free(ssl->s3->hs); - ssl->s3->hs = NULL; - ret = 1; ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1); goto end; @@ -894,6 +886,7 @@ static int ssl3_get_server_hello(SSL *ssl) { if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION) { ssl->state = SSL_ST_TLS13; + ssl->s3->hs->do_tls13_handshake = tls13_client_handshake; return 1; } diff --git a/ssl/handshake_server.c b/ssl/handshake_server.c index 3b66ab7d..ccf0e8b0 100644 --- a/ssl/handshake_server.c +++ b/ssl/handshake_server.c @@ -185,7 +185,8 @@ static int ssl3_get_next_proto(SSL *ssl); static int ssl3_get_channel_id(SSL *ssl); static int ssl3_send_new_session_ticket(SSL *ssl); -int ssl3_accept(SSL *ssl) { +int ssl3_accept(SSL_HANDSHAKE *hs) { + SSL *const ssl = hs->ssl; uint32_t alg_a; int ret = -1; int state, skip = 0; @@ -205,12 +206,6 @@ int ssl3_accept(SSL *ssl) { case SSL_ST_ACCEPT: ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1); - ssl->s3->hs = ssl_handshake_new(tls13_server_handshake); - if (ssl->s3->hs == NULL) { - ret = -1; - goto end; - } - /* Enable a write buffer. This groups handshake messages within a flight * into a single write. */ if (!ssl_init_wbio_buffer(ssl)) { @@ -271,7 +266,7 @@ int ssl3_accept(SSL *ssl) { case SSL3_ST_SW_CERT_STATUS_A: case SSL3_ST_SW_CERT_STATUS_B: - if (ssl->s3->hs->certificate_status_expected) { + if (hs->certificate_status_expected) { ret = ssl3_send_certificate_status(ssl); if (ret <= 0) { goto end; @@ -303,7 +298,7 @@ int ssl3_accept(SSL *ssl) { case SSL3_ST_SW_CERT_REQ_A: case SSL3_ST_SW_CERT_REQ_B: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_send_certificate_request(ssl); if (ret <= 0) { goto end; @@ -325,7 +320,7 @@ int ssl3_accept(SSL *ssl) { break; case SSL3_ST_SR_CERT_A: - if (ssl->s3->hs->cert_request) { + if (hs->cert_request) { ret = ssl3_get_client_certificate(ssl); if (ret <= 0) { goto end; @@ -367,7 +362,7 @@ int ssl3_accept(SSL *ssl) { break; case SSL3_ST_SR_NEXT_PROTO_A: - if (ssl->s3->hs->next_proto_neg_seen) { + if (hs->next_proto_neg_seen) { ret = ssl3_get_next_proto(ssl); if (ret <= 0) { goto end; @@ -416,7 +411,7 @@ int ssl3_accept(SSL *ssl) { case SSL3_ST_SW_SESSION_TICKET_A: case SSL3_ST_SW_SESSION_TICKET_B: - if (ssl->s3->hs->ticket_expected) { + if (hs->ticket_expected) { ret = ssl3_send_new_session_ticket(ssl); if (ret <= 0) { goto end; @@ -505,9 +500,6 @@ int ssl3_accept(SSL *ssl) { ssl->s3->initial_handshake_complete = 1; ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER); - ssl_handshake_free(ssl->s3->hs); - ssl->s3->hs = NULL; - ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1); ret = 1; goto end; @@ -717,6 +709,7 @@ static int ssl3_get_client_hello(SSL *ssl) { if (ssl3_protocol_version(ssl) >= TLS1_3_VERSION) { ssl->state = SSL_ST_TLS13; + ssl->s3->hs->do_tls13_handshake = tls13_server_handshake; return 1; } } diff --git a/ssl/internal.h b/ssl/internal.h index 5893d4da..af833fbb 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -878,15 +878,18 @@ enum ssl_hs_wait_t { ssl_hs_private_key_operation, }; -typedef struct ssl_handshake_st { - /* wait contains the operation |do_handshake| is currently blocking on or - * |ssl_hs_ok| if none. */ +struct ssl_handshake_st { + /* ssl is a non-owning pointer to the parent |SSL| object. */ + SSL *ssl; + + /* wait contains the operation |do_tls13_handshake| is currently blocking on + * or |ssl_hs_ok| if none. */ enum ssl_hs_wait_t wait; - /* do_handshake runs the handshake. On completion, it returns |ssl_hs_ok|. - * Otherwise, it returns a value corresponding to what operation is needed to - * progress. */ - enum ssl_hs_wait_t (*do_handshake)(SSL *ssl); + /* do_tls13_handshake runs the TLS 1.3 handshake. On completion, it returns + * |ssl_hs_ok|. Otherwise, it returns a value corresponding to what operation + * is needed to progress. */ + enum ssl_hs_wait_t (*do_tls13_handshake)(SSL *ssl); int state; @@ -1022,9 +1025,9 @@ typedef struct ssl_handshake_st { /* hostname, on the server, is the value of the SNI extension. */ char *hostname; -} SSL_HANDSHAKE; +} /* SSL_HANDSHAKE */; -SSL_HANDSHAKE *ssl_handshake_new(enum ssl_hs_wait_t (*do_handshake)(SSL *ssl)); +SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl); /* ssl_handshake_free releases all memory associated with |hs|. */ void ssl_handshake_free(SSL_HANDSHAKE *hs); @@ -1033,7 +1036,7 @@ void ssl_handshake_free(SSL_HANDSHAKE *hs); * 0 on error. */ int tls13_handshake(SSL *ssl); -/* The following are implementations of |do_handshake| for the client and +/* The following are implementations of |do_tls13_handshake| for the client and * server. */ enum ssl_hs_wait_t tls13_client_handshake(SSL *ssl); enum ssl_hs_wait_t tls13_server_handshake(SSL *ssl); @@ -1760,8 +1763,8 @@ const SSL_CIPHER *ssl3_choose_cipher( int ssl3_new(SSL *ssl); void ssl3_free(SSL *ssl); -int ssl3_accept(SSL *ssl); -int ssl3_connect(SSL *ssl); +int ssl3_accept(SSL_HANDSHAKE *hs); +int ssl3_connect(SSL_HANDSHAKE *hs); int ssl3_init_message(SSL *ssl, CBB *cbb, CBB *body, uint8_t type); int ssl3_finish_message(SSL *ssl, CBB *cbb, uint8_t **out_msg, size_t *out_len); diff --git a/ssl/s3_both.c b/ssl/s3_both.c index d8720205..b27938a0 100644 --- a/ssl/s3_both.c +++ b/ssl/s3_both.c @@ -130,14 +130,14 @@ #include "internal.h" -SSL_HANDSHAKE *ssl_handshake_new(enum ssl_hs_wait_t (*do_handshake)(SSL *ssl)) { +SSL_HANDSHAKE *ssl_handshake_new(SSL *ssl) { SSL_HANDSHAKE *hs = OPENSSL_malloc(sizeof(SSL_HANDSHAKE)); if (hs == NULL) { OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); return NULL; } memset(hs, 0, sizeof(SSL_HANDSHAKE)); - hs->do_handshake = do_handshake; + hs->ssl = ssl; hs->wait = ssl_hs_ok; return hs; } diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index aafad333..76a9de05 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c @@ -624,7 +624,28 @@ int SSL_do_handshake(SSL *ssl) { return 1; } - return ssl->handshake_func(ssl); + /* Set up a new handshake if necessary. */ + if (ssl->state == SSL_ST_INIT && ssl->s3->hs == NULL) { + ssl->s3->hs = ssl_handshake_new(ssl); + if (ssl->s3->hs == NULL) { + return -1; + } + } + + /* Run the handshake. */ + assert(ssl->s3->hs != NULL); + int ret = ssl->handshake_func(ssl->s3->hs); + if (ret <= 0) { + return ret; + } + + /* Destroy the handshake object if the handshake has completely finished. */ + if (!SSL_in_init(ssl)) { + ssl_handshake_free(ssl->s3->hs); + ssl->s3->hs = NULL; + } + + return 1; } int SSL_connect(SSL *ssl) { diff --git a/ssl/tls13_both.c b/ssl/tls13_both.c index 17f7161e..c8d32a1e 100644 --- a/ssl/tls13_both.c +++ b/ssl/tls13_both.c @@ -95,7 +95,7 @@ int tls13_handshake(SSL *ssl) { } /* Run the state machine again. */ - hs->wait = hs->do_handshake(ssl); + hs->wait = hs->do_tls13_handshake(ssl); if (hs->wait == ssl_hs_error) { /* Don't loop around to avoid a stray |SSL_R_SSL_HANDSHAKE_FAILURE| the * first time around. */