diff --git a/wots.c b/wots.c index 54ca088..a3f9354 100644 --- a/wots.c +++ b/wots.c @@ -1,4 +1,6 @@ #include +#include + #include "xmss_commons.h" #include "hash.h" #include "wots.h" @@ -35,10 +37,10 @@ static void gen_chain(const xmss_params *params, { uint32_t i; - for (i = 0; i < params->n; i++) { - out[i] = in[i]; - } + /* Initialize out with the value at position 'start'. */ + memcpy(out, in, params->n); + /* Iterate 'steps' calls to the hash function. */ for (i = start; i < (start+steps) && i < params->wots_w; i++) { set_hash_addr(addr, i); hash_f(params, out, out, pub_seed, addr); @@ -48,17 +50,18 @@ static void gen_chain(const xmss_params *params, /** * base_w algorithm as described in draft. * Interprets an array of bytes as integers in base w. + * This only works when log_w is a divisor of 8. */ static void base_w(const xmss_params *params, int *output, const int out_len, const unsigned char *input) { int in = 0; int out = 0; - uint8_t total = 0; + unsigned char total; int bits = 0; - int i; + int consumed; - for (i = 0; i < out_len; i++) { + for (consumed = 0; consumed < out_len; consumed++) { if (bits == 0) { total = input[in]; in++; @@ -70,83 +73,99 @@ static void base_w(const xmss_params *params, } } +/* Computes the WOTS+ checksum over a message (in base_w). */ +static void wots_checksum(const xmss_params *params, + int *csum_base_w, const int *msg_base_w) +{ + int csum = 0; + unsigned char csum_bytes[(params->wots_len2 * params->wots_log_w + 7) / 8]; + unsigned int i; + + /* Compute checksum. */ + for (i = 0; i < params->wots_len1; i++) { + csum += params->wots_w - 1 - msg_base_w[i]; + } + + /* Convert checksum to base_w. */ + /* Make sure expected empty zero bits are the least significant bits. */ + csum = csum << (8 - ((params->wots_len2 * params->wots_log_w) % 8)); + ull_to_bytes(csum_bytes, csum, sizeof(csum_bytes)); + base_w(params, csum_base_w, params->wots_len2, csum_bytes); +} + +/* Takes a message and derives the matching chain lengths. */ +static void chain_lengths(const xmss_params *params, + int *lengths, const unsigned char *msg) +{ + base_w(params, lengths, params->wots_len1, msg); + wots_checksum(params, lengths + params->wots_len1, lengths); +} + +/** + * WOTS key generation. Takes a 32 byte seed for the private key, expands it to + * a full WOTS private key and computes the corresponding public key. + * It requires the seed pub_seed (used to generate bitmasks and hash keys) + * and the address of this WOTS key pair. + * + * Writes the computed public key to 'pk'. + */ void wots_pkgen(const xmss_params *params, - unsigned char *pk, const unsigned char *sk, + unsigned char *pk, const unsigned char *seed, const unsigned char *pub_seed, uint32_t addr[8]) { uint32_t i; - expand_seed(params, pk, sk); + /* The WOTS+ private key is derived from the seed. */ + expand_seed(params, pk, seed); + for (i = 0; i < params->wots_len; i++) { set_chain_addr(addr, i); gen_chain(params, pk + i*params->n, pk + i*params->n, - 0, params->wots_w-1, pub_seed, addr); + 0, params->wots_w - 1, pub_seed, addr); } } - +/** + * Takes a m-byte message and the 32-byte seed for the private key to compute a + * signature that is placed at 'sig'. + */ void wots_sign(const xmss_params *params, unsigned char *sig, const unsigned char *msg, - const unsigned char *sk, const unsigned char *pub_seed, + const unsigned char *seed, const unsigned char *pub_seed, uint32_t addr[8]) { - int basew[params->wots_len]; - int csum = 0; - unsigned char csum_bytes[((params->wots_len2 * params->wots_log_w) + 7) / 8]; - int csum_basew[params->wots_len2]; + int lengths[params->wots_len]; uint32_t i; - base_w(params, basew, params->wots_len1, msg); - - for (i = 0; i < params->wots_len1; i++) { - csum += params->wots_w - 1 - basew[i]; - } - - csum = csum << (8 - ((params->wots_len2 * params->wots_log_w) % 8)); + chain_lengths(params, lengths, msg); - ull_to_bytes(csum_bytes, csum, ((params->wots_len2 * params->wots_log_w) + 7) / 8); - base_w(params, csum_basew, params->wots_len2, csum_bytes); - - for (i = 0; i < params->wots_len2; i++) { - basew[params->wots_len1 + i] = csum_basew[i]; - } - - expand_seed(params, sig, sk); + /* The WOTS+ private key is derived from the seed. */ + expand_seed(params, sig, seed); for (i = 0; i < params->wots_len; i++) { set_chain_addr(addr, i); gen_chain(params, sig + i*params->n, sig + i*params->n, - 0, basew[i], pub_seed, addr); + 0, lengths[i], pub_seed, addr); } } +/** + * Takes a WOTS signature and an m-byte message, computes a WOTS public key. + * + * Writes the computed public key to 'pk'. + */ void wots_pk_from_sig(const xmss_params *params, unsigned char *pk, const unsigned char *sig, const unsigned char *msg, const unsigned char *pub_seed, uint32_t addr[8]) { - int basew[params->wots_len]; - int csum = 0; - unsigned char csum_bytes[((params->wots_len2 * params->wots_log_w) + 7) / 8]; - int csum_basew[params->wots_len2]; - uint32_t i = 0; - - base_w(params, basew, params->wots_len1, msg); - - for (i=0; i < params->wots_len1; i++) { - csum += params->wots_w - 1 - basew[i]; - } - - csum = csum << (8 - ((params->wots_len2 * params->wots_log_w) % 8)); + int lengths[params->wots_len]; + uint32_t i; - ull_to_bytes(csum_bytes, csum, ((params->wots_len2 * params->wots_log_w) + 7) / 8); - base_w(params, csum_basew, params->wots_len2, csum_bytes); + chain_lengths(params, lengths, msg); - for (i = 0; i < params->wots_len2; i++) { - basew[params->wots_len1 + i] = csum_basew[i]; - } - for (i=0; i < params->wots_len; i++) { + for (i = 0; i < params->wots_len; i++) { set_chain_addr(addr, i); gen_chain(params, pk + i*params->n, sig + i*params->n, - basew[i], params->wots_w-1-basew[i], pub_seed, addr); + lengths[i], params->wots_w - 1 - lengths[i], pub_seed, addr); } } diff --git a/wots.h b/wots.h index 61f0b8b..f0866ac 100644 --- a/wots.h +++ b/wots.h @@ -5,24 +5,24 @@ #include "params.h" /** - * WOTS key generation. Takes a 32 byte seed for the secret key, expands it to - * a full WOTS secret key and computes the corresponding public key. + * WOTS key generation. Takes a 32 byte seed for the private key, expands it to + * a full WOTS private key and computes the corresponding public key. * It requires the seed pub_seed (used to generate bitmasks and hash keys) * and the address of this WOTS key pair. * * Writes the computed public key to 'pk'. */ void wots_pkgen(const xmss_params *params, - unsigned char *pk, const unsigned char *sk, + unsigned char *pk, const unsigned char *seed, const unsigned char *pub_seed, uint32_t addr[8]); /** - * Takes a m-byte message and the 32-byte seed for the secret key to compute a + * Takes a m-byte message and the 32-byte seed for the private key to compute a * signature that is placed at 'sig'. */ void wots_sign(const xmss_params *params, unsigned char *sig, const unsigned char *msg, - const unsigned char *sk, const unsigned char *pub_seed, + const unsigned char *seed, const unsigned char *pub_seed, uint32_t addr[8]); /**