diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index aadc4f01..ef389021 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc @@ -1079,23 +1079,9 @@ static ScopedEVP_PKEY GetTestKey() { PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); } -static bool TestSequenceNumber(bool dtls) { - ScopedSSL_CTX client_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); - ScopedSSL_CTX server_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); - if (!client_ctx || !server_ctx) { - return false; - } - - ScopedX509 cert = GetTestCertificate(); - ScopedEVP_PKEY key = GetTestKey(); - if (!cert || !key || - !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || - !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) { - return false; - } - - // Create a client and server connected to each other. - ScopedSSL client(SSL_new(client_ctx.get())), server(SSL_new(server_ctx.get())); +static bool ConnectClientAndServer(ScopedSSL *out_client, ScopedSSL *out_server, + SSL_CTX *client_ctx, SSL_CTX *server_ctx) { + ScopedSSL client(SSL_new(client_ctx)), server(SSL_new(server_ctx)); if (!client || !server) { return false; } @@ -1135,6 +1121,32 @@ static bool TestSequenceNumber(bool dtls) { } } + *out_client = std::move(client); + *out_server = std::move(server); + return true; +} + +static bool TestSequenceNumber(bool dtls) { + ScopedSSL_CTX client_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); + ScopedSSL_CTX server_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); + if (!client_ctx || !server_ctx) { + return false; + } + + ScopedX509 cert = GetTestCertificate(); + ScopedEVP_PKEY key = GetTestKey(); + if (!cert || !key || + !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || + !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) { + return false; + } + + ScopedSSL client, server; + if (!ConnectClientAndServer(&client, &server, client_ctx.get(), + server_ctx.get())) { + return false; + } + uint64_t client_read_seq = SSL_get_read_sequence(client.get()); uint64_t client_write_seq = SSL_get_write_sequence(client.get()); uint64_t server_read_seq = SSL_get_read_sequence(server.get()); @@ -1183,6 +1195,62 @@ static bool TestSequenceNumber(bool dtls) { return true; } +static bool TestOneSidedShutdown() { + ScopedSSL_CTX client_ctx(SSL_CTX_new(TLS_method())); + ScopedSSL_CTX server_ctx(SSL_CTX_new(TLS_method())); + if (!client_ctx || !server_ctx) { + return false; + } + + ScopedX509 cert = GetTestCertificate(); + ScopedEVP_PKEY key = GetTestKey(); + if (!cert || !key || + !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || + !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) { + return false; + } + + ScopedSSL client, server; + if (!ConnectClientAndServer(&client, &server, client_ctx.get(), + server_ctx.get())) { + return false; + } + + // Shut down half the connection. SSL_shutdown will return 0 to signal only + // one side has shut down. + if (SSL_shutdown(client.get()) != 0) { + fprintf(stderr, "Could not shutdown.\n"); + return false; + } + + // Reading from the server should consume the EOF. + uint8_t byte; + if (SSL_read(server.get(), &byte, 1) != 0 || + SSL_get_error(server.get(), 0) != SSL_ERROR_ZERO_RETURN) { + fprintf(stderr, "Connection was not shut down cleanly.\n"); + return false; + } + + // However, the server may continue to write data and then shut down the + // connection. + byte = 42; + if (SSL_write(server.get(), &byte, 1) != 1 || + SSL_read(client.get(), &byte, 1) != 1 || + byte != 42) { + fprintf(stderr, "Could not send byte.\n"); + return false; + } + + // The server may then shutdown the connection. + if (SSL_shutdown(server.get()) != 1 || + SSL_shutdown(client.get()) != 1) { + fprintf(stderr, "Could not complete shutdown.\n"); + return false; + } + + return true; +} + int main() { CRYPTO_library_init(); @@ -1206,7 +1274,8 @@ int main() { !TestClientCAList() || !TestInternalSessionCache() || !TestSequenceNumber(false /* TLS */) || - !TestSequenceNumber(true /* DTLS */)) { + !TestSequenceNumber(true /* DTLS */) || + !TestOneSidedShutdown()) { ERR_print_errors_fp(stderr); return 1; }