diff --git a/ssl/d1_both.c b/ssl/d1_both.c index f6442fd6..3eb26c52 100644 --- a/ssl/d1_both.c +++ b/ssl/d1_both.c @@ -259,7 +259,7 @@ static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start, /* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or * SSL3_RT_CHANGE_CIPHER_SPEC) */ -int dtls1_do_write(SSL *s, int type) { +int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) { int ret; int curr_mtu; unsigned int len, frag_off; @@ -350,7 +350,8 @@ int dtls1_do_write(SSL *s, int type) { len = s->init_num; } - ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len); + ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len, + use_epoch); if (ret < 0) { return -1; } @@ -684,7 +685,7 @@ int dtls1_send_change_cipher_spec(SSL *s, int a, int b) { } /* SSL3_ST_CW_CHANGE_B */ - return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC); + return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC, dtls1_use_current_epoch); } int dtls1_read_failed(SSL *s, int code) { @@ -724,7 +725,6 @@ static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) { int ret; /* XDTLS: for now assuming that read/writes are blocking */ unsigned long header_length; - uint8_t save_write_sequence[8]; /* assert(s->init_num == 0); assert(s->init_off == 0); */ @@ -743,45 +743,18 @@ static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) { frag->msg_header.msg_len, frag->msg_header.seq, 0, frag->msg_header.frag_len); - /* Save current state. */ - SSL_AEAD_CTX *aead_write_ctx = s->aead_write_ctx; - uint16_t epoch = s->d1->w_epoch; - /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1 * (negotiated cipher) exist. */ - assert(epoch == 0 || epoch == 1); - assert(frag->msg_header.epoch <= epoch); - const int fragment_from_previous_epoch = (epoch == 1 && - frag->msg_header.epoch == 0); - if (fragment_from_previous_epoch) { - /* Rewind to the previous epoch. - * - * TODO(davidben): Instead of swapping out connection-global state, this - * logic should pass a "use previous epoch" parameter down to lower-level - * functions. */ - s->d1->w_epoch = frag->msg_header.epoch; - s->aead_write_ctx = NULL; - memcpy(save_write_sequence, s->s3->write_sequence, - sizeof(s->s3->write_sequence)); - memcpy(s->s3->write_sequence, s->d1->last_write_sequence, - sizeof(s->s3->write_sequence)); - } else { - /* Otherwise the messages must be from the same epoch. */ - assert(frag->msg_header.epoch == epoch); + assert(s->d1->w_epoch == 0 || s->d1->w_epoch == 1); + assert(frag->msg_header.epoch <= s->d1->w_epoch); + enum dtls1_use_epoch_t use_epoch = dtls1_use_current_epoch; + if (s->d1->w_epoch == 1 && frag->msg_header.epoch == 0) { + use_epoch = dtls1_use_previous_epoch; } ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC - : SSL3_RT_HANDSHAKE); - - if (fragment_from_previous_epoch) { - /* Restore the current epoch. */ - s->aead_write_ctx = aead_write_ctx; - s->d1->w_epoch = epoch; - memcpy(s->d1->last_write_sequence, s->s3->write_sequence, - sizeof(s->s3->write_sequence)); - memcpy(s->s3->write_sequence, save_write_sequence, - sizeof(s->s3->write_sequence)); - } + : SSL3_RT_HANDSHAKE, + use_epoch); (void)BIO_flush(SSL_get_wbio(s)); return ret; diff --git a/ssl/d1_lib.c b/ssl/d1_lib.c index e53156f2..473588dd 100644 --- a/ssl/d1_lib.c +++ b/ssl/d1_lib.c @@ -338,5 +338,5 @@ int dtls1_set_handshake_header(SSL *s, int htype, unsigned long len) { } int dtls1_handshake_write(SSL *s) { - return dtls1_do_write(s, SSL3_RT_HANDSHAKE); + return dtls1_do_write(s, SSL3_RT_HANDSHAKE, dtls1_use_current_epoch); } diff --git a/ssl/d1_pkt.c b/ssl/d1_pkt.c index e80e773b..b6570ee3 100644 --- a/ssl/d1_pkt.c +++ b/ssl/d1_pkt.c @@ -185,7 +185,7 @@ static int dtls1_record_replay_check(SSL *s, DTLS1_BITMAP *bitmap); static void dtls1_record_bitmap_update(SSL *s, DTLS1_BITMAP *bitmap); static int dtls1_process_record(SSL *s); static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, - unsigned int len); + unsigned int len, enum dtls1_use_epoch_t use_epoch); static int dtls1_process_record(SSL *s) { int al; @@ -695,18 +695,19 @@ int dtls1_write_app_data_bytes(SSL *s, int type, const void *buf_, int len) { return -1; } - i = dtls1_write_bytes(s, type, buf_, len); + i = dtls1_write_bytes(s, type, buf_, len, dtls1_use_current_epoch); return i; } /* Call this to write data in records of type 'type' It will return <= 0 if not * all data has been sent or non-blocking IO. */ -int dtls1_write_bytes(SSL *s, int type, const void *buf, int len) { +int dtls1_write_bytes(SSL *s, int type, const void *buf, int len, + enum dtls1_use_epoch_t use_epoch) { int i; assert(len <= SSL3_RT_MAX_PLAIN_LENGTH); s->rwstate = SSL_NOTHING; - i = do_dtls1_write(s, type, buf, len); + i = do_dtls1_write(s, type, buf, len, use_epoch); return i; } @@ -716,27 +717,40 @@ int dtls1_write_bytes(SSL *s, int type, const void *buf, int len) { * number. */ static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len, size_t max_out, uint8_t type, const uint8_t *in, - size_t in_len) { + size_t in_len, enum dtls1_use_epoch_t use_epoch) { if (max_out < DTLS1_RT_HEADER_LENGTH) { OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, SSL_R_BUFFER_TOO_SMALL); return 0; } + /* Determine the parameters for the current epoch. */ + uint16_t epoch = s->d1->w_epoch; + SSL_AEAD_CTX *aead = s->aead_write_ctx; + uint8_t *seq = s->s3->write_sequence; + if (use_epoch == dtls1_use_previous_epoch) { + /* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1 + * (negotiated cipher) exist. */ + assert(s->d1->w_epoch == 1); + epoch = s->d1->w_epoch - 1; + aead = NULL; + seq = s->d1->last_write_sequence; + } + out[0] = type; uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION; out[1] = wire_version >> 8; out[2] = wire_version & 0xff; - out[3] = s->d1->w_epoch >> 8; - out[4] = s->d1->w_epoch & 0xff; - memcpy(&out[5], &s->s3->write_sequence[2], 6); + out[3] = epoch >> 8; + out[4] = epoch & 0xff; + memcpy(&out[5], &seq[2], 6); size_t ciphertext_len; - if (!SSL_AEAD_CTX_seal(s->aead_write_ctx, out + DTLS1_RT_HEADER_LENGTH, - &ciphertext_len, max_out - DTLS1_RT_HEADER_LENGTH, - type, wire_version, &out[3] /* seq */, in, in_len) || - !ssl3_record_sequence_update(&s->s3->write_sequence[2], 6)) { + if (!SSL_AEAD_CTX_seal(aead, out + DTLS1_RT_HEADER_LENGTH, &ciphertext_len, + max_out - DTLS1_RT_HEADER_LENGTH, type, wire_version, + &out[3] /* seq */, in, in_len) || + !ssl3_record_sequence_update(&seq[2], 6)) { return 0; } @@ -758,7 +772,7 @@ static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len, } static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, - unsigned int len) { + unsigned int len, enum dtls1_use_epoch_t use_epoch) { SSL3_BUFFER *wb = &s->s3->wbuf; /* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there @@ -790,7 +804,8 @@ static int do_dtls1_write(SSL *s, int type, const uint8_t *buf, size_t max_out = wb->len - wb->offset; size_t ciphertext_len; - if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len)) { + if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len, + use_epoch)) { return -1; } @@ -863,7 +878,8 @@ int dtls1_dispatch_alert(SSL *s) { *ptr++ = s->s3->send_alert[0]; *ptr++ = s->s3->send_alert[1]; - i = do_dtls1_write(s, SSL3_RT_ALERT, &buf[0], sizeof(buf)); + i = do_dtls1_write(s, SSL3_RT_ALERT, &buf[0], sizeof(buf), + dtls1_use_current_epoch); if (i <= 0) { s->s3->alert_dispatch = 1; } else { diff --git a/ssl/internal.h b/ssl/internal.h index c4e931d3..08a74dc5 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -940,7 +940,12 @@ int ssl3_do_change_cipher_spec(SSL *ssl); int ssl3_set_handshake_header(SSL *s, int htype, unsigned long len); int ssl3_handshake_write(SSL *s); -int dtls1_do_write(SSL *s, int type); +enum dtls1_use_epoch_t { + dtls1_use_previous_epoch, + dtls1_use_current_epoch, +}; + +int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch); int ssl3_read_n(SSL *s, int n, int extend); int dtls1_read_bytes(SSL *s, int type, uint8_t *buf, int len, int peek); int ssl3_write_pending(SSL *s, int type, const uint8_t *buf, unsigned int len); @@ -949,7 +954,8 @@ void dtls1_set_message_header(SSL *s, uint8_t mt, unsigned long len, unsigned long frag_len); int dtls1_write_app_data_bytes(SSL *s, int type, const void *buf, int len); -int dtls1_write_bytes(SSL *s, int type, const void *buf, int len); +int dtls1_write_bytes(SSL *s, int type, const void *buf, int len, + enum dtls1_use_epoch_t use_epoch); int dtls1_send_change_cipher_spec(SSL *s, int a, int b); int dtls1_send_finished(SSL *s, int a, int b, const char *sender, int slen);