Make bn_mul_part_recursive constant-time.

This follows similar lines as the previous cleanups and fixes the
documentation of the preconditions.

And with that, RSA private key operations, provided p and q have the
same bit length, should be constant time, as far as I know. (Though I'm
sure I've missed something.)

bn_cmp_part_words and bn_cmp_words are no longer used and deleted.

Bug: 234
Change-Id: Iceefa39f57e466c214794c69b335c4d2c81f5577
Reviewed-on: https://boringssl-review.googlesource.com/25404
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
David Benjamin 2018-01-26 11:18:37 -05:00 committed by CQ bot account: commit-bot@chromium.org
parent 6541308ff3
commit 2bf82975ad
3 changed files with 89 additions and 149 deletions

View File

@ -135,48 +135,6 @@ int BN_cmp(const BIGNUM *a, const BIGNUM *b) {
return 0; return 0;
} }
int bn_cmp_words(const BN_ULONG *a, const BN_ULONG *b, int n) {
int i;
BN_ULONG aa, bb;
aa = a[n - 1];
bb = b[n - 1];
if (aa != bb) {
return (aa > bb) ? 1 : -1;
}
for (i = n - 2; i >= 0; i--) {
aa = a[i];
bb = b[i];
if (aa != bb) {
return (aa > bb) ? 1 : -1;
}
}
return 0;
}
int bn_cmp_part_words(const BN_ULONG *a, const BN_ULONG *b, int cl, int dl) {
int n, i;
n = cl - 1;
if (dl < 0) {
for (i = dl; i < 0; i++) {
if (b[n - i] != 0) {
return -1; // a < b
}
}
}
if (dl > 0) {
for (i = dl; i > 0; i--) {
if (a[n + i] != 0) {
return 1; // a > b
}
}
}
return bn_cmp_words(a, b, cl);
}
static int bn_less_than_words_impl(const BN_ULONG *a, size_t a_len, static int bn_less_than_words_impl(const BN_ULONG *a, size_t a_len,
const BN_ULONG *b, size_t b_len) { const BN_ULONG *b, size_t b_len) {
OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t), OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),

View File

@ -280,16 +280,6 @@ void bn_sqr_comba8(BN_ULONG r[16], const BN_ULONG a[4]);
// bn_sqr_comba4 sets |r| to |a|^2. // bn_sqr_comba4 sets |r| to |a|^2.
void bn_sqr_comba4(BN_ULONG r[8], const BN_ULONG a[4]); void bn_sqr_comba4(BN_ULONG r[8], const BN_ULONG a[4]);
// bn_cmp_words returns a value less than, equal to or greater than zero if
// the, length |n|, array |a| is less than, equal to or greater than |b|.
int bn_cmp_words(const BN_ULONG *a, const BN_ULONG *b, int n);
// bn_cmp_words returns a value less than, equal to or greater than zero if the
// array |a| is less than, equal to or greater than |b|. The arrays can be of
// different lengths: |cl| gives the minimum of the two lengths and |dl| gives
// the length of |a| minus the length of |b|.
int bn_cmp_part_words(const BN_ULONG *a, const BN_ULONG *b, int cl, int dl);
// bn_less_than_words returns one if |a| < |b| and zero otherwise, where |a| // bn_less_than_words returns one if |a| < |b| and zero otherwise, where |a|
// and |b| both are |len| words long. It runs in constant time. // and |b| both are |len| words long. It runs in constant time.
int bn_less_than_words(const BN_ULONG *a, const BN_ULONG *b, size_t len); int bn_less_than_words(const BN_ULONG *a, const BN_ULONG *b, size_t len);

View File

@ -409,69 +409,63 @@ static void bn_mul_recursive(BN_ULONG *r, const BN_ULONG *a, const BN_ULONG *b,
assert(c == 0); assert(c == 0);
} }
// n+tn is the word length // bn_mul_part_recursive sets |r| to |a| * |b|, using |t| as scratch space. |r|
// t needs to be n*4 is size, as does r // has length 4*|n|, |a| has length |n| + |tna|, |b| has length |n| + |tnb|, and
// tnX may not be negative but less than n // |t| has length 8*|n|. |n| must be a power of two. Additionally, we must have
// 0 <= tna < n and 0 <= tnb < n, and |tna| and |tnb| must differ by at most
// one.
//
// TODO(davidben): Make this take |size_t| and perhaps the actual lengths of |a|
// and |b|.
static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a, static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a,
const BN_ULONG *b, int n, int tna, int tnb, const BN_ULONG *b, int n, int tna, int tnb,
BN_ULONG *t) { BN_ULONG *t) {
int i, j, n2 = n * 2; // |n| is a power of two.
int c1, c2, neg; assert(n != 0 && (n & (n - 1)) == 0);
BN_ULONG ln, lo, *p; // Check |tna| and |tnb| are in range.
assert(0 <= tna && tna < n);
assert(0 <= tnb && tnb < n);
assert(-1 <= tna - tnb && tna - tnb <= 1);
int n2 = n * 2;
if (n < 8) { if (n < 8) {
bn_mul_normal(r, a, n + tna, b, n + tnb); bn_mul_normal(r, a, n + tna, b, n + tnb);
OPENSSL_memset(r + n2 + tna + tnb, 0, n2 - tna - tnb);
return; return;
} }
// TODO(davidben): This function is not constant-time, but should be. See // Split |a| and |b| into a0,a1 and b0,b1, where a0 and b0 have size |n|. |a1|
// https://crbug.com/boringssl/234. // and |b1| have size |tna| and |tnb|, respectively.
// Split |t| into t0,t1,t2,t3, each of size |n|, with the remaining 4*|n| used
// for recursive calls.
// Split |r| into r0,r1,r2,r3. We must contribute a0*b0 to r0,r1, a0*a1+b0*b1
// to r1,r2, and a1*b1 to r2,r3. The middle term we will compute as:
//
// a0*a1 + b0*b1 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0
// r=(a[0]-a[1])*(b[1]-b[0]) // t0 = a0 - a1 and t1 = b1 - b0. The result will be multiplied, so we XOR
c1 = bn_cmp_part_words(a, &(a[n]), tna, n - tna); // their sign masks, giving the sign of (a0 - a1)*(b1 - b0). t0 and t1
c2 = bn_cmp_part_words(&(b[n]), b, tnb, tnb - n); // themselves store the absolute value.
neg = 0; BN_ULONG neg = bn_abs_sub_part_words(t, a, &a[n], tna, n - tna, &t[n2]);
switch (c1 * 3 + c2) { neg ^= bn_abs_sub_part_words(&t[n], &b[n], b, tnb, tnb - n, &t[n2]);
case -4:
bn_sub_part_words(t, &(a[n]), a, tna, tna - n); // -
bn_sub_part_words(&(t[n]), b, &(b[n]), tnb, n - tnb); // -
break;
case -3:
// break;
case -2:
bn_sub_part_words(t, &(a[n]), a, tna, tna - n); // -
bn_sub_part_words(&(t[n]), &(b[n]), b, tnb, tnb - n); // +
neg = 1;
break;
case -1:
case 0:
case 1:
// break;
case 2:
bn_sub_part_words(t, a, &(a[n]), tna, n - tna); // +
bn_sub_part_words(&(t[n]), b, &(b[n]), tnb, n - tnb); // -
neg = 1;
break;
case 3:
// break;
case 4:
bn_sub_part_words(t, a, &(a[n]), tna, n - tna);
bn_sub_part_words(&(t[n]), &(b[n]), b, tnb, tnb - n);
break;
}
// Compute:
// t2,t3 = t0 * t1 = |(a0 - a1)*(b1 - b0)|
// r0,r1 = a0 * b0
// r2,r3 = a1 * b1
if (n == 8) { if (n == 8) {
bn_mul_comba8(&(t[n2]), t, &(t[n])); bn_mul_comba8(&t[n2], t, &t[n]);
bn_mul_comba8(r, a, b); bn_mul_comba8(r, a, b);
bn_mul_normal(&(r[n2]), &(a[n]), tna, &(b[n]), tnb);
OPENSSL_memset(&(r[n2 + tna + tnb]), 0, sizeof(BN_ULONG) * (n2 - tna - tnb)); bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
// |bn_mul_normal| only writes |tna| + |tna| words. Zero the rest.
OPENSSL_memset(&r[n2 + tna + tnb], 0, sizeof(BN_ULONG) * (n2 - tna - tnb));
} else { } else {
p = &(t[n2 * 2]); BN_ULONG *p = &t[n2 * 2];
bn_mul_recursive(&(t[n2]), t, &(t[n]), n, 0, 0, p); bn_mul_recursive(&t[n2], t, &t[n], n, 0, 0, p);
bn_mul_recursive(r, a, b, n, 0, 0, p); bn_mul_recursive(r, a, b, n, 0, 0, p);
i = n / 2;
// If there is only a bottom half to the number, int i = n / 2, j;
// just do it
if (tna > tnb) { if (tna > tnb) {
j = tna - i; j = tna - i;
} else { } else {
@ -479,32 +473,37 @@ static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a,
} }
if (j == 0) { if (j == 0) {
bn_mul_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, p); // If there is only a bottom half to the number, just do it. We know the
OPENSSL_memset(&(r[n2 + i * 2]), 0, sizeof(BN_ULONG) * (n2 - i * 2)); // larger of |tna - i| and |tnb - i| is zero. The other is zero or -1
// because |tna| and |tnb| differ by at most one.
bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
// |bn_mul_recursive| only writes the bottom |i|*2 words.
OPENSSL_memset(&r[n2 + i * 2], 0, sizeof(BN_ULONG) * (n2 - i * 2));
} else if (j > 0) { } else if (j > 0) {
// eg, n == 16, i == 8 and tn == 11 // E.g,, n == 16, i == 8 and tna == 11.
bn_mul_part_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, p); // |tna| and |tnb| are within one of each other, so if |tna| is larger and
OPENSSL_memset(&(r[n2 + tna + tnb]), 0, // tna > i, then we know tnb >= i, and this call is valid.
sizeof(BN_ULONG) * (n2 - tna - tnb)); bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
} else { } else {
// (j < 0) eg, n == 16, i == 8 and tn == 5 // (j < 0) E.g., n == 16, i == 8 and tn == 5
OPENSSL_memset(&(r[n2]), 0, sizeof(BN_ULONG) * n2); OPENSSL_memset(&r[n2], 0, sizeof(BN_ULONG) * n2);
if (tna < BN_MUL_RECURSIVE_SIZE_NORMAL && if (tna < BN_MUL_RECURSIVE_SIZE_NORMAL &&
tnb < BN_MUL_RECURSIVE_SIZE_NORMAL) { tnb < BN_MUL_RECURSIVE_SIZE_NORMAL) {
bn_mul_normal(&(r[n2]), &(a[n]), tna, &(b[n]), tnb); bn_mul_normal(&r[n2], &a[n], tna, &b[n], tnb);
} else { } else {
for (;;) { for (;;) {
i /= 2; i /= 2;
// these simplified conditions work // These simplified conditions work exclusively because difference
// exclusively because difference // between |tna| and |tnb| is 1 or 0.
// between tna and tnb is 1 or 0 //
// TODO(davidben): This loop condition is exactly the same as the
// |j > 0| one but more complicated. Merge them.
if (i < tna || i < tnb) { if (i < tna || i < tnb) {
bn_mul_part_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, bn_mul_part_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
tnb - i, p);
break; break;
} else if (i == tna || i == tnb) { }
bn_mul_recursive(&(r[n2]), &(a[n]), &(b[n]), i, tna - i, tnb - i, if (i == tna || i == tnb) {
p); bn_mul_recursive(&r[n2], &a[n], &b[n], i, tna - i, tnb - i, p);
break; break;
} }
} }
@ -512,42 +511,32 @@ static void bn_mul_part_recursive(BN_ULONG *r, const BN_ULONG *a,
} }
} }
// t[32] holds (a[0]-a[1])*(b[1]-b[0]), c1 is the sign // t0,t1,c = r0,r1 + r2,r3 = a0*b0 + a1*b1
// r[10] holds (a[0]*b[0]) BN_ULONG c = bn_add_words(t, r, &r[n2], n2);
// r[32] holds (b[1]*b[1])
c1 = (int)(bn_add_words(t, r, &(r[n2]), n2)); // t2,t3,c = t0,t1,c + neg*t2,t3 = (a0 - a1)*(b1 - b0) + a1*b1 + a0*b0.
// The second term is stored as the absolute value, so we do this with a
// constant-time select.
BN_ULONG c_neg = c - bn_sub_words(&t[n2 * 2], t, &t[n2], n2);
BN_ULONG c_pos = c + bn_add_words(&t[n2], t, &t[n2], n2);
bn_select_words(&t[n2], neg, &t[n2 * 2], &t[n2], n2);
OPENSSL_COMPILE_ASSERT(sizeof(BN_ULONG) <= sizeof(crypto_word_t),
crypto_word_t_too_small);
c = constant_time_select_w(neg, c_neg, c_pos);
if (neg) { // We now have our three components. Add them together.
// if t[32] is negative // r1,r2,c = r1,r2 + t2,t3,c
c1 -= (int)(bn_sub_words(&(t[n2]), t, &(t[n2]), n2)); c += bn_add_words(&r[n], &r[n], &t[n2], n2);
} else {
// Might have a carry // Propagate the carry bit to the end.
c1 += (int)(bn_add_words(&(t[n2]), &(t[n2]), t, n2)); for (int i = n + n2; i < n2 + n2; i++) {
BN_ULONG old = r[i];
r[i] = old + c;
c = r[i] < old;
} }
// t[32] holds (a[0]-a[1])*(b[1]-b[0])+(a[0]*b[0])+(a[1]*b[1]) // The product should fit without carries.
// r[10] holds (a[0]*b[0]) assert(c == 0);
// r[32] holds (b[1]*b[1])
// c1 holds the carry bits
c1 += (int)(bn_add_words(&(r[n]), &(r[n]), &(t[n2]), n2));
if (c1) {
p = &(r[n + n2]);
lo = *p;
ln = lo + c1;
*p = ln;
// The overflow will stop before we over write
// words we should not overwrite
if (ln < (BN_ULONG)c1) {
do {
p++;
lo = *p;
ln = lo + 1;
*p = ln;
} while (ln == 0);
}
}
} }
// bn_mul_impl implements |BN_mul| and |bn_mul_fixed|. Note this function breaks // bn_mul_impl implements |BN_mul| and |bn_mul_fixed|. Note this function breaks
@ -605,6 +594,9 @@ static int bn_mul_impl(BIGNUM *r, const BIGNUM *a, const BIGNUM *b,
goto err; goto err;
} }
if (al > j || bl > j) { if (al > j || bl > j) {
// We know |al| and |bl| are at most one from each other, so if al > j,
// bl >= j, and vice versa. Thus we can use |bn_mul_part_recursive|.
assert(al >= j && bl >= j);
// TODO(davidben): Check that these are correctly-sized, after rewriting // TODO(davidben): Check that these are correctly-sized, after rewriting
// |bn_mul_part_recursive|. // |bn_mul_part_recursive|.
if (!bn_wexpand(t, j * 8) || if (!bn_wexpand(t, j * 8) ||