Don't use a global for early_callback_called.

We have a stateful object hanging off the SSL* now. May as well use it and
avoid having to remember to reset that.

Change-Id: I5fc5269aa9b158517dd551036e658afaa2ef9acd
Reviewed-on: https://boringssl-review.googlesource.com/3349
Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
David Benjamin 2015-02-09 13:03:50 -05:00 committed by Adam Langley
parent c273d2c537
commit 2d445c0921

View File

@ -40,23 +40,24 @@ static int Usage(const char *program) {
return 1; return 1;
} }
struct AsyncState { struct TestState {
AsyncState() : cert_ready(false) {} TestState() : cert_ready(false), early_callback_called(false) {}
ScopedEVP_PKEY channel_id; ScopedEVP_PKEY channel_id;
bool cert_ready; bool cert_ready;
ScopedSSL_SESSION session; ScopedSSL_SESSION session;
ScopedSSL_SESSION pending_session; ScopedSSL_SESSION pending_session;
bool early_callback_called;
}; };
static void AsyncExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad, int index, static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad,
long argl, void *argp) { int index, long argl, void *argp) {
delete ((AsyncState *)ptr); delete ((TestState *)ptr);
} }
static int g_config_index = 0; static int g_config_index = 0;
static int g_clock_index = 0; static int g_clock_index = 0;
static int g_async_index = 0; static int g_state_index = 0;
static bool SetConfigPtr(SSL *ssl, const TestConfig *config) { static bool SetConfigPtr(SSL *ssl, const TestConfig *config) {
return SSL_set_ex_data(ssl, g_config_index, (void *)config) == 1; return SSL_set_ex_data(ssl, g_config_index, (void *)config) == 1;
@ -74,16 +75,16 @@ static OPENSSL_timeval *GetClockPtr(SSL *ssl) {
return (OPENSSL_timeval *)SSL_get_ex_data(ssl, g_clock_index); return (OPENSSL_timeval *)SSL_get_ex_data(ssl, g_clock_index);
} }
static bool SetAsyncState(SSL *ssl, std::unique_ptr<AsyncState> async) { static bool SetTestState(SSL *ssl, std::unique_ptr<TestState> async) {
if (SSL_set_ex_data(ssl, g_async_index, (void *)async.get()) == 1) { if (SSL_set_ex_data(ssl, g_state_index, (void *)async.get()) == 1) {
async.release(); async.release();
return true; return true;
} }
return false; return false;
} }
static AsyncState *GetAsyncState(SSL *ssl) { static TestState *GetTestState(SSL *ssl) {
return (AsyncState *)SSL_get_ex_data(ssl, g_async_index); return (TestState *)SSL_get_ex_data(ssl, g_state_index);
} }
static ScopedEVP_PKEY LoadPrivateKey(const std::string &file) { static ScopedEVP_PKEY LoadPrivateKey(const std::string &file) {
@ -110,12 +111,9 @@ static bool InstallCertificate(SSL *ssl) {
return true; return true;
} }
static int g_early_callback_called = 0;
static int SelectCertificateCallback(const struct ssl_early_callback_ctx *ctx) { static int SelectCertificateCallback(const struct ssl_early_callback_ctx *ctx) {
g_early_callback_called = 1;
const TestConfig *config = GetConfigPtr(ctx->ssl); const TestConfig *config = GetConfigPtr(ctx->ssl);
GetTestState(ctx->ssl)->early_callback_called = true;
if (config->expected_server_name.empty()) { if (config->expected_server_name.empty()) {
return 1; return 1;
@ -274,11 +272,11 @@ static void CurrentTimeCallback(SSL *ssl, OPENSSL_timeval *out_clock) {
} }
static void ChannelIdCallback(SSL *ssl, EVP_PKEY **out_pkey) { static void ChannelIdCallback(SSL *ssl, EVP_PKEY **out_pkey) {
*out_pkey = GetAsyncState(ssl)->channel_id.release(); *out_pkey = GetTestState(ssl)->channel_id.release();
} }
static int CertCallback(SSL *ssl, void *arg) { static int CertCallback(SSL *ssl, void *arg) {
if (!GetAsyncState(ssl)->cert_ready) { if (!GetTestState(ssl)->cert_ready) {
return -1; return -1;
} }
if (!InstallCertificate(ssl)) { if (!InstallCertificate(ssl)) {
@ -289,7 +287,7 @@ static int CertCallback(SSL *ssl, void *arg) {
static SSL_SESSION *GetSessionCallback(SSL *ssl, uint8_t *data, int len, static SSL_SESSION *GetSessionCallback(SSL *ssl, uint8_t *data, int len,
int *copy) { int *copy) {
AsyncState *async_state = GetAsyncState(ssl); TestState *async_state = GetTestState(ssl);
if (async_state->session) { if (async_state->session) {
*copy = 0; *copy = 0;
return async_state->session.release(); return async_state->session.release();
@ -395,15 +393,15 @@ static int RetryAsync(SSL *ssl, int ret, BIO *async,
AsyncBioAllowWrite(async, 1); AsyncBioAllowWrite(async, 1);
return 1; return 1;
case SSL_ERROR_WANT_CHANNEL_ID_LOOKUP: case SSL_ERROR_WANT_CHANNEL_ID_LOOKUP:
GetAsyncState(ssl)->channel_id = GetTestState(ssl)->channel_id =
LoadPrivateKey(GetConfigPtr(ssl)->send_channel_id); LoadPrivateKey(GetConfigPtr(ssl)->send_channel_id);
return 1; return 1;
case SSL_ERROR_WANT_X509_LOOKUP: case SSL_ERROR_WANT_X509_LOOKUP:
GetAsyncState(ssl)->cert_ready = true; GetTestState(ssl)->cert_ready = true;
return 1; return 1;
case SSL_ERROR_PENDING_SESSION: case SSL_ERROR_PENDING_SESSION:
GetAsyncState(ssl)->session = GetTestState(ssl)->session =
std::move(GetAsyncState(ssl)->pending_session); std::move(GetTestState(ssl)->pending_session);
return 1; return 1;
default: default:
return 0; return 0;
@ -413,8 +411,6 @@ static int RetryAsync(SSL *ssl, int ret, BIO *async,
static int DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx, static int DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
const TestConfig *config, bool is_resume, const TestConfig *config, bool is_resume,
int fd, SSL_SESSION *session) { int fd, SSL_SESSION *session) {
g_early_callback_called = 0;
OPENSSL_timeval clock = {0}, clock_delta = {0}; OPENSSL_timeval clock = {0}, clock_delta = {0};
ScopedSSL ssl(SSL_new(ssl_ctx)); ScopedSSL ssl(SSL_new(ssl_ctx));
if (!ssl) { if (!ssl) {
@ -424,7 +420,7 @@ static int DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
if (!SetConfigPtr(ssl.get(), config) || if (!SetConfigPtr(ssl.get(), config) ||
!SetClockPtr(ssl.get(), &clock) | !SetClockPtr(ssl.get(), &clock) |
!SetAsyncState(ssl.get(), std::unique_ptr<AsyncState>(new AsyncState))) { !SetTestState(ssl.get(), std::unique_ptr<TestState>(new TestState))) {
BIO_print_errors_fp(stdout); BIO_print_errors_fp(stdout);
return 1; return 1;
} }
@ -564,7 +560,7 @@ static int DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
} else if (config->async) { } else if (config->async) {
// The internal session cache is disabled, so install the session // The internal session cache is disabled, so install the session
// manually. // manually.
GetAsyncState(ssl.get())->pending_session.reset( GetTestState(ssl.get())->pending_session.reset(
SSL_SESSION_up_ref(session)); SSL_SESSION_up_ref(session));
} }
} }
@ -605,7 +601,7 @@ static int DoExchange(ScopedSSL_SESSION *out_session, SSL_CTX *ssl_ctx,
return 2; return 2;
} }
if (!g_early_callback_called) { if (!GetTestState(ssl.get())->early_callback_called) {
fprintf(stderr, "early callback not called\n"); fprintf(stderr, "early callback not called\n");
return 2; return 2;
} }
@ -823,8 +819,8 @@ int main(int argc, char **argv) {
} }
g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); g_config_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
g_clock_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); g_clock_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
g_async_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, AsyncExFree); g_state_index = SSL_get_ex_new_index(0, NULL, NULL, NULL, TestStateExFree);
if (g_config_index < 0 || g_clock_index < 0 || g_async_index < 0) { if (g_config_index < 0 || g_clock_index < 0 || g_state_index < 0) {
return 1; return 1;
} }