You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

2202 line
74 KiB

  1. /* Copyright (c) 2018, Google Inc.
  2. *
  3. * Permission to use, copy, modify, and/or distribute this software for any
  4. * purpose with or without fee is hereby granted, provided that the above
  5. * copyright notice and this permission notice appear in all copies.
  6. *
  7. * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  8. * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  9. * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
  10. * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  11. * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
  12. * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
  13. * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
  14. #include <openssl/hrss.h>
  15. #include <assert.h>
  16. #include <stdio.h>
  17. #include <stdlib.h>
  18. #include <openssl/bn.h>
  19. #include <openssl/cpu.h>
  20. #include <openssl/hmac.h>
  21. #include <openssl/mem.h>
  22. #include <openssl/sha.h>
  23. #if defined(OPENSSL_X86) || defined(OPENSSL_X86_64)
  24. #include <emmintrin.h>
  25. #endif
  26. #if (defined(OPENSSL_ARM) || defined(OPENSSL_AARCH64)) && \
  27. (defined(__ARM_NEON__) || defined(__ARM_NEON))
  28. #include <arm_neon.h>
  29. #endif
  30. #if defined(_MSC_VER)
  31. #define RESTRICT
  32. #else
  33. #define RESTRICT restrict
  34. #endif
  35. #include "../internal.h"
  36. #include "internal.h"
  37. // This is an implementation of [HRSS], but with a KEM transformation based on
  38. // [SXY]. The primary references are:
  39. // HRSS: https://eprint.iacr.org/2017/667.pdf
  40. // HRSSNIST:
  41. // https://csrc.nist.gov/CSRC/media/Projects/Post-Quantum-Cryptography/documents/round-1/submissions/NTRU_HRSS_KEM.zip
  42. // SXY: https://eprint.iacr.org/2017/1005.pdf
  43. // NTRUTN14:
  44. // https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf
  45. // NTRUCOMP:
  46. // https://eprint.iacr.org/2018/1174
  47. // Vector operations.
  48. //
  49. // A couple of functions in this file can use vector operations to meaningful
  50. // effect. If we're building for a target that has a supported vector unit,
  51. // |HRSS_HAVE_VECTOR_UNIT| will be defined and |vec_t| will be typedefed to a
  52. // 128-bit vector. The following functions abstract over the differences between
  53. // NEON and SSE2 for implementing some vector operations.
  54. // TODO: MSVC can likely also be made to work with vector operations.
  55. #if ((defined(__SSE__) && defined(OPENSSL_X86)) || defined(OPENSSL_X86_64)) && \
  56. (defined(__clang__) || !defined(_MSC_VER))
  57. #define HRSS_HAVE_VECTOR_UNIT
  58. typedef __m128i vec_t;
// vec_capable returns one iff the current platform supports SSE2.
static int vec_capable(void) {
#if defined(__SSE2__)
  // SSE2 is guaranteed by the compilation target; no runtime test needed.
  return 1;
#else
  // Runtime check: bit 26 of the first ia32cap word is the SSE2 feature bit.
  int has_sse2 = (OPENSSL_ia32cap_P[0] & (1 << 26)) != 0;
  return has_sse2;
#endif
}

// vec_add performs a pair-wise addition of the eight uint16s in |a| and |b|.
static inline vec_t vec_add(vec_t a, vec_t b) { return _mm_add_epi16(a, b); }

// vec_sub performs a pair-wise subtraction of the eight uint16s in |a| and
// |b|.
static inline vec_t vec_sub(vec_t a, vec_t b) { return _mm_sub_epi16(a, b); }

// vec_mul multiplies each uint16_t in |a| by |b| and returns the resulting
// vector. (Only the low 16 bits of each product are kept, i.e. the
// multiplication is mod 2^16.)
static inline vec_t vec_mul(vec_t a, uint16_t b) {
  return _mm_mullo_epi16(a, _mm_set1_epi16(b));
}

// vec_fma multiplies each uint16_t in |b| by |c|, adds the result to |a|, and
// returns the resulting vector.
static inline vec_t vec_fma(vec_t a, vec_t b, uint16_t c) {
  return _mm_add_epi16(a, _mm_mullo_epi16(b, _mm_set1_epi16(c)));
}
// vec3_rshift_word right-shifts the 24 uint16_t's in |v| by one uint16,
// shifting a zero in at the bottom.
static inline void vec3_rshift_word(vec_t v[3]) {
  // Intel's left and right shifting is backwards compared to the order in
  // memory because they're based on little-endian order of words (and not just
  // bytes). So the shifts in this function will be backwards from what one
  // might expect.
  const __m128i carry0 = _mm_srli_si128(v[0], 14);
  v[0] = _mm_slli_si128(v[0], 2);

  const __m128i carry1 = _mm_srli_si128(v[1], 14);
  v[1] = _mm_slli_si128(v[1], 2);
  v[1] |= carry0;

  v[2] = _mm_slli_si128(v[2], 2);
  v[2] |= carry1;
}

// vec4_rshift_word right-shifts the 32 uint16_t's in |v| by one uint16,
// shifting a zero in at the bottom.
static inline void vec4_rshift_word(vec_t v[4]) {
  // See the note in |vec3_rshift_word| about the apparently-backwards shift
  // directions.
  const __m128i carry0 = _mm_srli_si128(v[0], 14);
  v[0] = _mm_slli_si128(v[0], 2);

  const __m128i carry1 = _mm_srli_si128(v[1], 14);
  v[1] = _mm_slli_si128(v[1], 2);
  v[1] |= carry0;

  const __m128i carry2 = _mm_srli_si128(v[2], 14);
  v[2] = _mm_slli_si128(v[2], 2);
  v[2] |= carry1;

  v[3] = _mm_slli_si128(v[3], 2);
  v[3] |= carry2;
}

// vec_merge_3_5 takes the final three uint16_t's from |left|, appends the
// first five from |right|, and returns the resulting vector.
static inline vec_t vec_merge_3_5(vec_t left, vec_t right) {
  return _mm_srli_si128(left, 10) | _mm_slli_si128(right, 6);
}
// poly3_vec_lshift1 left-shifts the 768 bits in |a_s|, and in |a_a|, by one
// bit.
static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};

  for (int i = 0; i < 6; i++) {
    // |_mm_slli_epi64| shifts each 64-bit half of the vector independently,
    // so the top bit of the low half must be carried into the high half by
    // hand, and the top bit of the high half into the next vector via
    // |carry_*|.
    vec_t next_carry_s = _mm_srli_epi64(a_s[i], 63);
    a_s[i] = _mm_slli_epi64(a_s[i], 1);
    a_s[i] |= _mm_slli_si128(next_carry_s, 8);
    a_s[i] |= carry_s;
    carry_s = _mm_srli_si128(next_carry_s, 8);

    vec_t next_carry_a = _mm_srli_epi64(a_a[i], 63);
    a_a[i] = _mm_slli_epi64(a_a[i], 1);
    a_a[i] |= _mm_slli_si128(next_carry_a, 8);
    a_a[i] |= carry_a;
    carry_a = _mm_srli_si128(next_carry_a, 8);
  }
}

// poly3_vec_rshift1 right-shifts the 768 bits in |a_s|, and in |a_a|, by one
// bit.
static inline void poly3_vec_rshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};

  for (int i = 5; i >= 0; i--) {
    // Mirror image of |poly3_vec_lshift1|: bits move down within each 64-bit
    // half and the half- and vector-crossing bits are patched up explicitly.
    const vec_t next_carry_s = _mm_slli_epi64(a_s[i], 63);
    a_s[i] = _mm_srli_epi64(a_s[i], 1);
    a_s[i] |= _mm_srli_si128(next_carry_s, 8);
    a_s[i] |= carry_s;
    carry_s = _mm_slli_si128(next_carry_s, 8);

    const vec_t next_carry_a = _mm_slli_epi64(a_a[i], 63);
    a_a[i] = _mm_srli_epi64(a_a[i], 1);
    a_a[i] |= _mm_srli_si128(next_carry_a, 8);
    a_a[i] |= carry_a;
    carry_a = _mm_slli_si128(next_carry_a, 8);
  }
}
// vec_broadcast_bit duplicates the least-significant bit in |a| to all bits
// in a vector and returns the result.
static inline vec_t vec_broadcast_bit(vec_t a) {
  // Move bit zero of each 64-bit half to the sign position, arithmetic-shift
  // to smear it across the upper 32-bit dword, then replicate that dword
  // (index 1) into all four positions.
  return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63), 31),
                           0b01010101);
}

// vec_broadcast_bit15 duplicates the most-significant bit of the first word
// in |a| to all bits in a vector and returns the result.
static inline vec_t vec_broadcast_bit15(vec_t a) {
  // As |vec_broadcast_bit|, but bit 15 is moved to position 63 first.
  return _mm_shuffle_epi32(_mm_srai_epi32(_mm_slli_epi64(a, 63 - 15), 31),
                           0b01010101);
}

// vec_get_word returns the |i|th uint16_t in |v|. (This is a macro because
// the compiler requires that |i| be a compile-time constant.)
#define vec_get_word(v, i) _mm_extract_epi16(v, i)
  169. #elif (defined(OPENSSL_ARM) || defined(OPENSSL_AARCH64)) && \
  170. (defined(__ARM_NEON__) || defined(__ARM_NEON))
  171. #define HRSS_HAVE_VECTOR_UNIT
  172. typedef uint16x8_t vec_t;
// These functions perform the same actions as the SSE2 functions of the same
// name, above, but using NEON.
static int vec_capable(void) { return CRYPTO_is_NEON_capable(); }

static inline vec_t vec_add(vec_t a, vec_t b) { return a + b; }

static inline vec_t vec_sub(vec_t a, vec_t b) { return a - b; }

static inline vec_t vec_mul(vec_t a, uint16_t b) { return vmulq_n_u16(a, b); }

static inline vec_t vec_fma(vec_t a, vec_t b, uint16_t c) {
  return vmlaq_n_u16(a, b, c);
}

// vec3_rshift_word shifts the 24 uint16_t's in |v| up by one position,
// shifting a zero in at the bottom. |vextq_u16(x, y, 7)| yields the last
// lane of |x| followed by the first seven lanes of |y|.
static inline void vec3_rshift_word(vec_t v[3]) {
  const uint16x8_t kZero = {0};
  v[2] = vextq_u16(v[1], v[2], 7);
  v[1] = vextq_u16(v[0], v[1], 7);
  v[0] = vextq_u16(kZero, v[0], 7);
}

// vec4_rshift_word is the 32-lane version of |vec3_rshift_word|.
static inline void vec4_rshift_word(vec_t v[4]) {
  const uint16x8_t kZero = {0};
  v[3] = vextq_u16(v[2], v[3], 7);
  v[2] = vextq_u16(v[1], v[2], 7);
  v[1] = vextq_u16(v[0], v[1], 7);
  v[0] = vextq_u16(kZero, v[0], 7);
}

// vec_merge_3_5 takes the final three lanes from |left| and the first five
// from |right|.
static inline vec_t vec_merge_3_5(vec_t left, vec_t right) {
  return vextq_u16(left, right, 5);
}

// vec_get_word returns the |i|th uint16_t in |v|.
static inline uint16_t vec_get_word(vec_t v, unsigned i) {
  return v[i];
}
  201. #if !defined(OPENSSL_AARCH64)
// vec_broadcast_bit duplicates the least-significant bit in |a| to all bits
// in a vector and returns the result.
static inline vec_t vec_broadcast_bit(vec_t a) {
  // Shift bit zero to each lane's sign position, arithmetic-shift it back
  // down to fill the lane, then duplicate lane zero everywhere.
  a = (vec_t)vshrq_n_s16(((int16x8_t)a) << 15, 15);
  return vdupq_lane_u16(vget_low_u16(a), 0);
}

// vec_broadcast_bit15 duplicates the most-significant bit of the first word
// in |a| to all bits in a vector and returns the result.
static inline vec_t vec_broadcast_bit15(vec_t a) {
  // Bit 15 is already in the sign position; an arithmetic shift smears it
  // across the lane before lane zero is duplicated.
  a = (vec_t)vshrq_n_s16((int16x8_t)a, 15);
  return vdupq_lane_u16(vget_low_u16(a), 0);
}

// poly3_vec_lshift1 left-shifts the 768 bits in |a_s|, and in |a_a|, by one
// bit.
static inline void poly3_vec_lshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};
  const vec_t kZero = {0};

  for (int i = 0; i < 6; i++) {
    // Lanes are shifted independently; |vextq_u16| moves each lane's
    // overflow bit into the next lane up, and |carry_*| carries the top
    // lane's bit into the next vector.
    vec_t next_carry_s = a_s[i] >> 15;
    a_s[i] <<= 1;
    a_s[i] |= vextq_u16(kZero, next_carry_s, 7);
    a_s[i] |= carry_s;
    carry_s = vextq_u16(next_carry_s, kZero, 7);

    vec_t next_carry_a = a_a[i] >> 15;
    a_a[i] <<= 1;
    a_a[i] |= vextq_u16(kZero, next_carry_a, 7);
    a_a[i] |= carry_a;
    carry_a = vextq_u16(next_carry_a, kZero, 7);
  }
}

// poly3_vec_rshift1 right-shifts the 768 bits in |a_s|, and in |a_a|, by one
// bit.
static inline void poly3_vec_rshift1(vec_t a_s[6], vec_t a_a[6]) {
  vec_t carry_s = {0};
  vec_t carry_a = {0};
  const vec_t kZero = {0};

  for (int i = 5; i >= 0; i--) {
    // Mirror image of |poly3_vec_lshift1|.
    vec_t next_carry_s = a_s[i] << 15;
    a_s[i] >>= 1;
    a_s[i] |= vextq_u16(next_carry_s, kZero, 1);
    a_s[i] |= carry_s;
    carry_s = vextq_u16(kZero, next_carry_s, 1);

    vec_t next_carry_a = a_a[i] << 15;
    a_a[i] >>= 1;
    a_a[i] |= vextq_u16(next_carry_a, kZero, 1);
    a_a[i] |= carry_a;
    carry_a = vextq_u16(kZero, next_carry_a, 1);
  }
}
  244. #endif // !OPENSSL_AARCH64
  245. #endif // (ARM || AARCH64) && NEON
  246. // Polynomials in this scheme have N terms.
  247. // #define N 701
  248. // Underlying data types and arithmetic operations.
  249. // ------------------------------------------------
  250. // Binary polynomials.
  251. // poly2 represents a degree-N polynomial over GF(2). The words are in little-
  252. // endian order, i.e. the coefficient of x^0 is the LSB of the first word. The
  253. // final word is only partially used since N is not a multiple of the word size.
  254. // Defined in internal.h:
  255. // struct poly2 {
  256. // crypto_word_t v[WORDS_PER_POLY];
  257. // };
  258. OPENSSL_UNUSED static void hexdump(const void *void_in, size_t len) {
  259. const uint8_t *in = (const uint8_t *)void_in;
  260. for (size_t i = 0; i < len; i++) {
  261. printf("%02x", in[i]);
  262. }
  263. printf("\n");
  264. }
// poly2_zero sets every word of |p| (and hence every coefficient) to zero.
static void poly2_zero(struct poly2 *p) {
  OPENSSL_memset(&p->v[0], 0, sizeof(crypto_word_t) * WORDS_PER_POLY);
}
  268. // poly2_cmov sets |out| to |in| iff |mov| is all ones.
  269. static void poly2_cmov(struct poly2 *out, const struct poly2 *in,
  270. crypto_word_t mov) {
  271. for (size_t i = 0; i < WORDS_PER_POLY; i++) {
  272. out->v[i] = (out->v[i] & ~mov) | (in->v[i] & mov);
  273. }
  274. }
// poly2_rotr_words performs a right-rotate on |in|, writing the result to
// |out|. The shift count, |bits|, must be a non-zero multiple of the word
// size, and |out| must not alias |in|.
static void poly2_rotr_words(struct poly2 *out, const struct poly2 *in,
                             size_t bits) {
  assert(bits >= BITS_PER_WORD && bits % BITS_PER_WORD == 0);
  assert(out != in);

  const size_t start = bits / BITS_PER_WORD;
  // |n| is the number of whole words that simply slide down.
  const size_t n = (N - bits) / BITS_PER_WORD;

  // The rotate is by a whole number of words so the first few words are easy:
  // just move them down.
  for (size_t i = 0; i < n; i++) {
    out->v[i] = in->v[start + i];
  }

  // Since the last word is only partially filled (N is not a multiple of
  // BITS_PER_WORD), the rotated-around words need shifting and merging so
  // that the wrap happens at bit position |BITS_IN_LAST_WORD| of the final
  // word rather than at a word boundary.
  crypto_word_t carry = in->v[WORDS_PER_POLY - 1];
  for (size_t i = 0; i < start; i++) {
    out->v[n + i] = carry | in->v[i] << BITS_IN_LAST_WORD;
    carry = in->v[i] >> (BITS_PER_WORD - BITS_IN_LAST_WORD);
  }
  out->v[WORDS_PER_POLY - 1] = carry;
}
// poly2_rotr_bits performs a right-rotate on |in|, writing the result to
// |out|. The shift count, |bits|, must be a power of two that is less than
// |BITS_PER_WORD|, and |out| must not alias |in|.
static void poly2_rotr_bits(struct poly2 *out, const struct poly2 *in,
                            size_t bits) {
  assert(bits <= BITS_PER_WORD / 2);
  assert(bits != 0);
  assert((bits & (bits - 1)) == 0);
  assert(out != in);

  // BITS_PER_WORD/2 is the greatest legal value of |bits|. If
  // |BITS_IN_LAST_WORD| is smaller than this then the code below doesn't work
  // because more than the last word needs to carry down in the previous one and
  // so on.
  OPENSSL_STATIC_ASSERT(
      BITS_IN_LAST_WORD >= BITS_PER_WORD / 2,
      "there are more carry bits than fit in BITS_IN_LAST_WORD");

  // Bits rotated off the bottom reappear at the top of the polynomial, i.e.
  // in the (partial) final word; seed |carry| with its contribution.
  crypto_word_t carry = in->v[WORDS_PER_POLY - 1] << (BITS_PER_WORD - bits);

  // |i| is unsigned, so the loop terminates when the decrement past zero
  // wraps |i| around to a value >= WORDS_PER_POLY.
  for (size_t i = WORDS_PER_POLY - 2; i < WORDS_PER_POLY; i--) {
    out->v[i] = carry | in->v[i] >> bits;
    carry = in->v[i] << (BITS_PER_WORD - bits);
  }

  // The final word combines the remaining carry (moved into the partial
  // word's bit range) with its own shifted bits, then masks off the
  // alignment bits above BITS_IN_LAST_WORD.
  crypto_word_t last_word = carry >> (BITS_PER_WORD - BITS_IN_LAST_WORD) |
                            in->v[WORDS_PER_POLY - 1] >> bits;
  last_word &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
  out->v[WORDS_PER_POLY - 1] = last_word;
}
// HRSS_poly2_rotr_consttime right-rotates |p| by |bits| in constant-time.
// The alignment bits of |p|'s final word must be clear on entry.
void HRSS_poly2_rotr_consttime(struct poly2 *p, size_t bits) {
  assert(bits <= N);
  assert(p->v[WORDS_PER_POLY-1] >> BITS_IN_LAST_WORD == 0);

  // Constant-time rotation is implemented by calculating the rotations of
  // powers-of-two bits and throwing away the unneeded values. 2^9 (i.e. 512) is
  // the largest power-of-two shift that we need to consider because 2^10 > N.
#define HRSS_POLY2_MAX_SHIFT 9
  size_t shift = HRSS_POLY2_MAX_SHIFT;
  OPENSSL_STATIC_ASSERT((1 << (HRSS_POLY2_MAX_SHIFT + 1)) > N,
                        "maximum shift is too small");
  OPENSSL_STATIC_ASSERT((1 << HRSS_POLY2_MAX_SHIFT) <= N,
                        "maximum shift is too large");
  struct poly2 shifted;

  // For each power of two, the candidate rotation is computed
  // unconditionally and then kept or discarded with a constant-time cmov
  // keyed on the corresponding bit of |bits|. Word-aligned rotations are
  // cheaper, so they are handled first.
  for (; (UINT64_C(1) << shift) >= BITS_PER_WORD; shift--) {
    poly2_rotr_words(&shifted, p, UINT64_C(1) << shift);
    poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
  }

  // Sub-word rotations. |shift| is a size_t so, after the shift == 0
  // iteration, the decrement wraps around to a huge value and the loop
  // terminates.
  for (; shift < HRSS_POLY2_MAX_SHIFT; shift--) {
    poly2_rotr_bits(&shifted, p, UINT64_C(1) << shift);
    poly2_cmov(p, &shifted, ~((1 & (bits >> shift)) - 1));
  }
#undef HRSS_POLY2_MAX_SHIFT
}
  347. // poly2_cswap exchanges the values of |a| and |b| if |swap| is all ones.
  348. static void poly2_cswap(struct poly2 *a, struct poly2 *b, crypto_word_t swap) {
  349. for (size_t i = 0; i < WORDS_PER_POLY; i++) {
  350. const crypto_word_t sum = swap & (a->v[i] ^ b->v[i]);
  351. a->v[i] ^= sum;
  352. b->v[i] ^= sum;
  353. }
  354. }
// poly2_fmadd sets |out| to |out| + |in| * m, where m is either
// |CONSTTIME_TRUE_W| or |CONSTTIME_FALSE_W|. (Addition in GF(2) is XOR.)
static void poly2_fmadd(struct poly2 *out, const struct poly2 *in,
                        crypto_word_t m) {
  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    out->v[i] ^= in->v[i] & m;
  }
}
  363. // poly2_lshift1 left-shifts |p| by one bit.
  364. static void poly2_lshift1(struct poly2 *p) {
  365. crypto_word_t carry = 0;
  366. for (size_t i = 0; i < WORDS_PER_POLY; i++) {
  367. const crypto_word_t next_carry = p->v[i] >> (BITS_PER_WORD - 1);
  368. p->v[i] <<= 1;
  369. p->v[i] |= carry;
  370. carry = next_carry;
  371. }
  372. }
  373. // poly2_rshift1 right-shifts |p| by one bit.
  374. static void poly2_rshift1(struct poly2 *p) {
  375. crypto_word_t carry = 0;
  376. for (size_t i = WORDS_PER_POLY - 1; i < WORDS_PER_POLY; i--) {
  377. const crypto_word_t next_carry = p->v[i] & 1;
  378. p->v[i] >>= 1;
  379. p->v[i] |= carry << (BITS_PER_WORD - 1);
  380. carry = next_carry;
  381. }
  382. }
// poly2_clear_top_bits clears the bits in the final word that are only for
// alignment, i.e. bit positions >= BITS_IN_LAST_WORD.
static void poly2_clear_top_bits(struct poly2 *p) {
  p->v[WORDS_PER_POLY - 1] &= (UINT64_C(1) << BITS_IN_LAST_WORD) - 1;
}

// poly2_top_bits_are_clear returns one iff the extra bits in the final words
// of |p| are zero.
static int poly2_top_bits_are_clear(const struct poly2 *p) {
  return (p->v[WORDS_PER_POLY - 1] &
          ~((UINT64_C(1) << BITS_IN_LAST_WORD) - 1)) == 0;
}
  394. // Ternary polynomials.
  395. // poly3 represents a degree-N polynomial over GF(3). Each coefficient is
  396. // bitsliced across the |s| and |a| arrays, like this:
  397. //
  398. // s | a | value
  399. // -----------------
  400. // 0 | 0 | 0
  401. // 0 | 1 | 1
  402. // 1 | 1 | -1 (aka 2)
  403. // 1 | 0 | <invalid>
  404. //
  405. // ('s' is for sign, and 'a' is the absolute value.)
  406. //
  407. // Once bitsliced as such, the following circuits can be used to implement
  408. // addition and multiplication mod 3:
  409. //
  410. // (s3, a3) = (s1, a1) × (s2, a2)
  411. // a3 = a1 ∧ a2
  412. // s3 = (s1 ⊕ s2) ∧ a3
  413. //
  414. // (s3, a3) = (s1, a1) + (s2, a2)
  415. // t = s1 ⊕ a2
  416. // s3 = t ∧ (s2 ⊕ a1)
  417. // a3 = (a1 ⊕ a2) ∨ (t ⊕ s2)
  418. //
  419. // (s3, a3) = (s1, a1) - (s2, a2)
  420. // t = a1 ⊕ a2
  421. // s3 = (s1 ⊕ a2) ∧ (t ⊕ s2)
  422. // a3 = t ∨ (s1 ⊕ s2)
  423. //
  424. // Negating a value just involves XORing s by a.
  425. //
  426. // struct poly3 {
  427. // struct poly2 s, a;
  428. // };
// poly3_print writes the sign and absolute-value bitmasks of |in| to stdout
// as hex words, for debugging. A local copy is masked first so the alignment
// bits in the final words don't appear in the output.
OPENSSL_UNUSED static void poly3_print(const struct poly3 *in) {
  struct poly3 p;
  OPENSSL_memcpy(&p, in, sizeof(p));
  p.s.v[WORDS_PER_POLY - 1] &= ((crypto_word_t)1 << BITS_IN_LAST_WORD) - 1;
  p.a.v[WORDS_PER_POLY - 1] &= ((crypto_word_t)1 << BITS_IN_LAST_WORD) - 1;

  printf("{[");
  for (unsigned i = 0; i < WORDS_PER_POLY; i++) {
    if (i) {
      printf(" ");
    }
    printf(BN_HEX_FMT2, p.s.v[i]);
  }
  printf("] [");
  for (unsigned i = 0; i < WORDS_PER_POLY; i++) {
    if (i) {
      printf(" ");
    }
    printf(BN_HEX_FMT2, p.a.v[i]);
  }
  printf("]}\n");
}
// poly3_zero sets |p| to zero (s = a = 0 for every coefficient, which is the
// encoding of 0 in the bitsliced representation above).
static void poly3_zero(struct poly3 *p) {
  poly2_zero(&p->s);
  poly2_zero(&p->a);
}
// poly3_word_mul sets (|out_s|, |out_a|) to (|s1|, |a1|) × (|s2|, |a2|),
// applying the bitsliced multiplication circuit (described above) to each of
// the word's bit positions independently.
static void poly3_word_mul(crypto_word_t *out_s, crypto_word_t *out_a,
                           const crypto_word_t s1, const crypto_word_t a1,
                           const crypto_word_t s2, const crypto_word_t a2) {
  *out_a = a1 & a2;
  *out_s = (s1 ^ s2) & *out_a;
}

// poly3_word_add sets (|out_s|, |out_a|) to (|s1|, |a1|) + (|s2|, |a2|),
// using the bitsliced addition circuit described above.
static void poly3_word_add(crypto_word_t *out_s, crypto_word_t *out_a,
                           const crypto_word_t s1, const crypto_word_t a1,
                           const crypto_word_t s2, const crypto_word_t a2) {
  const crypto_word_t t = s1 ^ a2;
  *out_s = t & (s2 ^ a1);
  *out_a = (a1 ^ a2) | (t ^ s2);
}

// poly3_word_sub sets (|out_s|, |out_a|) to (|s1|, |a1|) - (|s2|, |a2|),
// using the bitsliced subtraction circuit described above. (All inputs are
// read before the outputs are written, so the outputs may alias the inputs.)
static void poly3_word_sub(crypto_word_t *out_s, crypto_word_t *out_a,
                           const crypto_word_t s1, const crypto_word_t a1,
                           const crypto_word_t s2, const crypto_word_t a2) {
  const crypto_word_t t = a1 ^ a2;
  *out_s = (s1 ^ a2) & (t ^ s2);
  *out_a = t | (s1 ^ s2);
}
// lsb_to_all replicates the least-significant bit of |v| to all bits of the
// word (0 -> 0, 1 -> all-ones). This is used in bit-slicing operations to
// make a vector from a fixed value.
static crypto_word_t lsb_to_all(crypto_word_t v) { return 0u - (v & 1); }
  481. // poly3_mul_const sets |p| to |p|×m, where m = (ms, ma).
  482. static void poly3_mul_const(struct poly3 *p, crypto_word_t ms,
  483. crypto_word_t ma) {
  484. ms = lsb_to_all(ms);
  485. ma = lsb_to_all(ma);
  486. for (size_t i = 0; i < WORDS_PER_POLY; i++) {
  487. poly3_word_mul(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], ms, ma);
  488. }
  489. }
// poly3_rotr_consttime right-rotates |p| by |bits| in constant-time, by
// rotating the sign and absolute-value bitmasks independently.
static void poly3_rotr_consttime(struct poly3 *p, size_t bits) {
  assert(bits <= N);
  HRSS_poly2_rotr_consttime(&p->s, bits);
  HRSS_poly2_rotr_consttime(&p->a, bits);
}
// poly3_fmsub sets |out| to |out| - |in|×m, where m is (ms, ma).
static void poly3_fmsub(struct poly3 *RESTRICT out,
                        const struct poly3 *RESTRICT in, crypto_word_t ms,
                        crypto_word_t ma) {
  crypto_word_t product_s, product_a;
  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    // Multiply each word of |in| by m, then subtract the product from the
    // corresponding word of |out|.
    poly3_word_mul(&product_s, &product_a, in->s.v[i], in->a.v[i], ms, ma);
    poly3_word_sub(&out->s.v[i], &out->a.v[i], out->s.v[i], out->a.v[i],
                   product_s, product_a);
  }
}
// final_bit_to_all replicates the bit in the final (i.e. top-coefficient)
// position of the last word to all the bits in the word.
static crypto_word_t final_bit_to_all(crypto_word_t v) {
  return lsb_to_all(v >> (BITS_IN_LAST_WORD - 1));
}

// poly3_top_bits_are_clear returns one iff the extra bits in the final words
// of |p| are zero.
OPENSSL_UNUSED static int poly3_top_bits_are_clear(const struct poly3 *p) {
  return poly2_top_bits_are_clear(&p->s) && poly2_top_bits_are_clear(&p->a);
}
// poly3_mod_phiN reduces |p| by Φ(N) = x^(N-1) + x^(N-2) + … + 1.
static void poly3_mod_phiN(struct poly3 *p) {
  // In order to reduce by Φ(N) we subtract the value of the greatest
  // coefficient from every coefficient: c·Φ(N) ≡ 0 (mod Φ(N)), and this
  // zeroes the top coefficient (it is subtracted from itself).
  const crypto_word_t factor_s = final_bit_to_all(p->s.v[WORDS_PER_POLY - 1]);
  const crypto_word_t factor_a = final_bit_to_all(p->a.v[WORDS_PER_POLY - 1]);

  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    poly3_word_sub(&p->s.v[i], &p->a.v[i], p->s.v[i], p->a.v[i], factor_s,
                   factor_a);
  }

  // The subtraction also touched the unused alignment bits; clear them.
  poly2_clear_top_bits(&p->s);
  poly2_clear_top_bits(&p->a);
}
// poly3_cswap exchanges the values of |a| and |b| if |swap| is all ones,
// in constant time.
static void poly3_cswap(struct poly3 *a, struct poly3 *b, crypto_word_t swap) {
  poly2_cswap(&a->s, &b->s, swap);
  poly2_cswap(&a->a, &b->a, swap);
}

// poly3_lshift1 left-shifts both bitmasks of |p| by one bit.
static void poly3_lshift1(struct poly3 *p) {
  poly2_lshift1(&p->s);
  poly2_lshift1(&p->a);
}

// poly3_rshift1 right-shifts both bitmasks of |p| by one bit.
static void poly3_rshift1(struct poly3 *p) {
  poly2_rshift1(&p->s);
  poly2_rshift1(&p->a);
}
// poly3_span represents a (borrowed) pointer into a poly3: |s| and |a| point
// at corresponding words of the sign and absolute-value bitmasks. It carries
// no length; callers pass an explicit word count.
struct poly3_span {
  crypto_word_t *s;
  crypto_word_t *a;
};
// poly3_span_add adds |n| words of values from |a| and |b| and writes the
// result to |out|. (|out| may alias |a| or |b|: each pair of input words is
// fully read before the output words are written.)
static void poly3_span_add(const struct poly3_span *out,
                           const struct poly3_span *a,
                           const struct poly3_span *b, size_t n) {
  for (size_t i = 0; i < n; i++) {
    poly3_word_add(&out->s[i], &out->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
  }
}

// poly3_span_sub subtracts |n| words of |b| from |n| words of |a|, in place.
static void poly3_span_sub(const struct poly3_span *a,
                           const struct poly3_span *b, size_t n) {
  for (size_t i = 0; i < n; i++) {
    poly3_word_sub(&a->s[i], &a->a[i], a->s[i], a->a[i], b->s[i], b->a[i]);
  }
}
// poly3_mul_aux is a recursive function that multiplies |n| words from |a| and
// |b| and writes 2×|n| words to |out|. Each call uses 2*ceil(n/2) elements of
// |scratch| and the function recurses, except if |n| == 1, when |scratch|
// isn't used and the recursion stops. For |n| in {11, 22}, the transitive
// total amount of |scratch| needed happens to be 2n+2.
static void poly3_mul_aux(const struct poly3_span *out,
                          const struct poly3_span *scratch,
                          const struct poly3_span *a,
                          const struct poly3_span *b, size_t n) {
  if (n == 1) {
    // Base case: bit-serial schoolbook multiplication of one word by one
    // word, producing a two-word (double-length) result.
    crypto_word_t r_s_low = 0, r_s_high = 0, r_a_low = 0, r_a_high = 0;
    crypto_word_t b_s = b->s[0], b_a = b->a[0];
    const crypto_word_t a_s = a->s[0], a_a = a->a[0];

    for (size_t i = 0; i < BITS_PER_WORD; i++) {
      // Multiply (s, a) by the next value from (b_s, b_a).
      crypto_word_t m_s, m_a;
      poly3_word_mul(&m_s, &m_a, a_s, a_a, lsb_to_all(b_s), lsb_to_all(b_a));
      b_s >>= 1;
      b_a >>= 1;

      if (i == 0) {
        // Special case otherwise the code tries to shift by BITS_PER_WORD
        // below, which is undefined.
        r_s_low = m_s;
        r_a_low = m_a;
        continue;
      }

      // Shift the multiplication result to the correct position.
      const crypto_word_t m_s_low = m_s << i;
      const crypto_word_t m_s_high = m_s >> (BITS_PER_WORD - i);
      const crypto_word_t m_a_low = m_a << i;
      const crypto_word_t m_a_high = m_a >> (BITS_PER_WORD - i);

      // Add into the result.
      poly3_word_add(&r_s_low, &r_a_low, r_s_low, r_a_low, m_s_low, m_a_low);
      poly3_word_add(&r_s_high, &r_a_high, r_s_high, r_a_high, m_s_high,
                     m_a_high);
    }

    out->s[0] = r_s_low;
    out->s[1] = r_s_high;
    out->a[0] = r_a_low;
    out->a[1] = r_a_high;
    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The first
  // is always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const struct poly3_span a_high = {&a->s[low_len], &a->a[low_len]};
  const struct poly3_span b_high = {&b->s[low_len], &b->a[low_len]};

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half.
  const struct poly3_span a_cross_sum = *out;
  const struct poly3_span b_cross_sum = {&out->s[high_len], &out->a[high_len]};
  poly3_span_add(&a_cross_sum, a, &a_high, low_len);
  poly3_span_add(&b_cross_sum, b, &b_high, low_len);
  if (high_len != low_len) {
    // Odd |n|: the high halves are one word longer; that word's cross sum is
    // just the high word itself (the low half has no matching word).
    a_cross_sum.s[low_len] = a_high.s[low_len];
    a_cross_sum.a[low_len] = a_high.a[low_len];
    b_cross_sum.s[low_len] = b_high.s[low_len];
    b_cross_sum.a[low_len] = b_high.a[low_len];
  }

  const struct poly3_span child_scratch = {&scratch->s[2 * high_len],
                                           &scratch->a[2 * high_len]};
  const struct poly3_span out_mid = {&out->s[low_len], &out->a[low_len]};
  const struct poly3_span out_high = {&out->s[2 * low_len],
                                      &out->a[2 * low_len]};

  // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer.
  poly3_mul_aux(scratch, &child_scratch, &a_cross_sum, &b_cross_sum, high_len);
  // Calculate a_1 × b_1.
  poly3_mul_aux(&out_high, &child_scratch, &a_high, &b_high, high_len);
  // Calculate a_0 × b_0.
  poly3_mul_aux(out, &child_scratch, a, b, low_len);
  // Subtract those last two products from the first.
  poly3_span_sub(scratch, out, low_len * 2);
  poly3_span_sub(scratch, &out_high, high_len * 2);
  // Add the middle product into the output.
  poly3_span_add(&out_mid, &out_mid, scratch, high_len * 2);
}
// HRSS_poly3_mul sets |*out| to |x|×|y| mod Φ(N).
void HRSS_poly3_mul(struct poly3 *out, const struct poly3 *x,
                    const struct poly3 *y) {
  // Double-length product plus Karatsuba scratch; see |poly3_mul_aux| for the
  // 2n+2-word scratch bound.
  crypto_word_t prod_s[WORDS_PER_POLY * 2];
  crypto_word_t prod_a[WORDS_PER_POLY * 2];
  crypto_word_t scratch_s[WORDS_PER_POLY * 2 + 2];
  crypto_word_t scratch_a[WORDS_PER_POLY * 2 + 2];
  const struct poly3_span prod_span = {prod_s, prod_a};
  const struct poly3_span scratch_span = {scratch_s, scratch_a};
  // The const-discarding casts are safe because |poly3_mul_aux| only reads
  // through its |a| and |b| spans.
  const struct poly3_span x_span = {(crypto_word_t *)x->s.v,
                                    (crypto_word_t *)x->a.v};
  const struct poly3_span y_span = {(crypto_word_t *)y->s.v,
                                    (crypto_word_t *)y->a.v};

  poly3_mul_aux(&prod_span, &scratch_span, &x_span, &y_span, WORDS_PER_POLY);

  // |prod| needs to be reduced mod (𝑥^n - 1), which just involves adding the
  // upper-half to the lower-half. However, N is 701, which isn't a multiple of
  // BITS_PER_WORD, so the upper-half vectors all have to be shifted before
  // being added to the lower-half.
  for (size_t i = 0; i < WORDS_PER_POLY; i++) {
    crypto_word_t v_s = prod_s[WORDS_PER_POLY + i - 1] >> BITS_IN_LAST_WORD;
    v_s |= prod_s[WORDS_PER_POLY + i] << (BITS_PER_WORD - BITS_IN_LAST_WORD);
    crypto_word_t v_a = prod_a[WORDS_PER_POLY + i - 1] >> BITS_IN_LAST_WORD;
    v_a |= prod_a[WORDS_PER_POLY + i] << (BITS_PER_WORD - BITS_IN_LAST_WORD);

    poly3_word_add(&out->s.v[i], &out->a.v[i], prod_s[i], prod_a[i], v_s, v_a);
  }

  poly3_mod_phiN(out);
}
  669. #if defined(HRSS_HAVE_VECTOR_UNIT) && !defined(OPENSSL_AARCH64)
// poly3_vec_cswap swaps (|a_s|, |a_a|) and (|b_s|, |b_a|) if |swap| is
// |0xff..ff|. Otherwise, |swap| must be zero.
static inline void poly3_vec_cswap(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
                                   vec_t b_a[6], const vec_t swap) {
  // Constant-time masked XOR swap, applied to both bitmask halves.
  for (int i = 0; i < 6; i++) {
    const vec_t sum_s = swap & (a_s[i] ^ b_s[i]);
    a_s[i] ^= sum_s;
    b_s[i] ^= sum_s;

    const vec_t sum_a = swap & (a_a[i] ^ b_a[i]);
    a_a[i] ^= sum_a;
    b_a[i] ^= sum_a;
  }
}
// poly3_vec_fmsub subtracts (|ms|, |ma|) × (|b_s|, |b_a|) from (|a_s|,
// |a_a|).
static inline void poly3_vec_fmsub(vec_t a_s[6], vec_t a_a[6], vec_t b_s[6],
                                   vec_t b_a[6], const vec_t ms,
                                   const vec_t ma) {
  for (int i = 0; i < 6; i++) {
    // See the bitslice formulae, above. First the multiplication circuit:
    const vec_t s = b_s[i];
    const vec_t a = b_a[i];
    const vec_t product_a = a & ma;
    const vec_t product_s = (s ^ ms) & product_a;

    // ...then the subtraction circuit, inlined here rather than calling a
    // word-at-a-time helper.
    const vec_t out_s = a_s[i];
    const vec_t out_a = a_a[i];
    const vec_t t = out_a ^ product_a;
    a_s[i] = (out_s ^ product_a) & (t ^ product_s);
    a_a[i] = t | (out_s ^ product_s);
  }
}
// poly3_invert_vec sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
// Φ(N).
static void poly3_invert_vec(struct poly3 *out, const struct poly3 *in) {
  // See the comment in |HRSS_poly3_invert| about this algorithm. In addition to
  // the changes described there, this implementation attempts to use vector
  // registers to speed up the computation. Even non-poly3 variables are held in
  // vectors where possible to minimise the amount of data movement between
  // the vector and general-purpose registers.

  vec_t b_s[6], b_a[6], c_s[6], c_a[6], f_s[6], f_a[6], g_s[6], g_a[6];
  const vec_t kZero = {0};
  const vec_t kOne = {1};
  static const uint8_t kOneBytes[sizeof(vec_t)] = {1};
  static const uint8_t kBottomSixtyOne[sizeof(vec_t)] = {
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x1f};

  // b = 1, c = 0.
  memset(b_s, 0, sizeof(b_s));
  memcpy(b_a, kOneBytes, sizeof(kOneBytes));
  memset(&b_a[1], 0, 5 * sizeof(vec_t));
  memset(c_s, 0, sizeof(c_s));
  memset(c_a, 0, sizeof(c_a));

  // f = in. Only WORDS_PER_POLY words are copied, so the final vector of each
  // array is zeroed first.
  f_s[5] = kZero;
  memcpy(f_s, in->s.v, WORDS_PER_POLY * sizeof(crypto_word_t));
  f_a[5] = kZero;
  memcpy(f_a, in->a.v, WORDS_PER_POLY * sizeof(crypto_word_t));

  // Set g to all ones.
  memset(g_s, 0, sizeof(g_s));
  memset(g_a, 0xff, 5 * sizeof(vec_t));
  memcpy(&g_a[5], kBottomSixtyOne, sizeof(kBottomSixtyOne));

  vec_t deg_f = {N - 1}, deg_g = {N - 1}, rotation = kZero;
  vec_t k = kOne;
  vec_t f0s = {0}, f0a = {0};
  vec_t still_going;
  memset(&still_going, 0xff, sizeof(still_going));

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // s = lsb(f)·lsb(g), broadcast to every bit, and masked to zero once the
    // useful work is done so that the remaining iterations are no-ops.
    const vec_t s_a = vec_broadcast_bit(still_going & (f_a[0] & g_a[0]));
    const vec_t s_s =
        vec_broadcast_bit(still_going & ((f_s[0] ^ g_s[0]) & s_a));
    // Swap when s is non-zero and deg(f) < deg(g); bit 15 of the 16-bit
    // difference acts as the sign bit for the comparison.
    const vec_t should_swap =
        (s_s | s_a) & vec_broadcast_bit15(deg_f - deg_g);

    poly3_vec_cswap(f_s, f_a, g_s, g_a, should_swap);
    poly3_vec_fmsub(f_s, f_a, g_s, g_a, s_s, s_a);
    poly3_vec_rshift1(f_s, f_a);

    poly3_vec_cswap(b_s, b_a, c_s, c_a, should_swap);
    poly3_vec_fmsub(b_s, b_a, c_s, c_a, s_s, s_a);
    poly3_vec_lshift1(c_s, c_a);

    // Conditionally exchange the tracked degrees along with the swap above.
    const vec_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    deg_f -= kOne;
    // Stop doing useful work once deg_f would reach zero.
    still_going &= ~vec_broadcast_bit15(deg_f - kOne);

    // Record the iteration count, and the constant term of f, at the latest
    // point where f's constant term was non-zero.
    const vec_t f0_is_nonzero = vec_broadcast_bit(f_s[0] | f_a[0]);
    // |f0_is_nonzero| implies |still_going|.
    rotation ^= f0_is_nonzero & (k ^ rotation);
    k += kOne;

    const vec_t f0s_sum = f0_is_nonzero & (f_s[0] ^ f0s);
    f0s ^= f0s_sum;
    const vec_t f0a_sum = f0_is_nonzero & (f_a[0] ^ f0a);
    f0a ^= f0a_sum;
  }

  // Reduce the rotation count mod N (it is always less than 2N here) and
  // rotate b into place, then scale by the final constant term of f.
  crypto_word_t rotation_word = vec_get_word(rotation, 0);
  rotation_word -= N & constant_time_lt_w(N, rotation_word);
  memcpy(out->s.v, b_s, WORDS_PER_POLY * sizeof(crypto_word_t));
  memcpy(out->a.v, b_a, WORDS_PER_POLY * sizeof(crypto_word_t));
  assert(poly3_top_bits_are_clear(out));
  poly3_rotr_consttime(out, rotation_word);
  poly3_mul_const(out, vec_get_word(f0s, 0), vec_get_word(f0a, 0));
  poly3_mod_phiN(out);
}
  767. #endif // HRSS_HAVE_VECTOR_UNIT
// HRSS_poly3_invert sets |*out| to |in|^-1, i.e. such that |out|×|in| == 1 mod
// Φ(N).
void HRSS_poly3_invert(struct poly3 *out, const struct poly3 *in) {
  // The vector version of this function seems slightly slower on AArch64, but
  // is useful on ARMv7 and x86-64.
#if defined(HRSS_HAVE_VECTOR_UNIT) && !defined(OPENSSL_AARCH64)
  if (vec_capable()) {
    poly3_invert_vec(out, in);
    return;
  }
#endif

  // This algorithm mostly follows algorithm 10 in the paper. Some changes:
  //   1) k should start at zero, not one. In the code below k is omitted and
  //      the loop counter, |i|, is used instead.
  //   2) The rotation count is conditionally updated to handle trailing zero
  //      coefficients.
  // The best explanation for why it works is in the "Why it works" section of
  // [NTRUTN14].

  struct poly3 c, f, g;
  OPENSSL_memcpy(&f, in, sizeof(f));
  // Set g to all ones.
  OPENSSL_memset(&g.s, 0, sizeof(struct poly2));
  OPENSSL_memset(&g.a, 0xff, sizeof(struct poly2));
  g.a.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;

  // b is accumulated directly in |out|; c is its shifted counterpart.
  struct poly3 *b = out;
  poly3_zero(b);
  poly3_zero(&c);
  // Set b to one.
  b->a.v[0] = 1;

  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
  crypto_word_t f0s = 0, f0a = 0;
  crypto_word_t still_going = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // s is the product of the constant terms of f and g, replicated across
    // the whole word; it is zero once the useful work has completed, making
    // the remaining iterations no-ops.
    const crypto_word_t s_a = lsb_to_all(
        still_going & (f.a.v[0] & g.a.v[0]));
    const crypto_word_t s_s = lsb_to_all(
        still_going & ((f.s.v[0] ^ g.s.v[0]) & s_a));
    const crypto_word_t should_swap =
        (s_s | s_a) & constant_time_lt_w(deg_f, deg_g);

    // All steps are performed unconditionally, with the masks above selecting
    // whether they take effect, so that the loop is constant-time.
    poly3_cswap(&f, &g, should_swap);
    poly3_cswap(b, &c, should_swap);

    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    assert(deg_g >= 1);

    poly3_fmsub(&f, &g, s_s, s_a);
    poly3_fmsub(b, &c, s_s, s_a);
    poly3_rshift1(&f);
    poly3_lshift1(&c);

    deg_f--;
    const crypto_word_t f0_is_nonzero =
        lsb_to_all(f.s.v[0]) | lsb_to_all(f.a.v[0]);
    // |f0_is_nonzero| implies |still_going|.
    assert(!(f0_is_nonzero && !still_going));
    still_going &= ~constant_time_is_zero_w(deg_f);

    // Record the iteration count, and the constant term of f, at the latest
    // point where f's constant term was non-zero.
    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
    f0s = constant_time_select_w(f0_is_nonzero, f.s.v[0], f0s);
    f0a = constant_time_select_w(f0_is_nonzero, f.a.v[0], f0a);
  }

  rotation++;
  // Reduce |rotation| mod N. (The loop bound keeps it below 2N.)
  rotation -= N & constant_time_lt_w(N, rotation);
  assert(poly3_top_bits_are_clear(out));
  poly3_rotr_consttime(out, rotation);
  poly3_mul_const(out, f0s, f0a);
  poly3_mod_phiN(out);
}
// Polynomials in Q.

// Coefficients are reduced mod Q. (Q is clearly not prime, therefore the
// coefficients do not form a field.)
#define Q 8192

// VECS_PER_POLY is the number of 128-bit vectors needed to represent a
// polynomial.
#define COEFFICIENTS_PER_VEC (sizeof(vec_t) / sizeof(uint16_t))
#define VECS_PER_POLY ((N + COEFFICIENTS_PER_VEC - 1) / COEFFICIENTS_PER_VEC)

// poly represents a polynomial with coefficients mod Q. Note that, while Q is a
// power of two, this does not operate in GF(Q). That would be a binary field
// but this is simply mod Q. Thus the coefficients are not a field.
//
// Coefficients are ordered little-endian, thus the coefficient of x^0 is the
// first element of the array.
//
// The array carries three padding elements beyond the N real coefficients;
// the multiplication code keeps them at zero.
struct poly {
#if defined(HRSS_HAVE_VECTOR_UNIT)
  union {
    // N + 3 = 704, which is a multiple of 64 and thus aligns things, esp for
    // the vector code.
    uint16_t v[N + 3];
    vec_t vectors[VECS_PER_POLY];
  };
#else
  // Even if !HRSS_HAVE_VECTOR_UNIT, external assembly may be called that
  // requires alignment.
  alignas(16) uint16_t v[N + 3];
#endif
};
  862. OPENSSL_UNUSED static void poly_print(const struct poly *p) {
  863. printf("[");
  864. for (unsigned i = 0; i < N; i++) {
  865. if (i) {
  866. printf(" ");
  867. }
  868. printf("%d", p->v[i]);
  869. }
  870. printf("]\n");
  871. }
  872. #if defined(HRSS_HAVE_VECTOR_UNIT)
// poly_mul_vec_aux is a recursive function that multiplies |n| words from |a|
// and |b| and writes 2×|n| words to |out|. Each call uses 2*ceil(n/2) elements
// of |scratch| and the function recurses, except if |n| < 3, when |scratch|
// isn't used and the recursion stops. If |n| == |VECS_PER_POLY| then |scratch|
// needs 172 elements.
static void poly_mul_vec_aux(vec_t *restrict out, vec_t *restrict scratch,
                             const vec_t *restrict a, const vec_t *restrict b,
                             const size_t n) {
  // In [HRSS], the technique they used for polynomial multiplication is
  // described: they start with Toom-4 at the top level and then two layers of
  // Karatsuba. Karatsuba is a specific instance of the general Toom–Cook
  // decomposition, which splits an input n-ways and produces 2n-1
  // multiplications of those parts. So, starting with 704 coefficients (rounded
  // up from 701 to have more factors of two), Toom-4 gives seven
  // multiplications of degree-174 polynomials. Each round of Karatsuba (which
  // is Toom-2) increases the number of multiplications by a factor of three
  // while halving the size of the values being multiplied. So two rounds gives
  // 63 multiplications of degree-44 polynomials. Then they (I think) form
  // vectors by gathering all 63 coefficients of each power together, for each
  // input, and doing more rounds of Karatsuba on the vectors until they bottom-
  // out somewhere with schoolbook multiplication.
  //
  // I tried something like that for NEON. NEON vectors are 128 bits so hold
  // eight coefficients. I wrote a function that did Karatsuba on eight
  // multiplications at the same time, using such vectors, and a Go script that
  // decomposed from degree-704, with Karatsuba in non-transposed form, until it
  // reached multiplications of degree-44. It batched up those 81
  // multiplications into lots of eight with a single one left over (which was
  // handled directly).
  //
  // It worked, but it was significantly slower than the dumb algorithm used
  // below. Potentially that was because I misunderstood how [HRSS] did it, or
  // because Clang is bad at generating good code from NEON intrinsics on ARMv7.
  // (Which is true: the code generated by Clang for the below is pretty crap.)
  //
  // This algorithm is much simpler. It just does Karatsuba decomposition all
  // the way down and never transposes. When it gets down to degree-16 or
  // degree-24 values, they are multiplied using schoolbook multiplication and
  // vector intrinsics. The vector operations form each of the eight phase-
  // shifts of one of the inputs, point-wise multiply, and then add into the
  // result at the correct place. This means that 33% (degree-16) or 25%
  // (degree-24) of the multiplies and adds are wasted, but it does ok.

  // Base case: schoolbook multiplication of 16-coefficient (two-vector)
  // inputs, using word-rotations of |a| against broadcast words of |b|.
  if (n == 2) {
    vec_t result[4];
    vec_t vec_a[3];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = kZero;

    result[0] = vec_mul(vec_a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(vec_a[1], vec_get_word(b[0], 0));

    result[1] = vec_fma(result[1], vec_a[0], vec_get_word(b[1], 0));
    result[2] = vec_mul(vec_a[1], vec_get_word(b[1], 0));
    result[3] = kZero;

    vec3_rshift_word(vec_a);

    // BLOCK multiplies the shifted |vec_a| by coefficient |y| of |b| and
    // accumulates into |result| starting at vector |x|.
#define BLOCK(x, y)                                                      \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] =                                                      \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    BLOCK(0, 1);
    BLOCK(1, 9);

    vec3_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);

    vec3_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);

    vec3_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);

    vec3_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);

    vec3_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);

    vec3_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);

#undef BLOCK

    memcpy(out, result, sizeof(result));
    return;
  }

  // Base case: schoolbook multiplication of 24-coefficient (three-vector)
  // inputs, in the same style as the |n| == 2 case.
  if (n == 3) {
    vec_t result[6];
    vec_t vec_a[4];
    static const vec_t kZero = {0};
    vec_a[0] = a[0];
    vec_a[1] = a[1];
    vec_a[2] = a[2];
    vec_a[3] = kZero;

    result[0] = vec_mul(a[0], vec_get_word(b[0], 0));
    result[1] = vec_mul(a[1], vec_get_word(b[0], 0));
    result[2] = vec_mul(a[2], vec_get_word(b[0], 0));

    // BLOCK_PRE is used for the unshifted |vec_a|, where the top element is
    // known to be zero and so the last lane needs a plain multiply.
#define BLOCK_PRE(x, y)                                                  \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] = vec_mul(vec_a[2], vec_get_word(b[y / 8], y % 8));    \
  } while (0)

    BLOCK_PRE(1, 8);
    BLOCK_PRE(2, 16);

    result[5] = kZero;

    vec4_rshift_word(vec_a);

#define BLOCK(x, y)                                                      \
  do {                                                                   \
    result[x + 0] =                                                      \
        vec_fma(result[x + 0], vec_a[0], vec_get_word(b[y / 8], y % 8)); \
    result[x + 1] =                                                      \
        vec_fma(result[x + 1], vec_a[1], vec_get_word(b[y / 8], y % 8)); \
    result[x + 2] =                                                      \
        vec_fma(result[x + 2], vec_a[2], vec_get_word(b[y / 8], y % 8)); \
    result[x + 3] =                                                      \
        vec_fma(result[x + 3], vec_a[3], vec_get_word(b[y / 8], y % 8)); \
  } while (0)

    BLOCK(0, 1);
    BLOCK(1, 9);
    BLOCK(2, 17);

    vec4_rshift_word(vec_a);

    BLOCK(0, 2);
    BLOCK(1, 10);
    BLOCK(2, 18);

    vec4_rshift_word(vec_a);

    BLOCK(0, 3);
    BLOCK(1, 11);
    BLOCK(2, 19);

    vec4_rshift_word(vec_a);

    BLOCK(0, 4);
    BLOCK(1, 12);
    BLOCK(2, 20);

    vec4_rshift_word(vec_a);

    BLOCK(0, 5);
    BLOCK(1, 13);
    BLOCK(2, 21);

    vec4_rshift_word(vec_a);

    BLOCK(0, 6);
    BLOCK(1, 14);
    BLOCK(2, 22);

    vec4_rshift_word(vec_a);

    BLOCK(0, 7);
    BLOCK(1, 15);
    BLOCK(2, 23);

#undef BLOCK
#undef BLOCK_PRE

    memcpy(out, result, sizeof(result));
    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The first is
  // always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const vec_t *a_high = &a[low_len];
  const vec_t *b_high = &b[low_len];

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half.
  for (size_t i = 0; i < low_len; i++) {
    out[i] = vec_add(a_high[i], a[i]);
    out[high_len + i] = vec_add(b_high[i], b[i]);
  }
  if (high_len != low_len) {
    // Odd |n|: the high halves have one extra vector, which is added to an
    // implicit zero from the (shorter) low halves.
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  vec_t *const child_scratch = &scratch[2 * high_len];
  // Calculate (a_1 + a_0) × (b_1 + b_0) and write to scratch buffer.
  poly_mul_vec_aux(scratch, child_scratch, out, &out[high_len], high_len);
  // Calculate a_1 × b_1.
  poly_mul_vec_aux(&out[low_len * 2], child_scratch, a_high, b_high, high_len);
  // Calculate a_0 × b_0.
  poly_mul_vec_aux(out, child_scratch, a, b, low_len);

  // Subtract those last two products from the first, leaving the middle
  // Karatsuba term in |scratch|.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] = vec_sub(scratch[i], vec_add(out[i], out[low_len * 2 + i]));
  }
  if (low_len != high_len) {
    scratch[low_len * 2] = vec_sub(scratch[low_len * 2], out[low_len * 4]);
    scratch[low_len * 2 + 1] =
        vec_sub(scratch[low_len * 2 + 1], out[low_len * 4 + 1]);
  }

  // Add the middle product into the output.
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] = vec_add(out[low_len + i], scratch[i]);
  }
}
// poly_mul_vec sets |*out| to |x|×|y| mod (𝑥^n - 1).
static void poly_mul_vec(struct poly *out, const struct poly *x,
                         const struct poly *y) {
  // NOTE(review): this clears the three padding coefficients of the inputs by
  // casting away const. Callers must therefore never pass objects that were
  // defined const.
  OPENSSL_memset((uint16_t *)&x->v[N], 0, 3 * sizeof(uint16_t));
  OPENSSL_memset((uint16_t *)&y->v[N], 0, 3 * sizeof(uint16_t));

  OPENSSL_STATIC_ASSERT(sizeof(out->v) == sizeof(vec_t) * VECS_PER_POLY,
                        "struct poly is the wrong size");
  OPENSSL_STATIC_ASSERT(alignof(struct poly) == alignof(vec_t),
                        "struct poly has incorrect alignment");

  vec_t prod[VECS_PER_POLY * 2];
  vec_t scratch[172];
  poly_mul_vec_aux(prod, scratch, x->vectors, y->vectors, VECS_PER_POLY);

  // |prod| needs to be reduced mod (𝑥^n - 1), which just involves adding the
  // upper-half to the lower-half. However, N is 701, which isn't a multiple of
  // the vector size, so the upper-half vectors all have to be shifted before
  // being added to the lower-half.
  vec_t *out_vecs = (vec_t *)out->v;

  for (size_t i = 0; i < VECS_PER_POLY; i++) {
    const vec_t prev = prod[VECS_PER_POLY - 1 + i];
    const vec_t this = prod[VECS_PER_POLY + i];
    out_vecs[i] = vec_add(prod[i], vec_merge_3_5(prev, this));
  }

  // Keep the padding coefficients zero.
  OPENSSL_memset(&out->v[N], 0, 3 * sizeof(uint16_t));
}
  1090. #endif // HRSS_HAVE_VECTOR_UNIT
// poly_mul_novec_aux writes the product of |a| and |b| to |out|, using
// |scratch| as scratch space. It'll use Karatsuba if the inputs are large
// enough to warrant it. Each call uses 2*ceil(n/2) elements of |scratch| and
// the function recurses, except if |n| < 64, when |scratch| isn't used and the
// recursion stops. If |n| == |N| then |scratch| needs 1318 elements.
static void poly_mul_novec_aux(uint16_t *out, uint16_t *scratch,
                               const uint16_t *a, const uint16_t *b, size_t n) {
  static const size_t kSchoolbookLimit = 64;
  if (n < kSchoolbookLimit) {
    // Schoolbook multiplication. Products accumulate mod 2^16, which is fine
    // because coefficients are only ever needed mod Q = 2^13.
    OPENSSL_memset(out, 0, sizeof(uint16_t) * n * 2);
    for (size_t i = 0; i < n; i++) {
      for (size_t j = 0; j < n; j++) {
        out[i + j] += (unsigned) a[i] * b[j];
      }
    }
    return;
  }

  // Karatsuba multiplication.
  // https://en.wikipedia.org/wiki/Karatsuba_algorithm

  // When |n| is odd, the two "halves" will have different lengths. The
  // first is always the smaller.
  const size_t low_len = n / 2;
  const size_t high_len = n - low_len;
  const uint16_t *const a_high = &a[low_len];
  const uint16_t *const b_high = &b[low_len];

  // Store a_1 + a_0 in the first half of |out| and b_1 + b_0 in the second
  // half.
  for (size_t i = 0; i < low_len; i++) {
    out[i] = a_high[i] + a[i];
    out[high_len + i] = b_high[i] + b[i];
  }
  if (high_len != low_len) {
    // Odd |n|: the high halves have one extra element, which is added to an
    // implicit zero from the (shorter) low halves.
    out[low_len] = a_high[low_len];
    out[high_len + low_len] = b_high[low_len];
  }

  uint16_t *const child_scratch = &scratch[2 * high_len];
  // (a_1 + a_0) × (b_1 + b_0), into |scratch|.
  poly_mul_novec_aux(scratch, child_scratch, out, &out[high_len], high_len);
  // a_1 × b_1.
  poly_mul_novec_aux(&out[low_len * 2], child_scratch, a_high, b_high,
                     high_len);
  // a_0 × b_0.
  poly_mul_novec_aux(out, child_scratch, a, b, low_len);

  // Subtract the two outer products from the first, leaving the middle
  // Karatsuba term in |scratch|.
  for (size_t i = 0; i < low_len * 2; i++) {
    scratch[i] -= out[i] + out[low_len * 2 + i];
  }
  if (low_len != high_len) {
    scratch[low_len * 2] -= out[low_len * 4];
    assert(out[low_len * 4 + 1] == 0);
  }

  // Add the middle product into the output.
  for (size_t i = 0; i < high_len * 2; i++) {
    out[low_len + i] += scratch[i];
  }
}
  1140. // poly_mul_novec sets |*out| to |x|×|y| mod (𝑥^n - 1).
  1141. static void poly_mul_novec(struct poly *out, const struct poly *x,
  1142. const struct poly *y) {
  1143. uint16_t prod[2 * N];
  1144. uint16_t scratch[1318];
  1145. poly_mul_novec_aux(prod, scratch, x->v, y->v, N);
  1146. for (size_t i = 0; i < N; i++) {
  1147. out->v[i] = prod[i] + prod[i + N];
  1148. }
  1149. OPENSSL_memset(&out->v[N], 0, 3 * sizeof(uint16_t));
  1150. }
// poly_mul sets |r| to |a|×|b| mod (𝑥^n - 1), dispatching at run time to the
// fastest implementation that the CPU supports.
static void poly_mul(struct poly *r, const struct poly *a,
                     const struct poly *b) {
#if defined(POLY_RQ_MUL_ASM)
  // Bit five of |OPENSSL_ia32cap_P[2]| is the AVX2 feature flag.
  const int has_avx2 = (OPENSSL_ia32cap_P[2] & (1 << 5)) != 0;
  if (has_avx2) {
    poly_Rq_mul(r->v, a->v, b->v);
    return;
  }
#endif

#if defined(HRSS_HAVE_VECTOR_UNIT)
  if (vec_capable()) {
    poly_mul_vec(r, a, b);
    return;
  }
#endif

  // Fallback, non-vector case.
  poly_mul_novec(r, a, b);
}
  1169. // poly_mul_x_minus_1 sets |p| to |p|×(𝑥 - 1) mod (𝑥^n - 1).
  1170. static void poly_mul_x_minus_1(struct poly *p) {
  1171. // Multiplying by (𝑥 - 1) means negating each coefficient and adding in
  1172. // the value of the previous one.
  1173. const uint16_t orig_final_coefficient = p->v[N - 1];
  1174. for (size_t i = N - 1; i > 0; i--) {
  1175. p->v[i] = p->v[i - 1] - p->v[i];
  1176. }
  1177. p->v[0] = orig_final_coefficient - p->v[0];
  1178. }
  1179. // poly_mod_phiN sets |p| to |p| mod Φ(N).
  1180. static void poly_mod_phiN(struct poly *p) {
  1181. const uint16_t coeff700 = p->v[N - 1];
  1182. for (unsigned i = 0; i < N; i++) {
  1183. p->v[i] -= coeff700;
  1184. }
  1185. }
  1186. // poly_clamp reduces each coefficient mod Q.
  1187. static void poly_clamp(struct poly *p) {
  1188. for (unsigned i = 0; i < N; i++) {
  1189. p->v[i] &= Q - 1;
  1190. }
  1191. }
  1192. // Conversion functions
  1193. // --------------------
// poly2_from_poly sets |*out| to |in| mod 2, i.e. packs the low bit of each
// coefficient, little-endian, into the words of |out|.
static void poly2_from_poly(struct poly2 *out, const struct poly *in) {
  crypto_word_t *words = out->v;
  unsigned shift = 0;
  crypto_word_t word = 0;

  for (unsigned i = 0; i < N; i++) {
    // Bits are inserted at the top of |word| and shifted down, so the word is
    // complete (in little-endian bit order) after BITS_PER_WORD insertions.
    word >>= 1;
    word |= (crypto_word_t)(in->v[i] & 1) << (BITS_PER_WORD - 1);
    shift++;

    if (shift == BITS_PER_WORD) {
      *words = word;
      words++;
      word = 0;
      shift = 0;
    }
  }

  // Right-align the final, partial word before storing it.
  word >>= BITS_PER_WORD - shift;
  *words = word;
}
  1213. // mod3 treats |a| as a signed number and returns |a| mod 3.
  1214. static uint16_t mod3(int16_t a) {
  1215. const int16_t q = ((int32_t)a * 21845) >> 16;
  1216. int16_t ret = a - 3 * q;
  1217. // At this point, |ret| is in {0, 1, 2, 3} and that needs to be mapped to {0,
  1218. // 1, 2, 0}.
  1219. return ret & ((ret & (ret >> 1)) - 1);
  1220. }
// poly3_from_poly sets |*out| to |in|, reducing each coefficient mod 3 into
// the bitsliced (s, a) representation of |struct poly3|.
static void poly3_from_poly(struct poly3 *out, const struct poly *in) {
  crypto_word_t *words_s = out->s.v;
  crypto_word_t *words_a = out->a.v;
  crypto_word_t s = 0;
  crypto_word_t a = 0;
  unsigned shift = 0;

  for (unsigned i = 0; i < N; i++) {
    // This duplicates the 13th bit upwards to the top of the uint16,
    // essentially treating it as a sign bit and converting into a signed int16.
    // The signed value is reduced mod 3, yielding {0, 1, 2}.
    const uint16_t v = mod3((int16_t)(in->v[i] << 3) >> 3);

    // Pack the trit: the s bit is set for the value two and, per the code
    // below, the a bit is set for any non-zero value.
    s >>= 1;
    const crypto_word_t s_bit = (crypto_word_t)(v & 2) << (BITS_PER_WORD - 2);
    s |= s_bit;
    a >>= 1;
    a |= s_bit | (crypto_word_t)(v & 1) << (BITS_PER_WORD - 1);
    shift++;

    if (shift == BITS_PER_WORD) {
      *words_s = s;
      words_s++;
      *words_a = a;
      words_a++;
      s = a = 0;
      shift = 0;
    }
  }

  // Right-align the final, partial words before storing them.
  s >>= BITS_PER_WORD - shift;
  a >>= BITS_PER_WORD - shift;
  *words_s = s;
  *words_a = a;
}
// poly3_from_poly_checked sets |*out| to |in|, which has coefficients in {0, 1,
// Q-1}. It returns a mask indicating whether all coefficients were found to be
// in that set. (The check is done in constant time.)
static crypto_word_t poly3_from_poly_checked(struct poly3 *out,
                                             const struct poly *in) {
  crypto_word_t *words_s = out->s.v;
  crypto_word_t *words_a = out->a.v;
  crypto_word_t s = 0;
  crypto_word_t a = 0;
  unsigned shift = 0;
  crypto_word_t ok = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < N; i++) {
    const uint16_t v = in->v[i];
    // Maps {0, 1, Q-1} to {0, 1, 2}.
    uint16_t mod3 = v & 3;
    mod3 ^= mod3 >> 1;
    // Reconstruct the unique element of {0, 1, Q-1} that |mod3| corresponds
    // to and check that |v| actually equals it.
    const uint16_t expected = (uint16_t)((~((mod3 >> 1) - 1)) | mod3) % Q;
    ok &= constant_time_eq_w(v, expected);

    // Pack into the bitsliced (s, a) representation, as in |poly3_from_poly|.
    s >>= 1;
    const crypto_word_t s_bit = (crypto_word_t)(mod3 & 2)
                                << (BITS_PER_WORD - 2);
    s |= s_bit;
    a >>= 1;
    a |= s_bit | (crypto_word_t)(mod3 & 1) << (BITS_PER_WORD - 1);
    shift++;

    if (shift == BITS_PER_WORD) {
      *words_s = s;
      words_s++;
      *words_a = a;
      words_a++;
      s = a = 0;
      shift = 0;
    }
  }

  // Right-align the final, partial words before storing them.
  s >>= BITS_PER_WORD - shift;
  a >>= BITS_PER_WORD - shift;
  *words_s = s;
  *words_a = a;

  return ok;
}
  1293. static void poly_from_poly2(struct poly *out, const struct poly2 *in) {
  1294. const crypto_word_t *words = in->v;
  1295. unsigned shift = 0;
  1296. crypto_word_t word = *words;
  1297. for (unsigned i = 0; i < N; i++) {
  1298. out->v[i] = word & 1;
  1299. word >>= 1;
  1300. shift++;
  1301. if (shift == BITS_PER_WORD) {
  1302. words++;
  1303. word = *words;
  1304. shift = 0;
  1305. }
  1306. }
  1307. }
// poly_from_poly3 sets |*out| to |in|, mapping the trits {0, 1, 2} to the
// coefficient values {0, 1, 0xffff} (0xffff being -1 mod 2^16).
static void poly_from_poly3(struct poly *out, const struct poly3 *in) {
  const crypto_word_t *words_s = in->s.v;
  const crypto_word_t *words_a = in->a.v;
  // |word_s| is inverted so that a set s bit (the value two) produces
  // (0 - 1) = 0xffff in the subtraction below.
  crypto_word_t word_s = ~(*words_s);
  crypto_word_t word_a = *words_a;
  unsigned shift = 0;

  for (unsigned i = 0; i < N; i++) {
    out->v[i] = (uint16_t)(word_s & 1) - 1;
    out->v[i] |= word_a & 1;
    word_s >>= 1;
    word_a >>= 1;
    shift++;

    if (shift == BITS_PER_WORD) {
      words_s++;
      words_a++;
      word_s = ~(*words_s);
      word_a = *words_a;
      shift = 0;
    }
  }
}
  1329. // Polynomial inversion
  1330. // --------------------
// poly_invert_mod2 sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod
// Φ(N)), all mod 2. This isn't useful in itself, but is part of doing inversion
// mod Q.
static void poly_invert_mod2(struct poly *out, const struct poly *in) {
  // This algorithm follows algorithm 10 in the paper. (Although, in contrast to
  // the paper, k should start at zero, not one, and the rotation count needs
  // to handle trailing zero coefficients.) The best explanation for why it
  // works is in the "Why it works" section of [NTRUTN14].

  struct poly2 b, c, f, g;
  poly2_from_poly(&f, in);
  OPENSSL_memset(&b, 0, sizeof(b));
  b.v[0] = 1;
  OPENSSL_memset(&c, 0, sizeof(c));

  // Set g to all ones.
  OPENSSL_memset(&g, 0xff, sizeof(struct poly2));
  g.v[WORDS_PER_POLY - 1] >>= BITS_PER_WORD - BITS_IN_LAST_WORD;

  crypto_word_t deg_f = N - 1, deg_g = N - 1, rotation = 0;
  crypto_word_t still_going = CONSTTIME_TRUE_W;

  for (unsigned i = 0; i < 2 * (N - 1) - 1; i++) {
    // |s| is all ones while the constant term of f is set and the useful work
    // has not finished; every step below is masked by it so that the loop is
    // constant-time.
    const crypto_word_t s = still_going & lsb_to_all(f.v[0]);
    const crypto_word_t should_swap = s & constant_time_lt_w(deg_f, deg_g);
    poly2_cswap(&f, &g, should_swap);
    poly2_cswap(&b, &c, should_swap);
    const crypto_word_t deg_sum = should_swap & (deg_f ^ deg_g);
    deg_f ^= deg_sum;
    deg_g ^= deg_sum;
    assert(deg_g >= 1);
    poly2_fmadd(&f, &g, s);
    poly2_fmadd(&b, &c, s);
    poly2_rshift1(&f);
    poly2_lshift1(&c);

    deg_f--;
    const crypto_word_t f0_is_nonzero = lsb_to_all(f.v[0]);
    // |f0_is_nonzero| implies |still_going|. Record the iteration count at
    // the latest point where f's constant term was non-zero.
    assert(!(f0_is_nonzero && !still_going));
    rotation = constant_time_select_w(f0_is_nonzero, i, rotation);
    still_going &= ~constant_time_is_zero_w(deg_f);
  }

  rotation++;
  // Reduce |rotation| mod N. (The loop bound keeps it below 2N.)
  rotation -= N & constant_time_lt_w(N, rotation);
  assert(poly2_top_bits_are_clear(&b));
  HRSS_poly2_rotr_consttime(&b, rotation);
  poly_from_poly2(out, &b);
}
// poly_invert sets |*out| to |in^-1| (i.e. such that |*out|×|in| = 1 mod Φ(N)).
static void poly_invert(struct poly *out, const struct poly *in) {
  // Inversion mod Q, which is done based on the result of inverting mod
  // 2. See [NTRUTN14] paper, bottom of page two.
  struct poly a, *b, tmp;

  // a = -in.
  for (unsigned i = 0; i < N; i++) {
    a.v[i] = -in->v[i];
  }

  // b = in^-1 mod 2.
  b = out;
  poly_invert_mod2(b, in);

  // We are working mod Q=2**13 and we need to iterate ceil(log_2(13))
  // times, which is four. (Each step is the Newton iteration
  // b ← b×(2 - in×b) = b×(2 + a×b), doubling the number of correct bits.)
  for (unsigned i = 0; i < 4; i++) {
    poly_mul(&tmp, &a, b);
    tmp.v[0] += 2;
    poly_mul(b, b, &tmp);
  }
}
  1395. // Marshal and unmarshal functions for various basic types.
  1396. // --------------------------------------------------------
  1397. #define POLY_BYTES 1138
// poly_marshal serialises all but the final coefficient of |in| to |out|.
static void poly_marshal(uint8_t out[POLY_BYTES], const struct poly *in) {
  const uint16_t *p = in->v;

  // Pack the 13-bit coefficients eight at a time: each group of eight fills
  // exactly 13 output bytes.
  for (size_t i = 0; i < N / 8; i++) {
    out[0] = p[0];
    out[1] = (0x1f & (p[0] >> 8)) | ((p[1] & 0x07) << 5);
    out[2] = p[1] >> 3;
    out[3] = (3 & (p[1] >> 11)) | ((p[2] & 0x3f) << 2);
    out[4] = (0x7f & (p[2] >> 6)) | ((p[3] & 0x01) << 7);
    out[5] = p[3] >> 1;
    out[6] = (0xf & (p[3] >> 9)) | ((p[4] & 0x0f) << 4);
    out[7] = p[4] >> 4;
    out[8] = (1 & (p[4] >> 12)) | ((p[5] & 0x7f) << 1);
    out[9] = (0x3f & (p[5] >> 7)) | ((p[6] & 0x03) << 6);
    out[10] = p[6] >> 2;
    out[11] = (7 & (p[6] >> 10)) | ((p[7] & 0x1f) << 3);
    out[12] = p[7] >> 5;

    p += 8;
    out += 13;
  }

  // There are four remaining values. (N = 701 = 87×8 + 5, and the final
  // coefficient is not serialised.)
  out[0] = p[0];
  out[1] = (0x1f & (p[0] >> 8)) | ((p[1] & 0x07) << 5);
  out[2] = p[1] >> 3;
  out[3] = (3 & (p[1] >> 11)) | ((p[2] & 0x3f) << 2);
  out[4] = (0x7f & (p[2] >> 6)) | ((p[3] & 0x01) << 7);
  out[5] = p[3] >> 1;
  out[6] = 0xf & (p[3] >> 9);
}
// poly_unmarshal parses the output of |poly_marshal| and sets |out| such that
// all but the final coefficients match, and the final coefficient is calculated
// such that evaluating |out| at one results in zero. It returns one on success
// or zero if |in| is an invalid encoding.
static int poly_unmarshal(struct poly *out, const uint8_t in[POLY_BYTES]) {
  uint16_t *p = out->v;

  // Unpack 13-bit coefficients, eight per 13 input bytes — the inverse of the
  // packing in |poly_marshal|.
  for (size_t i = 0; i < N / 8; i++) {
    p[0] = (uint16_t)(in[0]) | (uint16_t)(in[1] & 0x1f) << 8;
    p[1] = (uint16_t)(in[1] >> 5) | (uint16_t)(in[2]) << 3 |
           (uint16_t)(in[3] & 3) << 11;
    p[2] = (uint16_t)(in[3] >> 2) | (uint16_t)(in[4] & 0x7f) << 6;
    p[3] = (uint16_t)(in[4] >> 7) | (uint16_t)(in[5]) << 1 |
           (uint16_t)(in[6] & 0xf) << 9;
    p[4] = (uint16_t)(in[6] >> 4) | (uint16_t)(in[7]) << 4 |
           (uint16_t)(in[8] & 1) << 12;
    p[5] = (uint16_t)(in[8] >> 1) | (uint16_t)(in[9] & 0x3f) << 7;
    p[6] = (uint16_t)(in[9] >> 6) | (uint16_t)(in[10]) << 2 |
           (uint16_t)(in[11] & 7) << 10;
    p[7] = (uint16_t)(in[11] >> 3) | (uint16_t)(in[12]) << 5;
    p += 8;
    in += 13;
  }

  // There are four coefficients remaining.
  p[0] = (uint16_t)(in[0]) | (uint16_t)(in[1] & 0x1f) << 8;
  p[1] = (uint16_t)(in[1] >> 5) | (uint16_t)(in[2]) << 3 |
         (uint16_t)(in[3] & 3) << 11;
  p[2] = (uint16_t)(in[3] >> 2) | (uint16_t)(in[4] & 0x7f) << 6;
  p[3] = (uint16_t)(in[4] >> 7) | (uint16_t)(in[5]) << 1 |
         (uint16_t)(in[6] & 0xf) << 9;

  // Sign-extend each 13-bit value to 16 bits.
  for (unsigned i = 0; i < N - 1; i++) {
    out->v[i] = (int16_t)(out->v[i] << 3) >> 3;
  }

  // There are four unused bits in the last byte. We require them to be zero.
  if ((in[6] & 0xf0) != 0) {
    return 0;
  }

  // Set the final coefficient as specified in [HRSSNIST] 1.9.2 step 6: it is
  // the negation of the sum of the others, so the polynomial evaluates to
  // zero at one.
  uint32_t sum = 0;
  for (size_t i = 0; i < N - 1; i++) {
    sum += out->v[i];
  }
  out->v[N - 1] = (uint16_t)(0u - sum);
  return 1;
}
  1471. // mod3_from_modQ maps {0, 1, Q-1, 65535} -> {0, 1, 2, 2}. Note that |v| may
  1472. // have an invalid value when processing attacker-controlled inputs.
  1473. static uint16_t mod3_from_modQ(uint16_t v) {
  1474. v &= 3;
  1475. return v ^ (v >> 1);
  1476. }
  1477. // poly_marshal_mod3 marshals |in| to |out| where the coefficients of |in| are
  1478. // all in {0, 1, Q-1, 65535} and |in| is mod Φ(N). (Note that coefficients may
  1479. // have invalid values when processing attacker-controlled inputs.)
  1480. static void poly_marshal_mod3(uint8_t out[HRSS_POLY3_BYTES],
  1481. const struct poly *in) {
  1482. const uint16_t *coeffs = in->v;
  1483. // Only 700 coefficients are marshaled because in[700] must be zero.
  1484. assert(coeffs[N-1] == 0);
  1485. for (size_t i = 0; i < HRSS_POLY3_BYTES; i++) {
  1486. const uint16_t coeffs0 = mod3_from_modQ(coeffs[0]);
  1487. const uint16_t coeffs1 = mod3_from_modQ(coeffs[1]);
  1488. const uint16_t coeffs2 = mod3_from_modQ(coeffs[2]);
  1489. const uint16_t coeffs3 = mod3_from_modQ(coeffs[3]);
  1490. const uint16_t coeffs4 = mod3_from_modQ(coeffs[4]);
  1491. out[i] = coeffs0 + coeffs1 * 3 + coeffs2 * 9 + coeffs3 * 27 + coeffs4 * 81;
  1492. coeffs += 5;
  1493. }
  1494. }
  1495. // HRSS-specific functions
  1496. // -----------------------
  1497. // poly_short_sample samples a vector of values in {0xffff (i.e. -1), 0, 1}.
  1498. // This is the same action as the algorithm in [HRSSNIST] section 1.8.1, but
  1499. // with HRSS-SXY the sampling algorithm is now a private detail of the
  1500. // implementation (previously it had to match between two parties). This
  1501. // function uses that freedom to implement a flatter distribution of values.
  1502. static void poly_short_sample(struct poly *out,
  1503. const uint8_t in[HRSS_SAMPLE_BYTES]) {
  1504. OPENSSL_STATIC_ASSERT(HRSS_SAMPLE_BYTES == N - 1,
  1505. "HRSS_SAMPLE_BYTES incorrect");
  1506. for (size_t i = 0; i < N - 1; i++) {
  1507. uint16_t v = mod3(in[i]);
  1508. // Map {0, 1, 2} -> {0, 1, 0xffff}
  1509. v |= ((v >> 1) ^ 1) - 1;
  1510. out->v[i] = v;
  1511. }
  1512. out->v[N - 1] = 0;
  1513. }
// poly_short_sample_plus performs the T+ sample as defined in [HRSSNIST],
// section 1.8.2: a short sample conditioned so that the correlation between
// adjacent coefficients is non-negative.
static void poly_short_sample_plus(struct poly *out,
                                   const uint8_t in[HRSS_SAMPLE_BYTES]) {
  poly_short_sample(out, in);

  // sum (and the product in the for loop) will overflow. But that's fine
  // because |sum| is bound by +/- (N-2), and N < 2^15 so it works out.
  // (Coefficients are 0, 1 or 0xffff, i.e. -1 in two's complement, so each
  // product contributes -1, 0 or 1 mod 2^16.)
  uint16_t sum = 0;
  for (unsigned i = 0; i < N - 2; i++) {
    sum += (unsigned) out->v[i] * out->v[i + 1];
  }

  // If the sum is negative, flip the sign of even-positioned coefficients. (See
  // page 8 of [HRSS].)
  // Broadcast the sign bit: |sum| becomes all-zeros or all-ones. NOTE(review):
  // this relies on arithmetic right shift of a negative int16_t, which is
  // implementation-defined but universal on supported compilers.
  sum = ((int16_t) sum) >> 15;
  // |scale| is 1 when the sum was non-negative and 0xffff (-1) otherwise.
  const uint16_t scale = sum | (~sum & 1);
  for (unsigned i = 0; i < N; i += 2) {
    out->v[i] = (unsigned) out->v[i] * scale;
  }
}
// poly_lift computes the function discussed in [HRSS], appendix B: it maps a
// ternary polynomial |a| to lift(a) = (a/(𝑥-1) mod Φ(N))·(𝑥-1), which
// evaluates to zero at one.
static void poly_lift(struct poly *out, const struct poly *a) {
  // We wish to calculate a/(𝑥-1) mod Φ(N) over GF(3), where Φ(N) is the
  // Nth cyclotomic polynomial, i.e. 1 + 𝑥 + … + 𝑥^700 (since N is prime).

  // 1/(𝑥-1) has a fairly basic structure that we can exploit to speed this up:
  //
  // R.<x> = PolynomialRing(GF(3)…)
  // inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
  // list(inv)[:15]
  //   [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
  //
  // This three-element pattern of coefficients repeats for the whole
  // polynomial.
  //
  // Next define the overbar operator such that z̅ = z[0] +
  // reverse(z[1:]). (Index zero of a polynomial here is the coefficient
  // of the constant term. So index one is the coefficient of 𝑥 and so
  // on.)
  //
  // A less odd way to define this is to see that z̅ negates the indexes,
  // so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
  //
  // The use of z̅ is that, when working mod (𝑥^701 - 1), vz[0] = <v,
  // z̅>, vz[1] = <v, 𝑥z̅>, …. (Where <a, b> is the inner product: the sum
  // of the point-wise products.) Although we calculated the inverse mod
  // Φ(N), we can work mod (𝑥^N - 1) and reduce mod Φ(N) at the end.
  // (That's because (𝑥^N - 1) is a multiple of Φ(N).)
  //
  // When working mod (𝑥^N - 1), multiplication by 𝑥 is a right-rotation
  // of the list of coefficients.
  //
  // Thus we can consider what the pattern of z̅, 𝑥z̅, 𝑥^2z̅, … looks like:
  //
  // def reverse(xs):
  //   suffix = list(xs[1:])
  //   suffix.reverse()
  //   return [xs[0]] + suffix
  //
  // def rotate(xs):
  //   return [xs[-1]] + xs[:-1]
  //
  // zoverbar = reverse(list(inv) + [0])
  // xzoverbar = rotate(reverse(list(inv) + [0]))
  // x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
  //
  // zoverbar[:15]
  //   [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
  // xzoverbar[:15]
  //   [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
  // x2zoverbar[:15]
  //   [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
  //
  // (For a formula for z̅, see lemma two of appendix B.)
  //
  // After the first three elements have been taken care of, all then have
  // a repeating three-element cycle. The next value (𝑥^3z̅) involves
  // three rotations of the first pattern, thus the three-element cycle
  // lines up. However, the discontinuity in the first three elements
  // obviously moves to a different position. Consider the difference
  // between 𝑥^3z̅ and z̅:
  //
  // [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
  //    [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  //
  // This pattern of differences is the same for all elements, although it
  // obviously moves right with the rotations.
  //
  // From this, we reach algorithm eight of appendix B.

  // Handle the first three elements of the inner products.
  out->v[0] = a->v[0] + a->v[2];
  out->v[1] = a->v[1];
  out->v[2] = -a->v[0] + a->v[2];

  // s0, s1, s2 are added into out->v[0], out->v[1], and out->v[2],
  // respectively. We do not compute s1 because it's just -(s0 + s2).
  uint16_t s0 = 0, s2 = 0;
  // 699 is the largest multiple of three that is < N-1, so this loop covers
  // indices 3 … 701 in steps of three; the stragglers are handled below.
  for (size_t i = 3; i < 699; i += 3) {
    s0 += -a->v[i] + a->v[i + 2];
    // s1 += a->v[i] - a->v[i + 1];
    s2 += a->v[i + 1] - a->v[i + 2];
  }

  // Handle the fact that the three-element pattern doesn't fill the
  // polynomial exactly (since 701 isn't a multiple of three).
  s0 -= a->v[699];
  // s1 += a->v[699] - a->v[700];
  s2 += a->v[700];

  // Note that s0 + s1 + s2 = 0.
  out->v[0] += s0;
  out->v[1] -= (s0 + s2); // = s1
  out->v[2] += s2;

  // Calculate the remaining inner products by taking advantage of the
  // fact that the pattern repeats every three cycles and the pattern of
  // differences moves with the rotation.
  for (size_t i = 3; i < N; i++) {
    out->v[i] = (out->v[i - 3] - (a->v[i - 2] + a->v[i - 1] + a->v[i]));
  }

  // Reduce mod Φ(N) by subtracting a multiple of out[700] from every
  // element and convert to mod Q. (See above about adding twice as
  // subtraction.)
  const crypto_word_t v = out->v[700];
  for (unsigned i = 0; i < N; i++) {
    const uint16_t vi_mod3 = mod3(out->v[i] - v);
    // Map {0, 1, 2} to {0, 1, 0xffff}.
    out->v[i] = (~((vi_mod3 >> 1) - 1)) | vi_mod3;
  }

  // Finally multiply by (𝑥-1) so the result evaluates to zero at one.
  poly_mul_x_minus_1(out);
}
// public_key is the internal, 16-byte-aligned representation of an HRSS
// public key: the single public polynomial |ph|.
struct public_key {
  struct poly ph;
};
// private_key is the internal, 16-byte-aligned representation of an HRSS
// private key.
struct private_key {
  // f is the secret short polynomial and f_inverse its inverse mod 3.
  struct poly3 f, f_inverse;
  // ph_inverse is the inverse of the public polynomial, used during
  // decapsulation to recover |r|.
  struct poly ph_inverse;
  // hmac_key keys the HMAC that derives the shared key returned for invalid
  // ciphertexts (implicit rejection) in |HRSS_decap|.
  uint8_t hmac_key[32];
};
  1647. // public_key_from_external converts an external public key pointer into an
  1648. // internal one. Externally the alignment is only specified to be eight bytes
  1649. // but we need 16-byte alignment. We could annotate the external struct with
  1650. // that alignment but we can only assume that malloced pointers are 8-byte
  1651. // aligned in any case. (Even if the underlying malloc returns values with
  1652. // 16-byte alignment, |OPENSSL_malloc| will store an 8-byte size prefix and mess
  1653. // that up.)
  1654. static struct public_key *public_key_from_external(
  1655. struct HRSS_public_key *ext) {
  1656. OPENSSL_STATIC_ASSERT(
  1657. sizeof(struct HRSS_public_key) >= sizeof(struct public_key) + 15,
  1658. "HRSS public key too small");
  1659. uintptr_t p = (uintptr_t)ext;
  1660. p = (p + 15) & ~15;
  1661. return (struct public_key *)p;
  1662. }
  1663. // private_key_from_external does the same thing as |public_key_from_external|,
  1664. // but for private keys. See the comment on that function about alignment
  1665. // issues.
  1666. static struct private_key *private_key_from_external(
  1667. struct HRSS_private_key *ext) {
  1668. OPENSSL_STATIC_ASSERT(
  1669. sizeof(struct HRSS_private_key) >= sizeof(struct private_key) + 15,
  1670. "HRSS private key too small");
  1671. uintptr_t p = (uintptr_t)ext;
  1672. p = (p + 15) & ~15;
  1673. return (struct private_key *)p;
  1674. }
// HRSS_generate_key derives a key pair from |in|, which must be uniformly
// random: the first two |HRSS_SAMPLE_BYTES| spans seed the short polynomials
// f and g, and the final 32 bytes become the implicit-rejection HMAC key.
void HRSS_generate_key(
    struct HRSS_public_key *out_pub, struct HRSS_private_key *out_priv,
    const uint8_t in[HRSS_SAMPLE_BYTES + HRSS_SAMPLE_BYTES + 32]) {
  struct public_key *pub = public_key_from_external(out_pub);
  struct private_key *priv = private_key_from_external(out_priv);

  OPENSSL_memcpy(priv->hmac_key, in + 2 * HRSS_SAMPLE_BYTES,
                 sizeof(priv->hmac_key));

  struct poly f;
  poly_short_sample_plus(&f, in);
  poly3_from_poly(&priv->f, &f);
  HRSS_poly3_invert(&priv->f_inverse, &priv->f);

  // pg_phi1 is p (i.e. 3) × g × Φ(1) (i.e. 𝑥-1).
  struct poly pg_phi1;
  poly_short_sample_plus(&pg_phi1, in + HRSS_SAMPLE_BYTES);
  for (unsigned i = 0; i < N; i++) {
    pg_phi1.v[i] *= 3;
  }
  poly_mul_x_minus_1(&pg_phi1);

  // pfg_phi1 = f × pg_phi1. Inverting this single product lets both the
  // public polynomial and its inverse be computed from one |poly_invert|.
  struct poly pfg_phi1;
  poly_mul(&pfg_phi1, &f, &pg_phi1);

  struct poly pfg_phi1_inverse;
  poly_invert(&pfg_phi1_inverse, &pfg_phi1);

  // ph = pg_phi1² / pfg_phi1 = (p·g·Φ(1)) / f.
  poly_mul(&pub->ph, &pfg_phi1_inverse, &pg_phi1);
  poly_mul(&pub->ph, &pub->ph, &pg_phi1);
  poly_clamp(&pub->ph);

  // ph_inverse = f² / pfg_phi1 = f / (p·g·Φ(1)), the inverse of |ph|.
  poly_mul(&priv->ph_inverse, &pfg_phi1_inverse, &f);
  poly_mul(&priv->ph_inverse, &priv->ph_inverse, &f);
  poly_clamp(&priv->ph_inverse);
}
// kSharedKey is the domain-separation label hashed into the shared key. Note
// that |sizeof(kSharedKey)| includes the trailing NUL, which is hashed too.
static const char kSharedKey[] = "shared key";
// HRSS_encap encapsulates: it consumes 2×|HRSS_SAMPLE_BYTES| of entropy from
// |in|, writes the ciphertext c = r·h + lift(m) to |out_ciphertext|, and
// derives the shared key as SHA-256("shared key"||m||r||ciphertext).
void HRSS_encap(uint8_t out_ciphertext[POLY_BYTES],
                uint8_t out_shared_key[32],
                const struct HRSS_public_key *in_pub,
                const uint8_t in[HRSS_SAMPLE_BYTES + HRSS_SAMPLE_BYTES]) {
  const struct public_key *pub =
      public_key_from_external((struct HRSS_public_key *)in_pub);

  // Sample the message |m| and blinding polynomial |r| from |in|.
  struct poly m, r, m_lifted;
  poly_short_sample(&m, in);
  poly_short_sample(&r, in + HRSS_SAMPLE_BYTES);
  poly_lift(&m_lifted, &m);

  // prh_plus_m = r·h + lift(m), the ciphertext polynomial.
  struct poly prh_plus_m;
  poly_mul(&prh_plus_m, &r, &pub->ph);
  for (unsigned i = 0; i < N; i++) {
    prh_plus_m.v[i] += m_lifted.v[i];
  }

  poly_marshal(out_ciphertext, &prh_plus_m);

  // Hash the mod-3 encodings of m and r, plus the ciphertext, into the
  // shared key.
  uint8_t m_bytes[HRSS_POLY3_BYTES], r_bytes[HRSS_POLY3_BYTES];
  poly_marshal_mod3(m_bytes, &m);
  poly_marshal_mod3(r_bytes, &r);

  SHA256_CTX hash_ctx;
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, kSharedKey, sizeof(kSharedKey));
  SHA256_Update(&hash_ctx, m_bytes, sizeof(m_bytes));
  SHA256_Update(&hash_ctx, r_bytes, sizeof(r_bytes));
  SHA256_Update(&hash_ctx, out_ciphertext, POLY_BYTES);
  SHA256_Final(out_shared_key, &hash_ctx);
}
// HRSS_decap decapsulates |ciphertext|. For invalid ciphertexts it returns a
// pseudorandom key derived from |priv->hmac_key| and the ciphertext (implicit
// rejection), switching to the real shared key in constant time only when all
// checks pass.
void HRSS_decap(uint8_t out_shared_key[HRSS_KEY_BYTES],
                const struct HRSS_private_key *in_priv,
                const uint8_t *ciphertext, size_t ciphertext_len) {
  const struct private_key *priv =
      private_key_from_external((struct HRSS_private_key *)in_priv);

  // This is HMAC, expanded inline rather than using the |HMAC| function so that
  // we can avoid dealing with possible allocation failures and so keep this
  // function infallible. First the inner hash: SHA-256((key^ipad)||ciphertext).
  uint8_t masked_key[SHA256_CBLOCK];
  OPENSSL_STATIC_ASSERT(sizeof(priv->hmac_key) <= sizeof(masked_key),
                        "HRSS HMAC key larger than SHA-256 block size");
  for (size_t i = 0; i < sizeof(priv->hmac_key); i++) {
    masked_key[i] = priv->hmac_key[i] ^ 0x36;
  }
  OPENSSL_memset(masked_key + sizeof(priv->hmac_key), 0x36,
                 sizeof(masked_key) - sizeof(priv->hmac_key));

  SHA256_CTX hash_ctx;
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, masked_key, sizeof(masked_key));
  SHA256_Update(&hash_ctx, ciphertext, ciphertext_len);
  uint8_t inner_digest[SHA256_DIGEST_LENGTH];
  SHA256_Final(inner_digest, &hash_ctx);

  // Re-mask the key from ipad (0x36) to opad (0x5c) for the outer hash.
  for (size_t i = 0; i < sizeof(priv->hmac_key); i++) {
    masked_key[i] ^= (0x5c ^ 0x36);
  }
  OPENSSL_memset(masked_key + sizeof(priv->hmac_key), 0x5c,
                 sizeof(masked_key) - sizeof(priv->hmac_key));

  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, masked_key, sizeof(masked_key));
  SHA256_Update(&hash_ctx, inner_digest, sizeof(inner_digest));
  OPENSSL_STATIC_ASSERT(HRSS_KEY_BYTES == SHA256_DIGEST_LENGTH,
                        "HRSS shared key length incorrect");
  // |out_shared_key| now holds the implicit-rejection key; it is
  // overwritten below only if the ciphertext turns out to be valid.
  SHA256_Final(out_shared_key, &hash_ctx);

  struct poly c;
  // If the ciphertext is publicly invalid then a random shared key is still
  // returned to simplify the logic of the caller, but this path is not
  // constant time.
  if (ciphertext_len != HRSS_CIPHERTEXT_BYTES ||
      !poly_unmarshal(&c, ciphertext)) {
    return;
  }

  // Recover m = (c·f)·f⁻¹ mod 3.
  struct poly f, cf;
  struct poly3 cf3, m3;
  poly_from_poly3(&f, &priv->f);
  poly_mul(&cf, &c, &f);
  poly3_from_poly(&cf3, &cf);
  // Note that cf3 is not reduced mod Φ(N). That reduction is deferred.
  HRSS_poly3_mul(&m3, &cf3, &priv->f_inverse);

  struct poly m, m_lifted;
  poly_from_poly3(&m, &m3);
  poly_lift(&m_lifted, &m);

  // Recover r = (c - lift(m))·h⁻¹.
  struct poly r;
  for (unsigned i = 0; i < N; i++) {
    r.v[i] = c.v[i] - m_lifted.v[i];
  }
  poly_mul(&r, &r, &priv->ph_inverse);
  poly_mod_phiN(&r);
  poly_clamp(&r);

  // |ok| is a constant-time mask: all-ones if |r| is ternary as required.
  struct poly3 r3;
  crypto_word_t ok = poly3_from_poly_checked(&r3, &r);

  // [NTRUCOMP] section 5.1 includes ReEnc2 and a proof that it's valid. Rather
  // than do an expensive |poly_mul|, it rebuilds |c'| from |c - lift(m)|
  // (called |b|) with:
  //
  //   t = (−b(1)/N) mod Q
  //   c' = b + tΦ(N) + lift(m) mod Q
  //
  // When polynomials are transmitted, the final coefficient is omitted and
  // |poly_unmarshal| sets it such that f(1) == 0. Thus c(1) == 0. Also,
  // |poly_lift| multiplies the result by (x-1) and therefore evaluating a
  // lifted polynomial at 1 is also zero. Thus lift(m)(1) == 0 and so
  // (c - lift(m))(1) == 0.
  //
  // Although we defer the reduction above, |b| is conceptually reduced mod
  // Φ(N). In order to do that reduction one subtracts |c[N-1]| from every
  // coefficient. Therefore b(1) = -c[N-1]×N. The value of |t|, above, then is
  // just recovering |c[N-1]|, and adding tΦ(N) is simply undoing the reduction.
  // Therefore b + tΦ(N) + lift(m) = c by construction and we don't need to
  // recover |c| at all so long as we do the checks in
  // |poly3_from_poly_checked|.
  //
  // The |poly_marshal| here then is just confirming that |poly_unmarshal| is
  // strict and could be omitted.
  uint8_t expected_ciphertext[HRSS_CIPHERTEXT_BYTES];
  OPENSSL_STATIC_ASSERT(HRSS_CIPHERTEXT_BYTES == POLY_BYTES,
                        "ciphertext is the wrong size");
  assert(ciphertext_len == sizeof(expected_ciphertext));
  poly_marshal(expected_ciphertext, &c);

  uint8_t m_bytes[HRSS_POLY3_BYTES];
  uint8_t r_bytes[HRSS_POLY3_BYTES];
  poly_marshal_mod3(m_bytes, &m);
  poly_marshal_mod3(r_bytes, &r);

  // Fold the re-encryption check into |ok| without branching.
  ok &= constant_time_is_zero_w(CRYPTO_memcmp(ciphertext, expected_ciphertext,
                                              sizeof(expected_ciphertext)));

  // Compute the real shared key, exactly as in |HRSS_encap|.
  uint8_t shared_key[32];
  SHA256_Init(&hash_ctx);
  SHA256_Update(&hash_ctx, kSharedKey, sizeof(kSharedKey));
  SHA256_Update(&hash_ctx, m_bytes, sizeof(m_bytes));
  SHA256_Update(&hash_ctx, r_bytes, sizeof(r_bytes));
  SHA256_Update(&hash_ctx, expected_ciphertext, sizeof(expected_ciphertext));
  SHA256_Final(shared_key, &hash_ctx);

  // Constant-time select between the real key and the implicit-rejection key.
  for (unsigned i = 0; i < sizeof(shared_key); i++) {
    out_shared_key[i] =
        constant_time_select_8(ok, shared_key[i], out_shared_key[i]);
  }
}
  1837. void HRSS_marshal_public_key(uint8_t out[HRSS_PUBLIC_KEY_BYTES],
  1838. const struct HRSS_public_key *in_pub) {
  1839. const struct public_key *pub =
  1840. public_key_from_external((struct HRSS_public_key *)in_pub);
  1841. poly_marshal(out, &pub->ph);
  1842. }
  1843. int HRSS_parse_public_key(struct HRSS_public_key *out,
  1844. const uint8_t in[HRSS_PUBLIC_KEY_BYTES]) {
  1845. struct public_key *pub = public_key_from_external(out);
  1846. if (!poly_unmarshal(&pub->ph, in)) {
  1847. return 0;
  1848. }
  1849. OPENSSL_memset(&pub->ph.v[N], 0, 3 * sizeof(uint16_t));
  1850. return 1;
  1851. }