pqc/crypto_kem/ntruhrss701/clean/poly_rq_mul.c

285 lines
8.1 KiB
C

#include "poly.h"
/* Multiplication of polynomials modulo (x^NTRU_N - 1), with coefficients     */
/* held as wrapping 16-bit integers, using Toom-4 and two layers of           */
/* Karatsuba.                                                                  */
/* NOTE(review): the interpolation divides by 2, 4 and 8 with right shifts,   */
/* so only the low bits of each result coefficient are exact — presumably     */
/* enough for the modulus q of this parameter set; confirm against params.h.  */
/*                                                                             */
/* Length parameters:                                                          */
/*   L : NTRU_N padded by PAD32 (defined elsewhere; presumably rounded up to  */
/*       a multiple of 32) — the length of each zero-padded operand.          */
/*   M : length of one Toom-4 limb (L/4).                                      */
/*   K : length of one schoolbook operand; the two Karatsuba layers split     */
/*       each limb into four K-coefficient pieces (L/16).                      */
#define L PAD32(NTRU_N)
#define M (L/4)
#define K (L/16)
/* The array bounds in the prototypes below are informal (C ignores them).   */
/* Several functions touch more entries than the bound suggests: the eval    */
/* functions read 4*M input entries, k2x2_interpolate reads 18*K and writes  */
/* 2*M, and toom4_k2x2_interpolate writes 2*L.                                */

/* Top level: full L x L product, written as 2*L coefficients. */
static void toom4_k2x2_mul(uint16_t ab[2 * L], const uint16_t a[L], const uint16_t b[L]);
/* Toom-4 evaluation at x = 0, 1, -1, 2, -2, 3 and infinity.  Each writes the */
/* 9*K-entry Karatsuba operand layout produced by k2x2_eval.                  */
static void toom4_k2x2_eval_0(uint16_t r[9 * K], const uint16_t a[M]);
static void toom4_k2x2_eval_p1(uint16_t r[9 * K], const uint16_t a[M]);
static void toom4_k2x2_eval_m1(uint16_t r[9 * K], const uint16_t a[M]);
static void toom4_k2x2_eval_p2(uint16_t r[9 * K], const uint16_t a[M]);
static void toom4_k2x2_eval_m2(uint16_t r[9 * K], const uint16_t a[M]);
static void toom4_k2x2_eval_p3(uint16_t r[9 * K], const uint16_t a[M]);
static void toom4_k2x2_eval_inf(uint16_t r[9 * K], const uint16_t a[M]);
static inline void k2x2_eval(uint16_t r[9 * K]);
/* Nine K x K schoolbook products per evaluation point (18*K outputs). */
static void toom4_k2x2_basemul(uint16_t r[18 * K], const uint16_t a[9 * K], const uint16_t b[9 * K]);
static inline void schoolbook_KxK(uint16_t r[2 * K], const uint16_t a[K], const uint16_t b[K]);
/* Interpolation: Karatsuba recombination per point, then the Toom-4 solve. */
static void toom4_k2x2_interpolate(uint16_t r[2 * M], const uint16_t a[63 * 2 * K]);
static inline void k2x2_interpolate(uint16_t r[M], const uint16_t a[9 * K]);
void PQCLEAN_NTRUHRSS701_CLEAN_poly_Rq_mul(poly *r, const poly *a, const poly *b) {
    /*
     * r = a * b modulo (x^NTRU_N - 1).
     *
     * Steps:
     *   1. Copy both operands into the two halves of one 16-bit scratch buffer
     *      and zero-pad each to L coefficients.
     *   2. toom4_k2x2_mul overwrites the whole buffer with the
     *      2*L-coefficient product.  It reads its inputs only before it
     *      writes its output, so this aliasing is safe.
     *   3. Fold the product: coefficient NTRU_N + n is added onto
     *      coefficient n.
     */
    uint16_t buf[2 * L];
    uint16_t *const lhs = buf;
    uint16_t *const rhs = buf + L;
    size_t n;

    for (n = 0; n < NTRU_N; n++) {
        lhs[n] = a->coeffs[n];
        rhs[n] = b->coeffs[n];
    }
    /* Zero padding: n carries on from NTRU_N. */
    for (; n < L; n++) {
        lhs[n] = 0;
        rhs[n] = 0;
    }

    toom4_k2x2_mul(buf, lhs, rhs);

    for (n = 0; n < NTRU_N; n++) {
        r->coeffs[n] = buf[n] + buf[NTRU_N + n];
    }
}
static void toom4_k2x2_mul(uint16_t ab[2 * L], const uint16_t a[L], const uint16_t b[L]) {
    /*
     * Toom-4 product of two L-coefficient operands into ab (2*L entries).
     *
     * a and b are each evaluated at x = 0, 1, -1, 2, -2, 3 and infinity.
     * Every pair of evaluations is multiplied by toom4_k2x2_basemul.  The
     * seven point-products (18*K entries each) are packed one after another
     * in prod, in exactly that order; this is the layout
     * toom4_k2x2_interpolate expects.  a and b are not read after the last
     * evaluation, so ab may alias them.
     */
    uint16_t opA[9 * K];
    uint16_t opB[9 * K];
    uint16_t prod[63 * 2 * K];
    uint16_t *slot = prod;

    /* x = 0 */
    toom4_k2x2_eval_0(opA, a);
    toom4_k2x2_eval_0(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);
    slot += 18 * K;

    /* x = 1 */
    toom4_k2x2_eval_p1(opA, a);
    toom4_k2x2_eval_p1(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);
    slot += 18 * K;

    /* x = -1 */
    toom4_k2x2_eval_m1(opA, a);
    toom4_k2x2_eval_m1(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);
    slot += 18 * K;

    /* x = 2 */
    toom4_k2x2_eval_p2(opA, a);
    toom4_k2x2_eval_p2(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);
    slot += 18 * K;

    /* x = -2 */
    toom4_k2x2_eval_m2(opA, a);
    toom4_k2x2_eval_m2(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);
    slot += 18 * K;

    /* x = 3 */
    toom4_k2x2_eval_p3(opA, a);
    toom4_k2x2_eval_p3(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);
    slot += 18 * K;

    /* x = infinity */
    toom4_k2x2_eval_inf(opA, a);
    toom4_k2x2_eval_inf(opB, b);
    toom4_k2x2_basemul(slot, opA, opB);

    toom4_k2x2_interpolate(ab, prod);
}
static void toom4_k2x2_eval_0(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = 0.  Copy the lowest limb a[0..M) into r, then
     * expand it into the 9*K-entry Karatsuba operand layout. */
    size_t idx = 0;
    while (idx < M) {
        r[idx] = a[idx];
        ++idx;
    }
    k2x2_eval(r);
}
static void toom4_k2x2_eval_p1(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = 1: r = a0 + a1 + a2 + a3 (mod 2^16), where ak
     * is the limb a[k*M .. (k+1)*M).  a therefore spans 4*M entries. */
    for (size_t j = 0; j < M; j++) {
        const uint16_t low = a[j] + a[M + j];
        const uint16_t high = a[2 * M + j] + a[3 * M + j];
        r[j] = low + high;
    }
    k2x2_eval(r);
}
static void toom4_k2x2_eval_m1(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = -1: r = (a0 + a2) - (a1 + a3) (mod 2^16),
     * where ak is the limb a[k*M .. (k+1)*M). */
    for (size_t j = 0; j < M; j++) {
        const uint32_t even = (uint32_t)a[j] + a[2 * M + j];
        const uint32_t odd = (uint32_t)a[M + j] + a[3 * M + j];
        r[j] = (uint16_t)(even - odd);
    }
    k2x2_eval(r);
}
static void toom4_k2x2_eval_p2(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = 2: r = a0 + 2*a1 + 4*a2 + 8*a3 (mod 2^16),
     * computed in Horner form from the top limb down. */
    for (size_t j = 0; j < M; j++) {
        uint32_t acc = a[3 * M + j];
        acc = 2 * acc + a[2 * M + j];
        acc = 2 * acc + a[M + j];
        acc = 2 * acc + a[j];
        r[j] = (uint16_t)acc;
    }
    k2x2_eval(r);
}
static void toom4_k2x2_eval_m2(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = -2: r = a0 - 2*a1 + 4*a2 - 8*a3 (mod 2^16),
     * computed in Horner form (a0 - 2*(a1 - 2*(a2 - 2*a3))). */
    for (size_t j = 0; j < M; j++) {
        uint32_t acc = a[3 * M + j];
        acc = a[2 * M + j] - 2 * acc;
        acc = a[M + j] - 2 * acc;
        acc = a[j] - 2 * acc;
        r[j] = (uint16_t)acc;
    }
    k2x2_eval(r);
}
static void toom4_k2x2_eval_p3(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = 3: r = a0 + 3*a1 + 9*a2 + 27*a3 (mod 2^16),
     * computed in Horner form from the top limb down. */
    for (size_t j = 0; j < M; j++) {
        uint32_t acc = a[3 * M + j];
        acc = 3 * acc + a[2 * M + j];
        acc = 3 * acc + a[M + j];
        acc = 3 * acc + a[j];
        r[j] = (uint16_t)acc;
    }
    k2x2_eval(r);
}
static void toom4_k2x2_eval_inf(uint16_t r[9 * K], const uint16_t a[M]) {
    /* Toom-4 evaluation at x = infinity: the leading limb a[3*M .. 4*M). */
    const uint16_t *top = a + 3 * M;
    for (size_t j = 0; j < M; j++) {
        r[j] = top[j];
    }
    k2x2_eval(r);
}
static inline void k2x2_eval(uint16_t r[9 * K]) {
    /*
     * Two-level Karatsuba evaluation, in place.
     *
     * On entry r[0..4K) holds one Toom limb as four K-coefficient pieces
     * e, f, g, h (limb = e + f.Y + g.Y^2 + h.Y^3 with Y = x^K).
     *
     * On exit r holds the nine K-coefficient operands
     *   [ e | f | g | h | e+f | f+h | g+e | h+g | e+f+g+h ]
     * consumed by toom4_k2x2_basemul.  Sums wrap mod 2^16.
     */
    for (size_t j = 0; j < K; j++) {
        const uint16_t e = r[j];
        const uint16_t f = r[K + j];
        const uint16_t g = r[2 * K + j];
        const uint16_t h = r[3 * K + j];
        const uint16_t fh = f + h;
        const uint16_t ge = g + e;

        r[4 * K + j] = e + f;
        r[5 * K + j] = fh;
        r[6 * K + j] = ge;
        r[7 * K + j] = h + g;
        r[8 * K + j] = fh + ge;
    }
}
static void toom4_k2x2_basemul(uint16_t r[18 * K], const uint16_t a[9 * K], const uint16_t b[9 * K]) {
    /* Nine independent K x K products.  Operand pair t (K coefficients each,
     * in the k2x2_eval layout) is multiplied into r[2K*t .. 2K*t + 2K). */
    for (size_t t = 0; t < 9; t++) {
        schoolbook_KxK(&r[2 * K * t], &a[K * t], &b[K * t]);
    }
}
static inline void schoolbook_KxK(uint16_t r[2 * K], const uint16_t a[K], const uint16_t b[K]) {
    /*
     * Schoolbook product of two K-coefficient polynomials, mod 2^16.
     *
     * r[0 .. 2K-2] receives the 2K-1 product coefficients.  r[2K-1] is set to
     * 0 so that the whole 2K-entry output is defined.
     *
     * Row i stores its last term r[i+K-1] with '=' rather than '+='.  No
     * earlier row has written that slot, so no separate zeroing pass is
     * needed.
     *
     * The products are formed in uint32_t.  Without the cast, both uint16_t
     * operands are promoted to signed int, and 65535 * 65535 overflows int,
     * which is undefined behaviour.  Only the low 16 bits are kept, so the
     * stored results are the same as a wrapping implementation would give.
     */
    size_t i, j;
    for (j = 0; j < K; j++) {
        r[j] = (uint16_t)((uint32_t)a[0] * b[j]);
    }
    for (i = 1; i < K; i++) {
        for (j = 0; j < K - 1; j++) {
            r[i + j] = (uint16_t)(r[i + j] + (uint32_t)a[i] * b[j]);
        }
        r[i + K - 1] = (uint16_t)((uint32_t)a[i] * b[K - 1]);
    }
    r[2 * K - 1] = 0;
}
static void toom4_k2x2_interpolate(uint16_t r[2 * M], const uint16_t a[7 * 18 * K]) {
    /*
     * Toom-4 interpolation.
     *
     * Input layout: a holds seven point-products of 18*K entries each, in
     * the two-level Karatsuba layout produced by toom4_k2x2_basemul, in the
     * order x = 0, 1, -1, 2, -2, 3, infinity.
     *
     * Procedure:
     *   1. k2x2_interpolate collapses each point-product into a
     *      2*M-coefficient value P(x).
     *   2. The seven values are solved for the product coefficients C0..C6
     *      (2*M entries each).
     *   3. The coefficients are placed into r at offsets 0, M, 2M, ..., 6M,
     *      so r spans 8*M = 2*L entries (more than the bound in the
     *      signature suggests).
     *
     * Arithmetic notes (all arithmetic is mod 2^16):
     *   - Divisions by 3 and 5 are multiplications by their inverses mod 2^16
     *     (3 * 43691 = 5 * 52429 = 1 mod 2^16).
     *   - Divisions by 2, 4 and 8 are right shifts, which discard high bits,
     *     so only the low bits of each output coefficient are meaningful.
     *     NOTE(review): presumably callers only use the coefficients modulo
     *     q <= 2^13 — confirm against params.h.
     *   - The multiplications by 43691 and 52429 are done in uint32_t.  With
     *     int operands (uint16_t promotes to int) they can exceed INT_MAX,
     *     which is undefined behaviour.  Unsigned arithmetic keeps exactly the
     *     same low 16 bits.
     *
     * Trailing comments name the coefficient combination each step recovers
     * (modulo the bits lost to the shifts).
     */
    size_t i;

    uint16_t P1[2 * M];
    uint16_t Pm1[2 * M];
    uint16_t P2[2 * M];
    uint16_t Pm2[2 * M];
    /* P(3) reuses Pm1's storage: Pm1 is no longer read after pass 1.
     * A pointer alias instead of a #define keeps the name out of the rest of
     * the translation unit. */
    uint16_t *const P3 = Pm1;

    /* Even coefficients are reconstructed directly in their place in r. */
    uint16_t *C0 = r;
    uint16_t *C2 = r + 2 * M;
    uint16_t *C4 = r + 4 * M;
    uint16_t *C6 = r + 6 * M;

    uint16_t V0, V1, V2;

    k2x2_interpolate(C0, a + 0 * 9 * 2 * K);   /* P(0) = C0   */
    k2x2_interpolate(P1, a + 1 * 9 * 2 * K);   /* P(1)        */
    k2x2_interpolate(Pm1, a + 2 * 9 * 2 * K);  /* P(-1)       */
    k2x2_interpolate(P2, a + 3 * 9 * 2 * K);   /* P(2)        */
    k2x2_interpolate(Pm2, a + 4 * 9 * 2 * K);  /* P(-2)       */
    k2x2_interpolate(C6, a + 6 * 9 * 2 * K);   /* P(inf) = C6 */

    /* Pass 1: recover C2 and C4, and turn P1 into C1 + C3 + C5. */
    for (i = 0; i < 2 * M; i++) {
        V0 = ((uint32_t)(P1[i] + Pm1[i])) >> 1;                            /* C0 + C2 + C4 + C6 */
        V0 = V0 - C0[i] - C6[i];                                           /* C2 + C4           */
        V1 = ((uint32_t)(P2[i] + Pm2[i] - 2 * C0[i] - 128 * C6[i])) >> 3;  /* C2 + 4*C4         */
        C4[i] = (uint16_t)(43691u * (uint32_t)(V1 - V0));                  /* (3*C4) / 3        */
        C2[i] = V0 - C4[i];
        P1[i] = ((uint32_t)(P1[i] - Pm1[i])) >> 1;                         /* C1 + C3 + C5      */
    }

    k2x2_interpolate(P3, a + 5 * 9 * 2 * K);   /* P(3) */

    /* Pass 2: recover the odd coefficients C1, C3 and C5. */
    for (i = 0; i < 2 * M; i++) {
        V0 = P1[i];                                                        /* C1 + C3 + C5      */
        V1 = 43691u * ((((uint32_t)(P2[i] - Pm2[i])) >> 2) - V0);          /* C3 + 5*C5         */
        V2 = (uint16_t)(43691u * (uint32_t)(P3[i] - C0[i]
                                            - 9 * (C2[i] + 9 * (C4[i] + 9 * C6[i]))));
                                                                           /* C1 + 9*C3 + 81*C5 */
        V2 = ((uint32_t)(V2 - V0)) >> 3;                                   /* C3 + 10*C5        */
        V2 -= V1;                                                          /* 5*C5              */
        P3[i] = (uint16_t)(52429u * (uint32_t)V2);                         /* C5                */
        P2[i] = V1 - V2;                                                   /* C3                */
        P1[i] = V0 - P2[i] - P3[i];                                        /* C1                */
    }

    /* Odd coefficients sit at offsets M, 3M, 5M and overlap the neighbouring
     * even coefficients by M entries on each side, so they are accumulated. */
    for (i = 0; i < 2 * M; i++) {
        r[1 * M + i] += P1[i];
        r[3 * M + i] += P2[i];
        r[5 * M + i] += P3[i];
    }
}
static inline void k2x2_interpolate(uint16_t r[M], const uint16_t a[9 * K]) {
size_t i;
uint16_t tmp[4 * K];
for (i = 0; i < 2 * K; i++) {
r[0 * K + i] = a[0 * K + i];
r[2 * K + i] = a[2 * K + i];
}
for (i = 0; i < 2 * K; i++) {
r[1 * K + i] += a[8 * K + i] - a[0 * K + i] - a[2 * K + i];
}
for (i = 0; i < 2 * K; i++) {
r[4 * K + i] = a[4 * K + i];
r[6 * K + i] = a[6 * K + i];
}
for (i = 0; i < 2 * K; i++) {
r[5 * K + i] += a[14 * K + i] - a[4 * K + i] - a[6 * K + i];
}
for (i = 0; i < 2 * K; i++) {
tmp[0 * K + i] = a[12 * K + i];
tmp[2 * K + i] = a[10 * K + i];
}
for (i = 0; i < 2 * K; i++) {
tmp[K + i] += a[16 * K + i] - a[12 * K + i] - a[10 * K + i];
}
for (i = 0; i < 4 * K; i++) {
tmp[0 * K + i] = tmp[0 * K + i] - r[0 * K + i] - r[4 * K + i];
}
for (i = 0; i < 4 * K; i++) {
r[2 * K + i] += tmp[0 * K + i];
}
}