diff --git a/crypto/fipsmodule/bn/mul.c b/crypto/fipsmodule/bn/mul.c index d97e5848..38c70cad 100644 --- a/crypto/fipsmodule/bn/mul.c +++ b/crypto/fipsmodule/bn/mul.c @@ -555,25 +555,19 @@ static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a, // callers. static int bn_mul_impl(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, BN_CTX *ctx) { - int ret = 0; - int top, al, bl; - BIGNUM *rr; - int i; - BIGNUM *t = NULL; - int j = 0, k; - - al = a->width; - bl = b->width; - - if ((al == 0) || (bl == 0)) { + int al = a->width; + int bl = b->width; + if (al == 0 || bl == 0) { BN_zero(r); return 1; } - top = al + bl; + int ret = 0; + BIGNUM *rr; BN_CTX_start(ctx); - if ((r == a) || (r == b)) { - if ((rr = BN_CTX_get(ctx)) == NULL) { + if (r == a || r == b) { + rr = BN_CTX_get(ctx); + if (r == NULL) { goto err; } } else { @@ -581,7 +575,7 @@ static int bn_mul_impl(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, } rr->neg = a->neg ^ b->neg; - i = al - bl; + int i = al - bl; if (i == 0) { if (al == 8) { if (!bn_wexpand(rr, 16)) { @@ -593,38 +587,37 @@ static int bn_mul_impl(BIGNUM *r, const BIGNUM *a, const BIGNUM *b, } } + int top = al + bl; static const int kMulNormalSize = 16; if (al >= kMulNormalSize && bl >= kMulNormalSize) { - if (i >= -1 && i <= 1) { - /* Find out the power of two lower or equal - to the longest of the two numbers */ + if (-1 <= i && i <= 1) { + // Find the larger power of two less than or equal to the larger length. + int j; if (i >= 0) { j = BN_num_bits_word((BN_ULONG)al); - } - if (i == -1) { + } else { j = BN_num_bits_word((BN_ULONG)bl); } j = 1 << (j - 1); assert(j <= al || j <= bl); - k = j + j; - t = BN_CTX_get(ctx); + BIGNUM *t = BN_CTX_get(ctx); if (t == NULL) { goto err; } if (al > j || bl > j) { - if (!bn_wexpand(t, k * 4)) { - goto err; - } - if (!bn_wexpand(rr, k * 4)) { + // TODO(davidben): Check that these are correctly-sized, after rewriting + // |bn_mul_part_recursive|. + if (!bn_wexpand(t, j * 8) || + !bn_wexpand(rr, j * 8)) { goto err; } bn_mul_part_recursive(rr->d, a->d, b->d, j, al - j, bl - j, t->d); } else { - // al <= j || bl <= j - if (!bn_wexpand(t, k * 2)) { - goto err; - } - if (!bn_wexpand(rr, k * 2)) { + // al <= j && bl <= j. Additionally, we know j <= al or j <= bl, so one + // of al - j or bl - j is zero. The other, by the bound on |i| above, is + // zero or -1. Thus, we can use |bn_mul_recursive|. + if (!bn_wexpand(t, j * 4) || + !bn_wexpand(rr, j * 2)) { goto err; } bn_mul_recursive(rr->d, a->d, b->d, j, al - j, bl - j, t->d);