diff --git a/ssl/internal.h b/ssl/internal.h index 28ea87b7..3deedc43 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -1070,6 +1070,9 @@ bool tls_has_unprocessed_handshake_data(const SSL *ssl); // |tls_has_unprocessed_handshake_data| for DTLS. bool dtls_has_unprocessed_handshake_data(const SSL *ssl); +// tls_flush_pending_hs_data flushes any handshake plaintext data. +bool tls_flush_pending_hs_data(SSL *ssl); + struct DTLS_OUTGOING_MESSAGE { DTLS_OUTGOING_MESSAGE() {} DTLS_OUTGOING_MESSAGE(const DTLS_OUTGOING_MESSAGE &) = delete; @@ -2399,6 +2402,11 @@ struct SSL3_STATE { // hs_buf is the buffer of handshake data to process. UniquePtr hs_buf; + // pending_hs_data contains the pending handshake data that has not yet + // been encrypted to |pending_flight|. This allows packing the handshake into + // fewer records. + UniquePtr pending_hs_data; + // pending_flight is the pending outgoing flight. This is used to flush each // handshake flight in a single write. |write_buffer| must be written out // before this data. diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc index 9d1218cc..788a897a 100644 --- a/ssl/s3_both.cc +++ b/ssl/s3_both.cc @@ -134,6 +134,8 @@ namespace bssl { static bool add_record_to_flight(SSL *ssl, uint8_t type, Span in) { + // The caller should have flushed |pending_hs_data| first. + assert(!ssl->s3->pending_hs_data); // We'll never add a flight while in the process of writing it out. assert(ssl->s3->pending_flight_offset == 0); @@ -182,17 +184,51 @@ bool ssl3_finish_message(SSL *ssl, CBB *cbb, Array *out_msg) { } bool ssl3_add_message(SSL *ssl, Array msg) { - // Add the message to the current flight, splitting into several records if - // needed. + // Pack handshake data into the minimal number of records. This avoids + // unnecessary encryption overhead, notably in TLS 1.3 where we send several + // encrypted messages in a row. For now, we do not do this for the null + // cipher. The benefit is smaller and there is a risk of breaking buggy + // implementations. Additionally, we tie this to draft-28 as a sanity check, + // on the off chance middleboxes have fixated on sizes. + // + // TODO(davidben): See if we can do this uniformly. Span rest = msg; - do { - Span chunk = rest.subspan(0, ssl->max_send_fragment); - rest = rest.subspan(chunk.size()); + if (ssl->s3->aead_write_ctx->is_null_cipher() || + ssl->version == TLS1_3_DRAFT23_VERSION) { + while (!rest.empty()) { + Span chunk = rest.subspan(0, ssl->max_send_fragment); + rest = rest.subspan(chunk.size()); - if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) { - return false; + if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) { + return false; + } } - } while (!rest.empty()); + } else { + while (!rest.empty()) { + // Flush if |pending_hs_data| is full. + if (ssl->s3->pending_hs_data && + ssl->s3->pending_hs_data->length >= ssl->max_send_fragment && + !tls_flush_pending_hs_data(ssl)) { + return false; + } + + size_t pending_len = + ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0; + Span chunk = + rest.subspan(0, ssl->max_send_fragment - pending_len); + assert(!chunk.empty()); + rest = rest.subspan(chunk.size()); + + if (!ssl->s3->pending_hs_data) { + ssl->s3->pending_hs_data.reset(BUF_MEM_new()); + } + if (!ssl->s3->pending_hs_data || + !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(), + chunk.size())) { + return false; + } + } + } ssl_do_msg_callback(ssl, 1 /* write */, SSL3_RT_HANDSHAKE, msg); // TODO(svaldez): Move this up a layer to fix abstraction for SSLTranscript on @@ -204,10 +240,23 @@ bool ssl3_add_message(SSL *ssl, Array msg) { return true; } +bool tls_flush_pending_hs_data(SSL *ssl) { + if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0) { + return true; + } + + UniquePtr pending_hs_data = std::move(ssl->s3->pending_hs_data); + return add_record_to_flight( + ssl, SSL3_RT_HANDSHAKE, + MakeConstSpan(reinterpret_cast(pending_hs_data->data), + pending_hs_data->length)); +} + bool ssl3_add_change_cipher_spec(SSL *ssl) { static const uint8_t kChangeCipherSpec[1] = {SSL3_MT_CCS}; - if (!add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC, + if (!tls_flush_pending_hs_data(ssl) || + !add_record_to_flight(ssl, SSL3_RT_CHANGE_CIPHER_SPEC, kChangeCipherSpec)) { return false; } @@ -219,7 +268,8 @@ bool ssl3_add_change_cipher_spec(SSL *ssl) { bool ssl3_add_alert(SSL *ssl, uint8_t level, uint8_t desc) { uint8_t alert[2] = {level, desc}; - if (!add_record_to_flight(ssl, SSL3_RT_ALERT, alert)) { + if (!tls_flush_pending_hs_data(ssl) || + !add_record_to_flight(ssl, SSL3_RT_ALERT, alert)) { return false; } @@ -229,6 +279,10 @@ bool ssl3_add_alert(SSL *ssl, uint8_t level, uint8_t desc) { } int ssl3_flush_flight(SSL *ssl) { + if (!tls_flush_pending_hs_data(ssl)) { + return -1; + } + if (ssl->s3->pending_flight == nullptr) { return 1; } diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc index 5eb68f6b..00ba8dfb 100644 --- a/ssl/s3_pkt.cc +++ b/ssl/s3_pkt.cc @@ -234,6 +234,9 @@ static int do_ssl3_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { return 0; } + if (!tls_flush_pending_hs_data(ssl)) { + return -1; + } size_t flight_len = 0; if (ssl->s3->pending_flight != nullptr) { flight_len = diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go index dcc1b01f..5d3345d6 100644 --- a/ssl/test/runner/common.go +++ b/ssl/test/runner/common.go @@ -1506,6 +1506,15 @@ type ProtocolBugs struct { // length accepted from the peer. MaxReceivePlaintext int + // ExpectPackedEncryptedHandshake, if non-zero, requires that the peer maximally + // pack their encrypted handshake messages, fitting at most the + // specified number of plaintext bytes per record. + ExpectPackedEncryptedHandshake int + + // ForbidHandshakePacking, if true, requires the peer place a record + // boundary after every handshake message. + ForbidHandshakePacking bool + // SendTicketLifetime, if non-zero, is the ticket lifetime to send in // NewSessionTicket messages. SendTicketLifetime time.Duration diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index 9cd61eb4..55e42a51 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -111,6 +111,11 @@ type Conn struct { expectTLS13ChangeCipherSpec bool + // seenHandshakePackEnd is whether the most recent handshake record was + // not full for ExpectPackedEncryptedHandshake. If true, no more + // handshake data may be received until the next flight or epoch change. + seenHandshakePackEnd bool + tmp [16]byte } @@ -756,6 +761,7 @@ func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []b side = clientWrite } c.in.useTrafficSecret(version, suite, secret, side) + c.seenHandshakePackEnd = false return nil } @@ -975,6 +981,13 @@ Again: return c.in.setErrorLocked(err) } + if typ != recordTypeHandshake { + c.seenHandshakePackEnd = false + } else if c.seenHandshakePackEnd { + c.in.freeBlock(b) + return c.in.setErrorLocked(errors.New("tls: peer violated ExpectPackedEncryptedHandshake")) + } + switch typ { default: c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) @@ -1037,6 +1050,9 @@ Again: return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) } c.hand.Write(data) + if pack := c.config.Bugs.ExpectPackedEncryptedHandshake; pack > 0 && len(data) < pack && c.out.cipher != nil { + c.seenHandshakePackEnd = true + } } if b != nil { @@ -1095,6 +1111,7 @@ func (c *Conn) writeV2Record(data []byte) (n int, err error) { // to the connection and updates the record layer state. // c.out.Mutex <= L. func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { + c.seenHandshakePackEnd = false if typ == recordTypeHandshake { msgType := data[0] if c.config.Bugs.SendWrongMessageType != 0 && msgType == c.config.Bugs.SendWrongMessageType { @@ -1304,6 +1321,9 @@ func (c *Conn) doReadHandshake() ([]byte, error) { return nil, err } } + if c.hand.Len() > 4+n && c.config.Bugs.ForbidHandshakePacking { + return nil, errors.New("tls: forbidden trailing data after a handshake message") + } return c.hand.Next(4 + n), nil } diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index ba52706c..15122a17 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -2998,10 +2998,12 @@ read alert 1 0 config: Config{ MaxVersion: VersionTLS13, Bugs: ProtocolBugs{ - MaxReceivePlaintext: 512, + MaxReceivePlaintext: 512, + ExpectPackedEncryptedHandshake: 512, }, }, - messageLen: 1024, + tls13Variant: TLS13Draft28, + messageLen: 1024, flags: []string{ "-max-send-fragment", "512", "-read-size", "1024", @@ -3028,6 +3030,32 @@ read alert 1 0 shouldFail: true, expectedLocalError: "local error: record overflow", }, + { + // Test that handshake data is not packed in TLS 1.3 + // draft-23. + testType: serverTest, + name: "ForbidHandshakePacking-TLS13Draft23", + config: Config{ + MaxVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ForbidHandshakePacking: true, + }, + }, + tls13Variant: TLS13Draft23, + }, + { + // Test that handshake data is tightly packed in TLS 1.3 + // draft-28. + testType: serverTest, + name: "PackedEncryptedHandshake-TLS13Draft28", + config: Config{ + MaxVersion: VersionTLS13, + Bugs: ProtocolBugs{ + ExpectPackedEncryptedHandshake: 16384, + }, + }, + tls13Variant: TLS13Draft28, + }, { // Test that DTLS can handle multiple application data // records in a single packet. diff --git a/ssl/tls_method.cc b/ssl/tls_method.cc index d0adcdba..2af51719 100644 --- a/ssl/tls_method.cc +++ b/ssl/tls_method.cc @@ -95,6 +95,10 @@ static bool ssl3_set_read_state(SSL *ssl, UniquePtr aead_ctx) { } static bool ssl3_set_write_state(SSL *ssl, UniquePtr aead_ctx) { + if (!tls_flush_pending_hs_data(ssl)) { + return false; + } + OPENSSL_memset(ssl->s3->write_sequence, 0, sizeof(ssl->s3->write_sequence)); ssl->s3->aead_write_ctx = std::move(aead_ctx); return true;