You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

241 lines
9.6 KiB

  1. #include <stdint.h>
  2. #include <string.h>
  3. #include "address.h"
  4. #include "hash.h"
  5. #include "hash_state.h"
  6. #include "hashx4.h"
  7. #include "params.h"
  8. #include "thash.h"
  9. #include "thashx4.h"
  10. #include "utils.h"
  11. #include "wots.h"
  12. // TODO clarify address expectations, and make them more uniform.
  13. // TODO i.e. do we expect types to be set already?
  14. // TODO and do we expect modifications or copies?
  15. /**
  16. * Computes the starting value for a chain, i.e. the secret key.
  17. * Expects the address to be complete up to the chain address.
  18. */
  19. static void wots_gen_sk(unsigned char *sk, const unsigned char *sk_seed,
  20. uint32_t wots_addr[8], const hash_state *state_seeded) {
  21. /* Make sure that the hash address is actually zeroed. */
  22. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_hash_addr(wots_addr, 0);
  23. /* Generate sk element. */
  24. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_prf_addr(sk, sk_seed, wots_addr, state_seeded);
  25. }
  26. /**
  27. * 4-way parallel version of wots_gen_sk; expects 4x as much space in sk
  28. */
  29. static void wots_gen_skx4(unsigned char *skx4, const unsigned char *sk_seed,
  30. uint32_t wots_addrx4[4 * 8], const hash_state *state_seeded) {
  31. unsigned int j;
  32. /* Make sure that the hash address is actually zeroed. */
  33. for (j = 0; j < 4; j++) {
  34. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_hash_addr(wots_addrx4 + j * 8, 0);
  35. }
  36. /* Generate sk element. */
  37. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_prf_addrx4(skx4 + 0 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  38. skx4 + 1 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  39. skx4 + 2 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  40. skx4 + 3 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  41. sk_seed, wots_addrx4,
  42. state_seeded);
  43. }
  44. /**
  45. * Computes the chaining function.
  46. * out and in have to be n-byte arrays.
  47. *
  48. * Interprets in as start-th value of the chain.
  49. * addr has to contain the address of the chain.
  50. */
  51. static void gen_chain(unsigned char *out, const unsigned char *in,
  52. unsigned int start, unsigned int steps,
  53. const unsigned char *pub_seed, uint32_t addr[8],
  54. const hash_state *state_seeded) {
  55. uint32_t i;
  56. /* Initialize out with the value at position 'start'. */
  57. memcpy(out, in, PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N);
  58. /* Iterate 'steps' calls to the hash function. */
  59. for (i = start; i < (start + steps) && i < PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_W; i++) {
  60. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_hash_addr(addr, i);
  61. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_thash_1(out, out, pub_seed, addr, state_seeded);
  62. }
  63. }
  64. /**
  65. * 4-way parallel version of gen_chain; expects 4x as much space in out, and
  66. * 4x as much space in inx4. Assumes start and step identical across chains.
  67. */
  68. static void gen_chainx4(unsigned char *outx4, const unsigned char *inx4,
  69. unsigned int start, unsigned int steps,
  70. const unsigned char *pub_seed, uint32_t addrx4[4 * 8],
  71. const hash_state *state_seeded) {
  72. uint32_t i;
  73. unsigned int j;
  74. /* Initialize outx4 with the value at position 'start'. */
  75. memcpy(outx4, inx4, 4 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N);
  76. /* Iterate 'steps' calls to the hash function. */
  77. for (i = start; i < (start + steps) && i < PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_W; i++) {
  78. for (j = 0; j < 4; j++) {
  79. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_hash_addr(addrx4 + j * 8, i);
  80. }
  81. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_thashx4_1(outx4 + 0 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  82. outx4 + 1 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  83. outx4 + 2 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  84. outx4 + 3 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  85. outx4 + 0 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  86. outx4 + 1 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  87. outx4 + 2 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  88. outx4 + 3 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  89. pub_seed, addrx4,
  90. state_seeded);
  91. }
  92. }
  93. /**
  94. * base_w algorithm as described in draft.
  95. * Interprets an array of bytes as integers in base w.
  96. * This only works when log_w is a divisor of 8.
  97. */
  98. static void base_w(unsigned int *output, const int out_len, const unsigned char *input) {
  99. int in = 0;
  100. int out = 0;
  101. unsigned char total = 0;
  102. int bits = 0;
  103. int consumed;
  104. for (consumed = 0; consumed < out_len; consumed++) {
  105. if (bits == 0) {
  106. total = input[in];
  107. in++;
  108. bits += 8;
  109. }
  110. bits -= PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LOGW;
  111. output[out] = (unsigned int)(total >> bits) & (PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_W - 1);
  112. out++;
  113. }
  114. }
  115. /* Computes the WOTS+ checksum over a message (in base_w). */
  116. static void wots_checksum(unsigned int *csum_base_w, const unsigned int *msg_base_w) {
  117. unsigned int csum = 0;
  118. unsigned char csum_bytes[(PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN2 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LOGW + 7) / 8];
  119. unsigned int i;
  120. /* Compute checksum. */
  121. for (i = 0; i < PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN1; i++) {
  122. csum += PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_W - 1 - msg_base_w[i];
  123. }
  124. /* Convert checksum to base_w. */
  125. /* Make sure expected empty zero bits are the least significant bits. */
  126. csum = csum << (8 - ((PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN2 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LOGW) % 8));
  127. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_ull_to_bytes(csum_bytes, sizeof(csum_bytes), csum);
  128. base_w(csum_base_w, PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN2, csum_bytes);
  129. }
  130. /* Takes a message and derives the matching chain lengths. */
  131. static void chain_lengths(unsigned int *lengths, const unsigned char *msg) {
  132. base_w(lengths, PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN1, msg);
  133. wots_checksum(lengths + PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN1, lengths);
  134. }
  135. /**
  136. * WOTS key generation. Takes a 32 byte sk_seed, expands it to WOTS private key
  137. * elements and computes the corresponding public key.
  138. * It requires the seed pub_seed (used to generate bitmasks and hash keys)
  139. * and the address of this WOTS key pair.
  140. *
  141. * Writes the computed public key to 'pk'.
  142. */
  143. void PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_wots_gen_pk(unsigned char *pk, const unsigned char *sk_seed,
  144. const unsigned char *pub_seed, uint32_t addr[8],
  145. const hash_state *state_seeded) {
  146. uint32_t i;
  147. unsigned int j;
  148. uint32_t addrx4[4 * 8];
  149. unsigned char pkbuf[4 * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N];
  150. for (j = 0; j < 4; j++) {
  151. memcpy(addrx4 + j * 8, addr, sizeof(uint32_t) * 8);
  152. }
  153. /* The last iteration typically does not have complete set of 4 chains,
  154. but because we use pkbuf, this is not an issue -- we still do as many
  155. in parallel as possible. */
  156. for (i = 0; i < ((PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN + 3) & ~0x3); i += 4) {
  157. for (j = 0; j < 4; j++) {
  158. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_chain_addr(addrx4 + j * 8, i + j);
  159. }
  160. wots_gen_skx4(pkbuf, sk_seed, addrx4, state_seeded);
  161. gen_chainx4(pkbuf, pkbuf, 0, PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_W - 1, pub_seed, addrx4, state_seeded);
  162. for (j = 0; j < 4; j++) {
  163. if (i + j < PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN) {
  164. memcpy(pk + (i + j)*PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N, pkbuf + j * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N, PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N);
  165. }
  166. }
  167. }
  168. // Get rid of unused argument variable.
  169. (void)state_seeded;
  170. }
  171. /**
  172. * Takes a n-byte message and the 32-byte sk_see to compute a signature 'sig'.
  173. */
  174. void PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_wots_sign(unsigned char *sig, const unsigned char *msg,
  175. const unsigned char *sk_seed, const unsigned char *pub_seed,
  176. uint32_t addr[8], const hash_state *state_seeded) {
  177. unsigned int lengths[PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN];
  178. uint32_t i;
  179. chain_lengths(lengths, msg);
  180. for (i = 0; i < PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN; i++) {
  181. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_chain_addr(addr, i);
  182. wots_gen_sk(sig + i * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N, sk_seed, addr, state_seeded);
  183. gen_chain(sig + i * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N, sig + i * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N, 0, lengths[i], pub_seed, addr, state_seeded);
  184. }
  185. // avoid unused argument
  186. (void)state_seeded;
  187. }
  188. /**
  189. * Takes a WOTS signature and an n-byte message, computes a WOTS public key.
  190. *
  191. * Writes the computed public key to 'pk'.
  192. */
  193. void PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_wots_pk_from_sig(unsigned char *pk,
  194. const unsigned char *sig, const unsigned char *msg,
  195. const unsigned char *pub_seed, uint32_t addr[8],
  196. const hash_state *state_seeded) {
  197. unsigned int lengths[PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN];
  198. uint32_t i;
  199. chain_lengths(lengths, msg);
  200. for (i = 0; i < PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_LEN; i++) {
  201. PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_set_chain_addr(addr, i);
  202. gen_chain(pk + i * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N, sig + i * PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_N,
  203. lengths[i], PQCLEAN_SPHINCSSHAKE256256FROBUST_AVX2_WOTS_W - 1 - lengths[i], pub_seed, addr,
  204. state_seeded);
  205. }
  206. // avoid unused argument
  207. (void)state_seeded;
  208. }