Convert rsa/padding.c to constant-time helpers.

Remove the custom copy of those helpers.

Change-Id: I810c3ae8dbf7bc0654d3e9fb9900c425d36f64aa
Reviewed-on: https://boringssl-review.googlesource.com/6611
Reviewed-by: Adam Langley <agl@google.com>
This commit is contained in:
David Benjamin 2015-12-01 20:32:22 -05:00 committed by Adam Langley
parent b36a395a9a
commit 1634a33495

View File

@ -56,6 +56,7 @@
#include <openssl/rsa.h>
#include <assert.h>
#include <limits.h>
#include <string.h>
#include <openssl/digest.h>
@ -65,6 +66,7 @@
#include <openssl/sha.h>
#include "internal.h"
#include "../internal.h"
/* TODO(fork): don't the check functions have to be constant time? */
@ -191,51 +193,30 @@ int RSA_padding_add_PKCS1_type_2(uint8_t *to, unsigned tlen,
return 1;
}
/* constant_time_byte_eq returns 1 if |x| == |y| and 0 otherwise. */
static int constant_time_byte_eq(unsigned char a, unsigned char b) {
unsigned char z = ~(a ^ b);
z &= z >> 4;
z &= z >> 2;
z &= z >> 1;
return z;
}
/* constant_time_select returns |x| if |v| is 1 and |y| if |v| is 0.
* Its behavior is undefined if |v| takes any other value. */
static int constant_time_select(int v, int x, int y) {
return ((~(v - 1)) & x) | ((v - 1) & y);
}
/* constant_time_le returns 1 if |x| <= |y| and 0 otherwise.
* |x| and |y| must be positive. */
static int constant_time_le(int x, int y) {
return ((x - y - 1) >> (sizeof(int) * 8 - 1)) & 1;
}
int RSA_message_index_PKCS1_type_2(const uint8_t *from, size_t from_len,
size_t *out_index) {
size_t i;
int first_byte_is_zero, second_byte_is_two, looking_for_index;
int valid_index, zero_index = 0;
unsigned first_byte_is_zero, second_byte_is_two, looking_for_index;
unsigned valid_index, zero_index = 0;
/* PKCS#1 v1.5 decryption. See "PKCS #1 v2.2: RSA Cryptography
* Standard", section 7.2.2. */
if (from_len < RSA_PKCS1_PADDING_SIZE) {
if (from_len < RSA_PKCS1_PADDING_SIZE || from_len > UINT_MAX) {
/* |from| is zero-padded to the size of the RSA modulus, a public value, so
* this can be rejected in non-constant time. */
* this can be rejected in non-constant time. This logic also requires
* |from_len| fit in an |unsigned|. */
*out_index = 0;
return 0;
}
first_byte_is_zero = constant_time_byte_eq(from[0], 0);
second_byte_is_two = constant_time_byte_eq(from[1], 2);
first_byte_is_zero = constant_time_eq(from[0], 0);
second_byte_is_two = constant_time_eq(from[1], 2);
looking_for_index = 1;
looking_for_index = ~0u;
for (i = 2; i < from_len; i++) {
int equals0 = constant_time_byte_eq(from[i], 0);
zero_index =
constant_time_select(looking_for_index & equals0, i, zero_index);
unsigned equals0 = constant_time_is_zero(from[i]);
zero_index = constant_time_select(looking_for_index & equals0, (unsigned)i,
zero_index);
looking_for_index = constant_time_select(equals0, 0, looking_for_index);
}
@ -247,13 +228,13 @@ int RSA_message_index_PKCS1_type_2(const uint8_t *from, size_t from_len,
valid_index &= ~looking_for_index;
/* PS must be at least 8 bytes long, and it starts two bytes into |from|. */
valid_index &= constant_time_le(2 + 8, zero_index);
valid_index &= constant_time_ge(zero_index, 2 + 8);
/* Skip the zero byte. */
zero_index++;
*out_index = constant_time_select(valid_index, zero_index, 0);
return valid_index;
return constant_time_select(valid_index, 1, 0);
}
int RSA_padding_check_PKCS1_type_2(uint8_t *to, unsigned tlen,
@ -421,10 +402,9 @@ int RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t *to, unsigned tlen,
const uint8_t *from, unsigned flen,
const uint8_t *param, unsigned plen,
const EVP_MD *md, const EVP_MD *mgf1md) {
unsigned i, dblen, mlen = -1, mdlen;
unsigned i, dblen, mlen = -1, mdlen, bad, looking_for_one_byte, one_index;
const uint8_t *maskeddb, *maskedseed;
uint8_t *db = NULL, seed[EVP_MAX_MD_SIZE], phash[EVP_MAX_MD_SIZE];
int bad, looking_for_one_byte, one_index = 0;
if (md == NULL) {
md = EVP_sha1();
@ -472,15 +452,14 @@ int RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t *to, unsigned tlen,
goto err;
}
bad = CRYPTO_memcmp(db, phash, mdlen);
bad |= from[0];
bad = ~constant_time_is_zero(CRYPTO_memcmp(db, phash, mdlen));
bad |= ~constant_time_is_zero(from[0]);
looking_for_one_byte = 1;
looking_for_one_byte = ~0u;
for (i = mdlen; i < dblen; i++) {
int equals1 = constant_time_byte_eq(db[i], 1);
int equals0 = constant_time_byte_eq(db[i], 0);
one_index =
constant_time_select(looking_for_one_byte & equals1, i, one_index);
unsigned equals1 = constant_time_eq(db[i], 1);
unsigned equals0 = constant_time_eq(db[i], 0);
one_index = constant_time_select(looking_for_one_byte & equals1, i, one_index);
looking_for_one_byte =
constant_time_select(equals1, 0, looking_for_one_byte);
bad |= looking_for_one_byte & ~equals0;