From 3c0f6668ef6c3891533f681af61a065399883aa1 Mon Sep 17 00:00:00 2001 From: Joost Rijneveld Date: Sat, 22 Jul 2017 16:27:02 +0200 Subject: [PATCH] Add parameter for hash alg family, support SHAKE --- Makefile | 20 +- fips202.c | 418 ++++++++++++++++++++++++++++++++++++++++ fips202.h | 12 ++ hash.c | 48 +++-- hash.h | 11 +- test/test_wots.c | 3 +- test/test_xmss.c | 4 +- test/test_xmss_fast.c | 6 +- test/test_xmssmt.c | 4 +- test/test_xmssmt_fast.c | 4 +- wots.c | 7 +- wots.h | 3 +- xmss.c | 45 ++--- xmss.h | 5 +- xmss_fast.c | 51 ++--- xmss_fast.h | 5 +- 16 files changed, 549 insertions(+), 97 deletions(-) create mode 100644 fips202.c create mode 100644 fips202.h diff --git a/Makefile b/Makefile index eae3e15..7189371 100644 --- a/Makefile +++ b/Makefile @@ -7,20 +7,20 @@ test/test_xmss_fast \ test/test_xmssmt_fast \ test/test_xmssmt -test/test_wots: hash.c hash_address.c randombytes.c wots.c xmss_commons.c test/test_wots.c hash.h hash_address.h randombytes.h wots.h xmss_commons.h - $(CC) $(CFLAGS) hash.c hash_address.c randombytes.c wots.c xmss_commons.c test/test_wots.c -o $@ -lcrypto -lm +test/test_wots: hash.c fips202.c hash_address.c randombytes.c wots.c xmss_commons.c test/test_wots.c hash.h fips202.h hash_address.h randombytes.h wots.h xmss_commons.h + $(CC) $(CFLAGS) hash.c fips202.c hash_address.c randombytes.c wots.c xmss_commons.c test/test_wots.c -o $@ -lcrypto -lm -test/test_xmss: hash.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmss.c hash.h hash_address.h randombytes.h wots.h xmss.h xmss_commons.h - $(CC) $(CFLAGS) hash.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmss.c -o $@ -lcrypto -lm +test/test_xmss: hash.c fips202.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmss.c hash.h fips202.h hash_address.h randombytes.h wots.h xmss.h xmss_commons.h + $(CC) $(CFLAGS) hash.c fips202.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmss.c -o $@ -lcrypto -lm -test/test_xmss_fast: hash.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmss_fast.c hash.h hash_address.h randombytes.h wots.h xmss_fast.h xmss_commons.h - $(CC) $(CFLAGS) hash.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmss_fast.c -o $@ -lcrypto -lm +test/test_xmss_fast: hash.c fips202.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmss_fast.c hash.h fips202.h hash_address.h randombytes.h wots.h xmss_fast.h xmss_commons.h + $(CC) $(CFLAGS) hash.c fips202.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmss_fast.c -o $@ -lcrypto -lm -test/test_xmssmt: hash.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmssmt.c hash.h hash_address.h randombytes.h wots.h xmss.h xmss_commons.h - $(CC) $(CFLAGS) hash.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmssmt.c -o $@ -lcrypto -lm +test/test_xmssmt: hash.c fips202.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmssmt.c hash.h fips202.h hash_address.h randombytes.h wots.h xmss.h xmss_commons.h + $(CC) $(CFLAGS) hash.c fips202.c hash_address.c randombytes.c wots.c xmss.c xmss_commons.c test/test_xmssmt.c -o $@ -lcrypto -lm -test/test_xmssmt_fast: hash.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmssmt_fast.c hash.h hash_address.h randombytes.h wots.h xmss_fast.h xmss_commons.h - $(CC) $(CFLAGS) hash.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmssmt_fast.c -o $@ -lcrypto -lm +test/test_xmssmt_fast: hash.c fips202.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmssmt_fast.c hash.h fips202.h hash_address.h randombytes.h wots.h xmss_fast.h xmss_commons.h + $(CC) $(CFLAGS) hash.c fips202.c hash_address.c randombytes.c wots.c xmss_fast.c xmss_commons.c test/test_xmssmt_fast.c -o $@ -lcrypto -lm .PHONY: clean diff --git a/fips202.c b/fips202.c new file mode 100644 index 0000000..4568896 --- /dev/null +++ b/fips202.c @@ -0,0 +1,418 @@ +/* Based on the public domain implementation in + * crypto_hash/keccakc512/simple/ from http://bench.cr.yp.to/supercop.html + * by Ronny Van Keer + * and the public domain "TweetFips202" implementation + * from https://twitter.com/tweetfips202 + * by Gilles Van Assche, Daniel J. Bernstein, and Peter Schwabe */ + +#include +#include +#include "fips202.h" + +#define NROUNDS 24 +#define ROL(a, offset) ((a << offset) ^ (a >> (64-offset))) + +static uint64_t load64(const unsigned char *x) +{ + unsigned long long r = 0, i; + + for (i = 0; i < 8; ++i) { + r |= (unsigned long long)x[i] << 8 * i; + } + return r; +} + +static void store64(uint8_t *x, uint64_t u) +{ + unsigned int i; + + for(i=0; i<8; ++i) { + x[i] = u; + u >>= 8; + } +} + +static const uint64_t KeccakF_RoundConstants[NROUNDS] = +{ + (uint64_t)0x0000000000000001ULL, + (uint64_t)0x0000000000008082ULL, + (uint64_t)0x800000000000808aULL, + (uint64_t)0x8000000080008000ULL, + (uint64_t)0x000000000000808bULL, + (uint64_t)0x0000000080000001ULL, + (uint64_t)0x8000000080008081ULL, + (uint64_t)0x8000000000008009ULL, + (uint64_t)0x000000000000008aULL, + (uint64_t)0x0000000000000088ULL, + (uint64_t)0x0000000080008009ULL, + (uint64_t)0x000000008000000aULL, + (uint64_t)0x000000008000808bULL, + (uint64_t)0x800000000000008bULL, + (uint64_t)0x8000000000008089ULL, + (uint64_t)0x8000000000008003ULL, + (uint64_t)0x8000000000008002ULL, + (uint64_t)0x8000000000000080ULL, + (uint64_t)0x000000000000800aULL, + (uint64_t)0x800000008000000aULL, + (uint64_t)0x8000000080008081ULL, + (uint64_t)0x8000000000008080ULL, + (uint64_t)0x0000000080000001ULL, + (uint64_t)0x8000000080008008ULL +}; + +void KeccakF1600_StatePermute(uint64_t * state) +{ + int round; + + uint64_t Aba, Abe, Abi, Abo, Abu; + uint64_t Aga, Age, Agi, Ago, Agu; + uint64_t Aka, Ake, Aki, Ako, Aku; + uint64_t Ama, Ame, Ami, Amo, Amu; + uint64_t Asa, Ase, Asi, Aso, Asu; + uint64_t BCa, BCe, BCi, BCo, BCu; + uint64_t Da, De, Di, Do, Du; + uint64_t Eba, Ebe, Ebi, Ebo, Ebu; + uint64_t Ega, Ege, Egi, Ego, Egu; + uint64_t Eka, Eke, Eki, Eko, Eku; + uint64_t Ema, Eme, Emi, Emo, Emu; + uint64_t Esa, Ese, Esi, Eso, Esu; + + //copyFromState(A, state) + Aba = state[ 0]; + Abe = state[ 1]; + Abi = state[ 2]; + Abo = state[ 3]; + Abu = state[ 4]; + Aga = state[ 5]; + Age = state[ 6]; + Agi = state[ 7]; + Ago = state[ 8]; + Agu = state[ 9]; + Aka = state[10]; + Ake = state[11]; + Aki = state[12]; + Ako = state[13]; + Aku = state[14]; + Ama = state[15]; + Ame = state[16]; + Ami = state[17]; + Amo = state[18]; + Amu = state[19]; + Asa = state[20]; + Ase = state[21]; + Asi = state[22]; + Aso = state[23]; + Asu = state[24]; + + for( round = 0; round < NROUNDS; round += 2 ) + { + // prepareTheta + BCa = Aba^Aga^Aka^Ama^Asa; + BCe = Abe^Age^Ake^Ame^Ase; + BCi = Abi^Agi^Aki^Ami^Asi; + BCo = Abo^Ago^Ako^Amo^Aso; + BCu = Abu^Agu^Aku^Amu^Asu; + + //thetaRhoPiChiIotaPrepareTheta(round , A, E) + Da = BCu^ROL(BCe, 1); + De = BCa^ROL(BCi, 1); + Di = BCe^ROL(BCo, 1); + Do = BCi^ROL(BCu, 1); + Du = BCo^ROL(BCa, 1); + + Aba ^= Da; + BCa = Aba; + Age ^= De; + BCe = ROL(Age, 44); + Aki ^= Di; + BCi = ROL(Aki, 43); + Amo ^= Do; + BCo = ROL(Amo, 21); + Asu ^= Du; + BCu = ROL(Asu, 14); + Eba = BCa ^((~BCe)& BCi ); + Eba ^= (uint64_t)KeccakF_RoundConstants[round]; + Ebe = BCe ^((~BCi)& BCo ); + Ebi = BCi ^((~BCo)& BCu ); + Ebo = BCo ^((~BCu)& BCa ); + Ebu = BCu ^((~BCa)& BCe ); + + Abo ^= Do; + BCa = ROL(Abo, 28); + Agu ^= Du; + BCe = ROL(Agu, 20); + Aka ^= Da; + BCi = ROL(Aka, 3); + Ame ^= De; + BCo = ROL(Ame, 45); + Asi ^= Di; + BCu = ROL(Asi, 61); + Ega = BCa ^((~BCe)& BCi ); + Ege = BCe ^((~BCi)& BCo ); + Egi = BCi ^((~BCo)& BCu ); + Ego = BCo ^((~BCu)& BCa ); + Egu = BCu ^((~BCa)& BCe ); + + Abe ^= De; + BCa = ROL(Abe, 1); + Agi ^= Di; + BCe = ROL(Agi, 6); + Ako ^= Do; + BCi = ROL(Ako, 25); + Amu ^= Du; + BCo = ROL(Amu, 8); + Asa ^= Da; + BCu = ROL(Asa, 18); + Eka = BCa ^((~BCe)& BCi ); + Eke = BCe ^((~BCi)& BCo ); + Eki = BCi ^((~BCo)& BCu ); + Eko = BCo ^((~BCu)& BCa ); + Eku = BCu ^((~BCa)& BCe ); + + Abu ^= Du; + BCa = ROL(Abu, 27); + Aga ^= Da; + BCe = ROL(Aga, 36); + Ake ^= De; + BCi = ROL(Ake, 10); + Ami ^= Di; + BCo = ROL(Ami, 15); + Aso ^= Do; + BCu = ROL(Aso, 56); + Ema = BCa ^((~BCe)& BCi ); + Eme = BCe ^((~BCi)& BCo ); + Emi = BCi ^((~BCo)& BCu ); + Emo = BCo ^((~BCu)& BCa ); + Emu = BCu ^((~BCa)& BCe ); + + Abi ^= Di; + BCa = ROL(Abi, 62); + Ago ^= Do; + BCe = ROL(Ago, 55); + Aku ^= Du; + BCi = ROL(Aku, 39); + Ama ^= Da; + BCo = ROL(Ama, 41); + Ase ^= De; + BCu = ROL(Ase, 2); + Esa = BCa ^((~BCe)& BCi ); + Ese = BCe ^((~BCi)& BCo ); + Esi = BCi ^((~BCo)& BCu ); + Eso = BCo ^((~BCu)& BCa ); + Esu = BCu ^((~BCa)& BCe ); + + // prepareTheta + BCa = Eba^Ega^Eka^Ema^Esa; + BCe = Ebe^Ege^Eke^Eme^Ese; + BCi = Ebi^Egi^Eki^Emi^Esi; + BCo = Ebo^Ego^Eko^Emo^Eso; + BCu = Ebu^Egu^Eku^Emu^Esu; + + //thetaRhoPiChiIotaPrepareTheta(round+1, E, A) + Da = BCu^ROL(BCe, 1); + De = BCa^ROL(BCi, 1); + Di = BCe^ROL(BCo, 1); + Do = BCi^ROL(BCu, 1); + Du = BCo^ROL(BCa, 1); + + Eba ^= Da; + BCa = Eba; + Ege ^= De; + BCe = ROL(Ege, 44); + Eki ^= Di; + BCi = ROL(Eki, 43); + Emo ^= Do; + BCo = ROL(Emo, 21); + Esu ^= Du; + BCu = ROL(Esu, 14); + Aba = BCa ^((~BCe)& BCi ); + Aba ^= (uint64_t)KeccakF_RoundConstants[round+1]; + Abe = BCe ^((~BCi)& BCo ); + Abi = BCi ^((~BCo)& BCu ); + Abo = BCo ^((~BCu)& BCa ); + Abu = BCu ^((~BCa)& BCe ); + + Ebo ^= Do; + BCa = ROL(Ebo, 28); + Egu ^= Du; + BCe = ROL(Egu, 20); + Eka ^= Da; + BCi = ROL(Eka, 3); + Eme ^= De; + BCo = ROL(Eme, 45); + Esi ^= Di; + BCu = ROL(Esi, 61); + Aga = BCa ^((~BCe)& BCi ); + Age = BCe ^((~BCi)& BCo ); + Agi = BCi ^((~BCo)& BCu ); + Ago = BCo ^((~BCu)& BCa ); + Agu = BCu ^((~BCa)& BCe ); + + Ebe ^= De; + BCa = ROL(Ebe, 1); + Egi ^= Di; + BCe = ROL(Egi, 6); + Eko ^= Do; + BCi = ROL(Eko, 25); + Emu ^= Du; + BCo = ROL(Emu, 8); + Esa ^= Da; + BCu = ROL(Esa, 18); + Aka = BCa ^((~BCe)& BCi ); + Ake = BCe ^((~BCi)& BCo ); + Aki = BCi ^((~BCo)& BCu ); + Ako = BCo ^((~BCu)& BCa ); + Aku = BCu ^((~BCa)& BCe ); + + Ebu ^= Du; + BCa = ROL(Ebu, 27); + Ega ^= Da; + BCe = ROL(Ega, 36); + Eke ^= De; + BCi = ROL(Eke, 10); + Emi ^= Di; + BCo = ROL(Emi, 15); + Eso ^= Do; + BCu = ROL(Eso, 56); + Ama = BCa ^((~BCe)& BCi ); + Ame = BCe ^((~BCi)& BCo ); + Ami = BCi ^((~BCo)& BCu ); + Amo = BCo ^((~BCu)& BCa ); + Amu = BCu ^((~BCa)& BCe ); + + Ebi ^= Di; + BCa = ROL(Ebi, 62); + Ego ^= Do; + BCe = ROL(Ego, 55); + Eku ^= Du; + BCi = ROL(Eku, 39); + Ema ^= Da; + BCo = ROL(Ema, 41); + Ese ^= De; + BCu = ROL(Ese, 2); + Asa = BCa ^((~BCe)& BCi ); + Ase = BCe ^((~BCi)& BCo ); + Asi = BCi ^((~BCo)& BCu ); + Aso = BCo ^((~BCu)& BCa ); + Asu = BCu ^((~BCa)& BCe ); + } + + //copyToState(state, A) + state[ 0] = Aba; + state[ 1] = Abe; + state[ 2] = Abi; + state[ 3] = Abo; + state[ 4] = Abu; + state[ 5] = Aga; + state[ 6] = Age; + state[ 7] = Agi; + state[ 8] = Ago; + state[ 9] = Agu; + state[10] = Aka; + state[11] = Ake; + state[12] = Aki; + state[13] = Ako; + state[14] = Aku; + state[15] = Ama; + state[16] = Ame; + state[17] = Ami; + state[18] = Amo; + state[19] = Amu; + state[20] = Asa; + state[21] = Ase; + state[22] = Asi; + state[23] = Aso; + state[24] = Asu; + + #undef round +} + +#include +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + + +static void keccak_absorb(uint64_t *s, + unsigned int r, + const unsigned char *m, unsigned long long int mlen, + unsigned char p) +{ + unsigned long long i; + unsigned char t[200]; + + while (mlen >= r) + { + for (i = 0; i < r / 8; ++i) + s[i] ^= load64(m + 8 * i); + KeccakF1600_StatePermute(s); + mlen -= r; + m += r; + } + + for (i = 0; i < r; ++i) + t[i] = 0; + for (i = 0; i < mlen; ++i) + t[i] = m[i]; + t[i] = p; + t[r - 1] |= 128; + for (i = 0; i < r / 8; ++i) + s[i] ^= load64(t + 8 * i); +} + +static void keccak_squeezeblocks(unsigned char *h, unsigned long long int nblocks, + uint64_t *s, + unsigned int r) +{ + unsigned int i; + while(nblocks > 0) + { + KeccakF1600_StatePermute(s); + for(i=0;i<(r>>3);i++) + { + store64(h+8*i, s[i]); + } + h += r; + nblocks--; + } +} + +void shake128(unsigned char *output, unsigned int outputByteLen, const unsigned char *input, unsigned int inputByteLen) +{ + unsigned int i; + uint64_t s[25]; + unsigned char d[SHAKE128_RATE]; + + for(i = 0; i < 25; i++) + s[i] = 0; + keccak_absorb(s, SHAKE128_RATE, input, inputByteLen, 0x1F); + + keccak_squeezeblocks(output, outputByteLen/SHAKE128_RATE, s, SHAKE128_RATE); + output += (outputByteLen/SHAKE128_RATE)*SHAKE128_RATE; + + if (outputByteLen % SHAKE128_RATE) { + keccak_squeezeblocks(d, 1, s, SHAKE128_RATE); + for(i = 0; i < outputByteLen % SHAKE128_RATE; i++) { + output[i] = d[i]; + } + } +} + +void shake256(unsigned char *output, unsigned int outputByteLen, const unsigned char *input, unsigned int inputByteLen) +{ + unsigned int i; + uint64_t s[25]; + unsigned char d[SHAKE256_RATE]; + + for(i = 0; i < 25; i++) + s[i] = 0; + keccak_absorb(s, SHAKE256_RATE, input, inputByteLen, 0x1F); + + keccak_squeezeblocks(output, outputByteLen/SHAKE256_RATE, s, SHAKE256_RATE); + output += (outputByteLen/SHAKE256_RATE)*SHAKE256_RATE; + + if (outputByteLen % SHAKE256_RATE) { + keccak_squeezeblocks(d, 1, s, SHAKE256_RATE); + for(i = 0; i < outputByteLen % SHAKE256_RATE; i++) { + output[i] = d[i]; + } + } +} diff --git a/fips202.h b/fips202.h new file mode 100644 index 0000000..5dd27a3 --- /dev/null +++ b/fips202.h @@ -0,0 +1,12 @@ +#ifndef FIPS202_H +#define FIPS202_H + +#include + +#define SHAKE128_RATE 168 +#define SHAKE256_RATE 136 + +void shake128(unsigned char *output, unsigned int outputByteLen, const unsigned char *input, unsigned int inputByteLen); +void shake256(unsigned char *output, unsigned int outputByteLen, const unsigned char *input, unsigned int inputByteLen); + +#endif diff --git a/hash.c b/hash.c index 4f713c6..83e1bea 100644 --- a/hash.c +++ b/hash.c @@ -8,6 +8,7 @@ Public domain. #include "hash_address.h" #include "xmss_commons.h" #include "hash.h" +#include "fips202.h" #include #include @@ -29,7 +30,7 @@ unsigned char* addr_to_byte(unsigned char *bytes, const uint32_t addr[8]){ #endif } -int core_hash_SHA2(unsigned char *out, const unsigned int type, const unsigned char *key, unsigned int keylen, const unsigned char *in, unsigned long long inlen, unsigned int n){ +int core_hash(unsigned char *out, const unsigned int type, const unsigned char *key, unsigned int keylen, const unsigned char *in, unsigned long long inlen, unsigned int n, const unsigned char hash_alg){ unsigned long long i = 0; unsigned char buf[inlen + n + keylen]; @@ -46,43 +47,48 @@ int core_hash_SHA2(unsigned char *out, const unsigned int type, const unsigned c buf[keylen + n + i] = in[i]; } - if (n == 32) { + if (n == 32 && hash_alg == XMSS_SHA2) { SHA256(buf, inlen + keylen + n, out); - return 0; + } + else if (n == 32 && hash_alg == XMSS_SHAKE) { + shake128(out, 32, buf, inlen + keylen + n); + } + else if (n == 64 && hash_alg == XMSS_SHA2) { + SHA512(buf, inlen + keylen + n, out); + } + else if (n == 64 && hash_alg == XMSS_SHAKE) { + shake256(out, 64, buf, inlen + keylen + n); } else { - if (n == 64) { - SHA512(buf, inlen + keylen + n, out); - return 0; - } + return 1; } - return 1; + return 0; } /** * Implements PRF */ -int prf(unsigned char *out, const unsigned char *in, const unsigned char *key, unsigned int keylen) +int prf(unsigned char *out, const unsigned char *in, const unsigned char *key, unsigned int keylen, const unsigned char hash_alg) { - return core_hash_SHA2(out, 3, key, keylen, in, 32, keylen); + return core_hash(out, 3, key, keylen, in, 32, keylen, hash_alg); } /* * Implemts H_msg */ -int h_msg(unsigned char *out, const unsigned char *in, unsigned long long inlen, const unsigned char *key, const unsigned int keylen, const unsigned int n) +int h_msg(unsigned char *out, const unsigned char *in, unsigned long long inlen, const unsigned char *key, const unsigned int keylen, const unsigned int n, const unsigned char hash_alg) { if (keylen != 3*n){ fprintf(stderr, "H_msg takes 3n-bit keys, we got n=%d but a keylength of %d.\n", n, keylen); return 1; } - return core_hash_SHA2(out, 2, key, keylen, in, inlen, n); + return core_hash(out, 2, key, keylen, in, inlen, n, hash_alg); } /** * We assume the left half is in in[0]...in[n-1] */ -int hash_h(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n) +int hash_h(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n, const unsigned char hash_alg) { unsigned char buf[2*n]; @@ -93,21 +99,21 @@ int hash_h(unsigned char *out, const unsigned char *in, const unsigned char *pub setKeyAndMask(addr, 0); addr_to_byte(byte_addr, addr); - prf(key, byte_addr, pub_seed, n); + prf(key, byte_addr, pub_seed, n, hash_alg); // Use MSB order setKeyAndMask(addr, 1); addr_to_byte(byte_addr, addr); - prf(bitmask, byte_addr, pub_seed, n); + prf(bitmask, byte_addr, pub_seed, n, hash_alg); setKeyAndMask(addr, 2); addr_to_byte(byte_addr, addr); - prf(bitmask+n, byte_addr, pub_seed, n); + prf(bitmask+n, byte_addr, pub_seed, n, hash_alg); for (i = 0; i < 2*n; i++) { buf[i] = in[i] ^ bitmask[i]; } - return core_hash_SHA2(out, 1, key, n, buf, 2*n, n); + return core_hash(out, 1, key, n, buf, 2*n, n, hash_alg); } -int hash_f(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n) +int hash_f(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n, const unsigned char hash_alg) { unsigned char buf[n]; unsigned char key[n]; @@ -117,14 +123,14 @@ int hash_f(unsigned char *out, const unsigned char *in, const unsigned char *pub setKeyAndMask(addr, 0); addr_to_byte(byte_addr, addr); - prf(key, byte_addr, pub_seed, n); + prf(key, byte_addr, pub_seed, n, hash_alg); setKeyAndMask(addr, 1); addr_to_byte(byte_addr, addr); - prf(bitmask, byte_addr, pub_seed, n); + prf(bitmask, byte_addr, pub_seed, n, hash_alg); for (i = 0; i < n; i++) { buf[i] = in[i] ^ bitmask[i]; } - return core_hash_SHA2(out, 0, key, n, buf, n, n); + return core_hash(out, 0, key, n, buf, n, n, hash_alg); } diff --git a/hash.h b/hash.h index 2fed730..12354f5 100644 --- a/hash.h +++ b/hash.h @@ -8,12 +8,15 @@ Public domain. #ifndef HASH_H #define HASH_H +#define XMSS_SHA2 0x01 +#define XMSS_SHAKE 0x02 + #define IS_LITTLE_ENDIAN 1 unsigned char* addr_to_byte(unsigned char *bytes, const uint32_t addr[8]); -int prf(unsigned char *out, const unsigned char *in, const unsigned char *key, unsigned int keylen); -int h_msg(unsigned char *out,const unsigned char *in,unsigned long long inlen, const unsigned char *key, const unsigned int keylen, const unsigned int n); -int hash_h(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n); -int hash_f(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n); +int prf(unsigned char *out, const unsigned char *in, const unsigned char *key, unsigned int keylen, const unsigned char hash_alg); +int h_msg(unsigned char *out,const unsigned char *in,unsigned long long inlen, const unsigned char *key, const unsigned int keylen, const unsigned int n, const unsigned char hash_alg); +int hash_h(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n, const unsigned char hash_alg); +int hash_f(unsigned char *out, const unsigned char *in, const unsigned char *pub_seed, uint32_t addr[8], const unsigned int n, const unsigned char hash_alg); #endif diff --git a/test/test_wots.c b/test/test_wots.c index 03dbf39..eba44ea 100644 --- a/test/test_wots.c +++ b/test/test_wots.c @@ -2,6 +2,7 @@ #include #include #include "../wots.h" +#include "../hash.h" #include "../randombytes.h" static void hexdump(unsigned char *a, size_t len) @@ -17,7 +18,7 @@ int main() unsigned char seed[n]; unsigned char pub_seed[n]; wots_params params; - wots_set_params(¶ms, n, 16); + wots_set_params(¶ms, n, 16, XMSS_SHA2); int sig_len = params.len*params.n; diff --git a/test/test_xmss.c b/test/test_xmss.c index f7bedd2..9831f51 100644 --- a/test/test_xmss.c +++ b/test/test_xmss.c @@ -2,6 +2,7 @@ #include #include "../xmss.h" +#include "../hash.h" #define MLEN 3491 #define SIGNATURES 50 @@ -17,6 +18,7 @@ int main() unsigned int n = 32; unsigned int h = 8; unsigned int w = 16; + unsigned char hash_alg = XMSS_SHA2; unsigned long errors = 0; @@ -25,7 +27,7 @@ int main() xmss_params p; xmss_params *params = &p; - xmss_set_params(params, n, h, w); + xmss_set_params(params, n, h, w, hash_alg); unsigned long long signature_length = 4+n+params->wots_par.keysize+h*n; unsigned char mo[MLEN+signature_length]; unsigned char sm[MLEN+signature_length]; diff --git a/test/test_xmss_fast.c b/test/test_xmss_fast.c index 4d8bfd8..ee38eb8 100644 --- a/test/test_xmss_fast.c +++ b/test/test_xmss_fast.c @@ -3,6 +3,7 @@ #include #include "../xmss_fast.h" +#include "../hash.h" #define MLEN 3491 #define SIGNATURES 256 @@ -28,6 +29,7 @@ int main() unsigned int h = 20; unsigned int w = 16; unsigned int k = 2; + unsigned char hash_alg = XMSS_SHA2; unsigned long errors = 0; @@ -36,7 +38,7 @@ int main() xmss_params p; xmss_params *params = &p; - xmss_set_params(params, n, h, w, k); + xmss_set_params(params, n, h, w, k, hash_alg); // TODO should we hide this into xmss_fast.c and just allocate a large enough chunk of memory here? unsigned char stack[(h+1)*n]; @@ -67,8 +69,6 @@ int main() printf("cycles = %llu\n", (t2-t1)); double sec = (t2-t1)/3500000; printf("ms = %f\n", sec); - int read; - read = fgetc(stdin); // check pub_seed in SK for (i = 0; i < n; i++) { if (pk[n+i] != sk[4+2*n+i]) printf("pk.pub_seed != sk.pub_seed %llu",i); diff --git a/test/test_xmssmt.c b/test/test_xmssmt.c index 29827bb..205a1a6 100644 --- a/test/test_xmssmt.c +++ b/test/test_xmssmt.c @@ -2,6 +2,7 @@ #include #include "../xmss.h" +#include "../hash.h" #define MLEN 3491 #define SIGNATURES 1024 @@ -18,10 +19,11 @@ int main() unsigned int h = 20; unsigned int d = 5; unsigned int w = 16; + unsigned char hash_alg = XMSS_SHA2; xmssmt_params p; xmssmt_params *params = &p; - xmssmt_set_params(params, n, h, d, w); + xmssmt_set_params(params, n, h, d, w, hash_alg); unsigned char sk[(params->index_len+4*n)]; unsigned char pk[2*n]; diff --git a/test/test_xmssmt_fast.c b/test/test_xmssmt_fast.c index fc0f057..168f5a7 100644 --- a/test/test_xmssmt_fast.c +++ b/test/test_xmssmt_fast.c @@ -2,6 +2,7 @@ #include #include "../xmss_fast.h" +#include "../hash.h" #define MLEN 3491 #define SIGNATURES 4096 @@ -19,10 +20,11 @@ int main() unsigned int d = 3; unsigned int w = 16; unsigned int k = 2; + unsigned char hash_alg = XMSS_SHA2; xmssmt_params p; xmssmt_params *params = &p; - if (xmssmt_set_params(params, n, h, d, w, k)) { + if (xmssmt_set_params(params, n, h, d, w, k, hash_alg)) { return 1; } diff --git a/wots.c b/wots.c index b3f15db..65881f0 100644 --- a/wots.c +++ b/wots.c @@ -16,7 +16,7 @@ Public domain. #include "hash_address.h" -void wots_set_params(wots_params *params, int n, int w) +void wots_set_params(wots_params *params, int n, int w, unsigned char hash_alg) { params->n = n; params->w = w; @@ -25,6 +25,7 @@ void wots_set_params(wots_params *params, int n, int w) params->len_2 = (int) floor(log2(params->len_1*(w-1)) / params->log_w) + 1; params->len = params->len_1 + params->len_2; params->keysize = params->len*params->n; + params->hash_alg = hash_alg; } /** @@ -38,7 +39,7 @@ static void expand_seed(unsigned char *outseeds, const unsigned char *inseed, co unsigned char ctr[32]; for(i = 0; i < params->len; i++){ to_byte(ctr, i, 32); - prf((outseeds + (i*params->n)), ctr, inseed, params->n); + prf((outseeds + (i*params->n)), ctr, inseed, params->n, params->hash_alg); } } @@ -57,7 +58,7 @@ static void gen_chain(unsigned char *out, const unsigned char *in, unsigned int for (i = start; i < (start+steps) && i < params->w; i++) { setHashADRS(addr, i); - hash_f(out, out, pub_seed, addr, params->n); + hash_f(out, out, pub_seed, addr, params->n, params->hash_alg); } } diff --git a/wots.h b/wots.h index 6e5c8f7..606fc54 100644 --- a/wots.h +++ b/wots.h @@ -23,6 +23,7 @@ typedef struct { uint32_t w; uint32_t log_w; uint32_t keysize; + unsigned char hash_alg; } wots_params; /** @@ -32,7 +33,7 @@ typedef struct { * * Assumes w is a power of 2 */ -void wots_set_params(wots_params *params, int n, int w); +void wots_set_params(wots_params *params, int n, int w, unsigned char hash_alg); /** * WOTS key generation. Takes a 32byte seed for the secret key, expands it to a full WOTS secret key and computes the corresponding public key. diff --git a/xmss.c b/xmss.c index 966f2c3..4cc638c 100644 --- a/xmss.c +++ b/xmss.c @@ -28,7 +28,7 @@ Public domain. * * takes n byte sk_seed and returns n byte seed using 32 byte address addr. */ -static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8]) +static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8], const unsigned char hash_alg) { unsigned char bytes[32]; // Make sure that chain addr, hash addr, and key bit are 0! @@ -37,20 +37,21 @@ static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, u setKeyAndMask(addr,0); // Generate pseudorandom value addr_to_byte(bytes, addr); - prf(seed, bytes, sk_seed, n); + prf(seed, bytes, sk_seed, n, hash_alg); } /** * Initialize xmss params struct * parameter names are the same as in the draft */ -void xmss_set_params(xmss_params *params, int n, int h, int w) +void xmss_set_params(xmss_params *params, int n, int h, int w, unsigned char hash_alg) { params->h = h; params->n = n; wots_params wots_par; - wots_set_params(&wots_par, n, w); + wots_set_params(&wots_par, n, w, hash_alg); params->wots_par = wots_par; + params->hash_alg = hash_alg; } /** @@ -59,7 +60,7 @@ void xmss_set_params(xmss_params *params, int n, int h, int w) * * Especially h is the total tree height, i.e. the XMSS trees have height h/d */ -void xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w) +void xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, unsigned char hash_alg) { if (h % d) { fprintf(stderr, "d must devide h without remainder!\n"); @@ -70,7 +71,7 @@ void xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w) params->n = n; params->index_len = (h + 7) / 8; xmss_params xmss_par; - xmss_set_params(&xmss_par, n, (h/d), w); + xmss_set_params(&xmss_par, n, (h/d), w, hash_alg); params->xmss_par = xmss_par; } @@ -94,7 +95,7 @@ static void l_tree(unsigned char *leaf, unsigned char *wots_pk, const xmss_param //ADRS.setTreeIndex(i); setTreeIndex(addr, i); //wots_pk[i] = RAND_HASH(pk[2i], pk[2i + 1], SEED, ADRS); - hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n); + hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n, params->hash_alg); } //if ( l % 2 == 1 ) { if (l & 1) { @@ -123,7 +124,7 @@ static void gen_leaf_wots(unsigned char *leaf, const unsigned char *sk_seed, con unsigned char seed[params->n]; unsigned char pk[params->wots_par.keysize]; - get_seed(seed, sk_seed, params->n, ots_addr); + get_seed(seed, sk_seed, params->n, ots_addr, params->hash_alg); wots_pkgen(pk, seed, &(params->wots_par), pub_seed, ots_addr); l_tree(leaf, pk, params, pub_seed, ltree_addr); @@ -169,7 +170,7 @@ static void treehash(unsigned char *node, uint16_t height, uint32_t index, const setTreeHeight(node_addr, stacklevels[stackoffset-1]); setTreeIndex(node_addr, (idx >> (stacklevels[stackoffset-1]+1))); hash_h(stack+(stackoffset-2)*n, stack+(stackoffset-2)*n, pub_seed, - node_addr, n); + node_addr, n, params->hash_alg); stacklevels[stackoffset-2]++; stackoffset--; } @@ -209,12 +210,12 @@ static void validate_authpath(unsigned char *root, const unsigned char *leaf, un leafidx >>= 1; setTreeIndex(addr, leafidx); if (leafidx&1) { - hash_h(buffer+n, buffer, pub_seed, addr, n); + hash_h(buffer+n, buffer, pub_seed, addr, n, params->hash_alg); for (j = 0; j < n; j++) buffer[j] = authpath[j]; } else { - hash_h(buffer, buffer, pub_seed, addr, n); + hash_h(buffer, buffer, pub_seed, addr, n, params->hash_alg); for (j = 0; j < n; j++) buffer[j+n] = authpath[j]; } @@ -223,7 +224,7 @@ static void validate_authpath(unsigned char *root, const unsigned char *leaf, un setTreeHeight(addr, (params->h-1)); leafidx >>= 1; setTreeIndex(addr, leafidx); - hash_h(root, buffer, pub_seed, addr, n); + hash_h(root, buffer, pub_seed, addr, n, params->hash_alg); } /** @@ -266,7 +267,7 @@ static void compute_authpath_wots(unsigned char *root, unsigned char *authpath, // Inner loop: for each pair of sibling nodes for (j = 0; j < i; j+=2) { setTreeIndex(node_addr, j>>1); - hash_h(tree + (i>>1)*n + (j>>1) * n, tree + i*n + j*n, pub_seed, node_addr, n); + hash_h(tree + (i>>1)*n + (j>>1) * n, tree + i*n + j*n, pub_seed, node_addr, n, params->hash_alg); } level++; } @@ -355,13 +356,13 @@ int xmss_sign(unsigned char *sk, unsigned char *sig_msg, unsigned long long *sig // Message Hash: // First compute pseudorandom value - prf(R, idx_bytes_32, sk_prf, n); + prf(R, idx_bytes_32, sk_prf, n, params->hash_alg); // Generate hash key (R || root || idx) memcpy(hash_key, R, n); memcpy(hash_key+n, sk+4+3*n, n); to_byte(hash_key+2*n, idx, n); // Then use it for message digest - h_msg(msg_h, msg, msglen, hash_key, 3*n, n); + h_msg(msg_h, msg, msglen, hash_key, 3*n, n, params->hash_alg); // Start collecting signature *sig_msg_len = 0; @@ -391,7 +392,7 @@ int xmss_sign(unsigned char *sk, unsigned char *sig_msg, unsigned long long *sig setOTSADRS(ots_addr, idx); // Compute seed for OTS key pair - get_seed(ots_seed, sk_seed, n, ots_addr); + get_seed(ots_seed, sk_seed, n, ots_addr, params->hash_alg); // Compute WOTS signature wots_sign(sig_msg, msg_h, ots_seed, &(params->wots_par), pub_seed, ots_addr); @@ -455,7 +456,7 @@ int xmss_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigne // hash message unsigned long long tmp_sig_len = params->wots_par.keysize+params->h*n; m_len = sig_msg_len - tmp_sig_len; - h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n); + h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n, params->hash_alg); //----------------------- // Verify signature @@ -579,14 +580,14 @@ int xmssmt_sign(unsigned char *sk, unsigned char *sig_msg, unsigned long long *s // Message Hash: // First compute pseudorandom value to_byte(idx_bytes_32, idx, 32); - prf(R, idx_bytes_32, sk_prf, n); + prf(R, idx_bytes_32, sk_prf, n, params->xmss_par.hash_alg); // Generate hash key (R || root || idx) memcpy(hash_key, R, n); memcpy(hash_key+n, sk+idx_len+3*n, n); to_byte(hash_key+2*n, idx, n); // Then use it for message digest - h_msg(msg_h, msg, msglen, hash_key, 3*n, n); + h_msg(msg_h, msg, msglen, hash_key, 3*n, n, params->xmss_par.hash_alg); // Start collecting signature *sig_msg_len = 0; @@ -621,7 +622,7 @@ int xmssmt_sign(unsigned char *sk, unsigned char *sig_msg, unsigned long long *s setOTSADRS(ots_addr, idx_leaf); // Compute seed for OTS key pair - get_seed(ots_seed, sk_seed, n, ots_addr); + get_seed(ots_seed, sk_seed, n, ots_addr, params->xmss_par.hash_alg); // Compute WOTS signature wots_sign(sig_msg, msg_h, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr); @@ -644,7 +645,7 @@ int xmssmt_sign(unsigned char *sk, unsigned char *sig_msg, unsigned long long *s setOTSADRS(ots_addr, idx_leaf); // Compute seed for OTS key pair - get_seed(ots_seed, sk_seed, n, ots_addr); + get_seed(ots_seed, sk_seed, n, ots_addr, params->xmss_par.hash_alg); // Compute WOTS signature wots_sign(sig_msg, root, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr); @@ -713,7 +714,7 @@ int xmssmt_sign_open(unsigned char *msg, unsigned long long *msglen, const unsig // hash message unsigned long long tmp_sig_len = (params->d * params->xmss_par.wots_par.keysize) + (params->h * n); m_len = sig_msg_len - tmp_sig_len; - h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n); + h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n, params->xmss_par.hash_alg); //----------------------- diff --git a/xmss.h b/xmss.h index a51c2b8..74b6614 100644 --- a/xmss.h +++ b/xmss.h @@ -19,6 +19,7 @@ typedef struct{ wots_params wots_par; unsigned int n; unsigned int h; + unsigned char hash_alg; } xmss_params; typedef struct{ @@ -32,14 +33,14 @@ typedef struct{ * Initializes parameter set. * Needed, for any of the other methods. */ -void xmss_set_params(xmss_params *params, int n, int h, int w); +void xmss_set_params(xmss_params *params, int n, int h, int w, unsigned char hash_alg); /** * Initialize xmssmt_params struct * parameter names are the same as in the draft * * Especially h is the total tree height, i.e. the XMSS trees have height h/d */ -void xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w); +void xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, unsigned char hash_alg); /** * Generates a XMSS key pair for a given parameter set. * Format sk: [(32bit) idx || SK_SEED || SK_PRF || PUB_SEED || root] diff --git a/xmss_fast.c b/xmss_fast.c index 53e146d..dd02cf2 100644 --- a/xmss_fast.c +++ b/xmss_fast.c @@ -28,7 +28,7 @@ Public domain. * * takes n byte sk_seed and returns n byte seed using 32 byte address addr. */ -static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8]) +static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8], unsigned char hash_alg) { unsigned char bytes[32]; // Make sure that chain addr, hash addr, and key bit are 0! @@ -37,7 +37,7 @@ static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, u setKeyAndMask(addr,0); // Generate pseudorandom value addr_to_byte(bytes, addr); - prf(seed, bytes, sk_seed, n); + prf(seed, bytes, sk_seed, n, hash_alg); } /** @@ -45,7 +45,7 @@ static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, u * parameter names are the same as in the draft * parameter k is K as used in the BDS algorithm */ -int xmss_set_params(xmss_params *params, int n, int h, int w, int k) +int xmss_set_params(xmss_params *params, int n, int h, int w, int k, unsigned char hash_alg) { if (k >= h || k < 2 || (h - k) % 2) { fprintf(stderr, "For BDS traversal, H - K must be even, with H > K >= 2!\n"); @@ -55,8 +55,9 @@ int xmss_set_params(xmss_params *params, int n, int h, int w, int k) params->n = n; params->k = k; wots_params wots_par; - wots_set_params(&wots_par, n, w); + wots_set_params(&wots_par, n, w, hash_alg); params->wots_par = wots_par; + params->hash_alg = hash_alg; return 0; } @@ -82,7 +83,7 @@ void xmss_set_bds_state(bds_state *state, unsigned char *stack, int stackoffset, * * Especially h is the total tree height, i.e. the XMSS trees have height h/d */ -int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k) +int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k, unsigned char hash_alg) { if (h % d) { fprintf(stderr, "d must divide h without remainder!\n"); @@ -93,7 +94,7 @@ int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k) params->n = n; params->index_len = (h + 7) / 8; xmss_params xmss_par; - if (xmss_set_params(&xmss_par, n, (h/d), w, k)) { + if (xmss_set_params(&xmss_par, n, (h/d), w, k, hash_alg)) { return 1; } params->xmss_par = xmss_par; @@ -120,7 +121,7 @@ static void l_tree(unsigned char *leaf, unsigned char *wots_pk, const xmss_param //ADRS.setTreeIndex(i); setTreeIndex(addr, i); //wots_pk[i] = RAND_HASH(pk[2i], pk[2i + 1], SEED, ADRS); - hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n); + hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n, params->hash_alg); } //if ( l % 2 == 1 ) { if (l & 1) { @@ -149,7 +150,7 @@ static void gen_leaf_wots(unsigned char *leaf, const unsigned char *sk_seed, con unsigned char seed[params->n]; unsigned char pk[params->wots_par.keysize]; - get_seed(seed, sk_seed, params->n, ots_addr); + get_seed(seed, sk_seed, params->n, ots_addr, params->hash_alg); wots_pkgen(pk, seed, &(params->wots_par), pub_seed, ots_addr); l_tree(leaf, pk, params, pub_seed, ltree_addr); @@ -230,7 +231,7 @@ static void treehash_setup(unsigned char *node, int height, int index, bds_state setTreeHeight(node_addr, stacklevels[stackoffset-1]); setTreeIndex(node_addr, (idx >> (stacklevels[stackoffset-1]+1))); hash_h(stack+(stackoffset-2)*n, stack+(stackoffset-2)*n, pub_seed, - node_addr, n); + node_addr, n, params->hash_alg); stacklevels[stackoffset-2]++; stackoffset--; } @@ -267,7 +268,7 @@ static void treehash_update(treehash_inst *treehash, bds_state *state, const uns memcpy(nodebuffer, state->stack + (state->stackoffset-1)*n, n); setTreeHeight(node_addr, nodeheight); setTreeIndex(node_addr, (treehash->next_idx >> (nodeheight+1))); - hash_h(nodebuffer, nodebuffer, pub_seed, node_addr, n); + hash_h(nodebuffer, nodebuffer, pub_seed, node_addr, n, params->hash_alg); nodeheight++; treehash->stackusage--; state->stackoffset--; @@ -316,12 +317,12 @@ static void validate_authpath(unsigned char *root, const unsigned char *leaf, un leafidx >>= 1; setTreeIndex(addr, leafidx); if (leafidx&1) { - hash_h(buffer+n, buffer, pub_seed, addr, n); + hash_h(buffer+n, buffer, pub_seed, addr, n, params->hash_alg); for (j = 0; j < n; j++) buffer[j] = authpath[j]; } else { - hash_h(buffer, buffer, pub_seed, addr, n); + hash_h(buffer, buffer, pub_seed, addr, n, params->hash_alg); for (j = 0; j < n; j++) buffer[j+n] = authpath[j]; } @@ -330,7 +331,7 @@ static void validate_authpath(unsigned char *root, const unsigned char *leaf, un setTreeHeight(addr, (params->h-1)); leafidx >>= 1; setTreeIndex(addr, leafidx); - hash_h(root, buffer, pub_seed, addr, n); + hash_h(root, buffer, pub_seed, addr, n, params->hash_alg); } /** @@ -424,7 +425,7 @@ static char bds_state_update(bds_state *state, const unsigned char *sk_seed, con } setTreeHeight(node_addr, state->stacklevels[state->stackoffset-1]); setTreeIndex(node_addr, (idx >> (state->stacklevels[state->stackoffset-1]+1))); - hash_h(state->stack+(state->stackoffset-2)*n, state->stack+(state->stackoffset-2)*n, pub_seed, node_addr, n); + hash_h(state->stack+(state->stackoffset-2)*n, state->stack+(state->stackoffset-2)*n, pub_seed, node_addr, n, params->hash_alg); state->stacklevels[state->stackoffset-2]++; state->stackoffset--; @@ -485,7 +486,7 @@ static void bds_round(bds_state *state, const unsigned long leaf_idx, const unsi else { setTreeHeight(node_addr, (tau-1)); setTreeIndex(node_addr, leaf_idx >> tau); - hash_h(state->auth + tau * n, buf, pub_seed, node_addr, n); + hash_h(state->auth + tau * n, buf, pub_seed, node_addr, n, params->hash_alg); for (i = 0; i < tau; i++) { if (i < h - k) { memcpy(state->auth + i * n, state->treehash[i].node, n); @@ -585,13 +586,13 @@ int xmss_sign(unsigned char *sk, bds_state *state, unsigned char *sig_msg, unsig // Message Hash: // First compute pseudorandom value - prf(R, idx_bytes_32, sk_prf, n); + prf(R, idx_bytes_32, sk_prf, n, params->hash_alg); // Generate hash key (R || root || idx) memcpy(hash_key, R, n); memcpy(hash_key+n, sk+4+3*n, n); to_byte(hash_key+2*n, idx, n); // Then use it for message digest - h_msg(msg_h, msg, msglen, hash_key, 3*n, n); + h_msg(msg_h, msg, msglen, hash_key, 3*n, n, params->hash_alg); // Start collecting signature *sig_msg_len = 0; @@ -621,7 +622,7 @@ int xmss_sign(unsigned char *sk, bds_state *state, unsigned char *sig_msg, unsig setOTSADRS(ots_addr, idx); // Compute seed for OTS key pair - get_seed(ots_seed, sk_seed, n, ots_addr); + get_seed(ots_seed, sk_seed, n, ots_addr, params->hash_alg); // Compute WOTS signature wots_sign(sig_msg, msg_h, ots_seed, &(params->wots_par), pub_seed, ots_addr); @@ -691,7 +692,7 @@ int xmss_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigne // hash message unsigned long long tmp_sig_len = params->wots_par.keysize+params->h*n; m_len = sig_msg_len - tmp_sig_len; - h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n); + h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n, params->hash_alg); //----------------------- // Verify signature @@ -761,7 +762,7 @@ int xmssmt_keypair(unsigned char *pk, unsigned char *sk, bds_state *states, unsi // Compute seed for OTS key pair treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr); setLayerADRS(addr, (i+1)); - get_seed(ots_seed, sk+params->index_len, n, addr); + get_seed(ots_seed, sk+params->index_len, n, addr, params->xmss_par.hash_alg); wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, pk, ots_seed, &(params->xmss_par.wots_par), pk+n, addr); } // Address now points to the single tree on layer d-1 @@ -829,14 +830,14 @@ int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, // Message Hash: // First compute pseudorandom value to_byte(idx_bytes_32, idx, 32); - prf(R, idx_bytes_32, sk_prf, n); + prf(R, idx_bytes_32, sk_prf, n, params->xmss_par.hash_alg); // Generate hash key (R || root || idx) memcpy(hash_key, R, n); memcpy(hash_key+n, sk+idx_len+3*n, n); to_byte(hash_key+2*n, idx, n); // Then use it for message digest - h_msg(msg_h, msg, msglen, hash_key, 3*n, n); + h_msg(msg_h, msg, msglen, hash_key, 3*n, n, params->xmss_par.hash_alg); // Start collecting signature *sig_msg_len = 0; @@ -871,7 +872,7 @@ int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, setOTSADRS(ots_addr, idx_leaf); // Compute seed for OTS key pair - get_seed(ots_seed, sk_seed, n, ots_addr); + get_seed(ots_seed, sk_seed, n, ots_addr, params->xmss_par.hash_alg); // Compute WOTS signature wots_sign(sig_msg, msg_h, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr); @@ -934,7 +935,7 @@ int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, setTreeADRS(ots_addr, ((idx + 1) >> ((i+2) * tree_h))); setOTSADRS(ots_addr, (((idx >> ((i+1) * tree_h)) + 1) & ((1 << tree_h)-1))); - get_seed(ots_seed, sk+params->index_len, n, ots_addr); + get_seed(ots_seed, sk+params->index_len, n, ots_addr, params->xmss_par.hash_alg); wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, states[i].stack, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr); states[params->d + i].stackoffset = 0; @@ -1005,7 +1006,7 @@ int xmssmt_sign_open(unsigned char *msg, unsigned long long *msglen, const unsig // hash message (recall, R is now on pole position at sig_msg unsigned long long tmp_sig_len = (params->d * params->xmss_par.wots_par.keysize) + (params->h * n); m_len = sig_msg_len - tmp_sig_len; - h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n); + h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n, params->xmss_par.hash_alg); //----------------------- diff --git a/xmss_fast.h b/xmss_fast.h index 922f908..c24e736 100644 --- a/xmss_fast.h +++ b/xmss_fast.h @@ -20,6 +20,7 @@ typedef struct{ unsigned int n; unsigned int h; unsigned int k; + unsigned char hash_alg; } xmss_params; typedef struct{ @@ -58,14 +59,14 @@ void xmss_set_bds_state(bds_state *state, unsigned char *stack, int stackoffset, * Initializes parameter set. * Needed, for any of the other methods. */ -int xmss_set_params(xmss_params *params, int n, int h, int w, int k); +int xmss_set_params(xmss_params *params, int n, int h, int w, int k, unsigned char hash_alg); /** * Initialize xmssmt_params struct * parameter names are the same as in the draft * * Especially h is the total tree height, i.e. the XMSS trees have height h/d */ -int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k); +int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k, unsigned char hash_alg); /** * Generates a XMSS key pair for a given parameter set. * Format sk: [(32bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]