Simplify ssl3_get_message.

Rather than this confusing coordination with the handshake state machine and
init_num changing meaning partway through, use the length field already in
BUF_MEM. Like the new record layer parsing, is no need to keep track of whether
we are reading the header or the body. Simply keep extending the handshake
message until it's far enough along.

ssl3_get_message still needs tons of work, but this allows us to disentangle it
from the handshake state.

Change-Id: Ic2b3e7cfe6152a7e28a04980317d3c7c396d9b08
Reviewed-on: https://boringssl-review.googlesource.com/7948
Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
David Benjamin 2016-05-13 18:12:19 -04:00 committed by Adam Langley
parent 1f9329aaf5
commit a6338be3fa
7 changed files with 95 additions and 104 deletions

View File

@ -4129,9 +4129,12 @@ typedef struct ssl3_state_st {
uint8_t peer_finish_md[EVP_MAX_MD_SIZE];
int peer_finish_md_len;
unsigned long message_size;
int message_type;
/* message_complete is one if the current message is complete and zero
* otherwise. */
unsigned message_complete:1;
/* used to hold the new cipher we are going to use */
const SSL_CIPHER *new_cipher;

View File

@ -578,8 +578,9 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int msg_type,
goto f_err;
}
*ok = 1;
assert(ssl->init_buf->length >= DTLS1_HM_HEADER_LENGTH);
ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH;
ssl->init_num = (int)ssl->s3->tmp.message_size;
ssl->init_num = (int)ssl->init_buf->length - DTLS1_HM_HEADER_LENGTH;
return ssl->init_num;
}
@ -600,11 +601,10 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int msg_type,
assert(frag->reassembly == NULL);
/* Reconstruct the assembled message. */
size_t len;
CBB cbb;
CBB_zero(&cbb);
if (!BUF_MEM_grow(ssl->init_buf, (size_t)frag->msg_header.msg_len +
DTLS1_HM_HEADER_LENGTH) ||
if (!BUF_MEM_reserve(ssl->init_buf, (size_t)frag->msg_header.msg_len +
DTLS1_HM_HEADER_LENGTH) ||
!CBB_init_fixed(&cbb, (uint8_t *)ssl->init_buf->data,
ssl->init_buf->max) ||
!CBB_add_u8(&cbb, frag->msg_header.type) ||
@ -613,19 +613,19 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int msg_type,
!CBB_add_u24(&cbb, 0 /* frag_off */) ||
!CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
!CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) ||
!CBB_finish(&cbb, NULL, &len)) {
!CBB_finish(&cbb, NULL, &ssl->init_buf->length)) {
CBB_cleanup(&cbb);
OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
goto err;
}
assert(len == (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
assert(ssl->init_buf->length ==
(size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
ssl->d1->handshake_read_seq++;
/* TODO(davidben): This function has a lot of implicit outputs. Simplify the
* |ssl_get_message| API. */
ssl->s3->tmp.message_type = frag->msg_header.type;
ssl->s3->tmp.message_size = frag->msg_header.msg_len;
ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH;
ssl->init_num = frag->msg_header.msg_len;

View File

@ -161,7 +161,7 @@ int dtls1_connect(SSL *ssl) {
if (ssl->init_buf == NULL) {
buf = BUF_MEM_new();
if (buf == NULL ||
!BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
!BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
ret = -1;
goto end;
}

View File

@ -158,7 +158,7 @@ int dtls1_accept(SSL *ssl) {
if (ssl->init_buf == NULL) {
buf = BUF_MEM_new();
if (buf == NULL || !BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
ret = -1;
goto end;
}

View File

@ -310,124 +310,113 @@ size_t ssl_max_handshake_message_len(const SSL *ssl) {
return kMaxMessageLen;
}
static int extend_handshake_buffer(SSL *ssl, size_t length) {
if (!BUF_MEM_reserve(ssl->init_buf, length)) {
return -1;
}
while (ssl->init_buf->length < length) {
int ret =
ssl3_read_bytes(ssl, SSL3_RT_HANDSHAKE,
(uint8_t *)ssl->init_buf->data + ssl->init_buf->length,
length - ssl->init_buf->length, 0);
if (ret <= 0) {
return ret;
}
ssl->init_buf->length += (size_t)ret;
}
return 1;
}
/* Obtain handshake message of message type |msg_type| (any if |msg_type| ==
* -1). The first four bytes (msg_type and length) are read in state
* |header_state|, the body is read in state |body_state|. */
long ssl3_get_message(SSL *ssl, int header_state, int body_state, int msg_type,
enum ssl_hash_message_t hash_message, int *ok) {
uint8_t *p;
unsigned long l;
long n;
int al;
*ok = 0;
if (ssl->s3->tmp.reuse_message) {
/* A ssl_dont_hash_message call cannot be combined with reuse_message; the
* ssl_dont_hash_message would have to have been applied to the previous
* call. */
assert(hash_message == ssl_hash_message);
assert(ssl->s3->tmp.message_complete);
ssl->s3->tmp.reuse_message = 0;
if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
return -1;
}
*ok = 1;
ssl->state = body_state;
assert(ssl->init_buf->length >= 4);
ssl->init_msg = (uint8_t *)ssl->init_buf->data + 4;
ssl->init_num = (int)ssl->s3->tmp.message_size;
ssl->init_num = (int)ssl->init_buf->length - 4;
return ssl->init_num;
}
p = (uint8_t *)ssl->init_buf->data;
if (ssl->state == header_state) {
assert(ssl->init_num < 4);
for (;;) {
while (ssl->init_num < 4) {
int bytes_read = ssl3_read_bytes(
ssl, SSL3_RT_HANDSHAKE, &p[ssl->init_num], 4 - ssl->init_num, 0);
if (bytes_read <= 0) {
*ok = 0;
return bytes_read;
}
ssl->init_num += bytes_read;
}
static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0};
if (ssl->server || memcmp(p, kHelloRequest, sizeof(kHelloRequest)) != 0) {
break;
}
/* The server may always send 'Hello Request' messages -- we are doing
* a handshake anyway now, so ignore them if their format is correct.
* Does not count for 'Finished' MAC. */
ssl->init_num = 0;
if (ssl->msg_callback) {
ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, p, 4, ssl,
ssl->msg_callback_arg);
}
}
/* ssl->init_num == 4 */
if (msg_type >= 0 && *p != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
}
ssl->s3->tmp.message_type = *(p++);
n2l3(p, l);
if (l > ssl_max_handshake_message_len(ssl)) {
al = SSL_AD_ILLEGAL_PARAMETER;
OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
goto f_err;
}
if (l && !BUF_MEM_grow_clean(ssl->init_buf, l + 4)) {
OPENSSL_PUT_ERROR(SSL, ERR_R_BUF_LIB);
goto err;
}
ssl->s3->tmp.message_size = l;
ssl->state = body_state;
ssl->init_msg = (uint8_t *)ssl->init_buf->data + 4;
ssl->init_num = 0;
again:
if (ssl->s3->tmp.message_complete) {
ssl->s3->tmp.message_complete = 0;
ssl->init_buf->length = 0;
}
/* next state (body_state) */
p = ssl->init_msg;
n = ssl->s3->tmp.message_size - ssl->init_num;
while (n > 0) {
int bytes_read =
ssl3_read_bytes(ssl, SSL3_RT_HANDSHAKE, &p[ssl->init_num], n, 0);
if (bytes_read <= 0) {
*ok = 0;
return bytes_read;
}
ssl->init_num += bytes_read;
n -= bytes_read;
/* Read the message header, if we haven't yet. */
int ret = extend_handshake_buffer(ssl, 4);
if (ret <= 0) {
return ret;
}
/* Parse out the length. Cap it so the peer cannot force us to buffer up to
* 2^24 bytes. */
const uint8_t *p = (uint8_t *)ssl->init_buf->data;
size_t msg_len = (((uint32_t)p[1]) << 16) | (((uint32_t)p[2]) << 8) | p[3];
if (msg_len > ssl_max_handshake_message_len(ssl)) {
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
return -1;
}
/* Read the message body, if we haven't yet. */
ret = extend_handshake_buffer(ssl, 4 + msg_len);
if (ret <= 0) {
return ret;
}
/* We have now received a complete message. */
ssl->s3->tmp.message_complete = 1;
if (ssl->msg_callback) {
ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
ssl->init_buf->length, ssl, ssl->msg_callback_arg);
}
static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0};
if (!ssl->server && ssl->init_buf->length == sizeof(kHelloRequest) &&
memcmp(kHelloRequest, ssl->init_buf->data, sizeof(kHelloRequest)) == 0) {
/* The server may always send 'Hello Request' messages -- we are doing a
* handshake anyway now, so ignore them if their format is correct. Does
* not count for 'Finished' MAC. */
goto again;
}
uint8_t actual_type = ((const uint8_t *)ssl->init_buf->data)[0];
if (msg_type >= 0 && actual_type != msg_type) {
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
return -1;
}
ssl->s3->tmp.message_type = actual_type;
ssl->state = body_state;
ssl->init_msg = (uint8_t*)ssl->init_buf->data + 4;
ssl->init_num = ssl->init_buf->length - 4;
/* Feed this message into MAC computation. */
if (hash_message == ssl_hash_message && !ssl3_hash_current_message(ssl)) {
goto err;
}
if (ssl->msg_callback) {
ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
(size_t)ssl->init_num + 4, ssl, ssl->msg_callback_arg);
return -1;
}
*ok = 1;
return ssl->init_num;
f_err:
ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
err:
*ok = 0;
return -1;
}
int ssl3_hash_current_message(SSL *ssl) {

View File

@ -200,7 +200,7 @@ int ssl3_connect(SSL *ssl) {
if (ssl->init_buf == NULL) {
buf = BUF_MEM_new();
if (buf == NULL ||
!BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
!BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
ret = -1;
goto end;
}

View File

@ -208,7 +208,7 @@ int ssl3_accept(SSL *ssl) {
if (ssl->init_buf == NULL) {
buf = BUF_MEM_new();
if (!buf || !BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
if (!buf || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) {
ret = -1;
goto end;
}
@ -625,7 +625,7 @@ int ssl3_get_v2_client_hello(SSL *ssl) {
const uint8_t *p;
int ret;
CBS v2_client_hello, cipher_specs, session_id, challenge;
size_t msg_length, rand_len, len;
size_t msg_length, rand_len;
uint8_t msg_type;
uint16_t version, cipher_spec_length, session_id_length, challenge_length;
CBB client_hello, hello_body, cipher_suites;
@ -730,7 +730,7 @@ int ssl3_get_v2_client_hello(SSL *ssl) {
/* Add the null compression scheme and finish. */
if (!CBB_add_u8(&hello_body, 1) || !CBB_add_u8(&hello_body, 0) ||
!CBB_finish(&client_hello, NULL, &len)) {
!CBB_finish(&client_hello, NULL, &ssl->init_buf->length)) {
CBB_cleanup(&client_hello);
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return -1;
@ -739,8 +739,7 @@ int ssl3_get_v2_client_hello(SSL *ssl) {
/* Mark the message for "re"-use by the version-specific method. */
ssl->s3->tmp.reuse_message = 1;
ssl->s3->tmp.message_type = SSL3_MT_CLIENT_HELLO;
/* The handshake message header is 4 bytes. */
ssl->s3->tmp.message_size = len - 4;
ssl->s3->tmp.message_complete = 1;
/* Consume and discard the V2ClientHello. */
ssl_read_buffer_consume(ssl, 2 + msg_length);