From c660417bd7a7a438eddb0bc087931671b6a849e7 Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Thu, 2 Jun 2016 16:38:35 -0400 Subject: [PATCH] Don't use dtls1_read_bytes to read messages. This was probably the worst offender of them all as read_bytes is the wrong abstraction to begin with. Note this is a slight change in how processing a record works. Rather than reading one fragment at a time, we process all fragments in a record and return. The intent here is so that all records are processed atomically since the connection eventually will not be able to retain a buffer holding the record. This loses a ton of (though not quite all yet) those a2b macros. Change-Id: Ibe4bbcc33c496328de08d272457d2282c411b38b Reviewed-on: https://boringssl-review.googlesource.com/8176 Reviewed-by: David Benjamin --- ssl/d1_both.c | 172 +++++++++++++++++++------------------- ssl/d1_pkt.c | 39 ++------- ssl/internal.h | 92 ++------------------ ssl/test/runner/runner.go | 6 +- 4 files changed, 107 insertions(+), 202 deletions(-) diff --git a/ssl/d1_both.c b/ssl/d1_both.c index 11c9b655..093bb692 100644 --- a/ssl/d1_both.c +++ b/ssl/d1_both.c @@ -415,24 +415,6 @@ static int dtls1_is_next_message_complete(SSL *ssl) { frag->reassembly == NULL; } -/* dtls1_discard_fragment_body discards a handshake fragment body of length - * |frag_len|. It returns one on success and zero on error. - * - * TODO(davidben): This function will go away when ssl_read_bytes is gone from - * the DTLS side. */ -static int dtls1_discard_fragment_body(SSL *ssl, size_t frag_len) { - uint8_t discard[256]; - while (frag_len > 0) { - size_t chunk = frag_len < sizeof(discard) ? frag_len : sizeof(discard); - int ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, discard, chunk, 0); - if (ret != (int) chunk) { - return 0; - } - frag_len -= chunk; - } - return 1; -} - /* dtls1_get_buffered_message returns the buffered message corresponding to * |msg_hdr|. If none exists, it creates a new one and inserts it in the * queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It @@ -478,75 +460,92 @@ static hm_fragment *dtls1_get_buffered_message( return frag; } -/* dtls1_process_fragment reads a handshake fragment and processes it. It - * returns one if a fragment was successfully processed and 0 or -1 on error. */ -static int dtls1_process_fragment(SSL *ssl) { - /* Read handshake message header. */ - uint8_t header[DTLS1_HM_HEADER_LENGTH]; - int ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, header, - DTLS1_HM_HEADER_LENGTH, 0); - if (ret <= 0) { - return ret; +/* dtls1_process_handshake_record reads a handshake record and processes it. It + * returns one if the record was successfully processed and 0 or -1 on error. */ +static int dtls1_process_handshake_record(SSL *ssl) { + SSL3_RECORD *rr = &ssl->s3->rrec; + +start: + if (rr->length == 0) { + int ret = dtls1_get_record(ssl); + if (ret <= 0) { + return ret; + } } - if (ret != DTLS1_HM_HEADER_LENGTH) { - OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); + + /* Cross-epoch records are discarded, but we may receive out-of-order + * application data between ChangeCipherSpec and Finished or a ChangeCipherSpec + * before the appropriate point in the handshake. Those must be silently + * discarded. + * + * However, only allow the out-of-order records in the correct epoch. + * Application data must come in the encrypted epoch, and ChangeCipherSpec in + * the unencrypted epoch (we never renegotiate). Other cases fall through and + * fail with a fatal error. */ + if ((rr->type == SSL3_RT_APPLICATION_DATA && + ssl->s3->aead_read_ctx != NULL) || + (rr->type == SSL3_RT_CHANGE_CIPHER_SPEC && + ssl->s3->aead_read_ctx == NULL)) { + rr->length = 0; + goto start; + } + + if (rr->type != SSL3_RT_HANDSHAKE) { ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD); return -1; } - /* Parse the message fragment header. */ - struct hm_header_st msg_hdr; - dtls1_get_message_header(header, &msg_hdr); + CBS cbs; + CBS_init(&cbs, rr->data, rr->length); - /* TODO(davidben): dtls1_read_bytes is the wrong abstraction for DTLS. There - * should be no need to reach into |ssl->s3->rrec.length|. */ - const size_t frag_off = msg_hdr.frag_off; - const size_t frag_len = msg_hdr.frag_len; - const size_t msg_len = msg_hdr.msg_len; - if (frag_off > msg_len || frag_off + frag_len < frag_off || - frag_off + frag_len > msg_len || - msg_len > ssl_max_handshake_message_len(ssl) || - frag_len > ssl->s3->rrec.length) { - OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); - ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER); - return -1; - } - - if (msg_hdr.seq < ssl->d1->handshake_read_seq || - msg_hdr.seq > (unsigned)ssl->d1->handshake_read_seq + - kHandshakeBufferSize) { - /* Ignore fragments from the past, or ones too far in the future. */ - if (!dtls1_discard_fragment_body(ssl, frag_len)) { + while (CBS_len(&cbs) > 0) { + /* Read a handshake fragment. */ + struct hm_header_st msg_hdr; + CBS body; + if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD); + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); return -1; } - return 1; - } - hm_fragment *frag = dtls1_get_buffered_message(ssl, &msg_hdr); - if (frag == NULL) { - return -1; - } - assert(frag->msg_header.msg_len == msg_len); - - if (frag->reassembly == NULL) { - /* The message is already assembled. */ - if (!dtls1_discard_fragment_body(ssl, frag_len)) { + const size_t frag_off = msg_hdr.frag_off; + const size_t frag_len = msg_hdr.frag_len; + const size_t msg_len = msg_hdr.msg_len; + if (frag_off > msg_len || frag_off + frag_len < frag_off || + frag_off + frag_len > msg_len || + msg_len > ssl_max_handshake_message_len(ssl)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE); + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER); return -1; } - return 1; - } - assert(msg_len > 0); - /* Read the body of the fragment. */ - ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, frag->fragment + frag_off, - frag_len, 0); - if (ret != (int) frag_len) { - OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); - ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR); - return -1; - } - dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len); + if (msg_hdr.seq < ssl->d1->handshake_read_seq || + msg_hdr.seq > + (unsigned)ssl->d1->handshake_read_seq + kHandshakeBufferSize) { + /* Ignore fragments from the past, or ones too far in the future. */ + continue; + } + hm_fragment *frag = dtls1_get_buffered_message(ssl, &msg_hdr); + if (frag == NULL) { + return -1; + } + assert(frag->msg_header.msg_len == msg_len); + + if (frag->reassembly == NULL) { + /* The message is already assembled. */ + continue; + } + assert(msg_len > 0); + + /* Copy the body into the fragment. */ + memcpy(frag->fragment + frag_off, CBS_data(&body), CBS_len(&body)); + dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len); + } + + rr->length = 0; + ssl_read_buffer_discard(ssl); return 1; } @@ -579,9 +578,9 @@ long dtls1_get_message(SSL *ssl, int msg_type, return ssl->init_num; } - /* Process fragments until one is found. */ + /* Process handshake records until the next message is ready. */ while (!dtls1_is_next_message_complete(ssl)) { - int ret = dtls1_process_fragment(ssl); + int ret = dtls1_process_handshake_record(ssl); if (ret <= 0) { *ok = 0; return ret; @@ -835,13 +834,18 @@ unsigned int dtls1_min_mtu(void) { return kMinMTU; } -void dtls1_get_message_header(uint8_t *data, - struct hm_header_st *msg_hdr) { - memset(msg_hdr, 0x00, sizeof(struct hm_header_st)); - msg_hdr->type = *(data++); - n2l3(data, msg_hdr->msg_len); +int dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr, + CBS *out_body) { + memset(out_hdr, 0x00, sizeof(struct hm_header_st)); - n2s(data, msg_hdr->seq); - n2l3(data, msg_hdr->frag_off); - n2l3(data, msg_hdr->frag_len); + if (!CBS_get_u8(cbs, &out_hdr->type) || + !CBS_get_u24(cbs, &out_hdr->msg_len) || + !CBS_get_u16(cbs, &out_hdr->seq) || + !CBS_get_u24(cbs, &out_hdr->frag_off) || + !CBS_get_u24(cbs, &out_hdr->frag_len) || + !CBS_get_bytes(cbs, out_body, out_hdr->frag_len)) { + return 0; + } + + return 1; } diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c index d0a884a0..c67a7aec 100644 --- a/ssl/d1_pkt.c +++ b/ssl/d1_pkt.c @@ -116,6 +116,7 @@ #include #include +#include #include #include #include @@ -127,10 +128,7 @@ static int do_dtls1_write(SSL *ssl, int type, const uint8_t *buf, unsigned int len, enum dtls1_use_epoch_t use_epoch); -/* dtls1_get_record reads a new input record. On success, it places it in - * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if - * more data is needed. */ -static int dtls1_get_record(SSL *ssl) { +int dtls1_get_record(SSL *ssl) { again: switch (ssl->s3->recv_shutdown) { case ssl_shutdown_none: @@ -258,10 +256,7 @@ void dtls1_read_close_notify(SSL *ssl) { } /* Return up to 'len' payload bytes received in 'type' records. - * 'type' is one of the following: - * - * - SSL3_RT_HANDSHAKE (when dtls1_get_message calls us) - * - SSL3_RT_APPLICATION_DATA (when dtls1_read_app_data calls us) + * 'type' must be SSL3_RT_APPLICATION_DATA (when dtls1_read_app_data calls us). * * If we don't have stored data to work from, read a DTLS record first (possibly * multiple records if we still don't have anything to return). @@ -273,8 +268,7 @@ int dtls1_read_bytes(SSL *ssl, int type, unsigned char *buf, int len, int peek) unsigned int n; SSL3_RECORD *rr; - if ((type != SSL3_RT_APPLICATION_DATA && type != SSL3_RT_HANDSHAKE) || - (peek && type != SSL3_RT_APPLICATION_DATA)) { + if (type != SSL3_RT_APPLICATION_DATA) { OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); return -1; } @@ -327,35 +321,18 @@ start: /* If we get here, then type != rr->type. */ - /* Cross-epoch records are discarded, but we may receive out-of-order - * application data between ChangeCipherSpec and Finished or a ChangeCipherSpec - * before the appropriate point in the handshake. Those must be silently - * discarded. - * - * However, only allow the out-of-order records in the correct epoch. - * Application data must come in the encrypted epoch, and ChangeCipherSpec in - * the unencrypted epoch (we never renegotiate). Other cases fall through and - * fail with a fatal error. */ - if ((rr->type == SSL3_RT_APPLICATION_DATA && - ssl->s3->aead_read_ctx != NULL) || - (rr->type == SSL3_RT_CHANGE_CIPHER_SPEC && - ssl->s3->aead_read_ctx == NULL)) { - rr->length = 0; - goto start; - } - if (rr->type == SSL3_RT_HANDSHAKE) { - assert(type == SSL3_RT_APPLICATION_DATA); /* Parse the first fragment header to determine if this is a pre-CCS or * post-CCS handshake record. DTLS resets handshake message numbers on each * handshake, so renegotiations and retransmissions are ambiguous. */ - if (rr->length < DTLS1_HM_HEADER_LENGTH) { + CBS cbs, body; + struct hm_header_st msg_hdr; + CBS_init(&cbs, rr->data, rr->length); + if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) { al = SSL_AD_DECODE_ERROR; OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD); goto f_err; } - struct hm_header_st msg_hdr; - dtls1_get_message_header(rr->data, &msg_hdr); if (msg_hdr.type == SSL3_MT_FINISHED) { if (msg_hdr.frag_off == 0) { diff --git a/ssl/internal.h b/ssl/internal.h index 48569695..1b2c5445 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -704,100 +704,17 @@ void ssl_write_buffer_clear(SSL *ssl); * * Functions below here haven't been touched up and may be underdocumented. */ -#define c2l(c, l) \ - (l = ((unsigned long)(*((c)++))), l |= (((unsigned long)(*((c)++))) << 8), \ - l |= (((unsigned long)(*((c)++))) << 16), \ - l |= (((unsigned long)(*((c)++))) << 24)) - -/* NOTE - c is not incremented as per c2l */ -#define c2ln(c, l1, l2, n) \ - { \ - c += n; \ - l1 = l2 = 0; \ - switch (n) { \ - case 8: \ - l2 = ((unsigned long)(*(--(c)))) << 24; \ - case 7: \ - l2 |= ((unsigned long)(*(--(c)))) << 16; \ - case 6: \ - l2 |= ((unsigned long)(*(--(c)))) << 8; \ - case 5: \ - l2 |= ((unsigned long)(*(--(c)))); \ - case 4: \ - l1 = ((unsigned long)(*(--(c)))) << 24; \ - case 3: \ - l1 |= ((unsigned long)(*(--(c)))) << 16; \ - case 2: \ - l1 |= ((unsigned long)(*(--(c)))) << 8; \ - case 1: \ - l1 |= ((unsigned long)(*(--(c)))); \ - } \ - } - -#define l2c(l, c) \ - (*((c)++) = (uint8_t)(((l)) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 8) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 16) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 24) & 0xff)) - -#define n2l(c, l) \ - (l = ((unsigned long)(*((c)++))) << 24, \ - l |= ((unsigned long)(*((c)++))) << 16, \ - l |= ((unsigned long)(*((c)++))) << 8, l |= ((unsigned long)(*((c)++)))) - #define l2n(l, c) \ (*((c)++) = (uint8_t)(((l) >> 24) & 0xff), \ *((c)++) = (uint8_t)(((l) >> 16) & 0xff), \ *((c)++) = (uint8_t)(((l) >> 8) & 0xff), \ *((c)++) = (uint8_t)(((l)) & 0xff)) -#define l2n8(l, c) \ - (*((c)++) = (uint8_t)(((l) >> 56) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 48) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 40) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 32) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 24) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 16) & 0xff), \ - *((c)++) = (uint8_t)(((l) >> 8) & 0xff), \ - *((c)++) = (uint8_t)(((l)) & 0xff)) - -/* NOTE - c is not incremented as per l2c */ -#define l2cn(l1, l2, c, n) \ - { \ - c += n; \ - switch (n) { \ - case 8: \ - *(--(c)) = (uint8_t)(((l2) >> 24) & 0xff); \ - case 7: \ - *(--(c)) = (uint8_t)(((l2) >> 16) & 0xff); \ - case 6: \ - *(--(c)) = (uint8_t)(((l2) >> 8) & 0xff); \ - case 5: \ - *(--(c)) = (uint8_t)(((l2)) & 0xff); \ - case 4: \ - *(--(c)) = (uint8_t)(((l1) >> 24) & 0xff); \ - case 3: \ - *(--(c)) = (uint8_t)(((l1) >> 16) & 0xff); \ - case 2: \ - *(--(c)) = (uint8_t)(((l1) >> 8) & 0xff); \ - case 1: \ - *(--(c)) = (uint8_t)(((l1)) & 0xff); \ - } \ - } - -#define n2s(c, s) \ - ((s = (((unsigned int)(c[0])) << 8) | (((unsigned int)(c[1])))), c += 2) - #define s2n(s, c) \ ((c[0] = (uint8_t)(((s) >> 8) & 0xff), \ c[1] = (uint8_t)(((s)) & 0xff)), \ c += 2) -#define n2l3(c, l) \ - ((l = (((unsigned long)(c[0])) << 16) | (((unsigned long)(c[1])) << 8) | \ - (((unsigned long)(c[2])))), \ - c += 3) - #define l2n3(l, c) \ ((c[0] = (uint8_t)(((l) >> 16) & 0xff), \ c[1] = (uint8_t)(((l) >> 8) & 0xff), \ @@ -1126,6 +1043,12 @@ int ssl3_set_handshake_header(SSL *ssl, int htype, unsigned long len); int ssl3_handshake_write(SSL *ssl); int dtls1_do_handshake_write(SSL *ssl, enum dtls1_use_epoch_t use_epoch); + +/* dtls1_get_record reads a new input record. On success, it places it in + * |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if + * more data is needed. */ +int dtls1_get_record(SSL *ssl); + int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek); int dtls1_read_change_cipher_spec(SSL *ssl); void dtls1_read_close_notify(SSL *ssl); @@ -1143,7 +1066,8 @@ int dtls1_send_finished(SSL *ssl, int a, int b, const char *sender, int slen); int dtls1_buffer_message(SSL *ssl); int dtls1_retransmit_buffered_messages(SSL *ssl); void dtls1_clear_record_buffer(SSL *ssl); -void dtls1_get_message_header(uint8_t *data, struct hm_header_st *msg_hdr); +int dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr, + CBS *out_body); int dtls1_check_timeout_num(SSL *ssl); int dtls1_set_handshake_header(SSL *ssl, int type, unsigned long len); int dtls1_handshake_write(SSL *ssl); diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index 1b0bf54b..71d14d2a 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -1772,7 +1772,7 @@ func addBasicTests() { }, }, shouldFail: true, - expectedError: ":UNEXPECTED_MESSAGE:", + expectedError: ":BAD_HANDSHAKE_RECORD:", }, { protocol: dtls, @@ -1783,7 +1783,7 @@ func addBasicTests() { }, }, shouldFail: true, - expectedError: ":EXCESSIVE_MESSAGE_SIZE:", + expectedError: ":BAD_HANDSHAKE_RECORD:", }, { protocol: dtls, @@ -1794,7 +1794,7 @@ func addBasicTests() { }, }, shouldFail: true, - expectedError: ":EXCESSIVE_MESSAGE_SIZE:", + expectedError: ":BAD_HANDSHAKE_RECORD:", }, { protocol: dtls,