diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 5b1c0679..0fa10b56 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h @@ -3098,12 +3098,15 @@ struct ssl_quic_method_st { int (*set_encryption_secrets)(SSL *ssl, enum ssl_encryption_level_t level, const uint8_t *read_secret, const uint8_t *write_secret, size_t secret_len); - // add_message adds a message to the current flight at the given encryption - // level. A single handshake flight may include multiple encryption levels. - // Callers can defer writing data to the network until |flush_flight| for - // optimal packing. It returns one on success and zero on error. - int (*add_message)(SSL *ssl, enum ssl_encryption_level_t level, - const uint8_t *data, size_t len); + // add_handshake_data adds handshake data to the current flight at the given + // encryption level. It returns one on success and zero on error. + // + // BoringSSL will pack data from a single encryption level together, but a + // single handshake flight may include multiple encryption levels. Callers + // should defer writing data to the network until |flush_flight| to better + // pack QUIC packets into transport datagrams. + int (*add_handshake_data)(SSL *ssl, enum ssl_encryption_level_t level, + const uint8_t *data, size_t len); // flush_flight is called when the current flight is complete and should be // written to the transport. Note a flight may contain data at several // encryption levels. It returns one on success and zero on error. diff --git a/ssl/s3_both.cc b/ssl/s3_both.cc index 689dd1d8..55e9aaa9 100644 --- a/ssl/s3_both.cc +++ b/ssl/s3_both.cc @@ -184,56 +184,49 @@ bool ssl3_finish_message(SSL *ssl, CBB *cbb, Array *out_msg) { } bool ssl3_add_message(SSL *ssl, Array msg) { - if (ssl->ctx->quic_method) { - if (!ssl->ctx->quic_method->add_message(ssl, ssl->s3->write_level, - msg.data(), msg.size())) { - OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR); - return false; + // Pack handshake data into the minimal number of records. This avoids + // unnecessary encryption overhead, notably in TLS 1.3 where we send several + // encrypted messages in a row. For now, we do not do this for the null + // cipher. The benefit is smaller and there is a risk of breaking buggy + // implementations. Additionally, we tie this to draft-28 as a sanity check, + // on the off chance middleboxes have fixated on sizes. + // + // TODO(davidben): See if we can do this uniformly. + Span rest = msg; + if (ssl->ctx->quic_method == nullptr && + (ssl->s3->aead_write_ctx->is_null_cipher() || + ssl->version == TLS1_3_DRAFT23_VERSION)) { + while (!rest.empty()) { + Span chunk = rest.subspan(0, ssl->max_send_fragment); + rest = rest.subspan(chunk.size()); + + if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) { + return false; + } } } else { - // Pack handshake data into the minimal number of records. This avoids - // unnecessary encryption overhead, notably in TLS 1.3 where we send several - // encrypted messages in a row. For now, we do not do this for the null - // cipher. The benefit is smaller and there is a risk of breaking buggy - // implementations. Additionally, we tie this to draft-28 as a sanity check, - // on the off chance middleboxes have fixated on sizes. - // - // TODO(davidben): See if we can do this uniformly. - Span rest = msg; - if (ssl->s3->aead_write_ctx->is_null_cipher() || - ssl->version == TLS1_3_DRAFT23_VERSION) { - while (!rest.empty()) { - Span chunk = rest.subspan(0, ssl->max_send_fragment); - rest = rest.subspan(chunk.size()); - - if (!add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, chunk)) { - return false; - } + while (!rest.empty()) { + // Flush if |pending_hs_data| is full. + if (ssl->s3->pending_hs_data && + ssl->s3->pending_hs_data->length >= ssl->max_send_fragment && + !tls_flush_pending_hs_data(ssl)) { + return false; } - } else { - while (!rest.empty()) { - // Flush if |pending_hs_data| is full. - if (ssl->s3->pending_hs_data && - ssl->s3->pending_hs_data->length >= ssl->max_send_fragment && - !tls_flush_pending_hs_data(ssl)) { - return false; - } - size_t pending_len = - ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0; - Span chunk = - rest.subspan(0, ssl->max_send_fragment - pending_len); - assert(!chunk.empty()); - rest = rest.subspan(chunk.size()); + size_t pending_len = + ssl->s3->pending_hs_data ? ssl->s3->pending_hs_data->length : 0; + Span chunk = + rest.subspan(0, ssl->max_send_fragment - pending_len); + assert(!chunk.empty()); + rest = rest.subspan(chunk.size()); - if (!ssl->s3->pending_hs_data) { - ssl->s3->pending_hs_data.reset(BUF_MEM_new()); - } - if (!ssl->s3->pending_hs_data || - !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(), - chunk.size())) { - return false; - } + if (!ssl->s3->pending_hs_data) { + ssl->s3->pending_hs_data.reset(BUF_MEM_new()); + } + if (!ssl->s3->pending_hs_data || + !BUF_MEM_append(ssl->s3->pending_hs_data.get(), chunk.data(), + chunk.size())) { + return false; } } } @@ -249,16 +242,24 @@ bool ssl3_add_message(SSL *ssl, Array msg) { } bool tls_flush_pending_hs_data(SSL *ssl) { - if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0 || - ssl->ctx->quic_method) { + if (!ssl->s3->pending_hs_data || ssl->s3->pending_hs_data->length == 0) { return true; } UniquePtr pending_hs_data = std::move(ssl->s3->pending_hs_data); - return add_record_to_flight( - ssl, SSL3_RT_HANDSHAKE, + auto data = MakeConstSpan(reinterpret_cast(pending_hs_data->data), - pending_hs_data->length)); + pending_hs_data->length); + if (ssl->ctx->quic_method) { + if (!ssl->ctx->quic_method->add_handshake_data(ssl, ssl->s3->write_level, + data.data(), data.size())) { + OPENSSL_PUT_ERROR(SSL, SSL_R_QUIC_INTERNAL_ERROR); + return false; + } + return true; + } + + return add_record_to_flight(ssl, SSL3_RT_HANDSHAKE, data); } bool ssl3_add_change_cipher_spec(SSL *ssl) { @@ -280,6 +281,10 @@ bool ssl3_add_change_cipher_spec(SSL *ssl) { } int ssl3_flush_flight(SSL *ssl) { + if (!tls_flush_pending_hs_data(ssl)) { + return -1; + } + if (ssl->ctx->quic_method) { if (ssl->s3->write_shutdown != ssl_shutdown_none) { OPENSSL_PUT_ERROR(SSL, SSL_R_PROTOCOL_IS_SHUTDOWN); @@ -292,10 +297,6 @@ int ssl3_flush_flight(SSL *ssl) { } } - if (!tls_flush_pending_hs_data(ssl)) { - return -1; - } - if (ssl->s3->pending_flight == nullptr) { return 1; } diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index c237809e..47925607 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc @@ -4650,8 +4650,9 @@ class QUICMethodTest : public testing::Test { write_key, key_len); } - static int AddMessageCallback(SSL *ssl, enum ssl_encryption_level_t level, - const uint8_t *data, size_t len) { + static int AddHandshakeDataCallback(SSL *ssl, + enum ssl_encryption_level_t level, + const uint8_t *data, size_t len) { EXPECT_EQ(level, SSL_quic_write_level(ssl)); return TransportFromSSL(ssl)->WriteHandshakeData(level, MakeConstSpan(data, len)); @@ -4681,7 +4682,7 @@ UnownedSSLExData QUICMethodTest::ex_data_; TEST_F(QUICMethodTest, Basic) { const SSL_QUIC_METHOD quic_method = { SetEncryptionSecretsCallback, - AddMessageCallback, + AddHandshakeDataCallback, FlushFlightCallback, SendAlertCallback, }; @@ -4732,7 +4733,7 @@ TEST_F(QUICMethodTest, Basic) { TEST_F(QUICMethodTest, Async) { const SSL_QUIC_METHOD quic_method = { SetEncryptionSecretsCallback, - AddMessageCallback, + AddHandshakeDataCallback, FlushFlightCallback, SendAlertCallback, }; @@ -4793,8 +4794,8 @@ TEST_F(QUICMethodTest, Buffered) { }; static UnownedSSLExData buffered_flights; - auto add_message = [](SSL *ssl, enum ssl_encryption_level_t level, - const uint8_t *data, size_t len) -> int { + auto add_handshake_data = [](SSL *ssl, enum ssl_encryption_level_t level, + const uint8_t *data, size_t len) -> int { BufferedFlight *flight = buffered_flights.Get(ssl); flight->data[level].insert(flight->data[level].end(), data, data + len); return 1; @@ -4817,7 +4818,7 @@ TEST_F(QUICMethodTest, Buffered) { const SSL_QUIC_METHOD quic_method = { SetEncryptionSecretsCallback, - add_message, + add_handshake_data, flush_flight, SendAlertCallback, }; @@ -4862,8 +4863,8 @@ TEST_F(QUICMethodTest, Buffered) { // EncryptedExtensions in a single chunk, BoringSSL notices and rejects this on // key change. TEST_F(QUICMethodTest, ExcessProvidedData) { - auto add_message = [](SSL *ssl, enum ssl_encryption_level_t level, - const uint8_t *data, size_t len) -> int { + auto add_handshake_data = [](SSL *ssl, enum ssl_encryption_level_t level, + const uint8_t *data, size_t len) -> int { // Switch everything to the initial level. return TransportFromSSL(ssl)->WriteHandshakeData(ssl_encryption_initial, MakeConstSpan(data, len)); @@ -4871,7 +4872,7 @@ TEST_F(QUICMethodTest, ExcessProvidedData) { const SSL_QUIC_METHOD quic_method = { SetEncryptionSecretsCallback, - add_message, + add_handshake_data, FlushFlightCallback, SendAlertCallback, }; @@ -4891,8 +4892,8 @@ TEST_F(QUICMethodTest, ExcessProvidedData) { // encryption. ASSERT_EQ(ssl_encryption_initial, SSL_quic_read_level(client_.get())); - // |add_message| incorrectly wrote everything at the initial level, so this - // queues up ServerHello through Finished in one chunk. + // |add_handshake_data| incorrectly wrote everything at the initial level, so + // this queues up ServerHello through Finished in one chunk. ASSERT_TRUE(ProvideHandshakeData(client_.get())); // The client reads ServerHello successfully, but then rejects the buffered @@ -4917,7 +4918,7 @@ TEST_F(QUICMethodTest, ExcessProvidedData) { TEST_F(QUICMethodTest, ProvideWrongLevel) { const SSL_QUIC_METHOD quic_method = { SetEncryptionSecretsCallback, - AddMessageCallback, + AddHandshakeDataCallback, FlushFlightCallback, SendAlertCallback, }; @@ -4962,7 +4963,7 @@ TEST_F(QUICMethodTest, ProvideWrongLevel) { TEST_F(QUICMethodTest, TooMuchData) { const SSL_QUIC_METHOD quic_method = { SetEncryptionSecretsCallback, - AddMessageCallback, + AddHandshakeDataCallback, FlushFlightCallback, SendAlertCallback, };