diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc index 45461cd6..952ac114 100644 --- a/ssl/ssl_test.cc +++ b/ssl/ssl_test.cc @@ -2970,6 +2970,166 @@ static bool TestAutoChain(bool is_dtls, const SSL_METHOD *method, return true; } +static bool ExpectBadWriteRetry() { + int err = ERR_get_error(); + if (ERR_GET_LIB(err) != ERR_LIB_SSL || + ERR_GET_REASON(err) != SSL_R_BAD_WRITE_RETRY) { + char buf[ERR_ERROR_STRING_BUF_LEN]; + ERR_error_string_n(err, buf, sizeof(buf)); + fprintf(stderr, "Wanted SSL_R_BAD_WRITE_RETRY, got: %s.\n", buf); + return false; + } + + if (ERR_peek_error() != 0) { + fprintf(stderr, "Unexpected error following SSL_R_BAD_WRITE_RETRY.\n"); + return false; + } + + return true; +} + +static bool TestSSLWriteRetry(bool is_dtls, const SSL_METHOD *method, + uint16_t version) { + if (is_dtls) { + return true; + } + + for (bool enable_partial_write : std::vector{false, true}) { + // Connect a client and server. + bssl::UniquePtr cert = GetTestCertificate(); + bssl::UniquePtr key = GetTestKey(); + bssl::UniquePtr ctx(SSL_CTX_new(method)); + bssl::UniquePtr client, server; + if (!cert || !key || !ctx || + !SSL_CTX_use_certificate(ctx.get(), cert.get()) || + !SSL_CTX_use_PrivateKey(ctx.get(), key.get()) || + !SSL_CTX_set_min_proto_version(ctx.get(), version) || + !SSL_CTX_set_max_proto_version(ctx.get(), version) || + !ConnectClientAndServer(&client, &server, ctx.get(), ctx.get(), + nullptr /* no session */)) { + return false; + } + + if (enable_partial_write) { + SSL_set_mode(client.get(), SSL_MODE_ENABLE_PARTIAL_WRITE); + } + + // Write without reading until the buffer is full and we have an unfinished + // write. Keep a count so we may reread it again later. "hello!" will be + // written in two chunks, "hello" and "!". + char data[] = "hello!"; + static const int kChunkLen = 5; // The length of "hello". + unsigned count = 0; + for (;;) { + int ret = SSL_write(client.get(), data, kChunkLen); + if (ret <= 0) { + int err = SSL_get_error(client.get(), ret); + if (SSL_get_error(client.get(), ret) == SSL_ERROR_WANT_WRITE) { + break; + } + fprintf(stderr, "SSL_write failed in unexpected way: %d\n", err); + return false; + } + + if (ret != 5) { + fprintf(stderr, "SSL_write wrote %d bytes, expected 5.\n", ret); + return false; + } + + count++; + } + + // Retrying with the same parameters is legal. + if (SSL_get_error(client.get(), SSL_write(client.get(), data, kChunkLen)) != + SSL_ERROR_WANT_WRITE) { + fprintf(stderr, "SSL_write retry unexpectedly failed.\n"); + return false; + } + + // Retrying with the same buffer but shorter length is not legal. + if (SSL_get_error(client.get(), + SSL_write(client.get(), data, kChunkLen - 1)) != + SSL_ERROR_SSL || + !ExpectBadWriteRetry()) { + fprintf(stderr, "SSL_write retry did not fail as expected.\n"); + return false; + } + + // Retrying with a different buffer pointer is not legal. + char data2[] = "hello"; + if (SSL_get_error(client.get(), SSL_write(client.get(), data2, + kChunkLen)) != SSL_ERROR_SSL || + !ExpectBadWriteRetry()) { + fprintf(stderr, "SSL_write retry did not fail as expected.\n"); + return false; + } + + // With |SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER|, the buffer may move. + SSL_set_mode(client.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + if (SSL_get_error(client.get(), + SSL_write(client.get(), data2, kChunkLen)) != + SSL_ERROR_WANT_WRITE) { + fprintf(stderr, "SSL_write retry unexpectedly failed.\n"); + return false; + } + + // |SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER| does not disable length checks. + if (SSL_get_error(client.get(), + SSL_write(client.get(), data2, kChunkLen - 1)) != + SSL_ERROR_SSL || + !ExpectBadWriteRetry()) { + fprintf(stderr, "SSL_write retry did not fail as expected.\n"); + return false; + } + + // Retrying with a larger buffer is legal. + if (SSL_get_error(client.get(), + SSL_write(client.get(), data, kChunkLen + 1)) != + SSL_ERROR_WANT_WRITE) { + fprintf(stderr, "SSL_write retry unexpectedly failed.\n"); + return false; + } + + // Drain the buffer. + char buf[20]; + for (unsigned i = 0; i < count; i++) { + if (SSL_read(server.get(), buf, sizeof(buf)) != kChunkLen || + OPENSSL_memcmp(buf, "hello", kChunkLen) != 0) { + fprintf(stderr, "Failed to read initial records.\n"); + return false; + } + } + + // Now that there is space, a retry with a larger buffer should flush the + // pending record, skip over that many bytes of input (on assumption they + // are the same), and write the remainder. If SSL_MODE_ENABLE_PARTIAL_WRITE + // is set, this will complete in two steps. + char data3[] = "_____!"; + if (enable_partial_write) { + if (SSL_write(client.get(), data3, kChunkLen + 1) != kChunkLen || + SSL_write(client.get(), data3 + kChunkLen, 1) != 1) { + fprintf(stderr, "SSL_write retry failed.\n"); + return false; + } + } else if (SSL_write(client.get(), data3, kChunkLen + 1) != kChunkLen + 1) { + fprintf(stderr, "SSL_write retry failed.\n"); + return false; + } + + // Check the last write was correct. The data will be spread over two + // records, so SSL_read returns twice. + if (SSL_read(server.get(), buf, sizeof(buf)) != kChunkLen || + OPENSSL_memcmp(buf, "hello", kChunkLen) != 0 || + SSL_read(server.get(), buf, sizeof(buf)) != 1 || + buf[0] != '!') { + fprintf(stderr, "Failed to read write retry.\n"); + return false; + } + } + + return true; +} + static bool ForEachVersion(bool (*test_func)(bool is_dtls, const SSL_METHOD *method, uint16_t version)) { @@ -3047,7 +3207,8 @@ int main() { !ForEachVersion(TestVersion) || !ForEachVersion(TestALPNCipherAvailable) || !ForEachVersion(TestSSLClearSessionResumption) || - !ForEachVersion(TestAutoChain)) { + !ForEachVersion(TestAutoChain) || + !ForEachVersion(TestSSLWriteRetry)) { ERR_print_errors_fp(stderr); return 1; }