diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c index 86914dca..2fa4101d 100644 --- a/ssl/t1_lib.c +++ b/ssl/t1_lib.c @@ -296,102 +296,64 @@ done: char ssl_early_callback_init(struct ssl_early_callback_ctx *ctx) { - size_t len = ctx->client_hello_len; - const unsigned char *p = ctx->client_hello; - CBS extensions; + CBS client_hello, session_id, cipher_suites, compression_methods, extensions; + + CBS_init(&client_hello, ctx->client_hello, ctx->client_hello_len); /* Skip client version. */ - if (len < 2) + if (!CBS_skip(&client_hello, 2)) return 0; - len -= 2; p += 2; /* Skip client nonce. */ - if (len < 32) + if (!CBS_skip(&client_hello, 32)) return 0; - len -= 32; p += 32; - /* Get length of session id. */ - if (len < 1) + /* Extract session_id. */ + if (!CBS_get_u8_length_prefixed(&client_hello, &session_id)) return 0; - ctx->session_id_len = *p; - p++; len--; - - ctx->session_id = p; - if (len < ctx->session_id_len) - return 0; - p += ctx->session_id_len; len -= ctx->session_id_len; + ctx->session_id = CBS_data(&session_id); + ctx->session_id_len = CBS_len(&session_id); /* Skip past DTLS cookie */ if (ctx->ssl->version == DTLS1_VERSION || ctx->ssl->version == DTLS1_BAD_VER) { - unsigned cookie_len; + CBS cookie; - if (len < 1) + if (!CBS_get_u8_length_prefixed(&client_hello, &cookie)) return 0; - cookie_len = *p; - p++; len--; - if (len < cookie_len) - return 0; - p += cookie_len; len -= cookie_len; } - /* Skip cipher suites. */ - if (len < 2) + /* Extract cipher_suites. */ + if (!CBS_get_u16_length_prefixed(&client_hello, &cipher_suites) || + CBS_len(&cipher_suites) < 2 || + (CBS_len(&cipher_suites) & 1) != 0) return 0; - n2s(p, ctx->cipher_suites_len); - len -= 2; + ctx->cipher_suites = CBS_data(&cipher_suites); + ctx->cipher_suites_len = CBS_len(&cipher_suites); - if ((ctx->cipher_suites_len & 1) != 0) + /* Extract compression_methods. */ + if (!CBS_get_u8_length_prefixed(&client_hello, &compression_methods) || + CBS_len(&compression_methods) < 1) return 0; - - ctx->cipher_suites = p; - if (len < ctx->cipher_suites_len) - return 0; - p += ctx->cipher_suites_len; len -= ctx->cipher_suites_len; - - /* Skip compression methods. */ - if (len < 1) - return 0; - ctx->compression_methods_len = *p; - p++; len--; - - ctx->compression_methods = p; - if (len < ctx->compression_methods_len) - return 0; - p += ctx->compression_methods_len; len -= ctx->compression_methods_len; + ctx->compression_methods = CBS_data(&compression_methods); + ctx->compression_methods_len = CBS_len(&compression_methods); /* If the ClientHello ends here then it's valid, but doesn't have any * extensions. (E.g. SSLv3.) */ - if (len == 0) + if (CBS_len(&client_hello) == 0) { ctx->extensions = NULL; ctx->extensions_len = 0; return 1; } - if (len < 2) - return 0; - n2s(p, ctx->extensions_len); - len -= 2; - - if (ctx->extensions_len == 0 && len == 0) - { - ctx->extensions = NULL; - return 1; - } - - ctx->extensions = p; - if (len != ctx->extensions_len) - return 0; - - /* Verify that the extensions have valid lengths and that there are - * no duplicates. - * - * TODO(fork): Port the rest of this processing to CBS. - */ - CBS_init(&extensions, ctx->extensions, ctx->extensions_len); - if (!tls1_check_duplicate_extensions(&extensions)) + /* Extract extensions and check it is valid. */ + if (!CBS_get_u16_length_prefixed(&client_hello, &extensions) || + !tls1_check_duplicate_extensions(&extensions) || + CBS_len(&client_hello) != 0) return 0; + ctx->extensions = CBS_data(&extensions); + ctx->extensions_len = CBS_len(&extensions); return 1; } @@ -402,29 +364,26 @@ SSL_early_callback_ctx_extension_get(const struct ssl_early_callback_ctx *ctx, const unsigned char **out_data, size_t *out_len) { - size_t len = ctx->extensions_len; - const unsigned char *p = ctx->extensions; + CBS extensions; - while (len != 0) + CBS_init(&extensions, ctx->extensions, ctx->extensions_len); + + while (CBS_len(&extensions) != 0) { - uint16_t ext_type, ext_len; + uint16_t type; + CBS extension; - if (len < 4) + /* Decode the next extension. */ + if (!CBS_get_u16(&extensions, &type) || + !CBS_get_u16_length_prefixed(&extensions, &extension)) return 0; - n2s(p, ext_type); - n2s(p, ext_len); - len -= 4; - if (len < ext_len) - return 0; - if (ext_type == extension_type) + if (type == extension_type) { - *out_data = p; - *out_len = ext_len; + *out_data = CBS_data(&extension); + *out_len = CBS_len(&extension); return 1; } - - p += ext_len; len -= ext_len; } return 0; diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc index 1dde8192..d0117da2 100644 --- a/ssl/test/bssl_shim.cc +++ b/ssl/test/bssl_shim.cc @@ -16,10 +16,51 @@ #include #include #include +#include #include #include +#include +const char *expected_server_name = NULL; +int early_callback_called = 0; + +int select_certificate_callback(const struct ssl_early_callback_ctx *ctx) { + early_callback_called = 1; + + if (expected_server_name) { + const unsigned char *extension_data; + size_t extension_len; + CBS extension, server_name_list, host_name; + uint8_t name_type; + + if (!SSL_early_callback_ctx_extension_get(ctx, TLSEXT_TYPE_server_name, + &extension_data, + &extension_len)) { + fprintf(stderr, "Could not find server_name extension."); + return -1; + } + + CBS_init(&extension, extension_data, extension_len); + if (!CBS_get_u16_length_prefixed(&extension, &server_name_list) || + CBS_len(&extension) != 0 || + !CBS_get_u8(&server_name_list, &name_type) || + name_type != TLSEXT_NAMETYPE_host_name || + !CBS_get_u16_length_prefixed(&server_name_list, &host_name) || + CBS_len(&server_name_list) != 0) { + fprintf(stderr, "Could not decode server_name extension."); + return -1; + } + + if (CBS_len(&host_name) != strlen(expected_server_name) || + memcmp(expected_server_name, + CBS_data(&host_name), CBS_len(&host_name)) != 0) { + fprintf(stderr, "Server name mismatch."); + } + } + + return 1; +} SSL *setup_test(int is_server) { if (!SSL_library_init()) { @@ -44,6 +85,8 @@ SSL *setup_test(int is_server) { goto err; } + ssl_ctx->select_certificate_cb = select_certificate_callback; + ssl = SSL_new(ssl_ctx); if (ssl == NULL) { goto err; @@ -74,7 +117,6 @@ err: int main(int argc, char **argv) { int i, is_server, ret; - const char *expected_server_name = NULL; if (argc < 2) { fprintf(stderr, "Usage: %s (client|server) [flags...]\n", argv[0]); @@ -153,6 +195,11 @@ int main(int argc, char **argv) { server_name, expected_server_name); return 2; } + + if (!early_callback_called) { + fprintf(stderr, "early callback not called\n"); + return 2; + } } for (;;) {