Rewrite DTLS handshake message sending logic.

This fixes a number of bugs with the original logic:

- If handshake messages are fragmented and writes need to be retried, frag_off
  gets completely confused.

- The BIO_flush call didn't set rwstate, so it wasn't resumable at that point.

- The msg_callback call gets garbage because the fragment header would get
  scribbled over the handshake buffer.

The original logic was also extremely confusing with how it handles init_off.
(init_off gets rewound to make room for the fragment header.  Depending on
where you pause, resuming may or may not have already been rewound.)

For simplicity, just allocate a new buffer to assemble the fragment in and
avoid clobbering the old one. I don't think it's worth the complexity to
optimize that. If we want to optimize this sort of thing, not clobbering seems
better anyway because the message may need to be retransmitted. We could avoid
doing a copy when buffering the outgoing message for retransmission later.

We do still need to track how far we are in sending the current message via
init_off, so I haven't opted to disconnect this function from
init_{buf,off,num} yet.

Test the fix to the retry + fragment case by having the splitHandshake option
to the state machine tests, in DTLS, also clamp the MTU to force handshake
fragmentation.

Change-Id: I66f634d6c752ea63649db8ed2f898f9cc2b13908
Reviewed-on: https://boringssl-review.googlesource.com/6421
Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
David Benjamin 2015-11-03 15:39:45 -05:00 committed by Adam Langley
parent c81ee8b40c
commit 16285ea800
2 changed files with 112 additions and 126 deletions

View File

@ -145,10 +145,6 @@ static const unsigned int kDefaultMTU = 1500 - 28;
* current one to buffer. */ * current one to buffer. */
static const unsigned int kHandshakeBufferSize = 10; static const unsigned int kHandshakeBufferSize = 10;
static void dtls1_fix_message_header(SSL *s, unsigned long frag_off,
unsigned long frag_len);
static unsigned char *dtls1_write_message_header(SSL *s, unsigned char *p);
static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) { static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) {
hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment)); hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
if (frag == NULL) { if (frag == NULL) {
@ -268,14 +264,34 @@ static void dtls1_update_mtu(SSL *ssl) {
assert(ssl->d1->mtu >= dtls1_min_mtu()); assert(ssl->d1->mtu >= dtls1_min_mtu());
} }
/* dtls1_max_record_size returns the maximum record body length that may be
* written without exceeding the MTU. It accounts for any buffering installed on
* the write BIO. If no record may be written, it returns zero. */
static size_t dtls1_max_record_size(SSL *ssl) {
size_t ret = ssl->d1->mtu;
size_t overhead = ssl_max_seal_overhead(ssl);
if (ret <= overhead) {
return 0;
}
ret -= overhead;
size_t pending = BIO_wpending(SSL_get_wbio(ssl));
if (ret <= pending) {
return 0;
}
ret -= pending;
return ret;
}
static int dtls1_write_change_cipher_spec(SSL *ssl, static int dtls1_write_change_cipher_spec(SSL *ssl,
enum dtls1_use_epoch_t use_epoch) { enum dtls1_use_epoch_t use_epoch) {
dtls1_update_mtu(ssl); dtls1_update_mtu(ssl);
/* During the handshake, wbio is buffered to pack messages together. Flush the /* During the handshake, wbio is buffered to pack messages together. Flush the
* buffer if the ChangeCipherSpec would not fit in a packet. */ * buffer if the ChangeCipherSpec would not fit in a packet. */
if (BIO_wpending(SSL_get_wbio(ssl)) + ssl_max_seal_overhead(ssl) + 1 > if (dtls1_max_record_size(ssl) == 0) {
ssl->d1->mtu) {
ssl->rwstate = SSL_WRITING; ssl->rwstate = SSL_WRITING;
int ret = BIO_flush(SSL_get_wbio(ssl)); int ret = BIO_flush(SSL_get_wbio(ssl));
if (ret <= 0) { if (ret <= 0) {
@ -300,103 +316,96 @@ static int dtls1_write_change_cipher_spec(SSL *ssl,
return 1; return 1;
} }
int dtls1_do_handshake_write(SSL *s, enum dtls1_use_epoch_t use_epoch) { int dtls1_do_handshake_write(SSL *ssl, enum dtls1_use_epoch_t use_epoch) {
int ret; dtls1_update_mtu(ssl);
int curr_mtu;
unsigned int len, frag_off;
dtls1_update_mtu(s); int ret = -1;
CBB cbb;
if (s->init_off == 0) { CBB_zero(&cbb);
assert(s->init_num == /* Allocate a temporary buffer to hold the message fragments to avoid
(int)s->d1->w_msg_hdr.msg_len + DTLS1_HM_HEADER_LENGTH); * clobbering the message. */
uint8_t *buf = OPENSSL_malloc(ssl->d1->mtu);
if (buf == NULL) {
goto err;
} }
/* Determine the maximum overhead of the current cipher. */ /* Consume the message header. Fragments will have different headers
size_t max_overhead = SSL_AEAD_CTX_max_overhead(s->aead_write_ctx); * prepended. */
if (ssl->init_off == 0) {
frag_off = 0; ssl->init_off += DTLS1_HM_HEADER_LENGTH;
while (s->init_num) { ssl->init_num -= DTLS1_HM_HEADER_LENGTH;
/* Account for data in the buffering BIO; multiple records may be packed
* into a single packet during the handshake.
*
* TODO(davidben): This is buggy; if the MTU is larger than the buffer size,
* the large record will be split across two packets. Moreover, in that
* case, the |dtls1_write_bytes| call may not return synchronously. This
* will break on retry as the |s->init_off| and |s->init_num| adjustment
* will run a second time. */
curr_mtu = s->d1->mtu - BIO_wpending(SSL_get_wbio(s)) -
DTLS1_RT_HEADER_LENGTH - max_overhead;
if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
/* Flush the buffer and continue with a fresh packet.
*
* TODO(davidben): If |BIO_flush| is not synchronous and requires multiple
* calls to |dtls1_do_write|, |frag_off| will be wrong. */
ret = BIO_flush(SSL_get_wbio(s));
if (ret <= 0) {
return ret;
} }
assert(BIO_wpending(SSL_get_wbio(s)) == 0); assert(ssl->init_off >= DTLS1_HM_HEADER_LENGTH);
curr_mtu = s->d1->mtu - DTLS1_RT_HEADER_LENGTH - max_overhead;
do {
/* During the handshake, wbio is buffered to pack messages together. Flush
* the buffer if there isn't enough room to make progress. */
if (dtls1_max_record_size(ssl) < DTLS1_HM_HEADER_LENGTH + 1) {
ssl->rwstate = SSL_WRITING;
int flush_ret = BIO_flush(SSL_get_wbio(ssl));
if (flush_ret <= 0) {
ret = flush_ret;
goto err;
}
ssl->rwstate = SSL_NOTHING;
assert(BIO_wpending(SSL_get_wbio(ssl)) == 0);
} }
/* If this isn't the first fragment, reserve space to prepend a new fragment size_t todo = dtls1_max_record_size(ssl);
* header. This will override the body of a previous fragment. */ if (todo < DTLS1_HM_HEADER_LENGTH + 1) {
if (s->init_off != 0) {
assert(s->init_off > DTLS1_HM_HEADER_LENGTH);
s->init_off -= DTLS1_HM_HEADER_LENGTH;
s->init_num += DTLS1_HM_HEADER_LENGTH;
}
if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
/* To make forward progress, the MTU must, at minimum, fit the handshake /* To make forward progress, the MTU must, at minimum, fit the handshake
* header and one byte of handshake body. */ * header and one byte of handshake body. */
OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL); OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
return -1; goto err;
}
todo -= DTLS1_HM_HEADER_LENGTH;
if (todo > (size_t)ssl->init_num) {
todo = ssl->init_num;
}
if (todo >= (1u << 24)) {
todo = (1u << 24) - 1;
} }
if (s->init_num > curr_mtu) { size_t len;
len = curr_mtu; if (!CBB_init_fixed(&cbb, buf, ssl->d1->mtu) ||
} else { !CBB_add_u8(&cbb, ssl->d1->w_msg_hdr.type) ||
len = s->init_num; !CBB_add_u24(&cbb, ssl->d1->w_msg_hdr.msg_len) ||
} !CBB_add_u16(&cbb, ssl->d1->w_msg_hdr.seq) ||
assert(len >= DTLS1_HM_HEADER_LENGTH); !CBB_add_u24(&cbb, ssl->init_off - DTLS1_HM_HEADER_LENGTH) ||
!CBB_add_u24(&cbb, todo) ||
dtls1_fix_message_header(s, frag_off, len - DTLS1_HM_HEADER_LENGTH); !CBB_add_bytes(
dtls1_write_message_header( &cbb, (const uint8_t *)ssl->init_buf->data + ssl->init_off, todo) ||
s, (uint8_t *)&s->init_buf->data[s->init_off]); !CBB_finish(&cbb, NULL, &len)) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
ret = dtls1_write_bytes(s, SSL3_RT_HANDSHAKE, goto err;
&s->init_buf->data[s->init_off], len, use_epoch);
if (ret < 0) {
return -1;
} }
/* bad if this assert fails, only part of the handshake message got sent. int write_ret = dtls1_write_bytes(ssl, SSL3_RT_HANDSHAKE, buf, len,
* But why would this happen? */ use_epoch);
assert(len == (unsigned int)ret); if (write_ret <= 0) {
ret = write_ret;
goto err;
}
ssl->init_off += todo;
ssl->init_num -= todo;
} while (ssl->init_num > 0);
if (ret == s->init_num) { if (ssl->msg_callback != NULL) {
if (s->msg_callback) { ssl->msg_callback(
/* TODO(davidben): At this point, |s->init_buf->data| has been clobbered 1 /* write */, ssl->version, SSL3_RT_HANDSHAKE, ssl->init_buf->data,
* already. */ (size_t)(ssl->init_off + ssl->init_num), ssl, ssl->msg_callback_arg);
s->msg_callback(1, s->version, SSL3_RT_HANDSHAKE, s->init_buf->data,
(size_t)(s->init_off + s->init_num), s,
s->msg_callback_arg);
} }
s->init_off = 0; /* done writing this message */ ssl->init_off = 0;
s->init_num = 0; ssl->init_num = 0;
return 1; ret = 1;
}
s->init_off += ret;
s->init_num -= ret;
frag_off += (ret -= DTLS1_HM_HEADER_LENGTH);
}
return 0; err:
CBB_cleanup(&cbb);
OPENSSL_free(buf);
return ret;
} }
/* dtls1_is_next_message_complete returns one if the next handshake message is /* dtls1_is_next_message_complete returns one if the next handshake message is
@ -855,27 +864,6 @@ void dtls1_set_message_header(SSL *s, uint8_t mt, unsigned long len,
msg_hdr->frag_len = frag_len; msg_hdr->frag_len = frag_len;
} }
static void dtls1_fix_message_header(SSL *s, unsigned long frag_off,
unsigned long frag_len) {
struct hm_header_st *msg_hdr = &s->d1->w_msg_hdr;
msg_hdr->frag_off = frag_off;
msg_hdr->frag_len = frag_len;
}
static uint8_t *dtls1_write_message_header(SSL *s, uint8_t *p) {
struct hm_header_st *msg_hdr = &s->d1->w_msg_hdr;
*p++ = msg_hdr->type;
l2n3(msg_hdr->msg_len, p);
s2n(msg_hdr->seq, p);
l2n3(msg_hdr->frag_off, p);
l2n3(msg_hdr->frag_len, p);
return p;
}
unsigned int dtls1_min_mtu(void) { unsigned int dtls1_min_mtu(void) {
return kMinMTU; return kMinMTU;
} }

View File

@ -2932,27 +2932,25 @@ func addStateMachineCoverageTests(async, splitHandshake bool, protocol protocol)
}) })
} }
var suffix string
var flags []string
var maxHandshakeRecordLength int
if protocol == dtls {
suffix = "-DTLS"
}
if async {
suffix += "-Async"
flags = append(flags, "-async")
} else {
suffix += "-Sync"
}
if splitHandshake {
suffix += "-SplitHandshakeRecords"
maxHandshakeRecordLength = 1
}
for _, test := range tests { for _, test := range tests {
test.protocol = protocol test.protocol = protocol
test.name += suffix if protocol == dtls {
test.config.Bugs.MaxHandshakeRecordLength = maxHandshakeRecordLength test.name += "-DTLS"
test.flags = append(test.flags, flags...) }
if async {
test.name += "-Async"
test.flags = append(test.flags, "-async")
} else {
test.name += "-Sync"
}
if splitHandshake {
test.name += "-SplitHandshakeRecords"
test.config.Bugs.MaxHandshakeRecordLength = 1
if protocol == dtls {
test.config.Bugs.MaxPacketLength = 256
test.flags = append(test.flags, "-mtu", "256")
}
}
testCases = append(testCases, test) testCases = append(testCases, test)
} }
} }