From a6338be3fa1a08f53d6d5f80aa4f26629fd047ab Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Fri, 13 May 2016 18:12:19 -0400 Subject: [PATCH] Simplify ssl3_get_message. Rather than this confusing coordination with the handshake state machine and init_num changing meaning partway through, use the length field already in BUF_MEM. Like the new record layer parsing, is no need to keep track of whether we are reading the header or the body. Simply keep extending the handshake message until it's far enough along. ssl3_get_message still needs tons of work, but this allows us to disentangle it from the handshake state. Change-Id: Ic2b3e7cfe6152a7e28a04980317d3c7c396d9b08 Reviewed-on: https://boringssl-review.googlesource.com/7948 Reviewed-by: Adam Langley --- include/openssl/ssl.h | 5 +- ssl/d1_both.c | 14 ++-- ssl/d1_clnt.c | 2 +- ssl/d1_srvr.c | 2 +- ssl/s3_both.c | 165 ++++++++++++++++++++---------------------- ssl/s3_clnt.c | 2 +- ssl/s3_srvr.c | 9 +-- 7 files changed, 95 insertions(+), 104 deletions(-) diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 960bc141..33d2710b 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h @@ -4129,9 +4129,12 @@ typedef struct ssl3_state_st { uint8_t peer_finish_md[EVP_MAX_MD_SIZE]; int peer_finish_md_len; - unsigned long message_size; int message_type; + /* message_complete is one if the current message is complete and zero + * otherwise. */ + unsigned message_complete:1; + /* used to hold the new cipher we are going to use */ const SSL_CIPHER *new_cipher; diff --git a/ssl/d1_both.c b/ssl/d1_both.c index 7a624126..86fbe4a6 100644 --- a/ssl/d1_both.c +++ b/ssl/d1_both.c @@ -578,8 +578,9 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int msg_type, goto f_err; } *ok = 1; + assert(ssl->init_buf->length >= DTLS1_HM_HEADER_LENGTH); ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH; - ssl->init_num = (int)ssl->s3->tmp.message_size; + ssl->init_num = (int)ssl->init_buf->length - DTLS1_HM_HEADER_LENGTH; return ssl->init_num; } @@ -600,11 +601,10 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int msg_type, assert(frag->reassembly == NULL); /* Reconstruct the assembled message. */ - size_t len; CBB cbb; CBB_zero(&cbb); - if (!BUF_MEM_grow(ssl->init_buf, (size_t)frag->msg_header.msg_len + - DTLS1_HM_HEADER_LENGTH) || + if (!BUF_MEM_reserve(ssl->init_buf, (size_t)frag->msg_header.msg_len + + DTLS1_HM_HEADER_LENGTH) || !CBB_init_fixed(&cbb, (uint8_t *)ssl->init_buf->data, ssl->init_buf->max) || !CBB_add_u8(&cbb, frag->msg_header.type) || @@ -613,19 +613,19 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int msg_type, !CBB_add_u24(&cbb, 0 /* frag_off */) || !CBB_add_u24(&cbb, frag->msg_header.msg_len) || !CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) || - !CBB_finish(&cbb, NULL, &len)) { + !CBB_finish(&cbb, NULL, &ssl->init_buf->length)) { CBB_cleanup(&cbb); OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); goto err; } - assert(len == (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH); + assert(ssl->init_buf->length == + (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH); ssl->d1->handshake_read_seq++; /* TODO(davidben): This function has a lot of implicit outputs. Simplify the * |ssl_get_message| API. */ ssl->s3->tmp.message_type = frag->msg_header.type; - ssl->s3->tmp.message_size = frag->msg_header.msg_len; ssl->init_msg = (uint8_t *)ssl->init_buf->data + DTLS1_HM_HEADER_LENGTH; ssl->init_num = frag->msg_header.msg_len; diff --git a/ssl/d1_clnt.c b/ssl/d1_clnt.c index 3088a7e9..62da1756 100644 --- a/ssl/d1_clnt.c +++ b/ssl/d1_clnt.c @@ -161,7 +161,7 @@ int dtls1_connect(SSL *ssl) { if (ssl->init_buf == NULL) { buf = BUF_MEM_new(); if (buf == NULL || - !BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { + !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { ret = -1; goto end; } diff --git a/ssl/d1_srvr.c b/ssl/d1_srvr.c index f9268a29..f715f29e 100644 --- a/ssl/d1_srvr.c +++ b/ssl/d1_srvr.c @@ -158,7 +158,7 @@ int dtls1_accept(SSL *ssl) { if (ssl->init_buf == NULL) { buf = BUF_MEM_new(); - if (buf == NULL || !BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { + if (buf == NULL || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { ret = -1; goto end; } diff --git a/ssl/s3_both.c b/ssl/s3_both.c index 5235b976..48519cb3 100644 --- a/ssl/s3_both.c +++ b/ssl/s3_both.c @@ -310,124 +310,113 @@ size_t ssl_max_handshake_message_len(const SSL *ssl) { return kMaxMessageLen; } +static int extend_handshake_buffer(SSL *ssl, size_t length) { + if (!BUF_MEM_reserve(ssl->init_buf, length)) { + return -1; + } + while (ssl->init_buf->length < length) { + int ret = + ssl3_read_bytes(ssl, SSL3_RT_HANDSHAKE, + (uint8_t *)ssl->init_buf->data + ssl->init_buf->length, + length - ssl->init_buf->length, 0); + if (ret <= 0) { + return ret; + } + ssl->init_buf->length += (size_t)ret; + } + return 1; +} + /* Obtain handshake message of message type |msg_type| (any if |msg_type| == * -1). The first four bytes (msg_type and length) are read in state * |header_state|, the body is read in state |body_state|. */ long ssl3_get_message(SSL *ssl, int header_state, int body_state, int msg_type, enum ssl_hash_message_t hash_message, int *ok) { - uint8_t *p; - unsigned long l; - long n; - int al; + *ok = 0; if (ssl->s3->tmp.reuse_message) { /* A ssl_dont_hash_message call cannot be combined with reuse_message; the * ssl_dont_hash_message would have to have been applied to the previous * call. */ assert(hash_message == ssl_hash_message); + assert(ssl->s3->tmp.message_complete); ssl->s3->tmp.reuse_message = 0; if (msg_type >= 0 && ssl->s3->tmp.message_type != msg_type) { - al = SSL_AD_UNEXPECTED_MESSAGE; + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); - goto f_err; + return -1; } *ok = 1; ssl->state = body_state; + assert(ssl->init_buf->length >= 4); ssl->init_msg = (uint8_t *)ssl->init_buf->data + 4; - ssl->init_num = (int)ssl->s3->tmp.message_size; + ssl->init_num = (int)ssl->init_buf->length - 4; return ssl->init_num; } - p = (uint8_t *)ssl->init_buf->data; - - if (ssl->state == header_state) { - assert(ssl->init_num < 4); - - for (;;) { - while (ssl->init_num < 4) { - int bytes_read = ssl3_read_bytes( - ssl, SSL3_RT_HANDSHAKE, &p[ssl->init_num], 4 - ssl->init_num, 0); - if (bytes_read <= 0) { - *ok = 0; - return bytes_read; - } - ssl->init_num += bytes_read; - } - - static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0}; - if (ssl->server || memcmp(p, kHelloRequest, sizeof(kHelloRequest)) != 0) { - break; - } - - /* The server may always send 'Hello Request' messages -- we are doing - * a handshake anyway now, so ignore them if their format is correct. - * Does not count for 'Finished' MAC. */ - ssl->init_num = 0; - - if (ssl->msg_callback) { - ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, p, 4, ssl, - ssl->msg_callback_arg); - } - } - - /* ssl->init_num == 4 */ - - if (msg_type >= 0 && *p != msg_type) { - al = SSL_AD_UNEXPECTED_MESSAGE; - OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); - goto f_err; - } - ssl->s3->tmp.message_type = *(p++); - - n2l3(p, l); - if (l > ssl_max_handshake_message_len(ssl)) { - al = SSL_AD_ILLEGAL_PARAMETER; - OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); - goto f_err; - } - - if (l && !BUF_MEM_grow_clean(ssl->init_buf, l + 4)) { - OPENSSL_PUT_ERROR(SSL, ERR_R_BUF_LIB); - goto err; - } - ssl->s3->tmp.message_size = l; - ssl->state = body_state; - - ssl->init_msg = (uint8_t *)ssl->init_buf->data + 4; - ssl->init_num = 0; +again: + if (ssl->s3->tmp.message_complete) { + ssl->s3->tmp.message_complete = 0; + ssl->init_buf->length = 0; } - /* next state (body_state) */ - p = ssl->init_msg; - n = ssl->s3->tmp.message_size - ssl->init_num; - while (n > 0) { - int bytes_read = - ssl3_read_bytes(ssl, SSL3_RT_HANDSHAKE, &p[ssl->init_num], n, 0); - if (bytes_read <= 0) { - *ok = 0; - return bytes_read; - } - ssl->init_num += bytes_read; - n -= bytes_read; + /* Read the message header, if we haven't yet. */ + int ret = extend_handshake_buffer(ssl, 4); + if (ret <= 0) { + return ret; } + /* Parse out the length. Cap it so the peer cannot force us to buffer up to + * 2^24 bytes. */ + const uint8_t *p = (uint8_t *)ssl->init_buf->data; + size_t msg_len = (((uint32_t)p[1]) << 16) | (((uint32_t)p[2]) << 8) | p[3]; + if (msg_len > ssl_max_handshake_message_len(ssl)) { + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER); + OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); + return -1; + } + + /* Read the message body, if we haven't yet. */ + ret = extend_handshake_buffer(ssl, 4 + msg_len); + if (ret <= 0) { + return ret; + } + + /* We have now received a complete message. */ + ssl->s3->tmp.message_complete = 1; + if (ssl->msg_callback) { + ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data, + ssl->init_buf->length, ssl, ssl->msg_callback_arg); + } + + static const uint8_t kHelloRequest[4] = {SSL3_MT_HELLO_REQUEST, 0, 0, 0}; + if (!ssl->server && ssl->init_buf->length == sizeof(kHelloRequest) && + memcmp(kHelloRequest, ssl->init_buf->data, sizeof(kHelloRequest)) == 0) { + /* The server may always send 'Hello Request' messages -- we are doing a + * handshake anyway now, so ignore them if their format is correct. Does + * not count for 'Finished' MAC. */ + goto again; + } + + uint8_t actual_type = ((const uint8_t *)ssl->init_buf->data)[0]; + if (msg_type >= 0 && actual_type != msg_type) { + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); + return -1; + } + ssl->s3->tmp.message_type = actual_type; + ssl->state = body_state; + + ssl->init_msg = (uint8_t*)ssl->init_buf->data + 4; + ssl->init_num = ssl->init_buf->length - 4; + /* Feed this message into MAC computation. */ if (hash_message == ssl_hash_message && !ssl3_hash_current_message(ssl)) { - goto err; - } - if (ssl->msg_callback) { - ssl->msg_callback(0, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data, - (size_t)ssl->init_num + 4, ssl, ssl->msg_callback_arg); + return -1; } + *ok = 1; return ssl->init_num; - -f_err: - ssl3_send_alert(ssl, SSL3_AL_FATAL, al); - -err: - *ok = 0; - return -1; } int ssl3_hash_current_message(SSL *ssl) { diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c index a34659ff..8a6064a6 100644 --- a/ssl/s3_clnt.c +++ b/ssl/s3_clnt.c @@ -200,7 +200,7 @@ int ssl3_connect(SSL *ssl) { if (ssl->init_buf == NULL) { buf = BUF_MEM_new(); if (buf == NULL || - !BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { + !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { ret = -1; goto end; } diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c index a13d0e69..266fb626 100644 --- a/ssl/s3_srvr.c +++ b/ssl/s3_srvr.c @@ -208,7 +208,7 @@ int ssl3_accept(SSL *ssl) { if (ssl->init_buf == NULL) { buf = BUF_MEM_new(); - if (!buf || !BUF_MEM_grow(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { + if (!buf || !BUF_MEM_reserve(buf, SSL3_RT_MAX_PLAIN_LENGTH)) { ret = -1; goto end; } @@ -625,7 +625,7 @@ int ssl3_get_v2_client_hello(SSL *ssl) { const uint8_t *p; int ret; CBS v2_client_hello, cipher_specs, session_id, challenge; - size_t msg_length, rand_len, len; + size_t msg_length, rand_len; uint8_t msg_type; uint16_t version, cipher_spec_length, session_id_length, challenge_length; CBB client_hello, hello_body, cipher_suites; @@ -730,7 +730,7 @@ int ssl3_get_v2_client_hello(SSL *ssl) { /* Add the null compression scheme and finish. */ if (!CBB_add_u8(&hello_body, 1) || !CBB_add_u8(&hello_body, 0) || - !CBB_finish(&client_hello, NULL, &len)) { + !CBB_finish(&client_hello, NULL, &ssl->init_buf->length)) { CBB_cleanup(&client_hello); OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); return -1; @@ -739,8 +739,7 @@ int ssl3_get_v2_client_hello(SSL *ssl) { /* Mark the message for "re"-use by the version-specific method. */ ssl->s3->tmp.reuse_message = 1; ssl->s3->tmp.message_type = SSL3_MT_CLIENT_HELLO; - /* The handshake message header is 4 bytes. */ - ssl->s3->tmp.message_size = len - 4; + ssl->s3->tmp.message_complete = 1; /* Consume and discard the V2ClientHello. */ ssl_read_buffer_consume(ssl, 2 + msg_length);