/*
* pqc/crypto_sign/falcon-1024/avx2/fpr.h
* 2020-10-21 16:37:33 -04:00
*
* 350 lines
* 11 KiB
* C
*/
#ifndef PQCLEAN_FALCON1024_AVX2_FPR_H
#define PQCLEAN_FALCON1024_AVX2_FPR_H
/*
* Floating-point operations.
*
* ==========================(LICENSE BEGIN)============================
*
* Copyright (c) 2017-2019 Falcon Project
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ===========================(LICENSE END)=============================
*
* @author Thomas Pornin <thomas.pornin@nccgroup.com>
*/
/* ====================================================================== */
#include <immintrin.h>
#include <math.h>
/*
* Multiply-add helpers on 256-bit vectors of doubles:
*   FMADD(a, b, c) evaluates a*b + c
*   FMSUB(a, b, c) evaluates a*b - c
* Both expand to a separate multiply intrinsic followed by an add or
* subtract intrinsic; they are not written with fused multiply-add
* intrinsics, so the product is rounded before the add/subtract.
*/
#define FMADD(a, b, c) _mm256_add_pd(_mm256_mul_pd(a, b), c)
#define FMSUB(a, b, c) _mm256_sub_pd(_mm256_mul_pd(a, b), c)
/*
* We wrap the native 'double' type into a structure so that the C compiler
* complains if we inadvertently use raw arithmetic operators on the 'fpr'
* type instead of using the inline functions below. This should have no
* extra runtime cost, since all the functions below are 'inline'.
*/
typedef struct {
    double v; /* the wrapped native value; use the fpr_* helpers to operate on it */
} fpr;
/*
 * Wrap a native double into an 'fpr' value.
 */
static inline fpr
FPR(double v) {
    fpr wrapped = { v };

    return wrapped;
}
static inline fpr
fpr_of(int64_t i) {
return FPR((double)i);
}
/*
* Named floating-point constants, each wrapped in an 'fpr'. Only the
* numeric values are visible in this header; the comments below describe
* what the values are, not how the rest of the code uses them.
*/
/* 12289 and its reciprocal (the division is a constant expression). */
static const fpr fpr_q = { 12289.0 };
static const fpr fpr_inverse_of_q = { 1.0 / 12289.0 };
/*
* NOTE(review): the names suggest sigma-related parameters of the
* signature sampler; their derivation is not visible here -- confirm
* against the Falcon specification before changing any digit.
*/
static const fpr fpr_inv_2sqrsigma0 = { .150865048875372721532312163019 };
static const fpr fpr_inv_sigma = { .005819826392951607426919370871 };
static const fpr fpr_sigma_min_9 = { 1.291500756233514568549480827642 };
static const fpr fpr_sigma_min_10 = { 1.311734375905083682667395805765 };
/* Natural logarithm of 2, and its reciprocal. */
static const fpr fpr_log2 = { 0.69314718055994530941723212146 };
static const fpr fpr_inv_log2 = { 1.4426950408889634073599246810 };
/* NOTE(review): presumably an upper bound on a norm; usage is outside this header. */
static const fpr fpr_bnorm_max = { 16822.4121 };
/* Small exact values. */
static const fpr fpr_zero = { 0.0 };
static const fpr fpr_one = { 1.0 };
static const fpr fpr_two = { 2.0 };
static const fpr fpr_onehalf = { 0.5 };
/* 1/sqrt(2) and 1/sqrt(8). */
static const fpr fpr_invsqrt2 = { 0.707106781186547524400844362105 };
static const fpr fpr_invsqrt8 = { 0.353553390593273762200422181052 };
/* 2^31, 2^31 - 1 and -(2^31 - 1); all exactly representable in a double. */
static const fpr fpr_ptwo31 = { 2147483648.0 };
static const fpr fpr_ptwo31m1 = { 2147483647.0 };
static const fpr fpr_mtwo31m1 = { -2147483647.0 };
/*
* Values around 2^63. NOTE(review): 2^63 - 1 needs 63 significant bits,
* so with an IEEE-754 binary64 'double' the two "m1" literals below are
* not exactly representable and round to +/-2^63 -- confirm that users
* of these constants do not rely on the "minus one".
*/
static const fpr fpr_ptwo63m1 = { 9223372036854775807.0 };
static const fpr fpr_mtwo63m1 = { -9223372036854775807.0 };
static const fpr fpr_ptwo63 = { 9223372036854775808.0 };
/*
* Round x to the nearest 64-bit integer without calling llrint().
* Three candidate results (tx, rp, rn) are computed unconditionally and
* exactly one is selected with bit masks, so no branch depends on x.
*
* NOTE(review): the (int64_t) casts are only defined by the C standard
* when the truncated value fits in int64_t, and 'sx >> 63' relies on an
* arithmetic right shift of a negative signed value (implementation-
* defined) -- confirm both hold for all callers and target compilers.
*/
static inline int64_t
fpr_rint(fpr x) {
/*
* We do not want to use llrint() since it might be not
* constant-time.
*
* Suppose that x >= 0. If x >= 2^52, then it is already an
* integer. Otherwise, if x < 2^52, then computing x+2^52 will
* yield a value that will be rounded to the nearest integer
* with exactly the right rules (round-to-nearest-even).
*
* In order to have constant-time processing, we must do the
* computation for both x >= 0 and x < 0 cases, and use a
* cast to an integer to access the sign and select the proper
* value. Such casts also allow us to find out if |x| < 2^52.
*/
int64_t sx, tx, rp, rn, m;
uint32_t ub;
/* sx: truncation of x - 1; only its sign bit is used, to pick rp or rn. */
sx = (int64_t)(x.v - 1.0);
/* tx: truncation of x; used as the result when |x| is large, and for the range test. */
tx = (int64_t)x.v;
/* rp / rn: candidates obtained by adding / subtracting 2^52 (4503599627370496). */
rp = (int64_t)(x.v + 4503599627370496.0) - 4503599627370496;
rn = (int64_t)(x.v - 4503599627370496.0) + 4503599627370496;
/*
* If tx >= 2^52 or tx < -2^52, then result is tx.
* Otherwise, if sx >= 0, then result is rp.
* Otherwise, result is rn. We use the fact that when x is
* close to 0 (|x| <= 0.25) then both rp and rn are correct;
* and if x is not close to 0, then trunc(x-1.0) yields the
* appropriate sign.
*/
/*
* m is all-ones when sx < 0, and zero otherwise.
* Clamp rp to zero if sx < 0.
* Clamp rn to zero if sx >= 0.
*/
m = sx >> 63;
rn &= m;
rp &= ~m;
/*
* Get the 12 upper bits of tx; if they are not all zeros or
* all ones, then tx >= 2^52 or tx < -2^52, and we clamp both
* rp and rn to zero. Otherwise, we clamp tx to zero.
*/
ub = (uint32_t)((uint64_t)tx >> 52);
/*
* (ub + 1) & 0xFFF is 0 or 1 exactly when the 12 bits are all ones
* or all zeros; subtracting 2 then wraps around in uint32_t and sets
* bit 31, so m is all-ones in that case and zero otherwise.
*/
m = -(int64_t)((((ub + 1) & 0xFFF) - 2) >> 31);
rp &= m;
rn &= m;
tx &= ~m;
/*
* Only one of tx, rn or rp (at most) can be non-zero at this
* point.
*/
return tx | rn | rp;
}
/*
 * Largest integer not greater than x, as an int64_t.
 */
static inline int64_t
fpr_floor(fpr x) {
    /*
     * Converting a double to an integer truncates toward zero, which
     * overshoots by one for negative non-integer inputs. The overshoot
     * is detected with a comparison whose 0/1 result is subtracted
     * directly, so that no branch is written in the source. This is
     * constant-time provided the compiler materialises the comparison
     * as an integer without a conditional jump, which is expected on
     * modern architectures with an FPU.
     */
    int64_t truncated = (int64_t)x.v;
    int64_t overshoot = (int64_t)(x.v < (double)truncated);

    return truncated - overshoot;
}
/*
 * Convert x to an int64_t, rounding toward zero (plain C cast).
 */
static inline int64_t
fpr_trunc(fpr x) {
    int64_t truncated = (int64_t)x.v;

    return truncated;
}
/*
 * Return x + y.
 */
static inline fpr
fpr_add(fpr x, fpr y) {
    fpr sum;

    sum.v = x.v + y.v;
    return sum;
}
/*
 * Return x - y.
 */
static inline fpr
fpr_sub(fpr x, fpr y) {
    fpr difference;

    difference.v = x.v - y.v;
    return difference;
}
/*
 * Return -x.
 */
static inline fpr
fpr_neg(fpr x) {
    fpr negated;

    negated.v = -x.v;
    return negated;
}
/*
 * Return x / 2, computed as a multiplication by 0.5.
 */
static inline fpr
fpr_half(fpr x) {
    fpr halved;

    halved.v = x.v * 0.5;
    return halved;
}
/*
 * Return 2 * x, computed as x + x.
 */
static inline fpr
fpr_double(fpr x) {
    fpr doubled;

    doubled.v = x.v + x.v;
    return doubled;
}
/*
 * Return x * y.
 */
static inline fpr
fpr_mul(fpr x, fpr y) {
    fpr product;

    product.v = x.v * y.v;
    return product;
}
/*
 * Return x squared (x * x).
 */
static inline fpr
fpr_sqr(fpr x) {
    fpr squared;

    squared.v = x.v * x.v;
    return squared;
}
/*
 * Return 1 / x. No check is made for x being zero.
 */
static inline fpr
fpr_inv(fpr x) {
    fpr reciprocal;

    reciprocal.v = 1.0 / x.v;
    return reciprocal;
}
/*
 * Return x / y. No check is made for y being zero.
 */
static inline fpr
fpr_div(fpr x, fpr y) {
    fpr quotient;

    quotient.v = x.v / y.v;
    return quotient;
}
/*
 * Replace *t with its square root, in place.
 *
 * The value is broadcast to both lanes of a 128-bit register, the
 * packed square root is taken, and only the low lane is written back.
 */
static inline void
fpr_sqrt_avx2(double *t) {
    _mm_storel_pd(t, _mm_sqrt_pd(_mm_load1_pd(t)));
}
/*
 * Return the square root of x.
 */
static inline fpr
fpr_sqrt(fpr x) {
    /*
     * We prefer not to depend on libm when it can be avoided. On x86,
     * calling sqrt() usually inlines the hardware opcode but may still
     * emit an optional call to the library function for error handling
     * when the operand is negative, which keeps the libm dependency
     * alive.
     *
     * In this AVX2 variant no platform detection is performed: the
     * square root always goes through fpr_sqrt_avx2(), which is written
     * with SSE2 intrinsics, so sqrt() is never referenced here.
     */
    fpr root;

    root.v = x.v;
    fpr_sqrt_avx2(&root.v);
    return root;
}
/*
 * Return 1 if x < y, 0 otherwise.
 */
static inline int
fpr_lt(fpr x, fpr y) {
    int is_less = (x.v < y.v);

    return is_less;
}
/*
* Compute ccs * exp(-x), scaled by 2^63 and truncated to an integer,
* returned as a uint64_t. exp(-x) is evaluated with the polynomial
* 1 + sum_{k=1..12} c[k-1] * (-x)^k using the coefficient table below.
* The notes inside the body give the expected input range (x >= 0,
* approximation tested on 0..log(2), and ccs < 1).
*/
static inline uint64_t
fpr_expm_p63(fpr x, fpr ccs) {
/*
* Polynomial approximation of exp(-x) is taken from FACCT:
* https://eprint.iacr.org/2018/1234
* Specifically, values are extracted from the implementation
* referenced from the FACCT article, and available at:
* https://github.com/raykzhao/gaussian
* Tests over more than 24 billion random inputs in the
* 0..log(2) range have never shown a deviation larger than
* 2^(-50) from the true mathematical value.
*/
/*
* AVX2 implementation uses more operations than Horner's method,
* but with a lower expression tree depth. This helps because
* additions and multiplications have a latency of 4 cycles on
* a Skylake, but the CPU can issue two of them per cycle.
*/
/*
* Coefficients for powers 1..12, in increasing order of power.
* NOTE(review): the 'v' member is never read in this function;
* presumably it is only there to give the table __m256d alignment.
* The loads below use the unaligned form either way.
*/
static const union {
double d[12];
__m256d v[3];
} c = {
{
0.999999999999994892974086724280,
0.500000000000019206858326015208,
0.166666666666984014666397229121,
0.041666666666110491190622155955,
0.008333333327800835146903501993,
0.001388888894063186997887560103,
0.000198412739277311890541063977,
0.000024801566833585381209939524,
0.000002755586350219122514855659,
0.000000275607356160477811864927,
0.000000025299506379442070029551,
0.000000002073772366009083061987
}
};
double d1, d2, d4, d8, y;
__m256d d14, d58, d9c;
/* Scalar powers of -x: d1 = (-x)^1, d2 = (-x)^2, d4 = (-x)^4, d8 = (-x)^8. */
d1 = -x.v;
d2 = d1 * d1;
d4 = d2 * d2;
d8 = d4 * d4;
/* Lanes 0..3 hold powers 1..4 (_mm256_set_pd takes the highest lane first). */
d14 = _mm256_set_pd(d4, d2 * d1, d2, d1);
/* Powers 5..8 and 9..12, obtained by scaling powers 1..4 by d4 and d8. */
d58 = _mm256_mul_pd(d14, _mm256_set1_pd(d4));
d9c = _mm256_mul_pd(d14, _mm256_set1_pd(d8));
/* Multiply each group by its coefficients and accumulate lane-wise into d9c. */
d14 = _mm256_mul_pd(d14, _mm256_loadu_pd(&c.d[0]));
d58 = FMADD(d58, _mm256_loadu_pd(&c.d[4]), d14);
d9c = FMADD(d9c, _mm256_loadu_pd(&c.d[8]), d58);
/*
* Horizontal sum of the four lanes: hadd adds adjacent pairs within
* each 128-bit half, then the low element of each half is extracted
* and the two are added, together with the constant term 1.0.
*/
d9c = _mm256_hadd_pd(d9c, d9c);
y = 1.0 + _mm_cvtsd_f64(_mm256_castpd256_pd128(d9c)) // _mm256_cvtsd_f64(d9c)
+ _mm_cvtsd_f64(_mm256_extractf128_pd(d9c, 1));
y *= ccs.v;
/*
* Final conversion goes through int64_t first, because that's what
* the underlying opcode (vcvttsd2si) will do, and we know that the
* result will fit, since x >= 0 and ccs < 1. If we did the
* conversion directly to uint64_t, then the compiler would add some
* extra code to cover the case of a source value of 2^63 or more,
* and though the alternate path would never be exercised, the
* extra comparison would cost us some cycles.
*/
return (uint64_t)(int64_t)(y * fpr_ptwo63.v);
}
/*
* Constant tables of 'fpr' values defined in another translation unit.
* The macros map the short names used in the code to symbols carrying
* the PQCLEAN_FALCON1024_AVX2_ prefix, so that is the name seen by the
* linker.
* NOTE(review): the sizes and contents of both tables are not visible
* from this header -- check the defining source file before indexing.
*/
#define fpr_gm_tab PQCLEAN_FALCON1024_AVX2_fpr_gm_tab
extern const fpr fpr_gm_tab[];
#define fpr_p2_tab PQCLEAN_FALCON1024_AVX2_fpr_p2_tab
extern const fpr fpr_p2_tab[];
/* ====================================================================== */
#endif