diff --git a/crypto/bn/bn_test.cc b/crypto/bn/bn_test.cc index 672d83f1..5ff91f21 100644 --- a/crypto/bn/bn_test.cc +++ b/crypto/bn/bn_test.cc @@ -568,21 +568,25 @@ static bool TestModSqrt(FileTest *t, BN_CTX *ctx) { bssl::UniquePtr a = GetBIGNUM(t, "A"); bssl::UniquePtr p = GetBIGNUM(t, "P"); bssl::UniquePtr mod_sqrt = GetBIGNUM(t, "ModSqrt"); - if (!a || !p || !mod_sqrt) { + bssl::UniquePtr mod_sqrt2(BN_new()); + if (!a || !p || !mod_sqrt || !mod_sqrt2 || + // There are two possible answers. + !BN_sub(mod_sqrt2.get(), p.get(), mod_sqrt.get())) { return false; } + // -0 is 0, not P. + if (BN_is_zero(mod_sqrt.get())) { + BN_zero(mod_sqrt2.get()); + } + bssl::UniquePtr ret(BN_new()); - bssl::UniquePtr ret2(BN_new()); if (!ret || - !ret2 || - !BN_mod_sqrt(ret.get(), a.get(), p.get(), ctx) || - // There are two possible answers. - !BN_sub(ret2.get(), p.get(), ret.get())) { + !BN_mod_sqrt(ret.get(), a.get(), p.get(), ctx)) { return false; } - if (BN_cmp(ret2.get(), mod_sqrt.get()) != 0 && + if (BN_cmp(ret.get(), mod_sqrt2.get()) != 0 && !ExpectBIGNUMsEqual(t, "sqrt(A) (mod P)", mod_sqrt.get(), ret.get())) { return false; }