diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 6898674a..d5883950 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h @@ -394,6 +394,26 @@ OPENSSL_EXPORT int SSL_pending(const SSL *ssl); // https://crbug.com/466303. OPENSSL_EXPORT int SSL_write(SSL *ssl, const void *buf, int num); +// SSL_KEY_UPDATE_REQUESTED indicates that the peer should reply to a KeyUpdate +// message with its own, thus updating traffic secrets for both directions on +// the connection. +#define SSL_KEY_UPDATE_REQUESTED 1 + +// SSL_KEY_UPDATE_NOT_REQUESTED indicates that the peer should not reply with +// it's own KeyUpdate message. +#define SSL_KEY_UPDATE_NOT_REQUESTED 0 + +// SSL_key_update queues a TLS 1.3 KeyUpdate message to be sent on |ssl| +// if one is not already queued. The |request_type| argument must one of the +// |SSL_KEY_UPDATE_*| values. This function requires that |ssl| have completed a +// TLS >= 1.3 handshake. It returns one on success or zero on error. +// +// Note that this function does not _send_ the message itself. The next call to +// |SSL_write| will cause the message to be sent. |SSL_write| may be called with +// a zero length to flush a KeyUpdate message when no application data is +// pending. +OPENSSL_EXPORT int SSL_key_update(SSL *ssl, int request_type); + // SSL_shutdown shuts down |ssl|. It runs in two stages. First, it sends // close_notify and returns zero or one on success or -1 on failure. Zero // indicates that close_notify was sent, but not received, and one additionally diff --git a/ssl/internal.h b/ssl/internal.h index bbce7ec4..fa755d26 100644 --- a/ssl/internal.h +++ b/ssl/internal.h @@ -1649,6 +1649,11 @@ const char *ssl_server_handshake_state(SSL_HANDSHAKE *hs); const char *tls13_client_handshake_state(SSL_HANDSHAKE *hs); const char *tls13_server_handshake_state(SSL_HANDSHAKE *hs); +// tls13_add_key_update queues a KeyUpdate message on |ssl|. The +// |update_requested| argument must be one of |SSL_KEY_UPDATE_REQUESTED| or +// |SSL_KEY_UPDATE_NOT_REQUESTED|. +bool tls13_add_key_update(SSL *ssl, int update_requested); + // tls13_post_handshake processes a post-handshake message. It returns true on // success and false on failure. bool tls13_post_handshake(SSL *ssl, const SSLMessage &msg); @@ -2506,10 +2511,6 @@ struct SSL_CONFIG { // From RFC 8446, used in determining PSK modes. #define SSL_PSK_DHE_KE 0x1 -// From RFC 8446, used in determining whether to respond with a KeyUpdate. -#define SSL_KEY_UPDATE_NOT_REQUESTED 0 -#define SSL_KEY_UPDATE_REQUESTED 1 - // kMaxEarlyDataAccepted is the advertised number of plaintext bytes of early // data that will be accepted. This value should be slightly below // kMaxEarlyDataSkipped in tls_record.c, which is measured in ciphertext. diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc index f0ae8a2e..f1d67d0b 100644 --- a/ssl/s3_pkt.cc +++ b/ssl/s3_pkt.cc @@ -232,25 +232,29 @@ static int do_ssl3_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { return -1; } - if (len == 0) { - return 0; - } - if (!tls_flush_pending_hs_data(ssl)) { return -1; } + size_t flight_len = 0; if (ssl->s3->pending_flight != nullptr) { flight_len = ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset; } - size_t max_out = len + SSL_max_seal_overhead(ssl); - if (max_out < len || max_out + flight_len < max_out) { - OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); - return -1; + size_t max_out = flight_len; + if (len > 0) { + const size_t max_ciphertext_len = len + SSL_max_seal_overhead(ssl); + if (max_ciphertext_len < len || max_out + max_ciphertext_len < max_out) { + OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW); + return -1; + } + max_out += max_ciphertext_len; + } + + if (max_out == 0) { + return 0; } - max_out += flight_len; if (!buf->EnsureCap(flight_len + ssl_seal_align_prefix_len(ssl), max_out)) { return -1; @@ -270,12 +274,14 @@ static int do_ssl3_write(SSL *ssl, int type, const uint8_t *in, unsigned len) { buf->DidWrite(flight_len); } - size_t ciphertext_len; - if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len, - buf->remaining().size(), type, in, len)) { - return -1; + if (len > 0) { + size_t ciphertext_len; + if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len, + buf->remaining().size(), type, in, len)) { + return -1; + } + buf->DidWrite(ciphertext_len); } - buf->DidWrite(ciphertext_len); // Now that we've made progress on the connection, uncork KeyUpdate // acknowledgments. diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc index ceeba89c..a3c25f3b 100644 --- a/ssl/ssl_lib.cc +++ b/ssl/ssl_lib.cc @@ -1135,6 +1135,37 @@ int SSL_write(SSL *ssl, const void *buf, int num) { return ret; } +int SSL_key_update(SSL *ssl, int request_type) { + ssl_reset_error_state(ssl); + + if (ssl->do_handshake == NULL) { + OPENSSL_PUT_ERROR(SSL, SSL_R_UNINITIALIZED); + return 0; + } + + if (ssl->ctx->quic_method != nullptr) { + OPENSSL_PUT_ERROR(SSL, ERR_R_SHOULD_NOT_HAVE_BEEN_CALLED); + return 0; + } + + if (!ssl->s3->initial_handshake_complete) { + OPENSSL_PUT_ERROR(SSL, SSL_R_HANDSHAKE_NOT_COMPLETE); + return 0; + } + + if (ssl_protocol_version(ssl) < TLS1_3_VERSION) { + OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_SSL_VERSION); + return 0; + } + + if (!ssl->s3->key_update_pending && + !tls13_add_key_update(ssl, request_type)) { + return 0; + } + + return 1; +} + int SSL_shutdown(SSL *ssl) { ssl_reset_error_state(ssl); diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 8d01c03a..1f09156e 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc @@ -4358,6 +4358,35 @@ TEST(SSLTest, ApplyHandoffRemovesUnsupportedCurves) { EXPECT_EQ(1u, server->config->supported_group_list.size()); } +TEST(SSLTest, ZeroSizedWiteFlushesHandshakeMessages) { + // If there are pending handshake mesages, an |SSL_write| of zero bytes should + // flush them. + bssl::UniquePtr server_ctx(SSL_CTX_new(TLS_method())); + EXPECT_TRUE(SSL_CTX_set_max_proto_version(server_ctx.get(), TLS1_3_VERSION)); + EXPECT_TRUE(SSL_CTX_set_min_proto_version(server_ctx.get(), TLS1_3_VERSION)); + bssl::UniquePtr cert = GetTestCertificate(); + bssl::UniquePtr key = GetTestKey(); + ASSERT_TRUE(cert); + ASSERT_TRUE(key); + ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get())); + ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())); + + bssl::UniquePtr client_ctx(SSL_CTX_new(TLS_method())); + EXPECT_TRUE(SSL_CTX_set_max_proto_version(client_ctx.get(), TLS1_3_VERSION)); + EXPECT_TRUE(SSL_CTX_set_min_proto_version(client_ctx.get(), TLS1_3_VERSION)); + + bssl::UniquePtr client, server; + ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(), + server_ctx.get())); + + BIO *client_wbio = SSL_get_wbio(client.get()); + EXPECT_EQ(0u, BIO_wpending(client_wbio)); + EXPECT_TRUE(SSL_key_update(client.get(), SSL_KEY_UPDATE_NOT_REQUESTED)); + EXPECT_EQ(0u, BIO_wpending(client_wbio)); + EXPECT_EQ(0, SSL_write(client.get(), nullptr, 0)); + EXPECT_NE(0u, BIO_wpending(client_wbio)); +} + TEST_P(SSLVersionTest, VerifyBeforeCertRequest) { // Configure the server to request client certificates. SSL_CTX_set_custom_verify( diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc index 77ed7968..62db0767 100644 --- a/ssl/test/bssl_shim.cc +++ b/ssl/test/bssl_shim.cc @@ -981,6 +981,12 @@ static bool DoExchange(bssl::UniquePtr *out_session, pending_initial_write = false; } + if (config->key_update && + !SSL_key_update(ssl, SSL_KEY_UPDATE_NOT_REQUESTED)) { + fprintf(stderr, "SSL_key_update failed.\n"); + return false; + } + for (int i = 0; i < n; i++) { buf[i] ^= 0xff; } diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index b6b6ffa0..816ffcad 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -106,6 +106,7 @@ type Conn struct { pendingFragments [][]byte // pending outgoing handshake fragments. pendingPacket []byte // pending outgoing packet. + keyUpdateSeen bool keyUpdateRequested bool seenOneByteRecord bool @@ -1635,6 +1636,8 @@ func (c *Conn) handlePostHandshakeMessage() error { } if keyUpdate, ok := msg.(*keyUpdateMsg); ok { + c.keyUpdateSeen = true + if c.config.Bugs.RejectUnsolicitedKeyUpdate { return errors.New("tls: unexpected KeyUpdate message") } @@ -1665,7 +1668,7 @@ func (c *Conn) ReadKeyUpdateACK() error { keyUpdate, ok := msg.(*keyUpdateMsg) if !ok { c.sendAlert(alertUnexpectedMessage) - return errors.New("tls: unexpected message when reading KeyUpdate") + return fmt.Errorf("tls: unexpected message (%T) when reading KeyUpdate", msg) } if keyUpdate.keyUpdateRequest != keyUpdateNotRequested { @@ -1708,13 +1711,12 @@ func (c *Conn) Read(b []byte) (n int, err error) { // Soft error, like EAGAIN return 0, err } - if c.hand.Len() > 0 { + for c.hand.Len() > 0 { // We received handshake bytes, indicating a // post-handshake message. if err := c.handlePostHandshakeMessage(); err != nil { return 0, err } - continue } } if err := c.in.err; err != nil { diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index b5cc0a79..d64e95f0 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -473,6 +473,11 @@ type testCase struct { sendKeyUpdates int // keyUpdateRequest is the KeyUpdateRequest value to send in KeyUpdate messages. keyUpdateRequest byte + // expectUnsolicitedKeyUpdate makes the test expect a one or more KeyUpdate + // messages while reading data from the shim. Don't use this in combination + // with any of the fields that send a KeyUpdate otherwise any received + // KeyUpdate might not be as unsolicited as expected. + expectUnsolicitedKeyUpdate bool // expectMessageDropped, if true, means the test message is expected to // be dropped by the client rather than echoed back. expectMessageDropped bool @@ -951,6 +956,10 @@ func doExchange(test *testCase, config *Config, conn net.Conn, isResume bool, tr return fmt.Errorf("bad reply contents at byte %d; got %q and wanted %q", i, buf, testMessage) } } + + if seen := tlsConn.keyUpdateSeen; seen != test.expectUnsolicitedKeyUpdate { + return fmt.Errorf("keyUpdateSeen (%t) != expectUnsolicitedKeyUpdate", seen) + } } return nil @@ -2826,7 +2835,7 @@ read alert 1 0 expectedError: ":WRONG_VERSION_NUMBER:", }, { - name: "KeyUpdate-Client", + name: "KeyUpdate-ToClient", config: Config{ MaxVersion: VersionTLS13, }, @@ -2835,13 +2844,30 @@ read alert 1 0 }, { testType: serverTest, - name: "KeyUpdate-Server", + name: "KeyUpdate-ToServer", config: Config{ MaxVersion: VersionTLS13, }, sendKeyUpdates: 1, keyUpdateRequest: keyUpdateNotRequested, }, + { + name: "KeyUpdate-FromClient", + config: Config{ + MaxVersion: VersionTLS13, + }, + expectUnsolicitedKeyUpdate: true, + flags: []string{"-key-update"}, + }, + { + testType: serverTest, + name: "KeyUpdate-FromServer", + config: Config{ + MaxVersion: VersionTLS13, + }, + expectUnsolicitedKeyUpdate: true, + flags: []string{"-key-update"}, + }, { name: "KeyUpdate-InvalidRequestMode", config: Config{ diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index bed05010..b88d0aec 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc @@ -148,6 +148,7 @@ const Flag kBoolFlags[] = { { "-jdk11-workaround", &TestConfig::jdk11_workaround }, { "-server-preference", &TestConfig::server_preference }, { "-export-traffic-secrets", &TestConfig::export_traffic_secrets }, + { "-key-update", &TestConfig::key_update }, }; const Flag kStringFlags[] = { diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h index 0d0753e8..5d5eb5ad 100644 --- a/ssl/test/test_config.h +++ b/ssl/test/test_config.h @@ -172,6 +172,7 @@ struct TestConfig { bool jdk11_workaround = false; bool server_preference = false; bool export_traffic_secrets = false; + bool key_update = false; int argc; char **argv; diff --git a/ssl/tls13_both.cc b/ssl/tls13_both.cc index 6baeaf71..f6e359c3 100644 --- a/ssl/tls13_both.cc +++ b/ssl/tls13_both.cc @@ -607,6 +607,25 @@ bool tls13_add_finished(SSL_HANDSHAKE *hs) { return true; } +bool tls13_add_key_update(SSL *ssl, int update_requested) { + ScopedCBB cbb; + CBB body_cbb; + if (!ssl->method->init_message(ssl, cbb.get(), &body_cbb, + SSL3_MT_KEY_UPDATE) || + !CBB_add_u8(&body_cbb, update_requested) || + !ssl_add_message_cbb(ssl, cbb.get()) || + !tls13_rotate_traffic_key(ssl, evp_aead_seal)) { + return false; + } + + // Suppress KeyUpdate acknowledgments until this change is written to the + // wire. This prevents us from accumulating write obligations when read and + // write progress at different rates. See RFC 8446, section 4.6.3. + ssl->s3->key_update_pending = true; + + return true; +} + static bool tls13_receive_key_update(SSL *ssl, const SSLMessage &msg) { CBS body = msg.body; uint8_t key_update_request; @@ -625,21 +644,9 @@ static bool tls13_receive_key_update(SSL *ssl, const SSLMessage &msg) { // Acknowledge the KeyUpdate if (key_update_request == SSL_KEY_UPDATE_REQUESTED && - !ssl->s3->key_update_pending) { - ScopedCBB cbb; - CBB body_cbb; - if (!ssl->method->init_message(ssl, cbb.get(), &body_cbb, - SSL3_MT_KEY_UPDATE) || - !CBB_add_u8(&body_cbb, SSL_KEY_UPDATE_NOT_REQUESTED) || - !ssl_add_message_cbb(ssl, cbb.get()) || - !tls13_rotate_traffic_key(ssl, evp_aead_seal)) { - return false; - } - - // Suppress KeyUpdate acknowledgments until this change is written to the - // wire. This prevents us from accumulating write obligations when read and - // write progress at different rates. See RFC 8446, section 4.6.3. - ssl->s3->key_update_pending = true; + !ssl->s3->key_update_pending && + !tls13_add_key_update(ssl, SSL_KEY_UPDATE_NOT_REQUESTED)) { + return false; } return true;