Pass a dtls1_use_epoch enum down to dtls1_seal_record.
This is considerably less scary than swapping out connection state. It also fixes a minor bug where, if dtls1_do_write had an alert to dispatch and we happened to retry during a rexmit, it would use the wrong epoch. BUG=468889 Change-Id: I754b0d46bfd02f797f4c3f7cfde28d3e5f30c52b Reviewed-on: https://boringssl-review.googlesource.com/4793 Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
parent
31a07798a5
commit
3e3090dc50
@ -259,7 +259,7 @@ static void dtls1_hm_fragment_mark(hm_fragment *frag, size_t start,
|
|||||||
|
|
||||||
/* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
|
/* send s->init_buf in records of type 'type' (SSL3_RT_HANDSHAKE or
|
||||||
* SSL3_RT_CHANGE_CIPHER_SPEC) */
|
* SSL3_RT_CHANGE_CIPHER_SPEC) */
|
||||||
int dtls1_do_write(SSL *s, int type) {
|
int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
|
||||||
int ret;
|
int ret;
|
||||||
int curr_mtu;
|
int curr_mtu;
|
||||||
unsigned int len, frag_off;
|
unsigned int len, frag_off;
|
||||||
@ -350,7 +350,8 @@ int dtls1_do_write(SSL *s, int type) {
|
|||||||
len = s->init_num;
|
len = s->init_num;
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len);
|
ret = dtls1_write_bytes(s, type, &s->init_buf->data[s->init_off], len,
|
||||||
|
use_epoch);
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
@ -684,7 +685,7 @@ int dtls1_send_change_cipher_spec(SSL *s, int a, int b) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* SSL3_ST_CW_CHANGE_B */
|
/* SSL3_ST_CW_CHANGE_B */
|
||||||
return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC);
|
return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC, dtls1_use_current_epoch);
|
||||||
}
|
}
|
||||||
|
|
||||||
int dtls1_read_failed(SSL *s, int code) {
|
int dtls1_read_failed(SSL *s, int code) {
|
||||||
@ -724,7 +725,6 @@ static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) {
|
|||||||
int ret;
|
int ret;
|
||||||
/* XDTLS: for now assuming that read/writes are blocking */
|
/* XDTLS: for now assuming that read/writes are blocking */
|
||||||
unsigned long header_length;
|
unsigned long header_length;
|
||||||
uint8_t save_write_sequence[8];
|
|
||||||
|
|
||||||
/* assert(s->init_num == 0);
|
/* assert(s->init_num == 0);
|
||||||
assert(s->init_off == 0); */
|
assert(s->init_off == 0); */
|
||||||
@ -743,45 +743,18 @@ static int dtls1_retransmit_message(SSL *s, hm_fragment *frag) {
|
|||||||
frag->msg_header.msg_len, frag->msg_header.seq,
|
frag->msg_header.msg_len, frag->msg_header.seq,
|
||||||
0, frag->msg_header.frag_len);
|
0, frag->msg_header.frag_len);
|
||||||
|
|
||||||
/* Save current state. */
|
|
||||||
SSL_AEAD_CTX *aead_write_ctx = s->aead_write_ctx;
|
|
||||||
uint16_t epoch = s->d1->w_epoch;
|
|
||||||
|
|
||||||
/* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
|
/* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
|
||||||
* (negotiated cipher) exist. */
|
* (negotiated cipher) exist. */
|
||||||
assert(epoch == 0 || epoch == 1);
|
assert(s->d1->w_epoch == 0 || s->d1->w_epoch == 1);
|
||||||
assert(frag->msg_header.epoch <= epoch);
|
assert(frag->msg_header.epoch <= s->d1->w_epoch);
|
||||||
const int fragment_from_previous_epoch = (epoch == 1 &&
|
enum dtls1_use_epoch_t use_epoch = dtls1_use_current_epoch;
|
||||||
frag->msg_header.epoch == 0);
|
if (s->d1->w_epoch == 1 && frag->msg_header.epoch == 0) {
|
||||||
if (fragment_from_previous_epoch) {
|
use_epoch = dtls1_use_previous_epoch;
|
||||||
/* Rewind to the previous epoch.
|
|
||||||
*
|
|
||||||
* TODO(davidben): Instead of swapping out connection-global state, this
|
|
||||||
* logic should pass a "use previous epoch" parameter down to lower-level
|
|
||||||
* functions. */
|
|
||||||
s->d1->w_epoch = frag->msg_header.epoch;
|
|
||||||
s->aead_write_ctx = NULL;
|
|
||||||
memcpy(save_write_sequence, s->s3->write_sequence,
|
|
||||||
sizeof(s->s3->write_sequence));
|
|
||||||
memcpy(s->s3->write_sequence, s->d1->last_write_sequence,
|
|
||||||
sizeof(s->s3->write_sequence));
|
|
||||||
} else {
|
|
||||||
/* Otherwise the messages must be from the same epoch. */
|
|
||||||
assert(frag->msg_header.epoch == epoch);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC
|
ret = dtls1_do_write(s, frag->msg_header.is_ccs ? SSL3_RT_CHANGE_CIPHER_SPEC
|
||||||
: SSL3_RT_HANDSHAKE);
|
: SSL3_RT_HANDSHAKE,
|
||||||
|
use_epoch);
|
||||||
if (fragment_from_previous_epoch) {
|
|
||||||
/* Restore the current epoch. */
|
|
||||||
s->aead_write_ctx = aead_write_ctx;
|
|
||||||
s->d1->w_epoch = epoch;
|
|
||||||
memcpy(s->d1->last_write_sequence, s->s3->write_sequence,
|
|
||||||
sizeof(s->s3->write_sequence));
|
|
||||||
memcpy(s->s3->write_sequence, save_write_sequence,
|
|
||||||
sizeof(s->s3->write_sequence));
|
|
||||||
}
|
|
||||||
|
|
||||||
(void)BIO_flush(SSL_get_wbio(s));
|
(void)BIO_flush(SSL_get_wbio(s));
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -338,5 +338,5 @@ int dtls1_set_handshake_header(SSL *s, int htype, unsigned long len) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int dtls1_handshake_write(SSL *s) {
|
int dtls1_handshake_write(SSL *s) {
|
||||||
return dtls1_do_write(s, SSL3_RT_HANDSHAKE);
|
return dtls1_do_write(s, SSL3_RT_HANDSHAKE, dtls1_use_current_epoch);
|
||||||
}
|
}
|
||||||
|
46
ssl/d1_pkt.c
46
ssl/d1_pkt.c
@ -185,7 +185,7 @@ static int dtls1_record_replay_check(SSL *s, DTLS1_BITMAP *bitmap);
|
|||||||
static void dtls1_record_bitmap_update(SSL *s, DTLS1_BITMAP *bitmap);
|
static void dtls1_record_bitmap_update(SSL *s, DTLS1_BITMAP *bitmap);
|
||||||
static int dtls1_process_record(SSL *s);
|
static int dtls1_process_record(SSL *s);
|
||||||
static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
|
static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
|
||||||
unsigned int len);
|
unsigned int len, enum dtls1_use_epoch_t use_epoch);
|
||||||
|
|
||||||
static int dtls1_process_record(SSL *s) {
|
static int dtls1_process_record(SSL *s) {
|
||||||
int al;
|
int al;
|
||||||
@ -695,18 +695,19 @@ int dtls1_write_app_data_bytes(SSL *s, int type, const void *buf_, int len) {
|
|||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
i = dtls1_write_bytes(s, type, buf_, len);
|
i = dtls1_write_bytes(s, type, buf_, len, dtls1_use_current_epoch);
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Call this to write data in records of type 'type' It will return <= 0 if not
|
/* Call this to write data in records of type 'type' It will return <= 0 if not
|
||||||
* all data has been sent or non-blocking IO. */
|
* all data has been sent or non-blocking IO. */
|
||||||
int dtls1_write_bytes(SSL *s, int type, const void *buf, int len) {
|
int dtls1_write_bytes(SSL *s, int type, const void *buf, int len,
|
||||||
|
enum dtls1_use_epoch_t use_epoch) {
|
||||||
int i;
|
int i;
|
||||||
|
|
||||||
assert(len <= SSL3_RT_MAX_PLAIN_LENGTH);
|
assert(len <= SSL3_RT_MAX_PLAIN_LENGTH);
|
||||||
s->rwstate = SSL_NOTHING;
|
s->rwstate = SSL_NOTHING;
|
||||||
i = do_dtls1_write(s, type, buf, len);
|
i = do_dtls1_write(s, type, buf, len, use_epoch);
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -716,27 +717,40 @@ int dtls1_write_bytes(SSL *s, int type, const void *buf, int len) {
|
|||||||
* number. */
|
* number. */
|
||||||
static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len,
|
static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len,
|
||||||
size_t max_out, uint8_t type, const uint8_t *in,
|
size_t max_out, uint8_t type, const uint8_t *in,
|
||||||
size_t in_len) {
|
size_t in_len, enum dtls1_use_epoch_t use_epoch) {
|
||||||
if (max_out < DTLS1_RT_HEADER_LENGTH) {
|
if (max_out < DTLS1_RT_HEADER_LENGTH) {
|
||||||
OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, SSL_R_BUFFER_TOO_SMALL);
|
OPENSSL_PUT_ERROR(SSL, dtls1_seal_record, SSL_R_BUFFER_TOO_SMALL);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Determine the parameters for the current epoch. */
|
||||||
|
uint16_t epoch = s->d1->w_epoch;
|
||||||
|
SSL_AEAD_CTX *aead = s->aead_write_ctx;
|
||||||
|
uint8_t *seq = s->s3->write_sequence;
|
||||||
|
if (use_epoch == dtls1_use_previous_epoch) {
|
||||||
|
/* DTLS renegotiation is unsupported, so only epochs 0 (NULL cipher) and 1
|
||||||
|
* (negotiated cipher) exist. */
|
||||||
|
assert(s->d1->w_epoch == 1);
|
||||||
|
epoch = s->d1->w_epoch - 1;
|
||||||
|
aead = NULL;
|
||||||
|
seq = s->d1->last_write_sequence;
|
||||||
|
}
|
||||||
|
|
||||||
out[0] = type;
|
out[0] = type;
|
||||||
|
|
||||||
uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION;
|
uint16_t wire_version = s->s3->have_version ? s->version : DTLS1_VERSION;
|
||||||
out[1] = wire_version >> 8;
|
out[1] = wire_version >> 8;
|
||||||
out[2] = wire_version & 0xff;
|
out[2] = wire_version & 0xff;
|
||||||
|
|
||||||
out[3] = s->d1->w_epoch >> 8;
|
out[3] = epoch >> 8;
|
||||||
out[4] = s->d1->w_epoch & 0xff;
|
out[4] = epoch & 0xff;
|
||||||
memcpy(&out[5], &s->s3->write_sequence[2], 6);
|
memcpy(&out[5], &seq[2], 6);
|
||||||
|
|
||||||
size_t ciphertext_len;
|
size_t ciphertext_len;
|
||||||
if (!SSL_AEAD_CTX_seal(s->aead_write_ctx, out + DTLS1_RT_HEADER_LENGTH,
|
if (!SSL_AEAD_CTX_seal(aead, out + DTLS1_RT_HEADER_LENGTH, &ciphertext_len,
|
||||||
&ciphertext_len, max_out - DTLS1_RT_HEADER_LENGTH,
|
max_out - DTLS1_RT_HEADER_LENGTH, type, wire_version,
|
||||||
type, wire_version, &out[3] /* seq */, in, in_len) ||
|
&out[3] /* seq */, in, in_len) ||
|
||||||
!ssl3_record_sequence_update(&s->s3->write_sequence[2], 6)) {
|
!ssl3_record_sequence_update(&seq[2], 6)) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -758,7 +772,7 @@ static int dtls1_seal_record(SSL *s, uint8_t *out, size_t *out_len,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
|
static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
|
||||||
unsigned int len) {
|
unsigned int len, enum dtls1_use_epoch_t use_epoch) {
|
||||||
SSL3_BUFFER *wb = &s->s3->wbuf;
|
SSL3_BUFFER *wb = &s->s3->wbuf;
|
||||||
|
|
||||||
/* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there
|
/* ssl3_write_pending drops the write if |BIO_write| fails in DTLS, so there
|
||||||
@ -790,7 +804,8 @@ static int do_dtls1_write(SSL *s, int type, const uint8_t *buf,
|
|||||||
size_t max_out = wb->len - wb->offset;
|
size_t max_out = wb->len - wb->offset;
|
||||||
|
|
||||||
size_t ciphertext_len;
|
size_t ciphertext_len;
|
||||||
if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len)) {
|
if (!dtls1_seal_record(s, out, &ciphertext_len, max_out, type, buf, len,
|
||||||
|
use_epoch)) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -863,7 +878,8 @@ int dtls1_dispatch_alert(SSL *s) {
|
|||||||
*ptr++ = s->s3->send_alert[0];
|
*ptr++ = s->s3->send_alert[0];
|
||||||
*ptr++ = s->s3->send_alert[1];
|
*ptr++ = s->s3->send_alert[1];
|
||||||
|
|
||||||
i = do_dtls1_write(s, SSL3_RT_ALERT, &buf[0], sizeof(buf));
|
i = do_dtls1_write(s, SSL3_RT_ALERT, &buf[0], sizeof(buf),
|
||||||
|
dtls1_use_current_epoch);
|
||||||
if (i <= 0) {
|
if (i <= 0) {
|
||||||
s->s3->alert_dispatch = 1;
|
s->s3->alert_dispatch = 1;
|
||||||
} else {
|
} else {
|
||||||
|
@ -940,7 +940,12 @@ int ssl3_do_change_cipher_spec(SSL *ssl);
|
|||||||
int ssl3_set_handshake_header(SSL *s, int htype, unsigned long len);
|
int ssl3_set_handshake_header(SSL *s, int htype, unsigned long len);
|
||||||
int ssl3_handshake_write(SSL *s);
|
int ssl3_handshake_write(SSL *s);
|
||||||
|
|
||||||
int dtls1_do_write(SSL *s, int type);
|
enum dtls1_use_epoch_t {
|
||||||
|
dtls1_use_previous_epoch,
|
||||||
|
dtls1_use_current_epoch,
|
||||||
|
};
|
||||||
|
|
||||||
|
int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch);
|
||||||
int ssl3_read_n(SSL *s, int n, int extend);
|
int ssl3_read_n(SSL *s, int n, int extend);
|
||||||
int dtls1_read_bytes(SSL *s, int type, uint8_t *buf, int len, int peek);
|
int dtls1_read_bytes(SSL *s, int type, uint8_t *buf, int len, int peek);
|
||||||
int ssl3_write_pending(SSL *s, int type, const uint8_t *buf, unsigned int len);
|
int ssl3_write_pending(SSL *s, int type, const uint8_t *buf, unsigned int len);
|
||||||
@ -949,7 +954,8 @@ void dtls1_set_message_header(SSL *s, uint8_t mt, unsigned long len,
|
|||||||
unsigned long frag_len);
|
unsigned long frag_len);
|
||||||
|
|
||||||
int dtls1_write_app_data_bytes(SSL *s, int type, const void *buf, int len);
|
int dtls1_write_app_data_bytes(SSL *s, int type, const void *buf, int len);
|
||||||
int dtls1_write_bytes(SSL *s, int type, const void *buf, int len);
|
int dtls1_write_bytes(SSL *s, int type, const void *buf, int len,
|
||||||
|
enum dtls1_use_epoch_t use_epoch);
|
||||||
|
|
||||||
int dtls1_send_change_cipher_spec(SSL *s, int a, int b);
|
int dtls1_send_change_cipher_spec(SSL *s, int a, int b);
|
||||||
int dtls1_send_finished(SSL *s, int a, int b, const char *sender, int slen);
|
int dtls1_send_finished(SSL *s, int a, int b, const char *sender, int slen);
|
||||||
|
Loading…
Reference in New Issue
Block a user