Move references to init_buf into SSL_PROTOCOL_METHOD.

Both DTLS and TLS still use it, but that will change in the following
commit. This also removes the handshake's knowledge of the
dtls_clear_incoming_messages function.

(It's possible we'll want to get rid of begin_handshake in favor of
allocating it lazily depending on how TLS 1.3 post-handshake messages
end up working out. But this should work for now.)

Change-Id: I0f512788bbc330ab2c947890939c73e0a1aca18b
Reviewed-on: https://boringssl-review.googlesource.com/8666
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
This commit is contained in:
David Benjamin 2016-07-08 09:16:50 -07:00
parent a2c42d7685
commit 97718f1437
5 changed files with 69 additions and 41 deletions

View File

@ -58,6 +58,8 @@
#include <assert.h> #include <assert.h>
#include <openssl/buf.h>
#include "internal.h" #include "internal.h"
@ -88,6 +90,32 @@ static uint16_t dtls1_version_to_wire(uint16_t version) {
return ~(version - 0x0201); return ~(version - 0x0201);
} }
static int dtls1_begin_handshake(SSL *ssl) {
if (ssl->init_buf != NULL) {
return 1;
}
BUF_MEM *buf = BUF_MEM_new();
if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
BUF_MEM_free(buf);
return 0;
}
ssl->init_buf = buf;
ssl->init_num = 0;
return 1;
}
static void dtls1_finish_handshake(SSL *ssl) {
BUF_MEM_free(ssl->init_buf);
ssl->init_buf = NULL;
ssl->init_num = 0;
ssl->d1->handshake_read_seq = 0;
ssl->d1->handshake_write_seq = 0;
dtls_clear_incoming_messages(ssl);
}
static const SSL_PROTOCOL_METHOD kDTLSProtocolMethod = { static const SSL_PROTOCOL_METHOD kDTLSProtocolMethod = {
1 /* is_dtls */, 1 /* is_dtls */,
TLS1_1_VERSION, TLS1_1_VERSION,
@ -96,6 +124,8 @@ static const SSL_PROTOCOL_METHOD kDTLSProtocolMethod = {
dtls1_version_to_wire, dtls1_version_to_wire,
dtls1_new, dtls1_new,
dtls1_free, dtls1_free,
dtls1_begin_handshake,
dtls1_finish_handshake,
dtls1_get_message, dtls1_get_message,
dtls1_read_app_data, dtls1_read_app_data,
dtls1_read_change_cipher_spec, dtls1_read_change_cipher_spec,

View File

@ -187,7 +187,6 @@ static int ssl3_send_channel_id(SSL *ssl);
static int ssl3_get_new_session_ticket(SSL *ssl); static int ssl3_get_new_session_ticket(SSL *ssl);
int ssl3_connect(SSL *ssl) { int ssl3_connect(SSL *ssl) {
BUF_MEM *buf = NULL;
int ret = -1; int ret = -1;
int state, skip = 0; int state, skip = 0;
@ -201,19 +200,11 @@ int ssl3_connect(SSL *ssl) {
case SSL_ST_CONNECT: case SSL_ST_CONNECT:
ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1); ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
if (ssl->init_buf == NULL) { if (!ssl->method->begin_handshake(ssl)) {
buf = BUF_MEM_new();
if (buf == NULL ||
!BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
ret = -1; ret = -1;
goto end; goto end;
} }
ssl->init_buf = buf;
buf = NULL;
}
ssl->init_num = 0;
if (!ssl_init_wbio_buffer(ssl)) { if (!ssl_init_wbio_buffer(ssl)) {
ret = -1; ret = -1;
goto end; goto end;
@ -503,9 +494,7 @@ int ssl3_connect(SSL *ssl) {
/* clean a few things up */ /* clean a few things up */
ssl3_cleanup_key_block(ssl); ssl3_cleanup_key_block(ssl);
BUF_MEM_free(ssl->init_buf); ssl->method->finish_handshake(ssl);
ssl->init_buf = NULL;
ssl->init_num = 0;
/* Remove write buffering now. */ /* Remove write buffering now. */
ssl_free_wbio_buffer(ssl); ssl_free_wbio_buffer(ssl);
@ -520,11 +509,6 @@ int ssl3_connect(SSL *ssl) {
ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT); ssl_update_cache(ssl, SSL_SESS_CACHE_CLIENT);
} }
if (SSL_IS_DTLS(ssl)) {
ssl->d1->handshake_read_seq = 0;
ssl->d1->handshake_write_seq = 0;
}
ret = 1; ret = 1;
ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1); ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_DONE, 1);
goto end; goto end;
@ -545,7 +529,6 @@ int ssl3_connect(SSL *ssl) {
} }
end: end:
BUF_MEM_free(buf);
ssl_do_info_callback(ssl, SSL_CB_CONNECT_EXIT, ret); ssl_do_info_callback(ssl, SSL_CB_CONNECT_EXIT, ret);
return ret; return ret;
} }

View File

@ -188,7 +188,6 @@ static int ssl3_get_channel_id(SSL *ssl);
static int ssl3_send_new_session_ticket(SSL *ssl); static int ssl3_send_new_session_ticket(SSL *ssl);
int ssl3_accept(SSL *ssl) { int ssl3_accept(SSL *ssl) {
BUF_MEM *buf = NULL;
uint32_t alg_a; uint32_t alg_a;
int ret = -1; int ret = -1;
int state, skip = 0; int state, skip = 0;
@ -203,16 +202,10 @@ int ssl3_accept(SSL *ssl) {
case SSL_ST_ACCEPT: case SSL_ST_ACCEPT:
ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1); ssl_do_info_callback(ssl, SSL_CB_HANDSHAKE_START, 1);
if (ssl->init_buf == NULL) { if (!ssl->method->begin_handshake(ssl)) {
buf = BUF_MEM_new();
if (!buf || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
ret = -1; ret = -1;
goto end; goto end;
} }
ssl->init_buf = buf;
buf = NULL;
}
ssl->init_num = 0;
/* Enable a write buffer. This groups handshake messages within a flight /* Enable a write buffer. This groups handshake messages within a flight
* into a single write. */ * into a single write. */
@ -470,9 +463,7 @@ int ssl3_accept(SSL *ssl) {
/* clean a few things up */ /* clean a few things up */
ssl3_cleanup_key_block(ssl); ssl3_cleanup_key_block(ssl);
BUF_MEM_free(ssl->init_buf); ssl->method->finish_handshake(ssl);
ssl->init_buf = NULL;
ssl->init_num = 0;
/* remove buffering on output */ /* remove buffering on output */
ssl_free_wbio_buffer(ssl); ssl_free_wbio_buffer(ssl);
@ -486,12 +477,6 @@ int ssl3_accept(SSL *ssl) {
ssl->session->cert_chain = NULL; ssl->session->cert_chain = NULL;
} }
if (SSL_IS_DTLS(ssl)) {
ssl->d1->handshake_read_seq = 0;
ssl->d1->handshake_write_seq = 0;
dtls_clear_incoming_messages(ssl);
}
ssl->s3->initial_handshake_complete = 1; ssl->s3->initial_handshake_complete = 1;
ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER); ssl_update_cache(ssl, SSL_SESS_CACHE_SERVER);
@ -517,7 +502,6 @@ int ssl3_accept(SSL *ssl) {
} }
end: end:
BUF_MEM_free(buf);
ssl_do_info_callback(ssl, SSL_CB_ACCEPT_EXIT, ret); ssl_do_info_callback(ssl, SSL_CB_ACCEPT_EXIT, ret);
return ret; return ret;
} }

View File

@ -829,6 +829,11 @@ struct ssl_protocol_method_st {
uint16_t (*version_to_wire)(uint16_t version); uint16_t (*version_to_wire)(uint16_t version);
int (*ssl_new)(SSL *ssl); int (*ssl_new)(SSL *ssl);
void (*ssl_free)(SSL *ssl); void (*ssl_free)(SSL *ssl);
/* begin_handshake is called to start a new handshake. It returns one on
* success and zero on error. */
int (*begin_handshake)(SSL *ssl);
/* finish_handshake is called when a handshake completes. */
void (*finish_handshake)(SSL *ssl);
long (*ssl_get_message)(SSL *ssl, int msg_type, long (*ssl_get_message)(SSL *ssl, int msg_type,
enum ssl_hash_message_t hash_message, int *ok); enum ssl_hash_message_t hash_message, int *ok);
int (*read_app_data)(SSL *ssl, uint8_t *buf, int len, int peek); int (*read_app_data)(SSL *ssl, uint8_t *buf, int len, int peek);

View File

@ -56,6 +56,8 @@
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/buf.h>
#include "internal.h" #include "internal.h"
@ -65,6 +67,28 @@ static uint16_t ssl3_version_from_wire(uint16_t wire_version) {
static uint16_t ssl3_version_to_wire(uint16_t version) { return version; } static uint16_t ssl3_version_to_wire(uint16_t version) { return version; }
static int ssl3_begin_handshake(SSL *ssl) {
if (ssl->init_buf != NULL) {
return 1;
}
BUF_MEM *buf = BUF_MEM_new();
if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
BUF_MEM_free(buf);
return 0;
}
ssl->init_buf = buf;
ssl->init_num = 0;
return 1;
}
static void ssl3_finish_handshake(SSL *ssl) {
BUF_MEM_free(ssl->init_buf);
ssl->init_buf = NULL;
ssl->init_num = 0;
}
static const SSL_PROTOCOL_METHOD kTLSProtocolMethod = { static const SSL_PROTOCOL_METHOD kTLSProtocolMethod = {
0 /* is_dtls */, 0 /* is_dtls */,
SSL3_VERSION, SSL3_VERSION,
@ -73,6 +97,8 @@ static const SSL_PROTOCOL_METHOD kTLSProtocolMethod = {
ssl3_version_to_wire, ssl3_version_to_wire,
ssl3_new, ssl3_new,
ssl3_free, ssl3_free,
ssl3_begin_handshake,
ssl3_finish_handshake,
ssl3_get_message, ssl3_get_message,
ssl3_read_app_data, ssl3_read_app_data,
ssl3_read_change_cipher_spec, ssl3_read_change_cipher_spec,