diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc index c976e7c2..8c1e75e8 100644 --- a/ssl/test/bssl_shim.cc +++ b/ssl/test/bssl_shim.cc @@ -48,6 +48,20 @@ static const TestConfig *GetConfigPtr(SSL *ssl) { return (const TestConfig *)SSL_get_ex_data(ssl, g_ex_data_index); } +static EVP_PKEY *LoadPrivateKey(const std::string &file) { + BIO *bio = BIO_new(BIO_s_file()); + if (bio == NULL) { + return NULL; + } + if (!BIO_read_filename(bio, file.c_str())) { + BIO_free(bio); + return NULL; + } + EVP_PKEY *pkey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL); + BIO_free(bio); + return pkey; +} + static int early_callback_called = 0; static int select_certificate_callback(const struct ssl_early_callback_ctx *ctx) { @@ -205,6 +219,8 @@ static SSL_CTX *setup_ctx(const TestConfig *config) { SSL_CTX_set_cookie_generate_cb(ssl_ctx, cookie_generate_callback); SSL_CTX_set_cookie_verify_cb(ssl_ctx, cookie_verify_callback); + ssl_ctx->tlsext_channel_id_enabled_new = 1; + DH_free(dh); return ssl_ctx; @@ -300,6 +316,23 @@ static int do_exchange(SSL_SESSION **out_session, if (config->cookie_exchange) { SSL_set_options(ssl, SSL_OP_COOKIE_EXCHANGE); } + if (!config->expected_channel_id.empty()) { + SSL_enable_tls_channel_id(ssl); + } + if (!config->send_channel_id.empty()) { + EVP_PKEY *pkey = LoadPrivateKey(config->send_channel_id); + if (pkey == NULL) { + BIO_print_errors_fp(stdout); + return 1; + } + SSL_enable_tls_channel_id(ssl); + if (!SSL_set1_tls_channel_id(ssl, pkey)) { + EVP_PKEY_free(pkey); + BIO_print_errors_fp(stdout); + return 1; + } + EVP_PKEY_free(pkey); + } BIO *bio = BIO_new_fd(fd, 1 /* take ownership */); if (bio == NULL) { @@ -363,7 +396,7 @@ static int do_exchange(SSL_SESSION **out_session, if (!config->expected_certificate_types.empty()) { uint8_t *certificate_types; int num_certificate_types = - SSL_get0_certificate_types(ssl, &certificate_types); + SSL_get0_certificate_types(ssl, &certificate_types); if (num_certificate_types != (int)config->expected_certificate_types.size() || memcmp(certificate_types, @@ -386,6 +419,20 @@ static int do_exchange(SSL_SESSION **out_session, } } + if (!config->expected_channel_id.empty()) { + uint8_t channel_id[64]; + if (!SSL_get_tls_channel_id(ssl, channel_id, sizeof(channel_id))) { + fprintf(stderr, "no channel id negotiated\n"); + return 2; + } + if (config->expected_channel_id.size() != 64 || + memcmp(config->expected_channel_id.data(), + channel_id, 64) != 0) { + fprintf(stderr, "channel id mismatch\n"); + return 2; + } + } + if (config->write_different_record_sizes) { if (config->is_dtls) { fprintf(stderr, "write_different_record_sizes not supported for DTLS\n"); diff --git a/ssl/test/runner/channel_id_key.pem b/ssl/test/runner/channel_id_key.pem new file mode 100644 index 00000000..604752bc --- /dev/null +++ b/ssl/test/runner/channel_id_key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIPwxu50c7LEhVNRYJFRWBUnoaz7JSos96T5hBp4rjyptoAoGCCqGSM49 +AwEHoUQDQgAEzFSVTE5guxJRQ0VbZ8dicPs5e/DT7xpW7Yc9hq0VOchv7cbXuI/T +CwadDjGWX/oaz0ftFqrVmfkwZu+C58ioWg== +-----END EC PRIVATE KEY----- diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index 360de125..72e8cce2 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -2,8 +2,11 @@ package main import ( "bytes" + "crypto/ecdsa" + "crypto/elliptic" "crypto/x509" "encoding/base64" + "encoding/pem" "flag" "fmt" "io" @@ -26,11 +29,14 @@ const ( ) const ( - rsaKeyFile = "key.pem" - ecdsaKeyFile = "ecdsa_key.pem" + rsaKeyFile = "key.pem" + ecdsaKeyFile = "ecdsa_key.pem" + channelIDKeyFile = "channel_id_key.pem" ) var rsaCertificate, ecdsaCertificate Certificate +var channelIDKey *ecdsa.PrivateKey +var channelIDBytes []byte func initCertificates() { var err error @@ -43,6 +49,26 @@ func initCertificates() { if err != nil { panic(err) } + + channelIDPEMBlock, err := ioutil.ReadFile(channelIDKeyFile) + if err != nil { + panic(err) + } + channelIDDERBlock, _ := pem.Decode(channelIDPEMBlock) + if channelIDDERBlock.Type != "EC PRIVATE KEY" { + panic("bad key type") + } + channelIDKey, err = x509.ParseECPrivateKey(channelIDDERBlock.Bytes) + if err != nil { + panic(err) + } + if channelIDKey.Curve != elliptic.P256() { + panic("bad curve") + } + + channelIDBytes = make([]byte, 64) + writeIntPadded(channelIDBytes[:32], channelIDKey.X) + writeIntPadded(channelIDBytes[32:], channelIDKey.Y) } var certificateOnce sync.Once @@ -84,6 +110,9 @@ type testCase struct { // expectedVersion, if non-zero, specifies the TLS version that must be // negotiated. expectedVersion uint16 + // expectChannelID controls whether the connection should have + // negotiated a Channel ID with channelIDKey. + expectChannelID bool // messageLen is the length, in bytes, of the test message that will be // sent. messageLen int @@ -488,6 +517,18 @@ func doExchange(test *testCase, config *Config, conn net.Conn, messageLen int) e return fmt.Errorf("got version %x, expected %x", vers, test.expectedVersion) } + if test.expectChannelID { + channelID := tlsConn.ConnectionState().ChannelID + if channelID == nil { + return fmt.Errorf("no channel ID negotiated") + } + if channelID.Curve != channelIDKey.Curve || + channelIDKey.X.Cmp(channelIDKey.X) != 0 || + channelIDKey.Y.Cmp(channelIDKey.Y) != 0 { + return fmt.Errorf("incorrect channel ID") + } + } + if messageLen < 0 { if test.protocol == dtls { return fmt.Errorf("messageLen < 0 not supported for DTLS tests") @@ -1141,7 +1182,7 @@ func addStateMachineCoverageTests(async, splitHandshake bool, protocol protocol) ), }) - // Client sends a V2ClientHello. + // Server parses a V2ClientHello. testCases = append(testCases, testCase{ protocol: protocol, testType: serverTest, @@ -1158,6 +1199,42 @@ func addStateMachineCoverageTests(async, splitHandshake bool, protocol protocol) }, flags: flags, }) + + // Client sends a Channel ID. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "ChannelID-Client" + suffix, + config: Config{ + RequestChannelID: true, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-send-channel-id", channelIDKeyFile, + ), + resumeSession: true, + expectChannelID: true, + }) + + // Server accepts a Channel ID. + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "ChannelID-Server" + suffix, + config: Config{ + ChannelID: channelIDKey, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-expect-channel-id", + base64.StdEncoding.EncodeToString(channelIDBytes), + ), + resumeSession: true, + expectChannelID: true, + }) } else { testCases = append(testCases, testCase{ protocol: protocol, diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index e69d570a..cf0d0af5 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc @@ -64,12 +64,14 @@ const StringFlag kStringFlags[] = { { "-advertise-npn", &TestConfig::advertise_npn }, { "-expect-next-proto", &TestConfig::expected_next_proto }, { "-select-next-proto", &TestConfig::select_next_proto }, + { "-send-channel-id", &TestConfig::send_channel_id }, }; const size_t kNumStringFlags = sizeof(kStringFlags) / sizeof(kStringFlags[0]); const StringFlag kBase64Flags[] = { { "-expect-certificate-types", &TestConfig::expected_certificate_types }, + { "-expect-channel-id", &TestConfig::expected_channel_id }, }; const size_t kNumBase64Flags = sizeof(kBase64Flags) / sizeof(kBase64Flags[0]); diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h index 34d720ef..dc5e8d3f 100644 --- a/ssl/test/test_config.h +++ b/ssl/test/test_config.h @@ -43,6 +43,8 @@ struct TestConfig { bool no_tls1; bool no_ssl3; bool cookie_exchange; + std::string expected_channel_id; + std::string send_channel_id; }; bool ParseConfig(int argc, char **argv, TestConfig *out_config);