diff --git a/crypto/rsa_extra/rsa_test.cc b/crypto/rsa_extra/rsa_test.cc index 211b6903..b0a0b7e4 100644 --- a/crypto/rsa_extra/rsa_test.cc +++ b/crypto/rsa_extra/rsa_test.cc @@ -72,6 +72,11 @@ #include "../internal.h" #include "../test/test_util.h" +#if !defined(OPENSSL_NO_THREADS) +#include +#include +#endif + // kPlaintext is a sample plaintext. static const uint8_t kPlaintext[] = "\x54\x85\x9b\x34\x2c\x49\xea\x2a"; @@ -1042,3 +1047,73 @@ TEST(RSATest, SqrtTwo) { EXPECT_EQ(3072u / 2u, bits); } #endif // !BORINGSSL_SHARED_LIBRARY + +#if !defined(OPENSSL_NO_THREADS) +TEST(RSATest, Threads) { + bssl::UniquePtr rsa_template( + RSA_private_key_from_bytes(kKey1, sizeof(kKey1) - 1)); + ASSERT_TRUE(rsa_template); + + const uint8_t kDummyHash[32] = {0}; + uint8_t sig[256]; + unsigned sig_len = sizeof(sig); + ASSERT_LE(RSA_size(rsa_template.get()), sizeof(sig)); + EXPECT_TRUE(RSA_sign(NID_sha256, kDummyHash, sizeof(kDummyHash), sig, + &sig_len, rsa_template.get())); + + // RSA keys may be assembled piece-meal and then used in parallel between + // threads, which requires internal locking to create some derived properties. + bssl::UniquePtr rsa(RSA_new()); + rsa->n = BN_dup(rsa_template->n); + ASSERT_TRUE(rsa->n); + rsa->e = BN_dup(rsa_template->e); + ASSERT_TRUE(rsa->e); + rsa->d = BN_dup(rsa_template->d); + ASSERT_TRUE(rsa->d); + rsa->p = BN_dup(rsa_template->p); + ASSERT_TRUE(rsa->p); + rsa->q = BN_dup(rsa_template->q); + ASSERT_TRUE(rsa->q); + rsa->dmp1 = BN_dup(rsa_template->dmp1); + ASSERT_TRUE(rsa->dmp1); + rsa->dmq1 = BN_dup(rsa_template->dmq1); + ASSERT_TRUE(rsa->dmq1); + rsa->iqmp = BN_dup(rsa_template->iqmp); + ASSERT_TRUE(rsa->iqmp); + + // Each of these operations must be safe to do concurrently on different + // threads. + auto raw_access = [&] { EXPECT_EQ(0, BN_cmp(rsa->d, rsa_template->d)); }; + auto getter = [&] { + const BIGNUM *d; + RSA_get0_key(rsa.get(), nullptr, nullptr, &d); + EXPECT_EQ(0, BN_cmp(d, rsa_template->d)); + }; + auto sign = [&] { + uint8_t sig2[256]; + unsigned sig2_len = sizeof(sig2); + ASSERT_LE(RSA_size(rsa.get()), sizeof(sig2)); + EXPECT_TRUE(RSA_sign(NID_sha256, kDummyHash, sizeof(kDummyHash), sig2, + &sig2_len, rsa.get())); + // RSASSA-PKCS1-v1_5 is deterministic. + EXPECT_EQ(Bytes(sig, sig_len), Bytes(sig2, sig2_len)); + }; + auto verify = [&] { + EXPECT_TRUE(RSA_verify(NID_sha256, kDummyHash, sizeof(kDummyHash), sig, + sig_len, rsa.get())); + }; + + std::vector threads; + threads.emplace_back(raw_access); + threads.emplace_back(raw_access); + threads.emplace_back(getter); + threads.emplace_back(getter); + threads.emplace_back(sign); + threads.emplace_back(sign); + threads.emplace_back(verify); + threads.emplace_back(verify); + for (auto &thread : threads) { + thread.join(); + } +} +#endif