diff --git a/crypto/rsa/padding.c b/crypto/rsa/padding.c index 12147ea4..045faa21 100644 --- a/crypto/rsa/padding.c +++ b/crypto/rsa/padding.c @@ -274,47 +274,46 @@ int RSA_padding_add_none(uint8_t *to, unsigned to_len, const uint8_t *from, return 1; } -static int PKCS1_MGF1(uint8_t *mask, unsigned len, const uint8_t *seed, - unsigned seedlen, const EVP_MD *dgst) { - unsigned outlen = 0; - uint32_t i; - uint8_t cnt[4]; - EVP_MD_CTX c; - uint8_t md[EVP_MAX_MD_SIZE]; - unsigned mdlen; - int ret = -1; +static int PKCS1_MGF1(uint8_t *out, size_t len, const uint8_t *seed, + size_t seed_len, const EVP_MD *md) { + int ret = 0; + EVP_MD_CTX ctx; + EVP_MD_CTX_init(&ctx); - EVP_MD_CTX_init(&c); - mdlen = EVP_MD_size(dgst); + size_t md_len = EVP_MD_size(md); - for (i = 0; outlen < len; i++) { - cnt[0] = (uint8_t)((i >> 24) & 255); - cnt[1] = (uint8_t)((i >> 16) & 255); - cnt[2] = (uint8_t)((i >> 8)) & 255; - cnt[3] = (uint8_t)(i & 255); - if (!EVP_DigestInit_ex(&c, dgst, NULL) || - !EVP_DigestUpdate(&c, seed, seedlen) || - !EVP_DigestUpdate(&c, cnt, 4)) { + for (uint32_t i = 0; len > 0; i++) { + uint8_t counter[4]; + counter[0] = (uint8_t)(i >> 24); + counter[1] = (uint8_t)(i >> 16); + counter[2] = (uint8_t)(i >> 8); + counter[3] = (uint8_t)i; + if (!EVP_DigestInit_ex(&ctx, md, NULL) || + !EVP_DigestUpdate(&ctx, seed, seed_len) || + !EVP_DigestUpdate(&ctx, counter, sizeof(counter))) { goto err; } - if (outlen + mdlen <= len) { - if (!EVP_DigestFinal_ex(&c, mask + outlen, NULL)) { + if (md_len <= len) { + if (!EVP_DigestFinal_ex(&ctx, out, NULL)) { goto err; } - outlen += mdlen; + out += md_len; + len -= md_len; } else { - if (!EVP_DigestFinal_ex(&c, md, NULL)) { + uint8_t digest[EVP_MAX_MD_SIZE]; + if (!EVP_DigestFinal_ex(&ctx, digest, NULL)) { goto err; } - memcpy(mask + outlen, md, len - outlen); - outlen = len; + memcpy(out, digest, len); + len = 0; } } - ret = 0; + + ret = 1; err: - EVP_MD_CTX_cleanup(&c); + EVP_MD_CTX_cleanup(&ctx); return ret; } @@ -372,14 +371,14 @@ int RSA_padding_add_PKCS1_OAEP_mgf1(uint8_t *to, unsigned to_len, return 0; } - if (PKCS1_MGF1(dbmask, emlen - mdlen, seed, mdlen, mgf1md) < 0) { + if (!PKCS1_MGF1(dbmask, emlen - mdlen, seed, mdlen, mgf1md)) { goto out; } for (i = 0; i < emlen - mdlen; i++) { db[i] ^= dbmask[i]; } - if (PKCS1_MGF1(seedmask, mdlen, db, emlen - mdlen, mgf1md) < 0) { + if (!PKCS1_MGF1(seedmask, mdlen, db, emlen - mdlen, mgf1md)) { goto out; } for (i = 0; i < mdlen; i++) { @@ -428,14 +427,14 @@ int RSA_padding_check_PKCS1_OAEP_mgf1(uint8_t *to, unsigned to_len, maskedseed = from + 1; maskeddb = from + 1 + mdlen; - if (PKCS1_MGF1(seed, mdlen, maskeddb, dblen, mgf1md)) { + if (!PKCS1_MGF1(seed, mdlen, maskeddb, dblen, mgf1md)) { goto err; } for (i = 0; i < mdlen; i++) { seed[i] ^= maskedseed[i]; } - if (PKCS1_MGF1(db, dblen, seed, mdlen, mgf1md)) { + if (!PKCS1_MGF1(db, dblen, seed, mdlen, mgf1md)) { goto err; } for (i = 0; i < dblen; i++) { @@ -547,7 +546,7 @@ int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const uint8_t *mHash, OPENSSL_PUT_ERROR(RSA, ERR_R_MALLOC_FAILURE); goto err; } - if (PKCS1_MGF1(DB, maskedDBLen, H, hLen, mgf1Hash) < 0) { + if (!PKCS1_MGF1(DB, maskedDBLen, H, hLen, mgf1Hash)) { goto err; } for (i = 0; i < maskedDBLen; i++) { @@ -673,7 +672,7 @@ int RSA_padding_add_PKCS1_PSS_mgf1(RSA *rsa, unsigned char *EM, EVP_MD_CTX_cleanup(&ctx); /* Generate dbMask in place then perform XOR on it */ - if (PKCS1_MGF1(EM, maskedDBLen, H, hLen, mgf1Hash)) { + if (!PKCS1_MGF1(EM, maskedDBLen, H, hLen, mgf1Hash)) { goto err; }