From 29a83c5a0c9a71670ddfeef6a2a68b666338e9ff Mon Sep 17 00:00:00 2001 From: David Benjamin Date: Fri, 17 Jun 2016 19:12:54 -0400 Subject: [PATCH] Rewrite DTLS outgoing message buffering. Now that retransitting is a lot less stateful, a lot of surrounding code can lose statefulness too. Rather than this overcomplicated pqueue structure, hardcode that a handshake flight is capped at 7 messages (actually, DTLS can only get up to 6 because we don't support NPN or Channel ID in DTLS) and used a fixed size array. This also resolves several TODOs. Change-Id: I2b54c3441577a75ad5ca411d872b807d69aa08eb Reviewed-on: https://boringssl-review.googlesource.com/8435 Reviewed-by: Adam Langley --- ssl/d1_both.c | 167 +++++++++++++++---------------------------------- ssl/d1_lib.c | 29 +++------ ssl/internal.h | 40 ++++++------ 3 files changed, 80 insertions(+), 156 deletions(-) diff --git a/ssl/d1_both.c b/ssl/d1_both.c index d98b5df7..6b2a0193 100644 --- a/ssl/d1_both.c +++ b/ssl/d1_both.c @@ -139,11 +139,7 @@ static const unsigned int kMinMTU = 256 - 28; * the underlying BIO supplies one. */ static const unsigned int kDefaultMTU = 1500 - 28; -/* kMaxHandshakeBuffer is the maximum number of handshake messages ahead of the - * current one to buffer. */ -static const unsigned int kHandshakeBufferSize = 10; - -static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) { +static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) { hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment)); if (frag == NULL) { OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); @@ -160,20 +156,18 @@ static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) { goto err; } - if (reassembly) { - /* Initialize reassembly bitmask. */ - if (frag_len + 7 < frag_len) { - OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); - goto err; - } - size_t bitmask_len = (frag_len + 7) / 8; - frag->reassembly = OPENSSL_malloc(bitmask_len); - if (frag->reassembly == NULL) { - OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); - goto err; - } - memset(frag->reassembly, 0, bitmask_len); + /* Initialize reassembly bitmask. */ + if (frag_len + 7 < frag_len) { + OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); + goto err; } + size_t bitmask_len = (frag_len + 7) / 8; + frag->reassembly = OPENSSL_malloc(bitmask_len); + if (frag->reassembly == NULL) { + OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE); + goto err; + } + memset(frag->reassembly, 0, bitmask_len); } return frag; @@ -440,8 +434,7 @@ static hm_fragment *dtls1_get_buffered_message( hm_fragment *frag; if (item == NULL) { /* This is the first fragment from this message. */ - frag = dtls1_hm_fragment_new(msg_hdr->msg_len, - 1 /* reassembly buffer needed */); + frag = dtls1_hm_fragment_new(msg_hdr->msg_len); if (frag == NULL) { return NULL; } @@ -532,7 +525,7 @@ start: if (msg_hdr.seq < ssl->d1->handshake_read_seq || msg_hdr.seq > - (unsigned)ssl->d1->handshake_read_seq + kHandshakeBufferSize) { + (unsigned)ssl->d1->handshake_read_seq + SSL_MAX_HANDSHAKE_FLIGHT) { /* Ignore fragments from the past, or ones too far in the future. */ continue; } @@ -661,39 +654,25 @@ err: return -1; } -static uint16_t dtls1_get_queue_priority(uint16_t seq, int is_ccs) { - assert(seq * 2 >= seq); - - /* The index of the retransmission queue actually is the message sequence - * number, since the queue only contains messages of a single handshake. - * However, the ChangeCipherSpec has no message sequence number and so using - * only the sequence will result in the CCS and Finished having the same - * index. To prevent this, the sequence number is multiplied by 2. In case of - * a CCS 1 is subtracted. This does not only differ CSS and Finished, it also - * maintains the order of the index (important for priority queues) and fits - * in the unsigned short variable. */ - return seq * 2 - is_ccs; -} - -static int dtls1_retransmit_message(SSL *ssl, hm_fragment *frag) { +static int dtls1_retransmit_message(SSL *ssl, + const DTLS_OUTGOING_MESSAGE *msg) { /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1 * (negotiated cipher) exist. */ assert(ssl->d1->w_epoch == 0 || ssl->d1->w_epoch == 1); - assert(frag->msg_header.epoch <= ssl->d1->w_epoch); + assert(msg->epoch <= ssl->d1->w_epoch); enum dtls1_use_epoch_t use_epoch = dtls1_use_current_epoch; - if (ssl->d1->w_epoch == 1 && frag->msg_header.epoch == 0) { + if (ssl->d1->w_epoch == 1 && msg->epoch == 0) { use_epoch = dtls1_use_previous_epoch; } /* TODO(davidben): This cannot handle non-blocking writes. */ int ret; - if (frag->msg_header.is_ccs) { + if (msg->is_ccs) { ret = dtls1_write_change_cipher_spec(ssl, use_epoch); } else { size_t offset = 0; - ret = dtls1_do_handshake_write( - ssl, &offset, frag->fragment, offset, - frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH, use_epoch); + ret = dtls1_do_handshake_write(ssl, &offset, msg->data, offset, msg->len, + use_epoch); } return ret; @@ -709,11 +688,9 @@ int dtls1_retransmit_buffered_messages(SSL *ssl) { assert(ssl_is_wbio_buffered(ssl)); int ret = -1; - piterator iter = pqueue_iterator(ssl->d1->sent_messages); - pitem *item; - for (item = pqueue_next(&iter); item != NULL; item = pqueue_next(&iter)) { - hm_fragment *frag = (hm_fragment *)item->data; - if (dtls1_retransmit_message(ssl, frag) <= 0) { + size_t i; + for (i = 0; i < ssl->d1->outgoing_messages_len; i++) { + if (dtls1_retransmit_message(ssl, &ssl->d1->outgoing_messages[i]) <= 0) { goto err; } } @@ -732,66 +709,41 @@ err: } /* dtls1_buffer_change_cipher_spec adds a ChangeCipherSpec to the current - * handshake flight, ordered just before the handshake message numbered - * |seq|. */ -static int dtls1_buffer_change_cipher_spec(SSL *ssl, uint16_t seq) { - hm_fragment *frag = dtls1_hm_fragment_new(0 /* frag_len */, - 0 /* no reassembly */); - if (frag == NULL) { + * handshake flight. */ +static int dtls1_buffer_change_cipher_spec(SSL *ssl) { + if (ssl->d1->outgoing_messages_len >= SSL_MAX_HANDSHAKE_FLIGHT) { + OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); return 0; } - frag->msg_header.is_ccs = 1; - frag->msg_header.epoch = ssl->d1->w_epoch; - uint16_t priority = dtls1_get_queue_priority(seq, 1 /* is_ccs */); - uint8_t seq64be[8]; - memset(seq64be, 0, sizeof(seq64be)); - seq64be[6] = (uint8_t)(priority >> 8); - seq64be[7] = (uint8_t)priority; + DTLS_OUTGOING_MESSAGE *msg = + &ssl->d1->outgoing_messages[ssl->d1->outgoing_messages_len]; + msg->data = NULL; + msg->len = 0; + msg->epoch = ssl->d1->w_epoch; + msg->is_ccs = 1; - pitem *item = pitem_new(seq64be, frag); - if (item == NULL) { - dtls1_hm_fragment_free(frag); - return 0; - } - - pqueue_insert(ssl->d1->sent_messages, item); + ssl->d1->outgoing_messages_len++; return 1; } int dtls1_buffer_message(SSL *ssl) { - hm_fragment *frag = dtls1_hm_fragment_new(ssl->init_num, 0); - if (!frag) { + if (ssl->d1->outgoing_messages_len >= SSL_MAX_HANDSHAKE_FLIGHT) { + OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); return 0; } - memcpy(frag->fragment, ssl->init_buf->data, ssl->init_num); - - assert(ssl->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH == - (unsigned int)ssl->init_num); - - frag->msg_header.msg_len = ssl->d1->w_msg_hdr.msg_len; - frag->msg_header.seq = ssl->d1->w_msg_hdr.seq; - frag->msg_header.type = ssl->d1->w_msg_hdr.type; - frag->msg_header.frag_off = 0; - frag->msg_header.frag_len = ssl->d1->w_msg_hdr.msg_len; - frag->msg_header.is_ccs = 0; - frag->msg_header.epoch = ssl->d1->w_epoch; - - uint16_t priority = dtls1_get_queue_priority(frag->msg_header.seq, - 0 /* handshake */); - uint8_t seq64be[8]; - memset(seq64be, 0, sizeof(seq64be)); - seq64be[6] = (uint8_t)(priority >> 8); - seq64be[7] = (uint8_t)priority; - - pitem *item = pitem_new(seq64be, frag); - if (item == NULL) { - dtls1_hm_fragment_free(frag); + DTLS_OUTGOING_MESSAGE *msg = + &ssl->d1->outgoing_messages[ssl->d1->outgoing_messages_len]; + msg->data = BUF_memdup(ssl->init_buf->data, ssl->init_num); + if (msg->data == NULL) { return 0; } + msg->len = ssl->init_num; + msg->epoch = ssl->d1->w_epoch; + msg->is_ccs = 0; - pqueue_insert(ssl->d1->sent_messages, item); + ssl->d1->outgoing_messages_len++; return 1; } @@ -799,35 +751,20 @@ int dtls1_send_change_cipher_spec(SSL *ssl, int a, int b) { if (ssl->state == a) { /* Buffer the message to handle retransmits. */ ssl->d1->handshake_write_seq = ssl->d1->next_handshake_write_seq; - dtls1_buffer_change_cipher_spec(ssl, ssl->d1->handshake_write_seq); + dtls1_buffer_change_cipher_spec(ssl); ssl->state = b; } return dtls1_write_change_cipher_spec(ssl, dtls1_use_current_epoch); } -/* call this function when the buffered messages are no longer needed */ -void dtls1_clear_record_buffer(SSL *ssl) { - pitem *item; - - for (item = pqueue_pop(ssl->d1->sent_messages); item != NULL; - item = pqueue_pop(ssl->d1->sent_messages)) { - dtls1_hm_fragment_free((hm_fragment *)item->data); - pitem_free(item); +void dtls_clear_outgoing_messages(SSL *ssl) { + size_t i; + for (i = 0; i < ssl->d1->outgoing_messages_len; i++) { + OPENSSL_free(ssl->d1->outgoing_messages[i].data); + ssl->d1->outgoing_messages[i].data = NULL; } -} - -/* don't actually do the writing, wait till the MTU has been retrieved */ -void dtls1_set_message_header(SSL *ssl, uint8_t mt, unsigned long len, - unsigned short seq_num, unsigned long frag_off, - unsigned long frag_len) { - struct hm_header_st *msg_hdr = &ssl->d1->w_msg_hdr; - - msg_hdr->type = mt; - msg_hdr->msg_len = len; - msg_hdr->seq = seq_num; - msg_hdr->frag_off = frag_off; - msg_hdr->frag_len = frag_len; + ssl->d1->outgoing_messages_len = 0; } unsigned int dtls1_min_mtu(void) { diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c index 78379362..6e05695e 100644 --- a/ssl/d1_lib.c +++ b/ssl/d1_lib.c @@ -98,11 +98,7 @@ int dtls1_new(SSL *ssl) { memset(d1, 0, sizeof *d1); d1->buffered_messages = pqueue_new(); - d1->sent_messages = pqueue_new(); - - if (!d1->buffered_messages || !d1->sent_messages) { - pqueue_free(d1->buffered_messages); - pqueue_free(d1->sent_messages); + if (d1->buffered_messages == NULL) { OPENSSL_free(d1); ssl3_free(ssl); return 0; @@ -128,12 +124,6 @@ static void dtls1_clear_queues(SSL *ssl) { dtls1_hm_fragment_free(frag); pitem_free(item); } - - while ((item = pqueue_pop(ssl->d1->sent_messages)) != NULL) { - frag = (hm_fragment *)item->data; - dtls1_hm_fragment_free(frag); - pitem_free(item); - } } void dtls1_free(SSL *ssl) { @@ -144,9 +134,9 @@ void dtls1_free(SSL *ssl) { } dtls1_clear_queues(ssl); - pqueue_free(ssl->d1->buffered_messages); - pqueue_free(ssl->d1->sent_messages); + + dtls_clear_outgoing_messages(ssl); OPENSSL_free(ssl->d1); ssl->d1 = NULL; @@ -255,7 +245,7 @@ void dtls1_stop_timer(SSL *ssl) { BIO_ctrl(ssl->rbio, BIO_CTRL_DGRAM_SET_NEXT_TIMEOUT, 0, &ssl->d1->next_timeout); /* Clear retransmission buffer */ - dtls1_clear_record_buffer(ssl); + dtls_clear_outgoing_messages(ssl); } int dtls1_check_timeout_num(SSL *ssl) { @@ -321,23 +311,20 @@ static void get_current_time(const SSL *ssl, struct timeval *out_clock) { int dtls1_set_handshake_header(SSL *ssl, int htype, unsigned long len) { uint8_t *message = (uint8_t *)ssl->init_buf->data; - const struct hm_header_st *msg_hdr = &ssl->d1->w_msg_hdr; ssl->d1->handshake_write_seq = ssl->d1->next_handshake_write_seq; ssl->d1->next_handshake_write_seq++; - dtls1_set_message_header(ssl, htype, len, ssl->d1->handshake_write_seq, 0, - len); ssl->init_num = (int)len + DTLS1_HM_HEADER_LENGTH; ssl->init_off = 0; /* Serialize the message header as if it were a single fragment. */ uint8_t *p = message; - *p++ = msg_hdr->type; - l2n3(msg_hdr->msg_len, p); - s2n(msg_hdr->seq, p); + *p++ = htype; + l2n3(len, p); + s2n(ssl->d1->handshake_write_seq, p); l2n3(0, p); - l2n3(msg_hdr->msg_len, p); + l2n3(len, p); assert(p == message + DTLS1_HM_HEADER_LENGTH); /* Buffer the message to handle re-xmits. */ diff --git a/ssl/internal.h b/ssl/internal.h index ae3a8c92..0b263f24 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -635,10 +635,26 @@ int SSL_ECDH_CTX_finish(SSL_ECDH_CTX *ctx, uint8_t **out_secret, /* Handshake messages. */ +/* SSL_MAX_HANDSHAKE_FLIGHT is the number of messages, including + * ChangeCipherSpec, in the longest handshake flight. Currently this is the + * client's second leg in a full handshake when client certificates, NPN, and + * Channel ID, are all enabled. */ +#define SSL_MAX_HANDSHAKE_FLIGHT 7 + /* ssl_max_handshake_message_len returns the maximum number of bytes permitted * in a handshake message for |ssl|. */ size_t ssl_max_handshake_message_len(const SSL *ssl); +typedef struct dtls_outgoing_message_st { + uint8_t *data; + uint32_t len; + uint16_t epoch; + char is_ccs; +} DTLS_OUTGOING_MESSAGE; + +/* dtls_clear_outgoing_messages releases all buffered outgoing messages. */ +void dtls_clear_outgoing_messages(SSL *ssl); + /* Callbacks. */ @@ -870,24 +886,14 @@ struct ssl3_enc_method { #define DTLS1_AL_HEADER_LENGTH 2 -/* TODO(davidben): This structure is used for both incoming messages and - * outgoing messages. |is_ccs| and |epoch| are only used in the latter and - * should be moved elsewhere. */ struct hm_header_st { uint8_t type; uint32_t msg_len; uint16_t seq; uint32_t frag_off; uint32_t frag_len; - int is_ccs; - /* epoch, for buffered outgoing messages, is the epoch the message was - * originally sent in. */ - uint16_t epoch; }; -/* TODO(davidben): This structure is used for both incoming messages and - * outgoing messages. |fragment| and |reassembly| are only used in the former - * and should be moved elsewhere. */ typedef struct hm_fragment_st { struct hm_header_st msg_header; uint8_t *fragment; @@ -928,16 +934,13 @@ typedef struct dtls1_state_st { * size. */ pqueue buffered_messages; - /* send_messages is a priority queue of outgoing handshake messages sent in - * the most recent handshake flight. - * - * TODO(davidben): This data structure may as well be a STACK_OF(T). */ - pqueue sent_messages; + /* outgoing_messages is the queue of outgoing messages from the last handshake + * flight. */ + DTLS_OUTGOING_MESSAGE outgoing_messages[SSL_MAX_HANDSHAKE_FLIGHT]; + uint8_t outgoing_messages_len; unsigned int mtu; /* max DTLS packet size */ - struct hm_header_st w_msg_hdr; - /* num_timeouts is the number of times the retransmit timer has fired since * the last time it was reset. */ unsigned int num_timeouts; @@ -1070,9 +1073,6 @@ int dtls1_get_record(SSL *ssl); int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek); int dtls1_read_change_cipher_spec(SSL *ssl); void dtls1_read_close_notify(SSL *ssl); -void dtls1_set_message_header(SSL *ssl, uint8_t mt, unsigned long len, - unsigned short seq_num, unsigned long frag_off, - unsigned long frag_len); int dtls1_write_app_data(SSL *ssl, const void *buf, int len);