You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1099 lines
32 KiB

  1. /*
  2. xmss_fast.c version 20160722
  3. Andreas Hülsing
  4. Joost Rijneveld
  5. Public domain.
  6. */
  7. #include "xmss_fast.h"
  8. #include <stdlib.h>
  9. #include <string.h>
  10. #include <stdint.h>
  11. #include <math.h>
  12. #include "randombytes.h"
  13. #include "wots.h"
  14. #include "hash.h"
  15. #include "xmss_commons.h"
  16. #include "hash_address.h"
  17. // For testing
  18. #include "stdio.h"
  /**
   * Derives the seed of the WOTS key pair selected by addr (pseudorandom key
   * generation).
   *
   * The chain address, hash address and key/mask words of addr are forced to
   * zero first, so only the remaining address words select the seed. The
   * address is then serialised into 32 bytes and passed, together with the
   * n-byte sk_seed, to prf, which writes n bytes to seed.
   *
   * Note: addr is modified in place (chain, hash and key/mask words cleared).
   */
static void get_seed(unsigned char *seed, const unsigned char *sk_seed, int n, uint32_t addr[8], unsigned char hash_alg)
{
    unsigned char addr_bytes[32];

    /* Only the key-pair selecting parts of the address may influence the seed. */
    setChainADRS(addr, 0);
    setHashADRS(addr, 0);
    setKeyAndMask(addr, 0);

    addr_to_byte(addr_bytes, addr);
    prf(seed, addr_bytes, sk_seed, n, hash_alg);
}
  36. /**
  37. * Initialize xmss params struct
  38. * parameter names are the same as in the draft
  39. * parameter k is K as used in the BDS algorithm
  40. */
  41. int xmss_set_params(xmss_params *params, int n, int h, int w, int k, unsigned char hash_alg)
  42. {
  43. if (k >= h || k < 2 || (h - k) % 2) {
  44. fprintf(stderr, "For BDS traversal, H - K must be even, with H > K >= 2!\n");
  45. return 1;
  46. }
  47. params->h = h;
  48. params->n = n;
  49. params->k = k;
  50. wots_params wots_par;
  51. wots_set_params(&wots_par, n, w, hash_alg);
  52. params->wots_par = wots_par;
  53. params->hash_alg = hash_alg;
  54. return 0;
  55. }
  56. /**
  57. * Initialize BDS state struct
  58. * parameter names are the same as used in the description of the BDS traversal
  59. */
  60. void xmss_set_bds_state(bds_state *state, unsigned char *stack, int stackoffset, unsigned char *stacklevels, unsigned char *auth, unsigned char *keep, treehash_inst *treehash, unsigned char *retain, int next_leaf)
  61. {
  62. state->stack = stack;
  63. state->stackoffset = stackoffset;
  64. state->stacklevels = stacklevels;
  65. state->auth = auth;
  66. state->keep = keep;
  67. state->treehash = treehash;
  68. state->retain = retain;
  69. state->next_leaf = next_leaf;
  70. }
  71. /**
  72. * Initialize xmssmt_params struct
  73. * parameter names are the same as in the draft
  74. *
  75. * Especially h is the total tree height, i.e. the XMSS trees have height h/d
  76. */
  77. int xmssmt_set_params(xmssmt_params *params, int n, int h, int d, int w, int k, unsigned char hash_alg)
  78. {
  79. if (h % d) {
  80. fprintf(stderr, "d must divide h without remainder!\n");
  81. return 1;
  82. }
  83. params->h = h;
  84. params->d = d;
  85. params->n = n;
  86. params->index_len = (h + 7) / 8;
  87. xmss_params xmss_par;
  88. if (xmss_set_params(&xmss_par, n, (h/d), w, k, hash_alg)) {
  89. return 1;
  90. }
  91. params->xmss_par = xmss_par;
  92. return 0;
  93. }
  94. /**
  95. * Computes a leaf from a WOTS public key using an L-tree.
  96. */
  97. static void l_tree(unsigned char *leaf, unsigned char *wots_pk, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
  98. {
  99. unsigned int l = params->wots_par.len;
  100. unsigned int n = params->n;
  101. uint32_t i = 0;
  102. uint32_t height = 0;
  103. uint32_t bound;
  104. //ADRS.setTreeHeight(0);
  105. setTreeHeight(addr, height);
  106. while (l > 1) {
  107. bound = l >> 1; //floor(l / 2);
  108. for (i = 0; i < bound; i++) {
  109. //ADRS.setTreeIndex(i);
  110. setTreeIndex(addr, i);
  111. //wots_pk[i] = RAND_HASH(pk[2i], pk[2i + 1], SEED, ADRS);
  112. hash_h(wots_pk+i*n, wots_pk+i*2*n, pub_seed, addr, n, params->hash_alg);
  113. }
  114. //if ( l % 2 == 1 ) {
  115. if (l & 1) {
  116. //pk[floor(l / 2) + 1] = pk[l];
  117. memcpy(wots_pk+(l>>1)*n, wots_pk+(l-1)*n, n);
  118. //l = ceil(l / 2);
  119. l=(l>>1)+1;
  120. }
  121. else {
  122. //l = ceil(l / 2);
  123. l=(l>>1);
  124. }
  125. //ADRS.setTreeHeight(ADRS.getTreeHeight() + 1);
  126. height++;
  127. setTreeHeight(addr, height);
  128. }
  129. //return pk[0];
  130. memcpy(leaf, wots_pk, n);
  131. }
  132. /**
  133. * Computes the leaf at a given address. First generates the WOTS key pair, then computes leaf using l_tree. As this happens position independent, we only require that addr encodes the right ltree-address.
  134. */
  135. static void gen_leaf_wots(unsigned char *leaf, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, uint32_t ltree_addr[8], uint32_t ots_addr[8])
  136. {
  137. unsigned char seed[params->n];
  138. unsigned char pk[params->wots_par.keysize];
  139. get_seed(seed, sk_seed, params->n, ots_addr, params->hash_alg);
  140. wots_pkgen(pk, seed, &(params->wots_par), pub_seed, ots_addr);
  141. l_tree(leaf, pk, params, pub_seed, ltree_addr);
  142. }
  143. static int treehash_minheight_on_stack(bds_state* state, const xmss_params *params, const treehash_inst *treehash) {
  144. unsigned int r = params->h, i;
  145. for (i = 0; i < treehash->stackusage; i++) {
  146. if (state->stacklevels[state->stackoffset - i - 1] < r) {
  147. r = state->stacklevels[state->stackoffset - i - 1];
  148. }
  149. }
  150. return r;
  151. }
  152. /**
  153. * Merkle's TreeHash algorithm. The address only needs to initialize the first 78 bits of addr. Everything else will be set by treehash.
  154. * Currently only used for key generation.
  155. *
  156. */
  157. static void treehash_setup(unsigned char *node, int height, int index, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8])
  158. {
  159. unsigned int idx = index;
  160. unsigned int n = params->n;
  161. unsigned int h = params->h;
  162. unsigned int k = params->k;
  163. // use three different addresses because at this point we use all three formats in parallel
  164. uint32_t ots_addr[8];
  165. uint32_t ltree_addr[8];
  166. uint32_t node_addr[8];
  167. // only copy layer and tree address parts
  168. memcpy(ots_addr, addr, 12);
  169. // type = ots
  170. setType(ots_addr, 0);
  171. memcpy(ltree_addr, addr, 12);
  172. setType(ltree_addr, 1);
  173. memcpy(node_addr, addr, 12);
  174. setType(node_addr, 2);
  175. uint32_t lastnode, i;
  176. unsigned char stack[(height+1)*n];
  177. unsigned int stacklevels[height+1];
  178. unsigned int stackoffset=0;
  179. unsigned int nodeh;
  180. lastnode = idx+(1<<height);
  181. for (i = 0; i < h-k; i++) {
  182. state->treehash[i].h = i;
  183. state->treehash[i].completed = 1;
  184. state->treehash[i].stackusage = 0;
  185. }
  186. i = 0;
  187. for (; idx < lastnode; idx++) {
  188. setLtreeADRS(ltree_addr, idx);
  189. setOTSADRS(ots_addr, idx);
  190. gen_leaf_wots(stack+stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  191. stacklevels[stackoffset] = 0;
  192. stackoffset++;
  193. if (h - k > 0 && i == 3) {
  194. memcpy(state->treehash[0].node, stack+stackoffset*n, n);
  195. }
  196. while (stackoffset>1 && stacklevels[stackoffset-1] == stacklevels[stackoffset-2])
  197. {
  198. nodeh = stacklevels[stackoffset-1];
  199. if (i >> nodeh == 1) {
  200. memcpy(state->auth + nodeh*n, stack+(stackoffset-1)*n, n);
  201. }
  202. else {
  203. if (nodeh < h - k && i >> nodeh == 3) {
  204. memcpy(state->treehash[nodeh].node, stack+(stackoffset-1)*n, n);
  205. }
  206. else if (nodeh >= h - k) {
  207. memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((i >> nodeh) - 3) >> 1)) * n, stack+(stackoffset-1)*n, n);
  208. }
  209. }
  210. setTreeHeight(node_addr, stacklevels[stackoffset-1]);
  211. setTreeIndex(node_addr, (idx >> (stacklevels[stackoffset-1]+1)));
  212. hash_h(stack+(stackoffset-2)*n, stack+(stackoffset-2)*n, pub_seed,
  213. node_addr, n, params->hash_alg);
  214. stacklevels[stackoffset-2]++;
  215. stackoffset--;
  216. }
  217. i++;
  218. }
  219. for (i = 0; i < n; i++)
  220. node[i] = stack[i];
  221. }
  222. static void treehash_update(treehash_inst *treehash, bds_state *state, const unsigned char *sk_seed, const xmss_params *params, const unsigned char *pub_seed, const uint32_t addr[8]) {
  223. int n = params->n;
  224. uint32_t ots_addr[8];
  225. uint32_t ltree_addr[8];
  226. uint32_t node_addr[8];
  227. // only copy layer and tree address parts
  228. memcpy(ots_addr, addr, 12);
  229. // type = ots
  230. setType(ots_addr, 0);
  231. memcpy(ltree_addr, addr, 12);
  232. setType(ltree_addr, 1);
  233. memcpy(node_addr, addr, 12);
  234. setType(node_addr, 2);
  235. setLtreeADRS(ltree_addr, treehash->next_idx);
  236. setOTSADRS(ots_addr, treehash->next_idx);
  237. unsigned char nodebuffer[2 * n];
  238. unsigned int nodeheight = 0;
  239. gen_leaf_wots(nodebuffer, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  240. while (treehash->stackusage > 0 && state->stacklevels[state->stackoffset-1] == nodeheight) {
  241. memcpy(nodebuffer + n, nodebuffer, n);
  242. memcpy(nodebuffer, state->stack + (state->stackoffset-1)*n, n);
  243. setTreeHeight(node_addr, nodeheight);
  244. setTreeIndex(node_addr, (treehash->next_idx >> (nodeheight+1)));
  245. hash_h(nodebuffer, nodebuffer, pub_seed, node_addr, n, params->hash_alg);
  246. nodeheight++;
  247. treehash->stackusage--;
  248. state->stackoffset--;
  249. }
  250. if (nodeheight == treehash->h) { // this also implies stackusage == 0
  251. memcpy(treehash->node, nodebuffer, n);
  252. treehash->completed = 1;
  253. }
  254. else {
  255. memcpy(state->stack + state->stackoffset*n, nodebuffer, n);
  256. treehash->stackusage++;
  257. state->stacklevels[state->stackoffset] = nodeheight;
  258. state->stackoffset++;
  259. treehash->next_idx++;
  260. }
  261. }
  262. /**
  263. * Computes a root node given a leaf and an authapth
  264. */
  265. static void validate_authpath(unsigned char *root, const unsigned char *leaf, unsigned long leafidx, const unsigned char *authpath, const xmss_params *params, const unsigned char *pub_seed, uint32_t addr[8])
  266. {
  267. unsigned int n = params->n;
  268. uint32_t i, j;
  269. unsigned char buffer[2*n];
  270. // If leafidx is odd (last bit = 1), current path element is a right child and authpath has to go to the left.
  271. // Otherwise, it is the other way around
  272. if (leafidx & 1) {
  273. for (j = 0; j < n; j++)
  274. buffer[n+j] = leaf[j];
  275. for (j = 0; j < n; j++)
  276. buffer[j] = authpath[j];
  277. }
  278. else {
  279. for (j = 0; j < n; j++)
  280. buffer[j] = leaf[j];
  281. for (j = 0; j < n; j++)
  282. buffer[n+j] = authpath[j];
  283. }
  284. authpath += n;
  285. for (i=0; i < params->h-1; i++) {
  286. setTreeHeight(addr, i);
  287. leafidx >>= 1;
  288. setTreeIndex(addr, leafidx);
  289. if (leafidx&1) {
  290. hash_h(buffer+n, buffer, pub_seed, addr, n, params->hash_alg);
  291. for (j = 0; j < n; j++)
  292. buffer[j] = authpath[j];
  293. }
  294. else {
  295. hash_h(buffer, buffer, pub_seed, addr, n, params->hash_alg);
  296. for (j = 0; j < n; j++)
  297. buffer[j+n] = authpath[j];
  298. }
  299. authpath += n;
  300. }
  301. setTreeHeight(addr, (params->h-1));
  302. leafidx >>= 1;
  303. setTreeIndex(addr, leafidx);
  304. hash_h(root, buffer, pub_seed, addr, n, params->hash_alg);
  305. }
  306. /**
  307. * Performs one treehash update on the instance that needs it the most.
  308. * Returns 1 if such an instance was not found
  309. **/
  310. static char bds_treehash_update(bds_state *state, unsigned int updates, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
  311. uint32_t i, j;
  312. unsigned int level, l_min, low;
  313. unsigned int h = params->h;
  314. unsigned int k = params->k;
  315. unsigned int used = 0;
  316. for (j = 0; j < updates; j++) {
  317. l_min = h;
  318. level = h - k;
  319. for (i = 0; i < h - k; i++) {
  320. if (state->treehash[i].completed) {
  321. low = h;
  322. }
  323. else if (state->treehash[i].stackusage == 0) {
  324. low = i;
  325. }
  326. else {
  327. low = treehash_minheight_on_stack(state, params, &(state->treehash[i]));
  328. }
  329. if (low < l_min) {
  330. level = i;
  331. l_min = low;
  332. }
  333. }
  334. if (level == h - k) {
  335. break;
  336. }
  337. treehash_update(&(state->treehash[level]), state, sk_seed, params, pub_seed, addr);
  338. used++;
  339. }
  340. return updates - used;
  341. }
  342. /**
  343. * Updates the state (typically NEXT_i) by adding a leaf and updating the stack
  344. * Returns 1 if all leaf nodes have already been processed
  345. **/
  346. static char bds_state_update(bds_state *state, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, const uint32_t addr[8]) {
  347. uint32_t ltree_addr[8];
  348. uint32_t node_addr[8];
  349. uint32_t ots_addr[8];
  350. int n = params->n;
  351. int h = params->h;
  352. int k = params->k;
  353. int nodeh;
  354. int idx = state->next_leaf;
  355. if (idx == 1 << h) {
  356. return 1;
  357. }
  358. // only copy layer and tree address parts
  359. memcpy(ots_addr, addr, 12);
  360. // type = ots
  361. setType(ots_addr, 0);
  362. memcpy(ltree_addr, addr, 12);
  363. setType(ltree_addr, 1);
  364. memcpy(node_addr, addr, 12);
  365. setType(node_addr, 2);
  366. setOTSADRS(ots_addr, idx);
  367. setLtreeADRS(ltree_addr, idx);
  368. gen_leaf_wots(state->stack+state->stackoffset*n, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  369. state->stacklevels[state->stackoffset] = 0;
  370. state->stackoffset++;
  371. if (h - k > 0 && idx == 3) {
  372. memcpy(state->treehash[0].node, state->stack+state->stackoffset*n, n);
  373. }
  374. while (state->stackoffset>1 && state->stacklevels[state->stackoffset-1] == state->stacklevels[state->stackoffset-2]) {
  375. nodeh = state->stacklevels[state->stackoffset-1];
  376. if (idx >> nodeh == 1) {
  377. memcpy(state->auth + nodeh*n, state->stack+(state->stackoffset-1)*n, n);
  378. }
  379. else {
  380. if (nodeh < h - k && idx >> nodeh == 3) {
  381. memcpy(state->treehash[nodeh].node, state->stack+(state->stackoffset-1)*n, n);
  382. }
  383. else if (nodeh >= h - k) {
  384. memcpy(state->retain + ((1 << (h - 1 - nodeh)) + nodeh - h + (((idx >> nodeh) - 3) >> 1)) * n, state->stack+(state->stackoffset-1)*n, n);
  385. }
  386. }
  387. setTreeHeight(node_addr, state->stacklevels[state->stackoffset-1]);
  388. setTreeIndex(node_addr, (idx >> (state->stacklevels[state->stackoffset-1]+1)));
  389. hash_h(state->stack+(state->stackoffset-2)*n, state->stack+(state->stackoffset-2)*n, pub_seed, node_addr, n, params->hash_alg);
  390. state->stacklevels[state->stackoffset-2]++;
  391. state->stackoffset--;
  392. }
  393. state->next_leaf++;
  394. return 0;
  395. }
  396. /**
  397. * Returns the auth path for node leaf_idx and computes the auth path for the
  398. * next leaf node, using the algorithm described by Buchmann, Dahmen and Szydlo
  399. * in "Post Quantum Cryptography", Springer 2009.
  400. */
  401. static void bds_round(bds_state *state, const unsigned long leaf_idx, const unsigned char *sk_seed, const xmss_params *params, unsigned char *pub_seed, uint32_t addr[8])
  402. {
  403. unsigned int i;
  404. unsigned int n = params->n;
  405. unsigned int h = params->h;
  406. unsigned int k = params->k;
  407. unsigned int tau = h;
  408. unsigned int startidx;
  409. unsigned int offset, rowidx;
  410. unsigned char buf[2 * n];
  411. uint32_t ots_addr[8];
  412. uint32_t ltree_addr[8];
  413. uint32_t node_addr[8];
  414. // only copy layer and tree address parts
  415. memcpy(ots_addr, addr, 12);
  416. // type = ots
  417. setType(ots_addr, 0);
  418. memcpy(ltree_addr, addr, 12);
  419. setType(ltree_addr, 1);
  420. memcpy(node_addr, addr, 12);
  421. setType(node_addr, 2);
  422. for (i = 0; i < h; i++) {
  423. if (! ((leaf_idx >> i) & 1)) {
  424. tau = i;
  425. break;
  426. }
  427. }
  428. if (tau > 0) {
  429. memcpy(buf, state->auth + (tau-1) * n, n);
  430. // we need to do this before refreshing state->keep to prevent overwriting
  431. memcpy(buf + n, state->keep + ((tau-1) >> 1) * n, n);
  432. }
  433. if (!((leaf_idx >> (tau + 1)) & 1) && (tau < h - 1)) {
  434. memcpy(state->keep + (tau >> 1)*n, state->auth + tau*n, n);
  435. }
  436. if (tau == 0) {
  437. setLtreeADRS(ltree_addr, leaf_idx);
  438. setOTSADRS(ots_addr, leaf_idx);
  439. gen_leaf_wots(state->auth, sk_seed, params, pub_seed, ltree_addr, ots_addr);
  440. }
  441. else {
  442. setTreeHeight(node_addr, (tau-1));
  443. setTreeIndex(node_addr, leaf_idx >> tau);
  444. hash_h(state->auth + tau * n, buf, pub_seed, node_addr, n, params->hash_alg);
  445. for (i = 0; i < tau; i++) {
  446. if (i < h - k) {
  447. memcpy(state->auth + i * n, state->treehash[i].node, n);
  448. }
  449. else {
  450. offset = (1 << (h - 1 - i)) + i - h;
  451. rowidx = ((leaf_idx >> i) - 1) >> 1;
  452. memcpy(state->auth + i * n, state->retain + (offset + rowidx) * n, n);
  453. }
  454. }
  455. for (i = 0; i < ((tau < h - k) ? tau : (h - k)); i++) {
  456. startidx = leaf_idx + 1 + 3 * (1 << i);
  457. if (startidx < 1U << h) {
  458. state->treehash[i].h = i;
  459. state->treehash[i].next_idx = startidx;
  460. state->treehash[i].completed = 0;
  461. state->treehash[i].stackusage = 0;
  462. }
  463. }
  464. }
  465. }
  466. /*
  467. * Generates a XMSS key pair for a given parameter set.
  468. * Format sk: [(32bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
  469. * Format pk: [root || PUB_SEED] omitting algo oid.
  470. */
  471. int xmss_keypair(unsigned char *pk, unsigned char *sk, bds_state *state, xmss_params *params)
  472. {
  473. unsigned int n = params->n;
  474. // Set idx = 0
  475. sk[0] = 0;
  476. sk[1] = 0;
  477. sk[2] = 0;
  478. sk[3] = 0;
  479. // Init SK_SEED (n byte), SK_PRF (n byte), and PUB_SEED (n byte)
  480. randombytes(sk+4, 3*n);
  481. // Copy PUB_SEED to public key
  482. memcpy(pk+n, sk+4+2*n, n);
  483. uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  484. // Compute root
  485. treehash_setup(pk, params->h, 0, state, sk+4, params, sk+4+2*n, addr);
  486. // copy root to sk
  487. memcpy(sk+4+3*n, pk, n);
  488. return 0;
  489. }
  490. /**
  491. * Signs a message.
  492. * Returns
  493. * 1. an array containing the signature followed by the message AND
  494. * 2. an updated secret key!
  495. *
  496. */
  497. int xmss_sign(unsigned char *sk, bds_state *state, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmss_params *params)
  498. {
  499. unsigned int h = params->h;
  500. unsigned int n = params->n;
  501. unsigned int k = params->k;
  502. uint16_t i = 0;
  503. // Extract SK
  504. unsigned long idx = ((unsigned long)sk[0] << 24) | ((unsigned long)sk[1] << 16) | ((unsigned long)sk[2] << 8) | sk[3];
  505. unsigned char sk_seed[n];
  506. memcpy(sk_seed, sk+4, n);
  507. unsigned char sk_prf[n];
  508. memcpy(sk_prf, sk+4+n, n);
  509. unsigned char pub_seed[n];
  510. memcpy(pub_seed, sk+4+2*n, n);
  511. // index as 32 bytes string
  512. unsigned char idx_bytes_32[32];
  513. to_byte(idx_bytes_32, idx, 32);
  514. unsigned char hash_key[3*n];
  515. // Update SK
  516. sk[0] = ((idx + 1) >> 24) & 255;
  517. sk[1] = ((idx + 1) >> 16) & 255;
  518. sk[2] = ((idx + 1) >> 8) & 255;
  519. sk[3] = (idx + 1) & 255;
  520. // -- Secret key for this non-forward-secure version is now updated.
  521. // -- A productive implementation should use a file handle instead and write the updated secret key at this point!
  522. // Init working params
  523. unsigned char R[n];
  524. unsigned char msg_h[n];
  525. unsigned char ots_seed[n];
  526. uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  527. // ---------------------------------
  528. // Message Hashing
  529. // ---------------------------------
  530. // Message Hash:
  531. // First compute pseudorandom value
  532. prf(R, idx_bytes_32, sk_prf, n, params->hash_alg);
  533. // Generate hash key (R || root || idx)
  534. memcpy(hash_key, R, n);
  535. memcpy(hash_key+n, sk+4+3*n, n);
  536. to_byte(hash_key+2*n, idx, n);
  537. // Then use it for message digest
  538. h_msg(msg_h, msg, msglen, hash_key, 3*n, n, params->hash_alg);
  539. // Start collecting signature
  540. *sig_msg_len = 0;
  541. // Copy index to signature
  542. sig_msg[0] = (idx >> 24) & 255;
  543. sig_msg[1] = (idx >> 16) & 255;
  544. sig_msg[2] = (idx >> 8) & 255;
  545. sig_msg[3] = idx & 255;
  546. sig_msg += 4;
  547. *sig_msg_len += 4;
  548. // Copy R to signature
  549. for (i = 0; i < n; i++)
  550. sig_msg[i] = R[i];
  551. sig_msg += n;
  552. *sig_msg_len += n;
  553. // ----------------------------------
  554. // Now we start to "really sign"
  555. // ----------------------------------
  556. // Prepare Address
  557. setType(ots_addr, 0);
  558. setOTSADRS(ots_addr, idx);
  559. // Compute seed for OTS key pair
  560. get_seed(ots_seed, sk_seed, n, ots_addr, params->hash_alg);
  561. // Compute WOTS signature
  562. wots_sign(sig_msg, msg_h, ots_seed, &(params->wots_par), pub_seed, ots_addr);
  563. sig_msg += params->wots_par.keysize;
  564. *sig_msg_len += params->wots_par.keysize;
  565. // the auth path was already computed during the previous round
  566. memcpy(sig_msg, state->auth, h*n);
  567. if (idx < (1U << h) - 1) {
  568. bds_round(state, idx, sk_seed, params, pub_seed, ots_addr);
  569. bds_treehash_update(state, (h - k) >> 1, sk_seed, params, pub_seed, ots_addr);
  570. }
  571. sig_msg += params->h*n;
  572. *sig_msg_len += params->h*n;
  573. //Whipe secret elements?
  574. //zerobytes(tsk, CRYPTO_SECRETKEYBYTES);
  575. memcpy(sig_msg, msg, msglen);
  576. *sig_msg_len += msglen;
  577. return 0;
  578. }
  579. /**
  580. * Verifies a given message signature pair under a given public key.
  581. */
  582. int xmss_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmss_params *params)
  583. {
  584. unsigned int n = params->n;
  585. unsigned long long i, m_len;
  586. unsigned long idx=0;
  587. unsigned char wots_pk[params->wots_par.keysize];
  588. unsigned char pkhash[n];
  589. unsigned char root[n];
  590. unsigned char msg_h[n];
  591. unsigned char hash_key[3*n];
  592. unsigned char pub_seed[n];
  593. memcpy(pub_seed, pk+n, n);
  594. // Init addresses
  595. uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  596. uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  597. uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  598. setType(ots_addr, 0);
  599. setType(ltree_addr, 1);
  600. setType(node_addr, 2);
  601. // Extract index
  602. idx = ((unsigned long)sig_msg[0] << 24) | ((unsigned long)sig_msg[1] << 16) | ((unsigned long)sig_msg[2] << 8) | sig_msg[3];
  603. printf("verify:: idx = %lu\n", idx);
  604. // Generate hash key (R || root || idx)
  605. memcpy(hash_key, sig_msg+4,n);
  606. memcpy(hash_key+n, pk, n);
  607. to_byte(hash_key+2*n, idx, n);
  608. sig_msg += (n+4);
  609. sig_msg_len -= (n+4);
  610. // hash message
  611. unsigned long long tmp_sig_len = params->wots_par.keysize+params->h*n;
  612. m_len = sig_msg_len - tmp_sig_len;
  613. h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n, params->hash_alg);
  614. //-----------------------
  615. // Verify signature
  616. //-----------------------
  617. // Prepare Address
  618. setOTSADRS(ots_addr, idx);
  619. // Check WOTS signature
  620. wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->wots_par), pub_seed, ots_addr);
  621. sig_msg += params->wots_par.keysize;
  622. sig_msg_len -= params->wots_par.keysize;
  623. // Compute Ltree
  624. setLtreeADRS(ltree_addr, idx);
  625. l_tree(pkhash, wots_pk, params, pub_seed, ltree_addr);
  626. // Compute root
  627. validate_authpath(root, pkhash, idx, sig_msg, params, pub_seed, node_addr);
  628. sig_msg += params->h*n;
  629. sig_msg_len -= params->h*n;
  630. for (i = 0; i < n; i++)
  631. if (root[i] != pk[i])
  632. goto fail;
  633. *msglen = sig_msg_len;
  634. for (i = 0; i < *msglen; i++)
  635. msg[i] = sig_msg[i];
  636. return 0;
  637. fail:
  638. *msglen = sig_msg_len;
  639. for (i = 0; i < *msglen; i++)
  640. msg[i] = 0;
  641. *msglen = -1;
  642. return -1;
  643. }
  644. /*
  645. * Generates a XMSSMT key pair for a given parameter set.
  646. * Format sk: [(ceil(h/8) bit) idx || SK_SEED || SK_PRF || PUB_SEED || root]
  647. * Format pk: [root || PUB_SEED] omitting algo oid.
  648. */
  649. int xmssmt_keypair(unsigned char *pk, unsigned char *sk, bds_state *states, unsigned char *wots_sigs, xmssmt_params *params)
  650. {
  651. unsigned int n = params->n;
  652. unsigned int i;
  653. unsigned char ots_seed[params->n];
  654. // Set idx = 0
  655. for (i = 0; i < params->index_len; i++) {
  656. sk[i] = 0;
  657. }
  658. // Init SK_SEED (n byte), SK_PRF (n byte), and PUB_SEED (n byte)
  659. randombytes(sk+params->index_len, 3*n);
  660. // Copy PUB_SEED to public key
  661. memcpy(pk+n, sk+params->index_len+2*n, n);
  662. uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  663. // Start with the bottom-most layer
  664. setLayerADRS(addr, 0);
  665. // Set up state and compute wots signatures for all but topmost tree root
  666. for (i = 0; i < params->d - 1; i++) {
  667. // Compute seed for OTS key pair
  668. treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr);
  669. setLayerADRS(addr, (i+1));
  670. get_seed(ots_seed, sk+params->index_len, n, addr, params->xmss_par.hash_alg);
  671. wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, pk, ots_seed, &(params->xmss_par.wots_par), pk+n, addr);
  672. }
  673. // Address now points to the single tree on layer d-1
  674. treehash_setup(pk, params->xmss_par.h, 0, states + i, sk+params->index_len, &(params->xmss_par), pk+n, addr);
  675. memcpy(sk+params->index_len+3*n, pk, n);
  676. return 0;
  677. }
  678. /**
  679. * Signs a message.
  680. * Returns
  681. * 1. an array containing the signature followed by the message AND
  682. * 2. an updated secret key!
  683. *
  684. */
  685. int xmssmt_sign(unsigned char *sk, bds_state *states, unsigned char *wots_sigs, unsigned char *sig_msg, unsigned long long *sig_msg_len, const unsigned char *msg, unsigned long long msglen, const xmssmt_params *params)
  686. {
  687. unsigned int n = params->n;
  688. unsigned int tree_h = params->xmss_par.h;
  689. unsigned int h = params->h;
  690. unsigned int k = params->xmss_par.k;
  691. unsigned int idx_len = params->index_len;
  692. uint64_t idx_tree;
  693. uint32_t idx_leaf;
  694. uint64_t i, j;
  695. int needswap_upto = -1;
  696. unsigned int updates;
  697. unsigned char sk_seed[n];
  698. unsigned char sk_prf[n];
  699. unsigned char pub_seed[n];
  700. // Init working params
  701. unsigned char R[n];
  702. unsigned char msg_h[n];
  703. unsigned char hash_key[3*n];
  704. unsigned char ots_seed[n];
  705. uint32_t addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  706. uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  707. unsigned char idx_bytes_32[32];
  708. bds_state tmp;
  709. // Extract SK
  710. unsigned long long idx = 0;
  711. for (i = 0; i < idx_len; i++) {
  712. idx |= ((unsigned long long)sk[i]) << 8*(idx_len - 1 - i);
  713. }
  714. memcpy(sk_seed, sk+idx_len, n);
  715. memcpy(sk_prf, sk+idx_len+n, n);
  716. memcpy(pub_seed, sk+idx_len+2*n, n);
  717. // Update SK
  718. for (i = 0; i < idx_len; i++) {
  719. sk[i] = ((idx + 1) >> 8*(idx_len - 1 - i)) & 255;
  720. }
  721. // -- Secret key for this non-forward-secure version is now updated.
  722. // -- A productive implementation should use a file handle instead and write the updated secret key at this point!
  723. // ---------------------------------
  724. // Message Hashing
  725. // ---------------------------------
  726. // Message Hash:
  727. // First compute pseudorandom value
  728. to_byte(idx_bytes_32, idx, 32);
  729. prf(R, idx_bytes_32, sk_prf, n, params->xmss_par.hash_alg);
  730. // Generate hash key (R || root || idx)
  731. memcpy(hash_key, R, n);
  732. memcpy(hash_key+n, sk+idx_len+3*n, n);
  733. to_byte(hash_key+2*n, idx, n);
  734. // Then use it for message digest
  735. h_msg(msg_h, msg, msglen, hash_key, 3*n, n, params->xmss_par.hash_alg);
  736. // Start collecting signature
  737. *sig_msg_len = 0;
  738. // Copy index to signature
  739. for (i = 0; i < idx_len; i++) {
  740. sig_msg[i] = (idx >> 8*(idx_len - 1 - i)) & 255;
  741. }
  742. sig_msg += idx_len;
  743. *sig_msg_len += idx_len;
  744. // Copy R to signature
  745. for (i = 0; i < n; i++)
  746. sig_msg[i] = R[i];
  747. sig_msg += n;
  748. *sig_msg_len += n;
  749. // ----------------------------------
  750. // Now we start to "really sign"
  751. // ----------------------------------
  752. // Handle lowest layer separately as it is slightly different...
  753. // Prepare Address
  754. setType(ots_addr, 0);
  755. idx_tree = idx >> tree_h;
  756. idx_leaf = (idx & ((1 << tree_h)-1));
  757. setLayerADRS(ots_addr, 0);
  758. setTreeADRS(ots_addr, idx_tree);
  759. setOTSADRS(ots_addr, idx_leaf);
  760. // Compute seed for OTS key pair
  761. get_seed(ots_seed, sk_seed, n, ots_addr, params->xmss_par.hash_alg);
  762. // Compute WOTS signature
  763. wots_sign(sig_msg, msg_h, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);
  764. sig_msg += params->xmss_par.wots_par.keysize;
  765. *sig_msg_len += params->xmss_par.wots_par.keysize;
  766. memcpy(sig_msg, states[0].auth, tree_h*n);
  767. sig_msg += tree_h*n;
  768. *sig_msg_len += tree_h*n;
  769. // prepare signature of remaining layers
  770. for (i = 1; i < params->d; i++) {
  771. // put WOTS signature in place
  772. memcpy(sig_msg, wots_sigs + (i-1)*params->xmss_par.wots_par.keysize, params->xmss_par.wots_par.keysize);
  773. sig_msg += params->xmss_par.wots_par.keysize;
  774. *sig_msg_len += params->xmss_par.wots_par.keysize;
  775. // put AUTH nodes in place
  776. memcpy(sig_msg, states[i].auth, tree_h*n);
  777. sig_msg += tree_h*n;
  778. *sig_msg_len += tree_h*n;
  779. }
  780. updates = (tree_h - k) >> 1;
  781. setTreeADRS(addr, (idx_tree + 1));
  782. // mandatory update for NEXT_0 (does not count towards h-k/2) if NEXT_0 exists
  783. if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << h)) {
  784. bds_state_update(&states[params->d], sk_seed, &(params->xmss_par), pub_seed, addr);
  785. }
  786. for (i = 0; i < params->d; i++) {
  787. // check if we're not at the end of a tree
  788. if (! (((idx + 1) & ((1ULL << ((i+1)*tree_h)) - 1)) == 0)) {
  789. idx_leaf = (idx >> (tree_h * i)) & ((1 << tree_h)-1);
  790. idx_tree = (idx >> (tree_h * (i+1)));
  791. setLayerADRS(addr, i);
  792. setTreeADRS(addr, idx_tree);
  793. if (i == (unsigned int) (needswap_upto + 1)) {
  794. bds_round(&states[i], idx_leaf, sk_seed, &(params->xmss_par), pub_seed, addr);
  795. }
  796. updates = bds_treehash_update(&states[i], updates, sk_seed, &(params->xmss_par), pub_seed, addr);
  797. setTreeADRS(addr, (idx_tree + 1));
  798. // if a NEXT-tree exists for this level;
  799. if ((1 + idx_tree) * (1 << tree_h) + idx_leaf < (1ULL << (h - tree_h * i))) {
  800. if (i > 0 && updates > 0 && states[params->d + i].next_leaf < (1ULL << h)) {
  801. bds_state_update(&states[params->d + i], sk_seed, &(params->xmss_par), pub_seed, addr);
  802. updates--;
  803. }
  804. }
  805. }
  806. else if (idx < (1ULL << h) - 1) {
  807. memcpy(&tmp, states+params->d + i, sizeof(bds_state));
  808. memcpy(states+params->d + i, states + i, sizeof(bds_state));
  809. memcpy(states + i, &tmp, sizeof(bds_state));
  810. setLayerADRS(ots_addr, (i+1));
  811. setTreeADRS(ots_addr, ((idx + 1) >> ((i+2) * tree_h)));
  812. setOTSADRS(ots_addr, (((idx >> ((i+1) * tree_h)) + 1) & ((1 << tree_h)-1)));
  813. get_seed(ots_seed, sk+params->index_len, n, ots_addr, params->xmss_par.hash_alg);
  814. wots_sign(wots_sigs + i*params->xmss_par.wots_par.keysize, states[i].stack, ots_seed, &(params->xmss_par.wots_par), pub_seed, ots_addr);
  815. states[params->d + i].stackoffset = 0;
  816. states[params->d + i].next_leaf = 0;
  817. updates--; // WOTS-signing counts as one update
  818. needswap_upto = i;
  819. for (j = 0; j < tree_h-k; j++) {
  820. states[i].treehash[j].completed = 1;
  821. }
  822. }
  823. }
  824. //Whipe secret elements?
  825. //zerobytes(tsk, CRYPTO_SECRETKEYBYTES);
  826. memcpy(sig_msg, msg, msglen);
  827. *sig_msg_len += msglen;
  828. return 0;
  829. }
  830. /**
  831. * Verifies a given message signature pair under a given public key.
  832. */
  833. int xmssmt_sign_open(unsigned char *msg, unsigned long long *msglen, const unsigned char *sig_msg, unsigned long long sig_msg_len, const unsigned char *pk, const xmssmt_params *params)
  834. {
  835. unsigned int n = params->n;
  836. unsigned int tree_h = params->xmss_par.h;
  837. unsigned int idx_len = params->index_len;
  838. uint64_t idx_tree;
  839. uint32_t idx_leaf;
  840. unsigned long long i, m_len;
  841. unsigned long long idx=0;
  842. unsigned char wots_pk[params->xmss_par.wots_par.keysize];
  843. unsigned char pkhash[n];
  844. unsigned char root[n];
  845. unsigned char msg_h[n];
  846. unsigned char hash_key[3*n];
  847. unsigned char pub_seed[n];
  848. memcpy(pub_seed, pk+n, n);
  849. // Init addresses
  850. uint32_t ots_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  851. uint32_t ltree_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  852. uint32_t node_addr[8] = {0, 0, 0, 0, 0, 0, 0, 0};
  853. // Extract index
  854. for (i = 0; i < idx_len; i++) {
  855. idx |= ((unsigned long long)sig_msg[i]) << (8*(idx_len - 1 - i));
  856. }
  857. printf("verify:: idx = %llu\n", idx);
  858. sig_msg += idx_len;
  859. sig_msg_len -= idx_len;
  860. // Generate hash key (R || root || idx)
  861. memcpy(hash_key, sig_msg,n);
  862. memcpy(hash_key+n, pk, n);
  863. to_byte(hash_key+2*n, idx, n);
  864. sig_msg += n;
  865. sig_msg_len -= n;
  866. // hash message (recall, R is now on pole position at sig_msg
  867. unsigned long long tmp_sig_len = (params->d * params->xmss_par.wots_par.keysize) + (params->h * n);
  868. m_len = sig_msg_len - tmp_sig_len;
  869. h_msg(msg_h, sig_msg + tmp_sig_len, m_len, hash_key, 3*n, n, params->xmss_par.hash_alg);
  870. //-----------------------
  871. // Verify signature
  872. //-----------------------
  873. // Prepare Address
  874. idx_tree = idx >> tree_h;
  875. idx_leaf = (idx & ((1 << tree_h)-1));
  876. setLayerADRS(ots_addr, 0);
  877. setTreeADRS(ots_addr, idx_tree);
  878. setType(ots_addr, 0);
  879. memcpy(ltree_addr, ots_addr, 12);
  880. setType(ltree_addr, 1);
  881. memcpy(node_addr, ltree_addr, 12);
  882. setType(node_addr, 2);
  883. setOTSADRS(ots_addr, idx_leaf);
  884. // Check WOTS signature
  885. wots_pkFromSig(wots_pk, sig_msg, msg_h, &(params->xmss_par.wots_par), pub_seed, ots_addr);
  886. sig_msg += params->xmss_par.wots_par.keysize;
  887. sig_msg_len -= params->xmss_par.wots_par.keysize;
  888. // Compute Ltree
  889. setLtreeADRS(ltree_addr, idx_leaf);
  890. l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
  891. // Compute root
  892. validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
  893. sig_msg += tree_h*n;
  894. sig_msg_len -= tree_h*n;
  895. for (i = 1; i < params->d; i++) {
  896. // Prepare Address
  897. idx_leaf = (idx_tree & ((1 << tree_h)-1));
  898. idx_tree = idx_tree >> tree_h;
  899. setLayerADRS(ots_addr, i);
  900. setTreeADRS(ots_addr, idx_tree);
  901. setType(ots_addr, 0);
  902. memcpy(ltree_addr, ots_addr, 12);
  903. setType(ltree_addr, 1);
  904. memcpy(node_addr, ltree_addr, 12);
  905. setType(node_addr, 2);
  906. setOTSADRS(ots_addr, idx_leaf);
  907. // Check WOTS signature
  908. wots_pkFromSig(wots_pk, sig_msg, root, &(params->xmss_par.wots_par), pub_seed, ots_addr);
  909. sig_msg += params->xmss_par.wots_par.keysize;
  910. sig_msg_len -= params->xmss_par.wots_par.keysize;
  911. // Compute Ltree
  912. setLtreeADRS(ltree_addr, idx_leaf);
  913. l_tree(pkhash, wots_pk, &(params->xmss_par), pub_seed, ltree_addr);
  914. // Compute root
  915. validate_authpath(root, pkhash, idx_leaf, sig_msg, &(params->xmss_par), pub_seed, node_addr);
  916. sig_msg += tree_h*n;
  917. sig_msg_len -= tree_h*n;
  918. }
  919. for (i = 0; i < n; i++)
  920. if (root[i] != pk[i])
  921. goto fail;
  922. *msglen = sig_msg_len;
  923. for (i = 0; i < *msglen; i++)
  924. msg[i] = sig_msg[i];
  925. return 0;
  926. fail:
  927. *msglen = sig_msg_len;
  928. for (i = 0; i < *msglen; i++)
  929. msg[i] = 0;
  930. *msglen = -1;
  931. return -1;
  932. }