@@ -139,6 +139,15 @@ static const unsigned int kMinMTU = 256 - 28;
* the underlying BIO supplies one. */
* the underlying BIO supplies one. */
static const unsigned int kDefaultMTU = 1500 - 28;
static const unsigned int kDefaultMTU = 1500 - 28;
static void dtls1_hm_fragment_free(hm_fragment *frag) {
if (frag == NULL) {
return;
}
OPENSSL_free(frag->fragment);
OPENSSL_free(frag->reassembly);
OPENSSL_free(frag);
}
static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) {
static hm_fragment *dtls1_hm_fragment_new(size_t frag_len) {
hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
if (frag == NULL) {
if (frag == NULL) {
@@ -177,15 +186,6 @@ err:
return NULL;
return NULL;
}
}
void dtls1_hm_fragment_free(hm_fragment *frag) {
if (frag == NULL) {
return;
}
OPENSSL_free(frag->fragment);
OPENSSL_free(frag->reassembly);
OPENSSL_free(frag);
}
#if !defined(inline)
#if !defined(inline)
#define inline __inline
#define inline __inline
#endif
#endif
@@ -407,16 +407,9 @@ err:
/* dtls1_is_next_message_complete returns one if the next handshake message is
/* dtls1_is_next_message_complete returns one if the next handshake message is
* complete and zero otherwise. */
* complete and zero otherwise. */
static int dtls1_is_next_message_complete(SSL *ssl) {
static int dtls1_is_next_message_complete(SSL *ssl) {
pitem *item = pqueue_peek(ssl->d1->buffered_messages);
if (item == NULL) {
return 0;
}
hm_fragment *frag = (hm_fragment *)item->data;
assert(ssl->d1->handshake_read_seq <= frag->msg_header.seq);
return ssl->d1->handshake_read_seq == frag->msg_header.seq &&
frag->reassembly == NULL;
hm_fragment *frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
SSL_MAX_HANDSHAKE_FLIGHT];
return frag != NULL && frag->reassembly == NULL;
}
}
/* dtls1_get_buffered_message returns the buffered message corresponding to
/* dtls1_get_buffered_message returns the buffered message corresponding to
@@ -425,41 +418,33 @@ static int dtls1_is_next_message_complete(SSL *ssl) {
* returns NULL on failure. The caller does not take ownership of the result. */
* returns NULL on failure. The caller does not take ownership of the result. */
static hm_fragment *dtls1_get_buffered_message(
static hm_fragment *dtls1_get_buffered_message(
SSL *ssl, const struct hm_header_st *msg_hdr) {
SSL *ssl, const struct hm_header_st *msg_hdr) {
uint8_t seq64be[8];
memset(seq64be, 0, sizeof(seq64be));
seq64be[6] = (uint8_t)(msg_hdr->seq >> 8);
seq64be[7] = (uint8_t)msg_hdr->seq;
pitem *item = pqueue_find(ssl->d1->buffered_messages, seq64be);
hm_fragment *frag;
if (item == NULL) {
/* This is the first fragment from this message. */
frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
if (frag == NULL) {
return NULL;
}
memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
item = pitem_new(seq64be, frag);
if (item == NULL) {
dtls1_hm_fragment_free(frag);
return NULL;
}
item = pqueue_insert(ssl->d1->buffered_messages, item);
/* |pqueue_insert| fails iff a duplicate item is inserted, but |item| cannot
* be a duplicate. */
assert(item != NULL);
} else {
frag = item->data;
if (msg_hdr->seq < ssl->d1->handshake_read_seq ||
msg_hdr->seq - ssl->d1->handshake_read_seq >= SSL_MAX_HANDSHAKE_FLIGHT) {
return NULL;
}
size_t idx = msg_hdr->seq % SSL_MAX_HANDSHAKE_FLIGHT;
hm_fragment *frag = ssl->d1->incoming_messages[idx];
if (frag != NULL) {
assert(frag->msg_header.seq == msg_hdr->seq);
assert(frag->msg_header.seq == msg_hdr->seq);
/* The new fragment must be compatible with the previous fragments from this
* message. */
if (frag->msg_header.type != msg_hdr->type ||
if (frag->msg_header.type != msg_hdr->type ||
frag->msg_header.msg_len != msg_hdr->msg_len) {
frag->msg_header.msg_len != msg_hdr->msg_len) {
/* The new fragment must be compatible with the previous fragments from
* this message. */
OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return NULL;
return NULL;
}
}
return frag;
}
/* This is the first fragment from this message. */
frag = dtls1_hm_fragment_new(msg_hdr->msg_len);
if (frag == NULL) {
return NULL;
}
}
memcpy(&frag->msg_header, msg_hdr, sizeof(*msg_hdr));
ssl->d1->incoming_messages[idx] = frag;
return frag;
return frag;
}
}
@@ -557,7 +542,6 @@ start:
* arrive in fragments. */
* arrive in fragments. */
long dtls1_get_message(SSL *ssl, int msg_type,
long dtls1_get_message(SSL *ssl, int msg_type,
enum ssl_hash_message_t hash_message, int *ok) {
enum ssl_hash_message_t hash_message, int *ok) {
pitem *item = NULL;
hm_fragment *frag = NULL;
hm_fragment *frag = NULL;
int al;
int al;
@@ -590,12 +574,17 @@ long dtls1_get_message(SSL *ssl, int msg_type,
}
}
}
}
/* Read out the next complete handshake message. */
item = pqueue_pop(ssl->d1->buffered_messages);
assert(item != NULL);
frag = (hm_fragment *)item->data;
assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
/* Pop an entry from the ring buffer. */
frag = ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
SSL_MAX_HANDSHAKE_FLIGHT];
ssl->d1->incoming_messages[ssl->d1->handshake_read_seq %
SSL_MAX_HANDSHAKE_FLIGHT] = NULL;
assert(frag != NULL);
assert(frag->reassembly == NULL);
assert(frag->reassembly == NULL);
assert(ssl->d1->handshake_read_seq == frag->msg_header.seq);
ssl->d1->handshake_read_seq++;
/* Reconstruct the assembled message. */
/* Reconstruct the assembled message. */
CBB cbb;
CBB cbb;
@@ -618,8 +607,6 @@ long dtls1_get_message(SSL *ssl, int msg_type,
assert(ssl->init_buf->length ==
assert(ssl->init_buf->length ==
(size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
(size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
ssl->d1->handshake_read_seq++;
/* TODO(davidben): This function has a lot of implicit outputs. Simplify the
/* TODO(davidben): This function has a lot of implicit outputs. Simplify the
* |ssl_get_message| API. */
* |ssl_get_message| API. */
ssl->s3->tmp.message_type = frag->msg_header.type;
ssl->s3->tmp.message_type = frag->msg_header.type;
@@ -639,7 +626,6 @@ long dtls1_get_message(SSL *ssl, int msg_type,
ssl->init_buf->data,
ssl->init_buf->data,
ssl->init_num + DTLS1_HM_HEADER_LENGTH);
ssl->init_num + DTLS1_HM_HEADER_LENGTH);
pitem_free(item);
dtls1_hm_fragment_free(frag);
dtls1_hm_fragment_free(frag);
*ok = 1;
*ok = 1;
@@ -648,7 +634,6 @@ long dtls1_get_message(SSL *ssl, int msg_type,
f_err:
f_err:
ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
ssl3_send_alert(ssl, SSL3_AL_FATAL, al);
err:
err:
pitem_free(item);
dtls1_hm_fragment_free(frag);
dtls1_hm_fragment_free(frag);
*ok = 0;
*ok = 0;
return -1;
return -1;
@@ -765,6 +750,14 @@ void dtls_clear_outgoing_messages(SSL *ssl) {
ssl->d1->outgoing_messages_len = 0;
ssl->d1->outgoing_messages_len = 0;
}
}
void dtls_clear_incoming_messages(SSL *ssl) {
size_t i;
for (i = 0; i < SSL_MAX_HANDSHAKE_FLIGHT; i++) {
dtls1_hm_fragment_free(ssl->d1->incoming_messages[i]);
ssl->d1->incoming_messages[i] = NULL;
}
}
unsigned int dtls1_min_mtu(void) {
unsigned int dtls1_min_mtu(void) {
return kMinMTU;
return kMinMTU;
}
}