Don't use dtls1_read_bytes to read messages.

This was probably the worst offender of them all as read_bytes is the wrong
abstraction to begin with. Note this is a slight change in how processing a
record works. Rather than reading one fragment at a time, we process all
fragments in a record and return. The intent here is so that all records are
processed atomically since the connection eventually will not be able to retain
a buffer holding the record.

This loses a ton of (though not quite all yet) those a2b macros.

Change-Id: Ibe4bbcc33c496328de08d272457d2282c411b38b
Reviewed-on: https://boringssl-review.googlesource.com/8176
Reviewed-by: David Benjamin <davidben@google.com>
This commit is contained in:
David Benjamin 2016-06-02 16:38:35 -04:00
parent 585320c9e9
commit c660417bd7
4 changed files with 107 additions and 202 deletions

View File

@ -415,24 +415,6 @@ static int dtls1_is_next_message_complete(SSL *ssl) {
frag->reassembly == NULL;
}
/* dtls1_discard_fragment_body discards a handshake fragment body of length
* |frag_len|. It returns one on success and zero on error.
*
* TODO(davidben): This function will go away when ssl_read_bytes is gone from
* the DTLS side. */
static int dtls1_discard_fragment_body(SSL *ssl, size_t frag_len) {
uint8_t discard[256];
while (frag_len > 0) {
size_t chunk = frag_len < sizeof(discard) ? frag_len : sizeof(discard);
int ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, discard, chunk, 0);
if (ret != (int) chunk) {
return 0;
}
frag_len -= chunk;
}
return 1;
}
/* dtls1_get_buffered_message returns the buffered message corresponding to
* |msg_hdr|. If none exists, it creates a new one and inserts it in the
* queue. Otherwise, it checks |msg_hdr| is consistent with the existing one. It
@ -478,75 +460,92 @@ static hm_fragment *dtls1_get_buffered_message(
return frag;
}
/* dtls1_process_fragment reads a handshake fragment and processes it. It
* returns one if a fragment was successfully processed and 0 or -1 on error. */
static int dtls1_process_fragment(SSL *ssl) {
/* Read handshake message header. */
uint8_t header[DTLS1_HM_HEADER_LENGTH];
int ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, header,
DTLS1_HM_HEADER_LENGTH, 0);
if (ret <= 0) {
return ret;
/* dtls1_process_handshake_record reads a handshake record and processes it. It
* returns one if the record was successfully processed and 0 or -1 on error. */
static int dtls1_process_handshake_record(SSL *ssl) {
SSL3_RECORD *rr = &ssl->s3->rrec;
start:
if (rr->length == 0) {
int ret = dtls1_get_record(ssl);
if (ret <= 0) {
return ret;
}
}
if (ret != DTLS1_HM_HEADER_LENGTH) {
OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
/* Cross-epoch records are discarded, but we may receive out-of-order
* application data between ChangeCipherSpec and Finished or a ChangeCipherSpec
* before the appropriate point in the handshake. Those must be silently
* discarded.
*
* However, only allow the out-of-order records in the correct epoch.
* Application data must come in the encrypted epoch, and ChangeCipherSpec in
* the unencrypted epoch (we never renegotiate). Other cases fall through and
* fail with a fatal error. */
if ((rr->type == SSL3_RT_APPLICATION_DATA &&
ssl->s3->aead_read_ctx != NULL) ||
(rr->type == SSL3_RT_CHANGE_CIPHER_SPEC &&
ssl->s3->aead_read_ctx == NULL)) {
rr->length = 0;
goto start;
}
if (rr->type != SSL3_RT_HANDSHAKE) {
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_RECORD);
return -1;
}
/* Parse the message fragment header. */
struct hm_header_st msg_hdr;
dtls1_get_message_header(header, &msg_hdr);
CBS cbs;
CBS_init(&cbs, rr->data, rr->length);
/* TODO(davidben): dtls1_read_bytes is the wrong abstraction for DTLS. There
* should be no need to reach into |ssl->s3->rrec.length|. */
const size_t frag_off = msg_hdr.frag_off;
const size_t frag_len = msg_hdr.frag_len;
const size_t msg_len = msg_hdr.msg_len;
if (frag_off > msg_len || frag_off + frag_len < frag_off ||
frag_off + frag_len > msg_len ||
msg_len > ssl_max_handshake_message_len(ssl) ||
frag_len > ssl->s3->rrec.length) {
OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
}
if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
msg_hdr.seq > (unsigned)ssl->d1->handshake_read_seq +
kHandshakeBufferSize) {
/* Ignore fragments from the past, or ones too far in the future. */
if (!dtls1_discard_fragment_body(ssl, frag_len)) {
while (CBS_len(&cbs) > 0) {
/* Read a handshake fragment. */
struct hm_header_st msg_hdr;
CBS body;
if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR);
return -1;
}
return 1;
}
hm_fragment *frag = dtls1_get_buffered_message(ssl, &msg_hdr);
if (frag == NULL) {
return -1;
}
assert(frag->msg_header.msg_len == msg_len);
if (frag->reassembly == NULL) {
/* The message is already assembled. */
if (!dtls1_discard_fragment_body(ssl, frag_len)) {
const size_t frag_off = msg_hdr.frag_off;
const size_t frag_len = msg_hdr.frag_len;
const size_t msg_len = msg_hdr.msg_len;
if (frag_off > msg_len || frag_off + frag_len < frag_off ||
frag_off + frag_len > msg_len ||
msg_len > ssl_max_handshake_message_len(ssl)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
}
return 1;
}
assert(msg_len > 0);
/* Read the body of the fragment. */
ret = dtls1_read_bytes(ssl, SSL3_RT_HANDSHAKE, frag->fragment + frag_off,
frag_len, 0);
if (ret != (int) frag_len) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
return -1;
}
dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
if (msg_hdr.seq < ssl->d1->handshake_read_seq ||
msg_hdr.seq >
(unsigned)ssl->d1->handshake_read_seq + kHandshakeBufferSize) {
/* Ignore fragments from the past, or ones too far in the future. */
continue;
}
hm_fragment *frag = dtls1_get_buffered_message(ssl, &msg_hdr);
if (frag == NULL) {
return -1;
}
assert(frag->msg_header.msg_len == msg_len);
if (frag->reassembly == NULL) {
/* The message is already assembled. */
continue;
}
assert(msg_len > 0);
/* Copy the body into the fragment. */
memcpy(frag->fragment + frag_off, CBS_data(&body), CBS_len(&body));
dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
}
rr->length = 0;
ssl_read_buffer_discard(ssl);
return 1;
}
@ -579,9 +578,9 @@ long dtls1_get_message(SSL *ssl, int msg_type,
return ssl->init_num;
}
/* Process fragments until one is found. */
/* Process handshake records until the next message is ready. */
while (!dtls1_is_next_message_complete(ssl)) {
int ret = dtls1_process_fragment(ssl);
int ret = dtls1_process_handshake_record(ssl);
if (ret <= 0) {
*ok = 0;
return ret;
@ -835,13 +834,18 @@ unsigned int dtls1_min_mtu(void) {
return kMinMTU;
}
void dtls1_get_message_header(uint8_t *data,
struct hm_header_st *msg_hdr) {
memset(msg_hdr, 0x00, sizeof(struct hm_header_st));
msg_hdr->type = *(data++);
n2l3(data, msg_hdr->msg_len);
int dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
CBS *out_body) {
memset(out_hdr, 0x00, sizeof(struct hm_header_st));
n2s(data, msg_hdr->seq);
n2l3(data, msg_hdr->frag_off);
n2l3(data, msg_hdr->frag_len);
if (!CBS_get_u8(cbs, &out_hdr->type) ||
!CBS_get_u24(cbs, &out_hdr->msg_len) ||
!CBS_get_u16(cbs, &out_hdr->seq) ||
!CBS_get_u24(cbs, &out_hdr->frag_off) ||
!CBS_get_u24(cbs, &out_hdr->frag_len) ||
!CBS_get_bytes(cbs, out_body, out_hdr->frag_len)) {
return 0;
}
return 1;
}

View File

@ -116,6 +116,7 @@
#include <openssl/bio.h>
#include <openssl/buf.h>
#include <openssl/bytestring.h>
#include <openssl/mem.h>
#include <openssl/evp.h>
#include <openssl/err.h>
@ -127,10 +128,7 @@
static int do_dtls1_write(SSL *ssl, int type, const uint8_t *buf,
unsigned int len, enum dtls1_use_epoch_t use_epoch);
/* dtls1_get_record reads a new input record. On success, it places it in
* |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if
* more data is needed. */
static int dtls1_get_record(SSL *ssl) {
int dtls1_get_record(SSL *ssl) {
again:
switch (ssl->s3->recv_shutdown) {
case ssl_shutdown_none:
@ -258,10 +256,7 @@ void dtls1_read_close_notify(SSL *ssl) {
}
/* Return up to 'len' payload bytes received in 'type' records.
* 'type' is one of the following:
*
* - SSL3_RT_HANDSHAKE (when dtls1_get_message calls us)
* - SSL3_RT_APPLICATION_DATA (when dtls1_read_app_data calls us)
* 'type' must be SSL3_RT_APPLICATION_DATA (when dtls1_read_app_data calls us).
*
* If we don't have stored data to work from, read a DTLS record first (possibly
* multiple records if we still don't have anything to return).
@ -273,8 +268,7 @@ int dtls1_read_bytes(SSL *ssl, int type, unsigned char *buf, int len, int peek)
unsigned int n;
SSL3_RECORD *rr;
if ((type != SSL3_RT_APPLICATION_DATA && type != SSL3_RT_HANDSHAKE) ||
(peek && type != SSL3_RT_APPLICATION_DATA)) {
if (type != SSL3_RT_APPLICATION_DATA) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return -1;
}
@ -327,35 +321,18 @@ start:
/* If we get here, then type != rr->type. */
/* Cross-epoch records are discarded, but we may receive out-of-order
* application data between ChangeCipherSpec and Finished or a ChangeCipherSpec
* before the appropriate point in the handshake. Those must be silently
* discarded.
*
* However, only allow the out-of-order records in the correct epoch.
* Application data must come in the encrypted epoch, and ChangeCipherSpec in
* the unencrypted epoch (we never renegotiate). Other cases fall through and
* fail with a fatal error. */
if ((rr->type == SSL3_RT_APPLICATION_DATA &&
ssl->s3->aead_read_ctx != NULL) ||
(rr->type == SSL3_RT_CHANGE_CIPHER_SPEC &&
ssl->s3->aead_read_ctx == NULL)) {
rr->length = 0;
goto start;
}
if (rr->type == SSL3_RT_HANDSHAKE) {
assert(type == SSL3_RT_APPLICATION_DATA);
/* Parse the first fragment header to determine if this is a pre-CCS or
* post-CCS handshake record. DTLS resets handshake message numbers on each
* handshake, so renegotiations and retransmissions are ambiguous. */
if (rr->length < DTLS1_HM_HEADER_LENGTH) {
CBS cbs, body;
struct hm_header_st msg_hdr;
CBS_init(&cbs, rr->data, rr->length);
if (!dtls1_parse_fragment(&cbs, &msg_hdr, &body)) {
al = SSL_AD_DECODE_ERROR;
OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_HANDSHAKE_RECORD);
goto f_err;
}
struct hm_header_st msg_hdr;
dtls1_get_message_header(rr->data, &msg_hdr);
if (msg_hdr.type == SSL3_MT_FINISHED) {
if (msg_hdr.frag_off == 0) {

View File

@ -704,100 +704,17 @@ void ssl_write_buffer_clear(SSL *ssl);
*
* Functions below here haven't been touched up and may be underdocumented. */
#define c2l(c, l) \
(l = ((unsigned long)(*((c)++))), l |= (((unsigned long)(*((c)++))) << 8), \
l |= (((unsigned long)(*((c)++))) << 16), \
l |= (((unsigned long)(*((c)++))) << 24))
/* NOTE - c is not incremented as per c2l */
#define c2ln(c, l1, l2, n) \
{ \
c += n; \
l1 = l2 = 0; \
switch (n) { \
case 8: \
l2 = ((unsigned long)(*(--(c)))) << 24; \
case 7: \
l2 |= ((unsigned long)(*(--(c)))) << 16; \
case 6: \
l2 |= ((unsigned long)(*(--(c)))) << 8; \
case 5: \
l2 |= ((unsigned long)(*(--(c)))); \
case 4: \
l1 = ((unsigned long)(*(--(c)))) << 24; \
case 3: \
l1 |= ((unsigned long)(*(--(c)))) << 16; \
case 2: \
l1 |= ((unsigned long)(*(--(c)))) << 8; \
case 1: \
l1 |= ((unsigned long)(*(--(c)))); \
} \
}
#define l2c(l, c) \
(*((c)++) = (uint8_t)(((l)) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 8) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 16) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 24) & 0xff))
#define n2l(c, l) \
(l = ((unsigned long)(*((c)++))) << 24, \
l |= ((unsigned long)(*((c)++))) << 16, \
l |= ((unsigned long)(*((c)++))) << 8, l |= ((unsigned long)(*((c)++))))
#define l2n(l, c) \
(*((c)++) = (uint8_t)(((l) >> 24) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 16) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 8) & 0xff), \
*((c)++) = (uint8_t)(((l)) & 0xff))
#define l2n8(l, c) \
(*((c)++) = (uint8_t)(((l) >> 56) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 48) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 40) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 32) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 24) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 16) & 0xff), \
*((c)++) = (uint8_t)(((l) >> 8) & 0xff), \
*((c)++) = (uint8_t)(((l)) & 0xff))
/* NOTE - c is not incremented as per l2c */
#define l2cn(l1, l2, c, n) \
{ \
c += n; \
switch (n) { \
case 8: \
*(--(c)) = (uint8_t)(((l2) >> 24) & 0xff); \
case 7: \
*(--(c)) = (uint8_t)(((l2) >> 16) & 0xff); \
case 6: \
*(--(c)) = (uint8_t)(((l2) >> 8) & 0xff); \
case 5: \
*(--(c)) = (uint8_t)(((l2)) & 0xff); \
case 4: \
*(--(c)) = (uint8_t)(((l1) >> 24) & 0xff); \
case 3: \
*(--(c)) = (uint8_t)(((l1) >> 16) & 0xff); \
case 2: \
*(--(c)) = (uint8_t)(((l1) >> 8) & 0xff); \
case 1: \
*(--(c)) = (uint8_t)(((l1)) & 0xff); \
} \
}
#define n2s(c, s) \
((s = (((unsigned int)(c[0])) << 8) | (((unsigned int)(c[1])))), c += 2)
#define s2n(s, c) \
((c[0] = (uint8_t)(((s) >> 8) & 0xff), \
c[1] = (uint8_t)(((s)) & 0xff)), \
c += 2)
#define n2l3(c, l) \
((l = (((unsigned long)(c[0])) << 16) | (((unsigned long)(c[1])) << 8) | \
(((unsigned long)(c[2])))), \
c += 3)
#define l2n3(l, c) \
((c[0] = (uint8_t)(((l) >> 16) & 0xff), \
c[1] = (uint8_t)(((l) >> 8) & 0xff), \
@ -1126,6 +1043,12 @@ int ssl3_set_handshake_header(SSL *ssl, int htype, unsigned long len);
int ssl3_handshake_write(SSL *ssl);
int dtls1_do_handshake_write(SSL *ssl, enum dtls1_use_epoch_t use_epoch);
/* dtls1_get_record reads a new input record. On success, it places it in
* |ssl->s3->rrec| and returns one. Otherwise it returns <= 0 on error or if
* more data is needed. */
int dtls1_get_record(SSL *ssl);
int dtls1_read_app_data(SSL *ssl, uint8_t *buf, int len, int peek);
int dtls1_read_change_cipher_spec(SSL *ssl);
void dtls1_read_close_notify(SSL *ssl);
@ -1143,7 +1066,8 @@ int dtls1_send_finished(SSL *ssl, int a, int b, const char *sender, int slen);
int dtls1_buffer_message(SSL *ssl);
int dtls1_retransmit_buffered_messages(SSL *ssl);
void dtls1_clear_record_buffer(SSL *ssl);
void dtls1_get_message_header(uint8_t *data, struct hm_header_st *msg_hdr);
int dtls1_parse_fragment(CBS *cbs, struct hm_header_st *out_hdr,
CBS *out_body);
int dtls1_check_timeout_num(SSL *ssl);
int dtls1_set_handshake_header(SSL *ssl, int type, unsigned long len);
int dtls1_handshake_write(SSL *ssl);

View File

@ -1772,7 +1772,7 @@ func addBasicTests() {
},
},
shouldFail: true,
expectedError: ":UNEXPECTED_MESSAGE:",
expectedError: ":BAD_HANDSHAKE_RECORD:",
},
{
protocol: dtls,
@ -1783,7 +1783,7 @@ func addBasicTests() {
},
},
shouldFail: true,
expectedError: ":EXCESSIVE_MESSAGE_SIZE:",
expectedError: ":BAD_HANDSHAKE_RECORD:",
},
{
protocol: dtls,
@ -1794,7 +1794,7 @@ func addBasicTests() {
},
},
shouldFail: true,
expectedError: ":EXCESSIVE_MESSAGE_SIZE:",
expectedError: ":BAD_HANDSHAKE_RECORD:",
},
{
protocol: dtls,