diff --git a/crypto/rsa/padding.c b/crypto/rsa/padding.c index 5a42e248..40542923 100644 --- a/crypto/rsa/padding.c +++ b/crypto/rsa/padding.c @@ -56,6 +56,7 @@ #include #include +#include #include #include @@ -65,6 +66,7 @@ #include #include "internal.h" +#include "../internal.h" /* TODO(fork): don't the check functions have to be constant time? */ @@ -191,51 +193,30 @@ int RSA_padding_add_PKCS1_type_2(uint8_t *to, unsigned tlen, return 1; } -/* constant_time_byte_eq returns 1 if |x| == |y| and 0 otherwise. */ -static int constant_time_byte_eq(unsigned char a, unsigned char b) { - unsigned char z = ~(a ^ b); - z &= z >> 4; - z &= z >> 2; - z &= z >> 1; - - return z; -} - -/* constant_time_select returns |x| if |v| is 1 and |y| if |v| is 0. - * Its behavior is undefined if |v| takes any other value. */ -static int constant_time_select(int v, int x, int y) { - return ((~(v - 1)) & x) | ((v - 1) & y); -} - -/* constant_time_le returns 1 if |x| <= |y| and 0 otherwise. - * |x| and |y| must be positive. */ -static int constant_time_le(int x, int y) { - return ((x - y - 1) >> (sizeof(int) * 8 - 1)) & 1; -} - int RSA_message_index_PKCS1_type_2(const uint8_t *from, size_t from_len, size_t *out_index) { size_t i; - int first_byte_is_zero, second_byte_is_two, looking_for_index; - int valid_index, zero_index = 0; + unsigned first_byte_is_zero, second_byte_is_two, looking_for_index; + unsigned valid_index, zero_index = 0; /* PKCS#1 v1.5 decryption. See "PKCS #1 v2.2: RSA Cryptography * Standard", section 7.2.2. */ - if (from_len < RSA_PKCS1_PADDING_SIZE) { + if (from_len < RSA_PKCS1_PADDING_SIZE || from_len > UINT_MAX) { /* |from| is zero-padded to the size of the RSA modulus, a public value, so - * this can be rejected in non-constant time. */ + * this can be rejected in non-constant time. This logic also requires + * |from_len| fit in an |unsigned|. */ *out_index = 0; return 0; } - first_byte_is_zero = constant_time_byte_eq(from[0], 0); - second_byte_is_two = constant_time_byte_eq(from[1], 2); + first_byte_is_zero = constant_time_eq(from[0], 0); + second_byte_is_two = constant_time_eq(from[1], 2); - looking_for_index = 1; + looking_for_index = ~0u; for (i = 2; i < from_len; i++) { - int equals0 = constant_time_byte_eq(from[i], 0); - zero_index = - constant_time_select(looking_for_index & equals0, i, zero_index); + unsigned equals0 = constant_time_is_zero(from[i]); + zero_index = constant_time_select(looking_for_index & equals0, (unsigned)i, + zero_index); looking_for_index = constant_time_select(equals0, 0, looking_for_index); } @@ -247,13 +228,13 @@ int RSA_message_index_PKCS1_type_2(const uint8_t *from, size_t from_len, valid_index &= ~looking_for_index; /* PS must be at least 8 bytes long, and it starts two bytes into |from|. */ - valid_index &= constant_time_le(2 + 8, zero_index); + valid_index &= constant_time_ge(zero_index, 2 + 8); /* Skip the zero byte. */ zero_index++; *out_index = constant_time_select(valid_index, zero_index, 0); - return valid_index; + return constant_time_select(valid_index, 1, 0); } int RSA_padding_check_PKCS1_type_2(uint8_t *to, unsigned tlen, @@ -421,10 +402,9 @@ int RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t *to, unsigned tlen, const uint8_t *from, unsigned flen, const uint8_t *param, unsigned plen, const EVP_MD *md, const EVP_MD *mgf1md) { - unsigned i, dblen, mlen = -1, mdlen; + unsigned i, dblen, mlen = -1, mdlen, bad, looking_for_one_byte, one_index; const uint8_t *maskeddb, *maskedseed; uint8_t *db = NULL, seed[EVP_MAX_MD_SIZE], phash[EVP_MAX_MD_SIZE]; - int bad, looking_for_one_byte, one_index = 0; if (md == NULL) { md = EVP_sha1(); @@ -472,15 +452,14 @@ int RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t *to, unsigned tlen, goto err; } - bad = CRYPTO_memcmp(db, phash, mdlen); - bad |= from[0]; + bad = ~constant_time_is_zero(CRYPTO_memcmp(db, phash, mdlen)); + bad |= ~constant_time_is_zero(from[0]); - looking_for_one_byte = 1; + looking_for_one_byte = ~0u; for (i = mdlen; i < dblen; i++) { - int equals1 = constant_time_byte_eq(db[i], 1); - int equals0 = constant_time_byte_eq(db[i], 0); - one_index = - constant_time_select(looking_for_one_byte & equals1, i, one_index); + unsigned equals1 = constant_time_eq(db[i], 1); + unsigned equals0 = constant_time_eq(db[i], 0); + one_index = constant_time_select(looking_for_one_byte & equals1, i, one_index); looking_for_one_byte = constant_time_select(equals1, 0, looking_for_one_byte); bad |= looking_for_one_byte & ~equals0;