diff --git a/ssl/d1_both.c b/ssl/d1_both.c index cc95a70b..7bd28241 100644 --- a/ssl/d1_both.c +++ b/ssl/d1_both.c @@ -737,25 +737,38 @@ static int dtls1_retransmit_message(SSL *ssl, hm_fragment *frag) { ret = dtls1_do_handshake_write(ssl, use_epoch); } - /* TODO(davidben): Check return value? */ - (void)BIO_flush(SSL_get_wbio(ssl)); return ret; } - int dtls1_retransmit_buffered_messages(SSL *ssl) { - pqueue sent = ssl->d1->sent_messages; - piterator iter = pqueue_iterator(sent); - pitem *item; + /* Ensure we are packing handshake messages. */ + const int was_buffered = ssl_is_wbio_buffered(ssl); + assert(was_buffered == SSL_in_init(ssl)); + if (!was_buffered && !ssl_init_wbio_buffer(ssl, 1)) { + return -1; + } + assert(ssl_is_wbio_buffered(ssl)); + int ret = -1; + piterator iter = pqueue_iterator(ssl->d1->sent_messages); + pitem *item; for (item = pqueue_next(&iter); item != NULL; item = pqueue_next(&iter)) { hm_fragment *frag = (hm_fragment *)item->data; if (dtls1_retransmit_message(ssl, frag) <= 0) { - return -1; + goto err; } } - return 1; + /* TODO(davidben): Check return value? */ + (void)BIO_flush(SSL_get_wbio(ssl)); + + ret = 1; + +err: + if (!was_buffered) { + ssl_free_wbio_buffer(ssl); + } + return ret; } /* dtls1_buffer_change_cipher_spec adds a ChangeCipherSpec to the current diff --git a/ssl/internal.h b/ssl/internal.h index 982f3040..2c40b44f 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -1136,6 +1136,10 @@ long dtls1_get_message(SSL *ssl, int st1, int stn, int mt, long max, enum ssl_hash_message_t hash_message, int *ok); int dtls1_dispatch_alert(SSL *ssl); +/* ssl_is_wbio_buffered returns one if |ssl|'s write BIO is buffered and zero + * otherwise. */ +int ssl_is_wbio_buffered(const SSL *ssl); + int ssl_init_wbio_buffer(SSL *ssl, int push); void ssl_free_wbio_buffer(SSL *ssl); diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index 84047b24..70bd8321 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c @@ -1865,6 +1865,10 @@ const COMP_METHOD *SSL_get_current_expansion(SSL *ssl) { return NULL; } int *SSL_get_server_tmp_key(SSL *ssl, EVP_PKEY **out_key) { return 0; } +int ssl_is_wbio_buffered(const SSL *ssl) { + return ssl->bbio != NULL && ssl->wbio == ssl->bbio; +} + int ssl_init_wbio_buffer(SSL *ssl, int push) { BIO *bbio; diff --git a/ssl/test/runner/dtls.go b/ssl/test/runner/dtls.go index c3ee5218..a31dfc00 100644 --- a/ssl/test/runner/dtls.go +++ b/ssl/test/runner/dtls.go @@ -46,6 +46,7 @@ func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { b := c.rawInput // Read a new packet only if the current one is empty. + var newPacket bool if len(b.data) == 0 { // Pick some absurdly large buffer size. b.resize(maxCiphertext + recordHeaderLen) @@ -57,6 +58,7 @@ func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length") } c.rawInput.resize(n) + newPacket = true } // Read out one record. @@ -108,6 +110,13 @@ func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { c.in.setErrorLocked(c.sendAlert(err)) } b.off = off + + // Require that ChangeCipherSpec always share a packet with either the + // previous or next handshake message. + if newPacket && typ == recordTypeChangeCipherSpec && c.rawInput == nil { + return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: ChangeCipherSpec not packed together with Finished")) + } + return typ, b, nil }