@@ -15,6 +15,7 @@
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <string>
#include <vector>
@@ -819,6 +820,157 @@ static bool TestClientCAList() {
return SSL_get_client_CA_list(ssl.get()) == stack;
}
static void AppendSession(SSL_SESSION *session, void *arg) {
std::vector<SSL_SESSION*> *out =
reinterpret_cast<std::vector<SSL_SESSION*>*>(arg);
out->push_back(session);
}
// ExpectCache returns true if |ctx|'s session cache consists of |expected|, in
// order.
static bool ExpectCache(SSL_CTX *ctx,
const std::vector<SSL_SESSION*> &expected) {
// Check the linked list.
SSL_SESSION *ptr = ctx->session_cache_head;
for (SSL_SESSION *session : expected) {
if (ptr != session) {
return false;
}
// TODO(davidben): This is an absurd way to denote the end of the list.
if (ptr->next ==
reinterpret_cast<SSL_SESSION *>(&ctx->session_cache_tail)) {
ptr = nullptr;
} else {
ptr = ptr->next;
}
}
if (ptr != nullptr) {
return false;
}
// Check the hash table.
std::vector<SSL_SESSION*> actual, expected_copy;
lh_SSL_SESSION_doall_arg(SSL_CTX_sessions(ctx), AppendSession, &actual);
expected_copy = expected;
std::sort(actual.begin(), actual.end());
std::sort(expected_copy.begin(), expected_copy.end());
return actual == expected_copy;
}
static ScopedSSL_SESSION CreateTestSession(uint32_t number) {
ScopedSSL_SESSION ret(SSL_SESSION_new());
if (!ret) {
return nullptr;
}
ret->session_id_length = SSL3_SSL_SESSION_ID_LENGTH;
memset(ret->session_id, 0, ret->session_id_length);
memcpy(ret->session_id, &number, sizeof(number));
return ret;
}
// TODO(davidben): Switch this to a |std::vector<ScopedSSL_SESSION>| once we can
// rely on a move-aware |std::vector|.
template <class T>
class ScopedVector {
public:
explicit ScopedVector(std::vector<T*> *v) : v_(v) {}
~ScopedVector() {
for (T *t : *v_) {
delete t;
}
}
private:
std::vector<T*> *const v_;
};
// Test that the internal session cache behaves as expected.
static bool TestInternalSessionCache() {
ScopedSSL_CTX ctx(SSL_CTX_new(TLS_method()));
if (!ctx) {
return false;
}
// Prepare 10 test sessions.
std::vector<SSL_SESSION*> sessions;
ScopedVector<SSL_SESSION> cleanup(&sessions);
for (int i = 0; i < 10; i++) {
ScopedSSL_SESSION session = CreateTestSession(i);
if (!session) {
return false;
}
sessions.push_back(session.release());
}
SSL_CTX_sess_set_cache_size(ctx.get(), 5);
// Insert all the test sessions.
for (SSL_SESSION *session : sessions) {
if (!SSL_CTX_add_session(ctx.get(), session)) {
return false;
}
}
// Only the last five should be in the list.
std::vector<SSL_SESSION*> expected;
expected.push_back(sessions[9]);
expected.push_back(sessions[8]);
expected.push_back(sessions[7]);
expected.push_back(sessions[6]);
expected.push_back(sessions[5]);
if (!ExpectCache(ctx.get(), expected)) {
return false;
}
// Inserting an element already in the cache should fail.
if (SSL_CTX_add_session(ctx.get(), sessions[7]) ||
!ExpectCache(ctx.get(), expected)) {
return false;
}
// Although collisions should be impossible (256-bit session IDs), the cache
// must handle them gracefully.
ScopedSSL_SESSION collision(CreateTestSession(7));
if (!collision || !SSL_CTX_add_session(ctx.get(), collision.get())) {
return false;
}
expected.clear();
expected.push_back(collision.get());
expected.push_back(sessions[9]);
expected.push_back(sessions[8]);
expected.push_back(sessions[6]);
expected.push_back(sessions[5]);
if (!ExpectCache(ctx.get(), expected)) {
return false;
}
// Removing sessions behaves correctly.
if (!SSL_CTX_remove_session(ctx.get(), sessions[6])) {
return false;
}
expected.clear();
expected.push_back(collision.get());
expected.push_back(sessions[9]);
expected.push_back(sessions[8]);
expected.push_back(sessions[5]);
if (!ExpectCache(ctx.get(), expected)) {
return false;
}
// Removing sessions requires an exact match.
if (SSL_CTX_remove_session(ctx.get(), sessions[0]) ||
SSL_CTX_remove_session(ctx.get(), sessions[7]) ||
!ExpectCache(ctx.get(), expected)) {
return false;
}
return true;
}
int main() {
CRYPTO_library_init();
@@ -839,7 +991,8 @@ int main() {
!TestDefaultVersion(DTLS1_2_VERSION, &DTLSv1_2_method) ||
!TestCipherGetRFCName() ||
!TestPaddingExtension() ||
!TestClientCAList()) {
!TestClientCAList() ||
!TestInternalSessionCache()) {
ERR_print_errors_fp(stderr);
return 1;
}