diff --git a/crypto/bn/bn_test.cc b/crypto/bn/bn_test.cc index 7636f302..e7e04f18 100644 --- a/crypto/bn/bn_test.cc +++ b/crypto/bn/bn_test.cc @@ -1316,23 +1316,23 @@ static bool test_exp(FILE *fp, BN_CTX *ctx) { // test_exp_mod_zero tests that 1**0 mod 1 == 0. static bool test_exp_mod_zero(void) { - ScopedBIGNUM zero(BN_new()); - if (!zero) { + ScopedBIGNUM zero(BN_new()), a(BN_new()), r(BN_new()); + if (!zero || !a || !r || !BN_rand(a.get(), 1024, 0, 0)) { return false; } BN_zero(zero.get()); - ScopedBN_CTX ctx(BN_CTX_new()); - ScopedBIGNUM r(BN_new()); - if (!ctx || !r || - !BN_mod_exp(r.get(), BN_value_one(), zero.get(), BN_value_one(), ctx.get())) { - return false; - } - - if (!BN_is_zero(r.get())) { - fprintf(stderr, "1**0 mod 1 = "); - BN_print_fp(stderr, r.get()); - fprintf(stderr, ", should be 0\n"); + if (!BN_mod_exp(r.get(), a.get(), zero.get(), BN_value_one(), nullptr) || + !BN_is_zero(r.get()) || + !BN_mod_exp_mont(r.get(), a.get(), zero.get(), BN_value_one(), nullptr, + nullptr) || + !BN_is_zero(r.get()) || + !BN_mod_exp_mont_consttime(r.get(), a.get(), zero.get(), BN_value_one(), + nullptr, nullptr) || + !BN_is_zero(r.get()) || + !BN_mod_exp_mont_word(r.get(), 42, zero.get(), BN_value_one(), nullptr, + nullptr) || + !BN_is_zero(r.get())) { return false; } diff --git a/crypto/bn/exponentiation.c b/crypto/bn/exponentiation.c index c580248c..72a8db4b 100644 --- a/crypto/bn/exponentiation.c +++ b/crypto/bn/exponentiation.c @@ -445,8 +445,12 @@ static int mod_exp_recp(BIGNUM *r, const BIGNUM *a, const BIGNUM *p, bits = BN_num_bits(p); if (bits == 0) { - ret = BN_one(r); - return ret; + /* x**0 mod 1 is still zero. */ + if (BN_is_one(m)) { + BN_zero(r); + return 1; + } + return BN_one(r); } BN_CTX_start(ctx); @@ -632,8 +636,12 @@ int BN_mod_exp_mont(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p, } bits = BN_num_bits(p); if (bits == 0) { - ret = BN_one(rr); - return ret; + /* x**0 mod 1 is still zero. */ + if (BN_is_one(m)) { + BN_zero(rr); + return 1; + } + return BN_one(rr); } BN_CTX_start(ctx); @@ -875,8 +883,12 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p, bits = BN_num_bits(p); if (bits == 0) { - ret = BN_one(rr); - return ret; + /* x**0 mod 1 is still zero. */ + if (BN_is_one(m)) { + BN_zero(rr); + return 1; + } + return BN_one(rr); } BN_CTX_start(ctx); @@ -1230,17 +1242,14 @@ int BN_mod_exp_mont_word(BIGNUM *rr, BN_ULONG a, const BIGNUM *p, if (bits == 0) { /* x**0 mod 1 is still zero. */ if (BN_is_one(m)) { - ret = 1; BN_zero(rr); - } else { - ret = BN_one(rr); + return 1; } - return ret; + return BN_one(rr); } if (a == 0) { BN_zero(rr); - ret = 1; - return ret; + return 1; } BN_CTX_start(ctx);