Add lh_FOO_retrieve_key to avoid stack-allocating SSL_SESSION.

lh_FOO_retrieve is often called with a dummy instance of FOO that has
only a few fields filled in. This works fine for C, but a C++
SSL_SESSION with destructors is a bit more of a nuisance here.

Instead, teach LHASH to allow queries by some external key type. This
avoids stack-allocating SSL_SESSION. Along the way, fix the
make_macros.sh script.

Change-Id: Ie0b482d4ffe1027049d49db63274c7c17f9398fa
Reviewed-on: https://boringssl-review.googlesource.com/29586
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
David Benjamin 2018-06-29 23:58:43 -04:00 committed by CQ bot account: commit-bot@chromium.org
parent 63c79122e0
commit 58150ed59b
8 changed files with 118 additions and 41 deletions

View File

@ -141,14 +141,12 @@ size_t lh_num_items(const _LHASH *lh) { return lh->num_items; }
static LHASH_ITEM **get_next_ptr_and_hash(const _LHASH *lh, uint32_t *out_hash,
const void *data) {
const uint32_t hash = lh->hash(data);
LHASH_ITEM *cur, **ret;
if (out_hash != NULL) {
*out_hash = hash;
}
ret = &lh->buckets[hash % lh->num_buckets];
for (cur = *ret; cur != NULL; cur = *ret) {
LHASH_ITEM **ret = &lh->buckets[hash % lh->num_buckets];
for (LHASH_ITEM *cur = *ret; cur != NULL; cur = *ret) {
if (lh->comp(cur->data, data) == 0) {
break;
}
@ -158,16 +156,32 @@ static LHASH_ITEM **get_next_ptr_and_hash(const _LHASH *lh, uint32_t *out_hash,
return ret;
}
void *lh_retrieve(const _LHASH *lh, const void *data) {
LHASH_ITEM **next_ptr;
next_ptr = get_next_ptr_and_hash(lh, NULL, data);
if (*next_ptr == NULL) {
return NULL;
// get_next_ptr_by_key behaves like |get_next_ptr_and_hash| but takes a key
// which may be a different type from the values stored in |lh|.
static LHASH_ITEM **get_next_ptr_by_key(const _LHASH *lh, const void *key,
uint32_t key_hash,
int (*cmp_key)(const void *key,
const void *value)) {
LHASH_ITEM **ret = &lh->buckets[key_hash % lh->num_buckets];
for (LHASH_ITEM *cur = *ret; cur != NULL; cur = *ret) {
if (cmp_key(key, cur->data) == 0) {
break;
}
ret = &cur->next;
}
return (*next_ptr)->data;
return ret;
}
void *lh_retrieve(const _LHASH *lh, const void *data) {
LHASH_ITEM **next_ptr = get_next_ptr_and_hash(lh, NULL, data);
return *next_ptr == NULL ? NULL : (*next_ptr)->data;
}
void *lh_retrieve_key(const _LHASH *lh, const void *key, uint32_t key_hash,
int (*cmp_key)(const void *key, const void *value)) {
LHASH_ITEM **next_ptr = get_next_ptr_by_key(lh, key, key_hash, cmp_key);
return *next_ptr == NULL ? NULL : (*next_ptr)->data;
}
// lh_rebucket allocates a new array of |new_num_buckets| pointers and

View File

@ -103,6 +103,17 @@ TEST(LHashTest, Basic) {
std::unique_ptr<char[]> key = RandString();
void *value = lh_retrieve(lh.get(), key.get());
EXPECT_EQ(Lookup(&dummy_lh, key.get()), value);
// Do the same lookup with |lh_retrieve_key|.
value = lh_retrieve_key(
lh.get(), &key, lh_strhash(key.get()),
[](const void *key_ptr, const void *data) -> int {
const char *key_data =
reinterpret_cast<const std::unique_ptr<char[]> *>(key_ptr)
->get();
return strcmp(key_data, reinterpret_cast<const char *>(data));
});
EXPECT_EQ(Lookup(&dummy_lh, key.get()), value);
break;
}

View File

@ -28,7 +28,7 @@ output_lhash () {
type=$1
cat >> $out << EOF
/* ${type} */
// ${type}
#define lh_${type}_new(hash, comp)\\
((LHASH_OF(${type})*) lh_new(CHECKED_CAST(lhash_hash_func, uint32_t (*) (const ${type} *), hash), CHECKED_CAST(lhash_cmp_func, int (*) (const ${type} *a, const ${type} *b), comp)))
@ -41,6 +41,9 @@ output_lhash () {
#define lh_${type}_retrieve(lh, data)\\
((${type}*) lh_retrieve(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), CHECKED_CAST(void*, ${type}*, data)))
#define lh_${type}_retrieve_key(lh, key, key_hash, cmp_key)\\
((${type}*) lh_retrieve_key(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), key, key_hash, CHECKED_CAST(int (*)(const void *, const void *), int (*)(const void *, const ${type} *), cmp_key)))
#define lh_${type}_insert(lh, old_data, data)\\
lh_insert(CHECKED_CAST(_LHASH*, LHASH_OF(${type})*, lh), CHECKED_CAST(void**, ${type}**, old_data), CHECKED_CAST(void*, ${type}*, data))
@ -57,7 +60,7 @@ output_lhash () {
EOF
}
lhash_types=$(cat ${include_dir}/lhash.h | grep '^ \* LHASH_OF:' | sed -e 's/.*LHASH_OF://' -e 's/ .*//')
lhash_types=$(cat ${include_dir}/lhash.h | grep '^// LHASH_OF:' | sed -e 's/.*LHASH_OF://' -e 's/ .*//')
for type in $lhash_types; do
echo Hash of ${type}

View File

@ -141,6 +141,15 @@ OPENSSL_EXPORT size_t lh_num_items(const _LHASH *lh);
// it. If no such element exists, it returns NULL.
OPENSSL_EXPORT void *lh_retrieve(const _LHASH *lh, const void *data);
// lh_retrieve_key finds an element matching |key|, given the specified hash and
// comparison function. This differs from |lh_retrieve| in that the key may be a
// different type than the values stored in |lh|. |key_hash| and |cmp_key| must
// be compatible with the functions passed into |lh_new|.
OPENSSL_EXPORT void *lh_retrieve_key(const _LHASH *lh, const void *key,
uint32_t key_hash,
int (*cmp_key)(const void *key,
const void *value));
// lh_insert inserts |data| into the hash table. If an existing element is
// equal to |data| (with respect to the comparison function) then |*old_data|
// will be set to that value and it will be replaced. Otherwise, or in the

View File

@ -35,6 +35,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), \
CHECKED_CAST(void *, ASN1_OBJECT *, data)))
#define lh_ASN1_OBJECT_retrieve_key(lh, key, key_hash, cmp_key) \
((ASN1_OBJECT *)lh_retrieve_key( \
CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), key, key_hash, \
CHECKED_CAST(int (*)(const void *, const void *), \
int (*)(const void *, const ASN1_OBJECT *), cmp_key)))
#define lh_ASN1_OBJECT_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(ASN1_OBJECT) *, lh), \
CHECKED_CAST(void **, ASN1_OBJECT **, old_data), \
@ -74,6 +80,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), \
CHECKED_CAST(void *, CONF_VALUE *, data)))
#define lh_CONF_VALUE_retrieve_key(lh, key, key_hash, cmp_key) \
((CONF_VALUE *)lh_retrieve_key( \
CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), key, key_hash, \
CHECKED_CAST(int (*)(const void *, const void *), \
int (*)(const void *, const CONF_VALUE *), cmp_key)))
#define lh_CONF_VALUE_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(CONF_VALUE) *, lh), \
CHECKED_CAST(void **, CONF_VALUE **, old_data), \
@ -113,6 +125,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), \
CHECKED_CAST(void *, CRYPTO_BUFFER *, data)))
#define lh_CRYPTO_BUFFER_retrieve_key(lh, key, key_hash, cmp_key) \
((CRYPTO_BUFFER *)lh_retrieve_key( \
CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), key, key_hash, \
CHECKED_CAST(int (*)(const void *, const void *), \
int (*)(const void *, const CRYPTO_BUFFER *), cmp_key)))
#define lh_CRYPTO_BUFFER_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(CRYPTO_BUFFER) *, lh), \
CHECKED_CAST(void **, CRYPTO_BUFFER **, old_data), \
@ -153,6 +171,12 @@
CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), \
CHECKED_CAST(void *, SSL_SESSION *, data)))
#define lh_SSL_SESSION_retrieve_key(lh, key, key_hash, cmp_key) \
((SSL_SESSION *)lh_retrieve_key( \
CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), key, key_hash, \
CHECKED_CAST(int (*)(const void *, const void *), \
int (*)(const void *, const SSL_SESSION *), cmp_key)))
#define lh_SSL_SESSION_insert(lh, old_data, data) \
lh_insert(CHECKED_CAST(_LHASH *, LHASH_OF(SSL_SESSION) *, lh), \
CHECKED_CAST(void **, SSL_SESSION **, old_data), \

View File

@ -2909,6 +2909,10 @@ int ssl_ctx_rotate_ticket_encryption_key(SSL_CTX *ctx);
// error.
UniquePtr<SSL_SESSION> ssl_session_new(const SSL_X509_METHOD *x509_method);
// ssl_hash_session_id returns a hash of |session_id|, suitable for a hash table
// keyed on session IDs.
uint32_t ssl_hash_session_id(Span<const uint8_t> session_id);
// SSL_SESSION_parse parses an |SSL_SESSION| from |cbs| and advances |cbs| over
// the parsed data.
UniquePtr<SSL_SESSION> SSL_SESSION_parse(CBS *cbs,

View File

@ -567,22 +567,8 @@ int OPENSSL_init_ssl(uint64_t opts, const OPENSSL_INIT_SETTINGS *settings) {
}
static uint32_t ssl_session_hash(const SSL_SESSION *sess) {
const uint8_t *session_id = sess->session_id;
uint8_t tmp_storage[sizeof(uint32_t)];
if (sess->session_id_length < sizeof(tmp_storage)) {
OPENSSL_memset(tmp_storage, 0, sizeof(tmp_storage));
OPENSSL_memcpy(tmp_storage, sess->session_id, sess->session_id_length);
session_id = tmp_storage;
}
uint32_t hash =
((uint32_t)session_id[0]) |
((uint32_t)session_id[1] << 8) |
((uint32_t)session_id[2] << 16) |
((uint32_t)session_id[3] << 24);
return hash;
return ssl_hash_session_id(
MakeConstSpan(sess->session_id, sess->session_id_length));
}
static int ssl_session_cmp(const SSL_SESSION *a, const SSL_SESSION *b) {

View File

@ -184,6 +184,26 @@ UniquePtr<SSL_SESSION> ssl_session_new(const SSL_X509_METHOD *x509_method) {
return session;
}
uint32_t ssl_hash_session_id(Span<const uint8_t> session_id) {
// Take the first four bytes of |session_id|. Session IDs are generated by the
// server randomly, so we can assume even using the first four bytes results
// in a good distribution.
uint8_t tmp_storage[sizeof(uint32_t)];
if (session_id.size() < sizeof(tmp_storage)) {
OPENSSL_memset(tmp_storage, 0, sizeof(tmp_storage));
OPENSSL_memcpy(tmp_storage, session_id.data(), session_id.size());
session_id = tmp_storage;
}
uint32_t hash =
((uint32_t)session_id[0]) |
((uint32_t)session_id[1] << 8) |
((uint32_t)session_id[2] << 16) |
((uint32_t)session_id[3] << 24);
return hash;
}
UniquePtr<SSL_SESSION> SSL_SESSION_dup(SSL_SESSION *session, int dup_flags) {
UniquePtr<SSL_SESSION> new_session = ssl_session_new(session->x509_method);
if (!new_session) {
@ -657,11 +677,11 @@ int ssl_session_is_resumable(const SSL_HANDSHAKE *hs,
// |*out_session| to an |SSL_SESSION| object if found.
static enum ssl_hs_wait_t ssl_lookup_session(
SSL_HANDSHAKE *hs, UniquePtr<SSL_SESSION> *out_session,
const uint8_t *session_id, size_t session_id_len) {
Span<const uint8_t> session_id) {
SSL *const ssl = hs->ssl;
out_session->reset();
if (session_id_len == 0 || session_id_len > SSL_MAX_SSL_SESSION_ID_LENGTH) {
if (session_id.empty() || session_id.size() > SSL_MAX_SSL_SESSION_ID_LENGTH) {
return ssl_hs_ok;
}
@ -669,21 +689,26 @@ static enum ssl_hs_wait_t ssl_lookup_session(
// Try the internal cache, if it exists.
if (!(ssl->session_ctx->session_cache_mode &
SSL_SESS_CACHE_NO_INTERNAL_LOOKUP)) {
SSL_SESSION data;
data.session_id_length = session_id_len;
OPENSSL_memcpy(data.session_id, session_id, session_id_len);
uint32_t hash = ssl_hash_session_id(session_id);
auto cmp = [](const void *key, const SSL_SESSION *sess) -> int {
Span<const uint8_t> key_id =
*reinterpret_cast<const Span<const uint8_t> *>(key);
Span<const uint8_t> sess_id =
MakeConstSpan(sess->session_id, sess->session_id_length);
return key_id == sess_id ? 0 : 1;
};
MutexReadLock lock(&ssl->session_ctx->lock);
// |lh_SSL_SESSION_retrieve| returns a non-owning pointer.
session = UpRef(lh_SSL_SESSION_retrieve(ssl->session_ctx->sessions, &data));
// |lh_SSL_SESSION_retrieve_key| returns a non-owning pointer.
session = UpRef(lh_SSL_SESSION_retrieve_key(ssl->session_ctx->sessions,
&session_id, hash, cmp));
// TODO(davidben): This should probably move it to the front of the list.
}
// Fall back to the external cache, if it exists.
if (!session && ssl->session_ctx->get_session_cb != nullptr) {
int copy = 1;
session.reset(ssl->session_ctx->get_session_cb(ssl, session_id,
session_id_len, &copy));
session.reset(ssl->session_ctx->get_session_cb(ssl, session_id.data(),
session_id.size(), &copy));
if (!session) {
return ssl_hs_ok;
}
@ -752,7 +777,8 @@ enum ssl_hs_wait_t ssl_get_prev_session(SSL_HANDSHAKE *hs,
} else {
// The client didn't send a ticket, so the session ID is a real ID.
enum ssl_hs_wait_t lookup_ret = ssl_lookup_session(
hs, &session, client_hello->session_id, client_hello->session_id_len);
hs, &session,
MakeConstSpan(client_hello->session_id, client_hello->session_id_len));
if (lookup_ret != ssl_hs_ok) {
return lookup_ret;
}