|
|
@@ -924,7 +924,10 @@ static size_t GetClientHelloLen(size_t ticket_len) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get())); |
|
|
|
if (!ssl || !SSL_set_session(ssl.get(), session.get())) { |
|
|
|
// Test at TLS 1.2. TLS 1.3 adds enough extensions that the ClientHello is |
|
|
|
// longer than our test vectors. |
|
|
|
if (!ssl || !SSL_set_session(ssl.get(), session.get()) || |
|
|
|
!SSL_set_max_proto_version(ssl.get(), TLS1_2_VERSION)) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
std::vector<uint8_t> client_hello; |
|
|
@@ -1270,70 +1273,100 @@ static bool ConnectClientAndServer(bssl::UniquePtr<SSL> *out_client, bssl::Uniqu |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
static bool TestSequenceNumber(bool dtls) { |
|
|
|
bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); |
|
|
|
bssl::UniquePtr<SSL_CTX> server_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method())); |
|
|
|
if (!client_ctx || !server_ctx) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
static uint16_t kTLSVersions[] = { |
|
|
|
SSL3_VERSION, TLS1_VERSION, TLS1_1_VERSION, TLS1_2_VERSION, TLS1_3_VERSION, |
|
|
|
}; |
|
|
|
|
|
|
|
bssl::UniquePtr<X509> cert = GetTestCertificate(); |
|
|
|
bssl::UniquePtr<EVP_PKEY> key = GetTestKey(); |
|
|
|
if (!cert || !key || |
|
|
|
!SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || |
|
|
|
!SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
static uint16_t kDTLSVersions[] = { |
|
|
|
DTLS1_VERSION, DTLS1_2_VERSION, |
|
|
|
}; |
|
|
|
|
|
|
|
bssl::UniquePtr<SSL> client, server; |
|
|
|
if (!ConnectClientAndServer(&client, &server, client_ctx.get(), |
|
|
|
server_ctx.get(), nullptr /* no session */)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
static bool TestSequenceNumber() { |
|
|
|
for (bool is_dtls : std::vector<bool>{false, true}) { |
|
|
|
const SSL_METHOD *method = is_dtls ? DTLS_method() : TLS_method(); |
|
|
|
const uint16_t *versions = is_dtls ? kDTLSVersions : kTLSVersions; |
|
|
|
size_t num_versions = is_dtls ? OPENSSL_ARRAY_SIZE(kDTLSVersions) |
|
|
|
: OPENSSL_ARRAY_SIZE(kTLSVersions); |
|
|
|
for (size_t i = 0; i < num_versions; i++) { |
|
|
|
uint16_t version = versions[i]; |
|
|
|
bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(method)); |
|
|
|
bssl::UniquePtr<SSL_CTX> server_ctx(SSL_CTX_new(method)); |
|
|
|
if (!server_ctx || !client_ctx || |
|
|
|
!SSL_CTX_set_min_proto_version(client_ctx.get(), version) || |
|
|
|
!SSL_CTX_set_max_proto_version(client_ctx.get(), version) || |
|
|
|
!SSL_CTX_set_min_proto_version(server_ctx.get(), version) || |
|
|
|
!SSL_CTX_set_max_proto_version(server_ctx.get(), version)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
uint64_t client_read_seq = SSL_get_read_sequence(client.get()); |
|
|
|
uint64_t client_write_seq = SSL_get_write_sequence(client.get()); |
|
|
|
uint64_t server_read_seq = SSL_get_read_sequence(server.get()); |
|
|
|
uint64_t server_write_seq = SSL_get_write_sequence(server.get()); |
|
|
|
bssl::UniquePtr<X509> cert = GetTestCertificate(); |
|
|
|
bssl::UniquePtr<EVP_PKEY> key = GetTestKey(); |
|
|
|
if (!cert || !key || |
|
|
|
!SSL_CTX_use_certificate(server_ctx.get(), cert.get()) || |
|
|
|
!SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (dtls) { |
|
|
|
// Both client and server must be at epoch 1. |
|
|
|
if (EpochFromSequence(client_read_seq) != 1 || |
|
|
|
EpochFromSequence(client_write_seq) != 1 || |
|
|
|
EpochFromSequence(server_read_seq) != 1 || |
|
|
|
EpochFromSequence(server_write_seq) != 1) { |
|
|
|
fprintf(stderr, "Bad epochs.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
bssl::UniquePtr<SSL> client, server; |
|
|
|
if (!ConnectClientAndServer(&client, &server, client_ctx.get(), |
|
|
|
server_ctx.get(), nullptr /* no session */)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// The next record to be written should exceed the largest received. |
|
|
|
if (client_write_seq <= server_read_seq || |
|
|
|
server_write_seq <= client_read_seq) { |
|
|
|
fprintf(stderr, "Inconsistent sequence numbers.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// The next record to be written should equal the next to be received. |
|
|
|
if (client_write_seq != server_read_seq || |
|
|
|
server_write_seq != client_write_seq) { |
|
|
|
fprintf(stderr, "Inconsistent sequence numbers.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
// Drain any post-handshake messages to ensure there are no unread records |
|
|
|
// on either end. |
|
|
|
uint8_t byte = 0; |
|
|
|
if (SSL_read(client.get(), &byte, 1) > 0 || |
|
|
|
SSL_read(server.get(), &byte, 1) > 0) { |
|
|
|
fprintf(stderr, "Received unexpected data.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// Send a record from client to server. |
|
|
|
uint8_t byte = 0; |
|
|
|
if (SSL_write(client.get(), &byte, 1) != 1 || |
|
|
|
SSL_read(server.get(), &byte, 1) != 1) { |
|
|
|
fprintf(stderr, "Could not send byte.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
uint64_t client_read_seq = SSL_get_read_sequence(client.get()); |
|
|
|
uint64_t client_write_seq = SSL_get_write_sequence(client.get()); |
|
|
|
uint64_t server_read_seq = SSL_get_read_sequence(server.get()); |
|
|
|
uint64_t server_write_seq = SSL_get_write_sequence(server.get()); |
|
|
|
|
|
|
|
if (is_dtls) { |
|
|
|
// Both client and server must be at epoch 1. |
|
|
|
if (EpochFromSequence(client_read_seq) != 1 || |
|
|
|
EpochFromSequence(client_write_seq) != 1 || |
|
|
|
EpochFromSequence(server_read_seq) != 1 || |
|
|
|
EpochFromSequence(server_write_seq) != 1) { |
|
|
|
fprintf(stderr, "Bad epochs.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// The next record to be written should exceed the largest received. |
|
|
|
if (client_write_seq <= server_read_seq || |
|
|
|
server_write_seq <= client_read_seq) { |
|
|
|
fprintf(stderr, "Inconsistent sequence numbers.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// The next record to be written should equal the next to be received. |
|
|
|
if (client_write_seq != server_read_seq || |
|
|
|
server_write_seq != client_read_seq) { |
|
|
|
fprintf(stderr, "Inconsistent sequence numbers.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// The client write and server read sequence numbers should have incremented. |
|
|
|
if (client_write_seq + 1 != SSL_get_write_sequence(client.get()) || |
|
|
|
server_read_seq + 1 != SSL_get_read_sequence(server.get())) { |
|
|
|
fprintf(stderr, "Sequence numbers did not increment.\n");\ |
|
|
|
return false; |
|
|
|
// Send a record from client to server. |
|
|
|
if (SSL_write(client.get(), &byte, 1) != 1 || |
|
|
|
SSL_read(server.get(), &byte, 1) != 1) { |
|
|
|
fprintf(stderr, "Could not send byte.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// The client write and server read sequence numbers should have |
|
|
|
// incremented. |
|
|
|
if (client_write_seq + 1 != SSL_get_write_sequence(client.get()) || |
|
|
|
server_read_seq + 1 != SSL_get_read_sequence(server.get())) { |
|
|
|
fprintf(stderr, "Sequence numbers did not increment.\n"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
@@ -1596,14 +1629,6 @@ static bool TestSetBIO() { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
static uint16_t kTLSVersions[] = { |
|
|
|
SSL3_VERSION, TLS1_VERSION, TLS1_1_VERSION, TLS1_2_VERSION, TLS1_3_VERSION, |
|
|
|
}; |
|
|
|
|
|
|
|
static uint16_t kDTLSVersions[] = { |
|
|
|
DTLS1_VERSION, DTLS1_2_VERSION, |
|
|
|
}; |
|
|
|
|
|
|
|
static int VerifySucceed(X509_STORE_CTX *store_ctx, void *arg) { return 1; } |
|
|
|
|
|
|
|
static bool TestGetPeerCertificate() { |
|
|
@@ -2324,8 +2349,7 @@ int main() { |
|
|
|
!TestPaddingExtension() || |
|
|
|
!TestClientCAList() || |
|
|
|
!TestInternalSessionCache() || |
|
|
|
!TestSequenceNumber(false /* TLS */) || |
|
|
|
!TestSequenceNumber(true /* DTLS */) || |
|
|
|
!TestSequenceNumber() || |
|
|
|
!TestOneSidedShutdown() || |
|
|
|
!TestSessionDuplication() || |
|
|
|
!TestSetFD() || |
|
|
|