th5/hrss.go
2019-04-05 14:39:31 +01:00

1213 lines
26 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tls
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"io"
"math/bits"
)
// Parameters of the HRSS instantiation implemented in this file.
const (
	// N is the number of coefficients in every polynomial.
	N = 701
	// Q is the modulus used for "mod Q" coefficients.
	Q = 8192
	// mod3Bytes is the length of a serialized mod-3 polynomial: five
	// base-3 digits per byte.
	mod3Bytes = 140
	// modQBytes is the length of a serialized mod-Q polynomial: 13 bits per
	// coefficient, with the final coefficient omitted.
	modQBytes = 1138
)

// Sizes of the externally visible byte strings.
const (
	// PublicKeySize is the length, in bytes, of a marshaled PublicKey.
	PublicKeySize = modQBytes
	// CiphertextSize is the length, in bytes, of an encapsulated ciphertext.
	CiphertextSize = modQBytes
)

// Layout of bit-packed polynomials in machine words.
const (
	bitsPerWord      = bits.UintSize
	wordsPerPoly     = (N + bitsPerWord - 1) / bitsPerWord
	fullWordsPerPoly = N / bitsPerWord
	bitsInLastWord   = N % bitsPerWord
)
// poly3 represents a polynomial with N coefficients (degree below N) over
// GF(3). Each coefficient is bitsliced across the |s| and |a| arrays, like
// this:
//
//	 s  |  a  | value
//	-----------------
//	 0  |  0  | 0
//	 0  |  1  | 1
//	 1  |  0  | 2 (aka -1)
//	 1  |  1  | <invalid>
//
// ('s' is for sign, and 'a' is just a letter.) Coefficient i lives in bit
// i%bitsPerWord of word i/bitsPerWord of each array.
//
// Once bitsliced as such, the following circuits can be used to implement
// addition and multiplication mod 3:
//
//	(s3, a3) = (s1, a1) × (s2, a2)
//	s3 = (s2 ∧ a1) ⊕ (s1 ∧ a2)
//	a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2)
//
//	(s3, a3) = (s1, a1) + (s2, a2)
//	t1 = ~(s1 ∨ a1)
//	t2 = ~(s2 ∨ a2)
//	s3 = (a1 ∧ a2) ⊕ (t1 ∧ s2) ⊕ (t2 ∧ s1)
//	a3 = (s1 ∧ s2) ⊕ (t1 ∧ a2) ⊕ (t2 ∧ a1)
//
// Negating a value just involves swapping s and a.
type poly3 struct {
	s [wordsPerPoly]uint
	a [wordsPerPoly]uint
}
// trim clears the unused bits above coefficient N-1 in the final word of
// both bitslice arrays.
func (p *poly3) trim() {
	const keep = uint(1)<<bitsInLastWord - 1
	last := len(p.s) - 1
	p.s[last] &= keep
	p.a[last] &= keep
}
// zero sets every coefficient of p, including the unused high bits, to 0.
func (p *poly3) zero() {
	*p = poly3{}
}
// fromDiscrete packs in into bitsliced form. Only the two low bits of each
// coefficient are used: bit 1 goes to s and bit 0 to a. Coefficient i lands
// in bit i%bitsPerWord of word i/bitsPerWord. The unused high bits of the
// last word end up zero.
func (p *poly3) fromDiscrete(in *poly) {
	*p = poly3{}
	for i, v := range in {
		word, bit := i/bitsPerWord, uint(i%bitsPerWord)
		p.s[word] |= uint((v>>1)&1) << bit
		p.a[word] |= uint(v&1) << bit
	}
}
// fromModQ converts in to bitsliced mod-3 form. Each coefficient is
// expected to be 0, 1 or Q-1, and Q-1 becomes 2. It returns 1 if every
// coefficient was one of those values and 0 otherwise. The whole polynomial
// is processed either way, so the check does not short-circuit.
func (p *poly3) fromModQ(in *poly) int {
	*p = poly3{}
	valid := 1
	for i, v := range in {
		digit, ok := modQToMod3(v)
		valid &= ok
		word, bit := i/bitsPerWord, uint(i%bitsPerWord)
		p.s[word] |= uint((digit>>1)&1) << bit
		p.a[word] |= uint(digit&1) << bit
	}
	return valid
}
// fromDiscreteMod3 reduces each coefficient of in mod 3 and packs the result
// into bitsliced form. Coefficients are read as signed 13-bit values (bit 12
// is the sign), so Q-1 counts as -1 and becomes 2.
func (p *poly3) fromDiscreteMod3(in *poly) {
	*p = poly3{}
	for i, v := range in {
		// Copy bit 12 up through the top of the uint16 so the value reads
		// as a signed int16, then reduce it mod 3. Go's % truncates toward
		// zero, so the remainder lies in {-2, -1, 0, 1, 2}. Its low three
		// bits are therefore {6, 7, 0, 1, 2}.
		r := uint16((int16(v<<3)>>3)%3) & 7

		// Bit r of each constant gives the s and a bit of the mapping
		// {-2, -1, 0, 1, 2} -> {1, 2, 0, 1, 2}.
		word, bit := i/bitsPerWord, uint(i%bitsPerWord)
		p.s[word] |= ((uint(0xbc) >> r) & 1) << bit
		p.a[word] |= ((uint(0x7a) >> r) & 1) << bit
	}
}
// marshal encodes the first 700 coefficients of p into out[0:140]. Each
// byte holds five base-3 digits, with the lowest-indexed coefficient as the
// least-significant digit. Coefficient N-1 is not encoded.
func (p *poly3) marshal(out []byte) {
	digit := func(i int) int {
		word, bit := i/bitsPerWord, uint(i%bitsPerWord)
		return int((p.a[word]>>bit)&1) | int((p.s[word]>>bit)&1)<<1
	}
	for k := 0; k < 140; k++ {
		acc := 0
		// Horner's rule, starting from the most-significant digit.
		for j := 4; j >= 0; j-- {
			acc = acc*3 + digit(5*k+j)
		}
		out[k] = byte(acc)
	}
}
// fromMod2 unpacks the bit-packed polynomial in. It sets p[i] to bit
// i%bitsPerWord of word i/bitsPerWord, for each of the N coefficients.
func (p *poly) fromMod2(in *poly2) {
	for i := range p {
		p[i] = uint16((in[i/bitsPerWord] >> uint(i%bitsPerWord)) & 1)
	}
}
// fromMod3 unpacks the bitsliced polynomial in. Each p[i] becomes
// a-bit | s-bit<<1, so valid coefficients come out as 0, 1 or 2.
func (p *poly) fromMod3(in *poly3) {
	for i := range p {
		word, bit := i/bitsPerWord, uint(i%bitsPerWord)
		lo := (in.a[word] >> bit) & 1
		hi := (in.s[word] >> bit) & 1
		p[i] = uint16(lo | hi<<1)
	}
}
// fromMod3ToModQ unpacks the bitsliced polynomial in and maps every
// coefficient to its mod-Q representative with mod3ToModQ: 0 -> 0, 1 -> 1,
// 2 -> Q-1.
func (p *poly) fromMod3ToModQ(in *poly3) {
	for i := range p {
		word, bit := i/bitsPerWord, uint(i%bitsPerWord)
		digit := (in.a[word]>>bit)&1 | ((in.s[word]>>bit)&1)<<1
		p[i] = mod3ToModQ(uint16(digit))
	}
}
// lsbToAll returns a word whose bits all equal the least-significant bit of
// v: all ones when v is odd, zero when it is even.
func lsbToAll(v uint) uint {
	return -(v & 1)
}
// mulConst multiplies every coefficient of p by the GF(3) constant whose
// bitsliced form is (ms, ma). Only the low bit of ms and ma is used.
func (p *poly3) mulConst(ms, ma uint) {
	ms, ma = lsbToAll(ms), lsbToAll(ma)
	for i := range p.s {
		s, a := p.s[i], p.a[i]
		p.s[i] = (ma & s) ^ (ms & a)
		p.a[i] = (ma & a) ^ (ms & s)
	}
}
// cmovWords copies each bit of in into out wherever the matching bit of mov
// is set, and leaves out unchanged elsewhere. Callers pass mov as all zeros
// or all ones, which selects between the two arrays without branching.
func cmovWords(out, in *[wordsPerPoly]uint, mov uint) {
	for i, w := range in {
		out[i] ^= (out[i] ^ w) & mov
	}
}
// rotWords writes into out the N-bit polynomial in, right-rotated by bits
// positions, so that bit j of out is bit (j+bits) mod N of in. It handles
// only whole-word rotations: bits is presumably a multiple of bitsPerWord
// (poly3.rot and poly2.rot only pass such values). out and in must not
// alias, because words of in are read after earlier words of out have been
// written.
func rotWords(out, in *[wordsPerPoly]uint, bits uint) {
	start := bits / bitsPerWord
	// n is the number of output words that come straight from a single
	// input word without crossing the wrap-around point at bit N.
	n := (N - bits) / bitsPerWord
	for i := uint(0); i < n; i++ {
		out[i] = in[start+i]
	}
	// The remaining output words straddle the wrap: their low bits come from
	// the partial last word of in and their high bits from the start of in.
	carry := in[wordsPerPoly-1]
	for i := uint(0); i < start; i++ {
		out[n+i] = carry | in[i]<<bitsInLastWord
		carry = in[i] >> (bitsPerWord - bitsInLastWord)
	}
	out[wordsPerPoly-1] = carry
}
// rotBits writes into out the N-bit polynomial in, right-rotated by bits
// positions, so that bit j of out is bit (j+bits) mod N of in. bits must be
// a non-zero power of two no larger than bitsPerWord/2, and the final word
// must hold at least bitsPerWord/2 coefficients. It panics otherwise. out
// and in must not alias: each carry is read from in after the matching word
// of out has been written.
// NOTE(review): bits of in above position N-1 are assumed to be zero; any
// set there would be OR-ed into the wrapped-around coefficients.
func rotBits(out, in *[wordsPerPoly]uint, bits uint) {
	if bits == 0 || (bits&(bits-1)) != 0 || bits > bitsPerWord/2 || bitsInLastWord < bitsPerWord/2 {
		panic("internal error")
	}
	// Walk from the top word down, moving the low bits of each word into
	// the top of the word below it.
	carry := in[wordsPerPoly-1] << (bitsPerWord - bits)
	for i := wordsPerPoly - 2; i >= 0; i-- {
		out[i] = carry | in[i]>>bits
		carry = in[i] << (bitsPerWord - bits)
	}
	// The low bits of word 0 wrap around to just below bit N of the last
	// word.
	out[wordsPerPoly-1] = carry>>(bitsPerWord-bitsInLastWord) | in[wordsPerPoly-1]>>bits
}
// rotWords sets p to in right-rotated by bits positions. It applies the
// package-level whole-word rotation to each bitslice array.
func (p *poly3) rotWords(bits uint, in *poly3) {
	rotWords(&p.a, &in.a, bits)
	rotWords(&p.s, &in.s, bits)
}
// rotBits sets p to in right-rotated by bits positions. It applies the
// package-level sub-word rotation to each bitslice array.
func (p *poly3) rotBits(bits uint, in *poly3) {
	rotBits(&p.a, &in.a, bits)
	rotBits(&p.s, &in.s, bits)
}
// cmov replaces p with in where the bits of mov are set (mov is all zeros
// or all ones in practice), without branching.
func (p *poly3) cmov(in *poly3, mov uint) {
	cmovWords(&p.a, &in.a, mov)
	cmovWords(&p.s, &in.s, mov)
}
// rot right-rotates p by bits positions, where 0 ≤ bits ≤ N, so that the
// new coefficient j is the old coefficient (j+bits) mod N. It panics if bits > N.
// Every call performs the same sequence of rotations: for each power of two
// from 2^9 down to 1, the rotated copy is conditionally moved into p
// according to the corresponding bit of bits.
func (p *poly3) rot(bits uint) {
	if bits > N {
		panic("invalid")
	}
	var shifted poly3
	shift := uint(9)
	// Rotations of at least one word use whole-word moves.
	for ; (1 << shift) >= bitsPerWord; shift-- {
		shifted.rotWords(1<<shift, p)
		p.cmov(&shifted, lsbToAll(bits>>shift))
	}
	// Smaller rotations use bit shifts. shift is unsigned, so this loop ends
	// when it wraps around after shift == 0.
	for ; shift < 9; shift-- {
		shifted.rotBits(1<<shift, p)
		p.cmov(&shifted, lsbToAll(bits>>shift))
	}
}
// fmadd sets p to p + m·in, coefficient-wise mod 3. m is the GF(3) scalar
// whose bitsliced form is (ms, ma); only their low bits are used.
func (p *poly3) fmadd(ms, ma uint, in *poly3) {
	ms, ma = lsbToAll(ms), lsbToAll(ma)
	for i := range p.s {
		// Bitsliced product m × in[i].
		prodS := (ma & in.s[i]) ^ (ms & in.a[i])
		prodA := (ma & in.a[i]) ^ (ms & in.s[i])

		// Bitsliced sum p[i] + product, using the addition circuit.
		s1, a1 := p.s[i], p.a[i]
		t1 := ^(s1 | a1)
		t2 := ^(prodS | prodA)
		p.s[i] = (a1 & prodA) ^ (t1 & prodS) ^ (s1 & t2)
		p.a[i] = (s1 & prodS) ^ (t1 & prodA) ^ (a1 & t2)
	}
}
// modPhiN reduces p modulo Φ(N) = 1 + 𝑥 + … + 𝑥^(N-1). It does this by
// subtracting the coefficient of 𝑥^(N-1) from every coefficient, which also
// leaves that top coefficient zero.
func (p *poly3) modPhiN() {
	const topShift = bitsPerWord - bitsInLastWord
	last := wordsPerPoly - 1

	// Spread the top coefficient across whole words. Using the a-bit as the
	// new s and the s-bit as the new a negates it.
	negS := uint(int(p.a[last]<<topShift) >> (bitsPerWord - 1))
	negA := uint(int(p.s[last]<<topShift) >> (bitsPerWord - 1))
	negZero := ^(negS | negA)

	for i := range p.s {
		s, a := p.s[i], p.a[i]
		isZero := ^(s | a)
		p.s[i] = (a & negA) ^ (isZero & negS) ^ (s & negZero)
		p.a[i] = (s & negS) ^ (isZero & negA) ^ (a & negZero)
	}
}
// cswap exchanges p and other when swap is all ones and leaves both
// unchanged when swap is zero, without branching on swap.
func (p *poly3) cswap(other *poly3, swap uint) {
	for i := range p.s {
		ds := (p.s[i] ^ other.s[i]) & swap
		da := (p.a[i] ^ other.a[i]) & swap
		p.s[i], other.s[i] = p.s[i]^ds, other.s[i]^ds
		p.a[i], other.a[i] = p.a[i]^da, other.a[i]^da
	}
}
// mulx multiplies p by 𝑥 modulo 𝑥^N - 1. Every coefficient moves up one
// position and the coefficient of 𝑥^(N-1) wraps around to 𝑥^0. The copy
// shifted out of position N-1 is left in the unused bits of the last word
// rather than cleared.
func (p *poly3) mulx() {
	wrapS := (p.s[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1
	wrapA := (p.a[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1
	for i := range p.s {
		nextS := p.s[i] >> (bitsPerWord - 1)
		nextA := p.a[i] >> (bitsPerWord - 1)
		p.s[i] = p.s[i]<<1 | wrapS
		p.a[i] = p.a[i]<<1 | wrapA
		wrapS, wrapA = nextS, nextA
	}
}
// divx shifts both bitslice arrays right by one bit. Coefficient i+1 moves
// to position i, coefficient 0 is discarded (no wrap-around), and the top
// bit of the last word becomes zero.
func (p *poly3) divx() {
	var inS, inA uint
	for i := len(p.s) - 1; i >= 0; i-- {
		outS, outA := p.s[i]&1, p.a[i]&1
		p.s[i] = p.s[i]>>1 | inS<<(bitsPerWord-1)
		p.a[i] = p.a[i]>>1 | inA<<(bitsPerWord-1)
		inS, inA = outS, outA
	}
}
// poly2 is a polynomial over GF(2) stored one coefficient per bit.
// Coefficient i is bit i%bitsPerWord of word i/bitsPerWord.
type poly2 [wordsPerPoly]uint
// fromDiscrete reduces in mod 2 by keeping the low bit of each coefficient,
// and packs the result into p. The unused high bits of the last word end up
// zero.
func (p *poly2) fromDiscrete(in *poly) {
	*p = poly2{}
	for i, v := range in {
		p[i/bitsPerWord] |= uint(v&1) << uint(i%bitsPerWord)
	}
}
// setPhiN sets p to Φ(N) = 1 + 𝑥 + … + 𝑥^(N-1) over GF(2). The low N bits
// are set and the unused bits of the last word are cleared.
func (p *poly2) setPhiN() {
	last := len(p) - 1
	for i := 0; i < last; i++ {
		p[i] = ^uint(0)
	}
	p[last] = 1<<bitsInLastWord - 1
}
// cswap exchanges p and other when swap is all ones and leaves both
// unchanged when swap is zero, without branching on swap.
func (p *poly2) cswap(other *poly2, swap uint) {
	for i := range p {
		d := (p[i] ^ other[i]) & swap
		p[i], other[i] = p[i]^d, other[i]^d
	}
}
// fmadd adds in to p over GF(2), i.e. XORs it in, when m is 1, and leaves p
// unchanged when m is 0. In general it XORs in (in & -m).
func (p *poly2) fmadd(m uint, in *poly2) {
	mask := -m
	for i, w := range in {
		p[i] ^= w & mask
	}
}
// lshift1 multiplies p by 𝑥 with no reduction. Every bit moves up one
// position and the top bit of the last word is discarded.
func (p *poly2) lshift1() {
	var in uint
	for i, w := range p {
		p[i] = w<<1 | in
		in = w >> (bitsPerWord - 1)
	}
}
// rshift1 shifts p right by one bit across all words. Bit 0 is discarded
// and the top bit of the last word becomes zero.
func (p *poly2) rshift1() {
	var in uint
	for i := len(p) - 1; i >= 0; i-- {
		w := p[i]
		p[i] = w>>1 | in<<(bitsPerWord-1)
		in = w & 1
	}
}
// rot right-rotates p by bits positions, where 0 ≤ bits ≤ N, so that the
// new coefficient j is the old coefficient (j+bits) mod N. It panics if bits > N.
// Every call performs the same sequence of rotations: for each power of two
// from 2^9 down to 1, the rotated copy is conditionally moved into p
// according to the corresponding bit of bits.
func (p *poly2) rot(bits uint) {
	if bits > N {
		panic("invalid")
	}
	var shifted [wordsPerPoly]uint
	out := (*[wordsPerPoly]uint)(p)
	shift := uint(9)
	// Rotations of at least one word use whole-word moves.
	for ; (1 << shift) >= bitsPerWord; shift-- {
		rotWords(&shifted, out, 1<<shift)
		cmovWords(out, &shifted, lsbToAll(bits>>shift))
	}
	// Smaller rotations use bit shifts. shift is unsigned, so this loop ends
	// when it wraps around after shift == 0.
	for ; shift < 9; shift-- {
		rotBits(&shifted, out, 1<<shift)
		cmovWords(out, &shifted, lsbToAll(bits>>shift))
	}
}
// poly is a polynomial with N coefficients, where index i holds the
// coefficient of 𝑥^i. Depending on context the coefficients are reduced
// mod Q, mod 3 (as 0, 1, 2) or mod 2.
type poly [N]uint16
// marshal encodes coefficients 0..N-2 of in into out (1138 bytes). Each
// coefficient is assumed to be below Q, i.e. 13 bits; higher bits would
// spill into neighbouring fields. Coefficients are packed least-significant
// bit first, so each group of eight fills 13 bytes. The final coefficient is
// not written: unmarshal reconstructs it from the others.
func (in *poly) marshal(out []byte) {
	p := in[:]
	for len(p) >= 8 {
		out[0] = byte(p[0])
		out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
		out[2] = byte(p[1] >> 3)
		out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
		out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
		out[5] = byte(p[3] >> 1)
		out[6] = byte(p[3]>>9) | byte((p[4]&0x0f)<<4)
		out[7] = byte(p[4] >> 4)
		out[8] = byte(p[4]>>12) | byte((p[5]&0x7f)<<1)
		out[9] = byte(p[5]>>7) | byte((p[6]&0x03)<<6)
		out[10] = byte(p[6] >> 2)
		out[11] = byte(p[6]>>10) | byte((p[7]&0x1f)<<3)
		out[12] = byte(p[7] >> 5)
		p = p[8:]
		out = out[13:]
	}
	// Five coefficients remain (701 = 87·8 + 5). The first four fill the
	// last seven bytes; the top four bits of the final byte stay zero. The
	// fifth, in[N-1], is omitted.
	out[0] = byte(p[0])
	out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
	out[2] = byte(p[1] >> 3)
	out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
	out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
	out[5] = byte(p[3] >> 1)
	out[6] = byte(p[3] >> 9)
}
// unmarshal decodes a mod-Q polynomial produced by marshal. Coefficients
// 0..N-2 are read as 13-bit fields packed least-significant bit first. The
// final coefficient is not transmitted; it is reconstructed so that all N
// coefficients sum to zero mod Q. It returns false if in is shorter than the
// 1138-byte encoding or if the padding bits of the last byte are not zero.
// Bytes beyond the encoding are ignored.
func (out *poly) unmarshal(in []byte) bool {
	const (
		// fullGroups counts the complete eight-coefficient (13-byte)
		// groups; four transmitted coefficients remain after them.
		fullGroups = (N - 1) / 8
		// encodedLen is the length of the whole encoding (modQBytes).
		encodedLen = fullGroups*13 + 7
	)
	// Reject short input up front. Otherwise the fixed-offset reads below
	// would panic with an index out of range.
	if len(in) < encodedLen {
		return false
	}

	p := out[:]
	for i := 0; i < fullGroups; i++ {
		p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8
		p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11
		p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6
		p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9
		p[4] = uint16(in[6]>>4) | uint16(in[7])<<4 | uint16(in[8]&1)<<12
		p[5] = uint16(in[8]>>1) | uint16(in[9]&0x3f)<<7
		p[6] = uint16(in[9]>>6) | uint16(in[10])<<2 | uint16(in[11]&7)<<10
		p[7] = uint16(in[11]>>3) | uint16(in[12])<<5
		p = p[8:]
		in = in[13:]
	}

	// Four transmitted coefficients are left over.
	p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8
	p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11
	p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6
	p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9
	// The top four bits of the last byte are padding and must be zero.
	if in[6]&0xf0 != 0 {
		return false
	}

	// Reconstruct the final coefficient as minus the sum of the others.
	// Converting -top to uint16 wraps mod 2^16, a multiple of Q, so the
	// result is correct mod Q.
	out[N-1] = 0
	var top int
	for _, v := range out {
		top += int(v)
	}
	out[N-1] = uint16(-top) % Q
	return true
}
// marshalS3 packs the first 700 coefficients of in into out as 140 bytes.
// Each byte holds five base-3 digits, with the lowest index as the
// least-significant digit. Coefficients are expected to be in {0, 1, 2}.
// Coefficient N-1 is not encoded.
func (in *poly) marshalS3(out []byte) {
	for i := 0; i+5 <= len(in); i += 5 {
		var b uint16
		for j := 4; j >= 0; j-- {
			b = b*3 + in[i+j]
		}
		out[i/5] = byte(b)
	}
}
// unmarshalS3 decodes 140 bytes of base-3 digits, five per byte and lowest
// index least significant, into coefficients 0..699 of out, and sets
// out[N-1] to zero. It returns false if in is shorter than 140 bytes or if
// any byte is 243 or more (not a valid five-digit base-3 number). In the
// latter case out may be partially written.
func (out *poly) unmarshalS3(in []byte) bool {
	// encodedLen is 140: five coefficients per byte, last one omitted.
	const encodedLen = (N - 1) / 5
	// Reject short input instead of panicking on an out-of-range index.
	if len(in) < encodedLen {
		return false
	}
	p := out[:]
	for i := 0; i < encodedLen; i++ {
		c := in[0]
		if c >= 243 {
			return false
		}
		p[0] = uint16(c % 3)
		p[1] = uint16((c / 3) % 3)
		p[2] = uint16((c / 9) % 3)
		p[3] = uint16((c / 27) % 3)
		p[4] = uint16((c / 81) % 3)
		p = p[5:]
		in = in[1:]
	}
	out[N-1] = 0
	return true
}
// modPhiN reduces p modulo Φ(N) = 1 + 𝑥 + … + 𝑥^(N-1) over Z/Q. It
// subtracts the top coefficient from every coefficient, which leaves
// p[N-1] equal to zero.
func (p *poly) modPhiN() {
	top := p[N-1]
	for i := range p {
		p[i] = (p[i] + Q - top) % Q
	}
}
// shortSample fills coefficients 0..699 of out with values in {0, 1, 2}
// derived from the first 352 bytes of in, and sets out[N-1] to zero.
func (out *poly) shortSample(in []byte) {
	// Each coefficient uses four input bits, seen as two 2-bit fields.
	// Every field is first replaced by the number of set bits in it (so it
	// lies in {0, 1, 2}). Call the high field b and the low field a. The
	// coefficient is then a - b mod 3, read from this table:
	//
	//  b  a  result
	// 00 00 00
	// 00 01 01
	// 00 10 10
	// 00 11 11
	// 01 00 10
	// 01 01 00
	// 01 10 01
	// 01 11 11
	// 10 00 01
	// 10 01 10
	// 10 10 00
	// 10 11 11
	// 11 00 11
	// 11 01 11
	// 11 10 11
	// 11 11 11
	//
	// Packed two bits per entry, lowest entry first, the results form:
	// 1111 1111 1100 1001 1101 0010 1110 0100
	//    f    f    c    9    d    2    e    4
	const lookup = uint32(0xffc9d2e4)

	var fields uint32
	for i := 0; i < N-1; i++ {
		if i%8 == 0 {
			// Load the next 32 bits (eight coefficients' worth) and
			// replace every 2-bit field with its population count.
			w := binary.LittleEndian.Uint32(in[i/2:])
			fields = (w & 0x55555555) + ((w >> 1) & 0x55555555)
		}
		out[i] = uint16((lookup >> ((fields & 15) << 1)) & 3)
		fields >>= 4
	}
	out[N-1] = 0
}
// shortSamplePlus samples out as shortSample does. It then computes
// Σ out[i]·out[i+1] mod Q, reading coefficients as {0, 1, -1}. If bit 12 of
// that sum is set (the sum is negative in centred form), it negates every
// even-indexed coefficient. Each adjacent product involves exactly one
// even index, so this negates the sum.
func (out *poly) shortSamplePlus(in []byte) {
	out.shortSample(in)

	var correlation uint16
	for i := 1; i < N; i++ {
		correlation += mod3ResultToModQ(out[i-1] * out[i])
	}

	// Multiplying by 2 negates mod 3, so scale is 2 exactly when the sign
	// bit of the correlation is set.
	scale := 1 + (correlation>>12)&1
	for i := 0; i < N; i += 2 {
		out[i] = (out[i] * scale) % 3
	}
}
// mul sets out[0:2n] to the product of the n-coefficient polynomials a and
// b, with all arithmetic mod 2^16 (uint16 wrap-around). It uses Karatsuba
// recursion down to a schoolbook base case.
// NOTE(review): b is assumed to have the same length as a; only len(a) is
// used to size the split. out must hold 2n values, and scratch must be
// large enough for the recursion (poly.mul passes 2N for n = N).
func mul(out, scratch, a, b []uint16) {
	const schoolbookLimit = 32
	if len(a) < schoolbookLimit {
		// Schoolbook product. The zeroed top entry out[2n-1] stays zero,
		// which the odd-length Karatsuba step below relies on.
		for i := 0; i < len(a)*2; i++ {
			out[i] = 0
		}
		for i := range a {
			for j := range b {
				out[i+j] += a[i] * b[j]
			}
		}
		return
	}
	lowLen := len(a) / 2
	highLen := len(a) - lowLen
	aLow, aHigh := a[:lowLen], a[lowLen:]
	bLow, bHigh := b[:lowLen], b[lowLen:]
	// Use out temporarily to hold (aLow + aHigh) and (bLow + bHigh), each
	// highLen long.
	for i := 0; i < lowLen; i++ {
		out[i] = aHigh[i] + aLow[i]
	}
	if highLen != lowLen {
		out[lowLen] = aHigh[lowLen]
	}
	for i := 0; i < lowLen; i++ {
		out[highLen+i] = bHigh[i] + bLow[i]
	}
	if highLen != lowLen {
		out[highLen+lowLen] = bHigh[lowLen]
	}
	// scratch[0:2*highLen] = (aLow + aHigh)·(bLow + bHigh).
	mul(scratch, scratch[2*highLen:], out[:highLen], out[highLen:highLen*2])
	// Upper part of out = aHigh·bHigh; lower part = aLow·bLow.
	mul(out[lowLen*2:], scratch[2*highLen:], aHigh, bHigh)
	mul(out, scratch[2*highLen:], aLow, bLow)
	// Middle term: (aL+aH)(bL+bH) - aL·bL - aH·bH.
	for i := 0; i < lowLen*2; i++ {
		scratch[i] -= out[i] + out[lowLen*2+i]
	}
	if lowLen != highLen {
		scratch[lowLen*2] -= out[lowLen*4]
	}
	// Add the middle term in at offset lowLen.
	for i := 0; i < 2*highLen; i++ {
		out[lowLen+i] += scratch[i]
	}
}
// mul sets out to a·b mod (Q, 𝑥^N - 1). out may alias a or b, because the
// full product is formed in a temporary before out is written.
func (out *poly) mul(a, b *poly) {
	var product, scratch [2 * N]uint16
	mul(product[:], scratch[:], a[:], b[:])
	// 𝑥^N ≡ 1, so coefficient N+i folds onto coefficient i.
	for i := 0; i < N; i++ {
		out[i] = (product[i] + product[N+i]) % Q
	}
}
// mulMod3 sets p3 to x·y mod (3, Φ(N)). x and y are copied before p3 is
// cleared, so p3 may alias either input.
// NOTE(review): mulx leaves shifted-out bits above position N-1 in the last
// word, so p3 may end with garbage there. Callers in this file only read
// the low N coefficients; confirm before relying on the high bits.
func (p3 *poly3) mulMod3(x, y *poly3) {
	// (𝑥^N - 1) is a multiple of Φ(N), so we can work mod (𝑥^N - 1) here
	// and reduce mod Φ(N) afterwards.
	x3 := *x
	y3 := *y
	s := x3.s[:]
	a := x3.a[:]
	sw := s[0]
	aw := a[0]
	p3.zero()
	var shift uint
	// Schoolbook multiplication. Iteration i adds x_i·(𝑥^i·y) into p3,
	// taking x_i from the low bits of sw/aw. y3 is multiplied by 𝑥 at the
	// end of each iteration so that it holds 𝑥^i·y at iteration i.
	for i := 0; i < N; i++ {
		p3.fmadd(sw, aw, &y3)
		sw >>= 1
		aw >>= 1
		shift++
		if shift == bitsPerWord {
			s = s[1:]
			a = a[1:]
			sw = s[0]
			aw = a[0]
			shift = 0
		}
		y3.mulx()
	}
	p3.modPhiN()
}
// mod3ToModQ maps a coefficient in {0, 1, 2} to its mod-Q representative:
// {0, 1, 2, 3} -> {0, 1, Q-1, 0xffff}. The input 3 never comes from valid
// data. It is included so that modQToMod3 can easily catch invalid inputs.
func mod3ToModQ(n uint16) uint16 {
	// A table of four 16-bit entries; entry n sits at bits [16n, 16n+16).
	const table uint64 = 0xffff<<48 | 0x1fff<<32 | 0x0001<<16 | 0x0000
	return uint16(table >> (16 * n))
}
// modQToMod3 maps the mod-Q values {0, 1, Q-1} to the mod-3 values
// {0, 1, 2}, i.e. bitsliced (0, 0), (0, 1), (1, 0). It also returns an int
// that is 1 if n was one of those three values and 0 otherwise. The check
// uses a constant-time comparison.
func modQToMod3(n uint16) (uint16, int) {
	lowBits := n & 3
	bitOne := (n >> 1) & 1
	result := lowBits - bitOne
	valid := subtle.ConstantTimeEq(int32(mod3ToModQ(result)), int32(n))
	return result, valid
}
// mod3ResultToModQ maps the product of two coefficients in {0, 1, 2}, which
// is one of {0, 1, 2, 4}, to its value mod Q: {0, 1, Q-1, 1}. This works
// because 2 ≡ -1 and 4 ≡ 1 mod 3. Other inputs have no defined meaning.
func mod3ResultToModQ(n uint16) uint16 {
	// Among the valid inputs, bit n of 0x13 (0b10011) is clear only for
	// n = 2. Subtracting 1 from that bit gives 0xffff exactly then, and
	// masking yields Q-1 = 0x1fff.
	minusOne := (((uint16(0x13) >> n) & 1) - 1) & 0x1fff
	// Bit n of 0x12 (0b10010) is set for n = 1 and n = 4, whose images are 1.
	one := (uint16(0x12) >> n) & 1
	return minusOne | one
}
// mulXMinus1 sets out to out·(𝑥 - 1) mod (Q, 𝑥^N - 1). Coefficients of out
// are expected to be reduced mod Q.
func (out *poly) mulXMinus1() {
	// Multiplying by (𝑥 - 1) means negating each coefficient and adding in
	// the previous one. The predecessor of index 0 is index N-1, because
	// 𝑥^N ≡ 1. Save it before the loop overwrites it.
	last := out[N-1]
	for i := N - 1; i > 0; i-- {
		out[i] = (Q - out[i] + out[i-1]) % Q
	}
	out[0] = (Q - out[0] + last) % Q
}
// lift computes q = a/(𝑥 - 1) mod (3, Φ(N)) and converts q's coefficients to
// their mod-Q representatives {0, 1, Q-1}. It then sets out to
// q·(𝑥 - 1) mod (Q, 𝑥^N - 1). a's coefficients are expected to be in
// {0, 1, 2}. out and a must not alias: out[0..2] are written before the
// rest of a is read.
func (out *poly) lift(a *poly) {
	// We wish to calculate a/(𝑥-1) mod Φ(N) over GF(3), where Φ(N) is the
	// Nth cyclotomic polynomial, i.e. 1 + 𝑥 + … + 𝑥^700 (since N is prime).
	// 1/(𝑥-1) has a fairly basic structure that we can exploit to speed this up:
	//
	// R.<x> = PolynomialRing(GF(3)…)
	// inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
	// list(inv)[:15]
	// [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
	//
	// This three-element pattern of coefficients repeats for the whole
	// polynomial.
	//
	// Next define the overbar operator such that z̅ = z[0] +
	// reverse(z[1:]). (Index zero of a polynomial here is the coefficient
	// of the constant term. So index one is the coefficient of 𝑥 and so
	// on.)
	//
	// A less odd way to define this is to see that z̅ negates the indexes,
	// so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
	//
	// The use of z̅ is that, when working mod (𝑥^701 - 1), vz[0] = <v,
	// z̅>, vz[1] = <v, 𝑥z̅>, …. (Where <a, b> is the inner product: the sum
	// of the point-wise products.) Although we calculated the inverse mod
	// Φ(N), we can work mod (𝑥^N - 1) and reduce mod Φ(N) at the end.
	// (That's because (𝑥^N - 1) is a multiple of Φ(N).)
	//
	// When working mod (𝑥^N - 1), multiplication by 𝑥 is a right-rotation
	// of the list of coefficients.
	//
	// Thus we can consider what the pattern of z̅, 𝑥z̅, 𝑥^2z̅, … looks like:
	//
	// def reverse(xs):
	//   suffix = list(xs[1:])
	//   suffix.reverse()
	//   return [xs[0]] + suffix
	//
	// def rotate(xs):
	//   return [xs[-1]] + xs[:-1]
	//
	// zoverbar = reverse(list(inv) + [0])
	// xzoverbar = rotate(reverse(list(inv) + [0]))
	// x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
	//
	// zoverbar[:15]
	// [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
	// xzoverbar[:15]
	// [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
	// x2zoverbar[:15]
	// [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
	//
	// (For a formula for z̅, see lemma two of appendix B.)
	//
	// After the first three elements have been taken care of, all then have
	// a repeating three-element cycle. The next value (𝑥^3z̅) involves
	// three rotations of the first pattern, thus the three-element cycle
	// lines up. However, the discontinuity in the first three elements
	// obviously moves to a different position. Consider the difference
	// between 𝑥^3z̅ and z̅:
	//
	// [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
	// [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
	//
	// This pattern of differences is the same for all elements, although it
	// obviously moves right with the rotations.
	//
	// From this, we reach algorithm eight of appendix B.

	// Handle the first three elements of the inner products.
	out[0] = a[0] + a[2]
	out[1] = a[1]
	out[2] = 2*a[0] + a[2]
	// Use the repeating pattern to complete the first three inner products.
	// The uint16 sums stay far below 2^16, so no reduction is needed inside
	// the loop.
	for i := 3; i < 699; i += 3 {
		out[0] += 2*a[i] + a[i+2]
		out[1] += a[i] + 2*a[i+1]
		out[2] += a[i+1] + 2*a[i+2]
	}
	// Handle the fact that the three-element pattern doesn't fill the
	// polynomial exactly (since 701 isn't a multiple of three).
	out[2] += a[700]
	out[0] += 2 * a[699]
	out[1] += a[699] + 2*a[700]
	out[0] = out[0] % 3
	out[1] = out[1] % 3
	out[2] = out[2] % 3
	// Calculate the remaining inner products by taking advantage of the
	// fact that the pattern repeats every three cycles and the pattern of
	// differences moves with the rotation.
	for i := 3; i < N; i++ {
		// Adding twice something is the same as subtracting when working
		// mod 3. Doing it this way avoids underflow. Underflow is bad
		// because "% 3" doesn't work correctly for negative numbers
		// here since underflow will wrap to 2^16-1 and 2^16 isn't a
		// multiple of three.
		out[i] = (out[i-3] + 2*(a[i-2]+a[i-1]+a[i])) % 3
	}
	// Reduce mod Φ(N) by subtracting a multiple of out[700] from every
	// element and convert to mod Q. (See above about adding twice as
	// subtraction.)
	v := out[700] * 2
	for i := range out {
		out[i] = mod3ToModQ((out[i] + v) % 3)
	}
	out.mulXMinus1()
}
// cswap exchanges a and b where the bits of swap are set. With swap equal to
// 0xffff the polynomials are fully exchanged; with 0 they are unchanged.
// There is no data-dependent branch.
func (a *poly) cswap(b *poly, swap uint16) {
	for i := range a {
		d := (a[i] ^ b[i]) & swap
		a[i], b[i] = a[i]^d, b[i]^d
	}
}
// lt returns all ones if a < b (unsigned) and zero otherwise, without
// branching on the inputs. invertMod3 and invertMod2 use it to build
// cswap masks from polynomial degrees that evolve with the secret input. A
// data-dependent branch here would leak timing information.
func lt(a, b uint) uint {
	// The top bit of this expression is set exactly when a < b. If the top
	// bits of a and b differ, it is the complement of a's top bit.
	// Otherwise it is the top bit of a-b, which borrows iff a < b.
	top := (a ^ ((a ^ b) | ((a - b) ^ a))) >> (bits.UintSize - 1)
	return -top
}
// bsMul multiplies two bitsliced GF(3) values, (s1, a1) × (s2, a2), lane by
// lane, using the multiplication circuit documented on poly3.
func bsMul(s1, a1, s2, a2 uint) (s3, a3 uint) {
	return (s2 & a1) ^ (s1 & a2), (s1 & s2) ^ (a1 & a2)
}
// invertMod3 sets out to the inverse of in modulo (3, Φ(N)). It runs a fixed
// number of iterations with branch-free swaps, so the sequence of
// operations does not depend on the value of in.
func (out *poly3) invertMod3(in *poly3) {
	// This algorithm follows algorithm 10 in the paper. (Although note that
	// the paper appears to have a bug: k should start at zero, not one.)
	// The best explanation for why it works is in the "Why it works"
	// section of
	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
	var k uint
	degF, degG := uint(N-1), uint(N-1)
	var b, c, g poly3
	f := *in
	// g starts with every a-bit set: all coefficients are 1, which is Φ(N)
	// in the low N positions. The unused high bits are set as well.
	for i := range g.a {
		g.a[i] = ^uint(0)
	}
	b.a[0] = 1
	var f0s, f0a uint
	stillGoing := ^uint(0)
	for i := 0; i < 2*(N-1)-1; i++ {
		// Elimination factor -(f[0]·g[0]). Swapping s and a negates; it is
		// forced to zero once stillGoing is clear.
		ss, sa := bsMul(f.s[0], f.a[0], g.s[0], g.a[0])
		ss, sa = sa&stillGoing&1, ss&stillGoing&1
		// Swap (f, b) with (g, c) when the factor is nonzero and
		// deg f < deg g. g[0] starts as 1 and only ever receives a
		// nonzero f[0], so it is never zero. Because g[0]² = 1 in GF(3),
		// the fmadd below clears f's constant term.
		shouldSwap := ^uint(int((ss|sa)-1)>>(bitsPerWord-1)) & lt(degF, degG)
		f.cswap(&g, shouldSwap)
		b.cswap(&c, shouldSwap)
		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
		f.fmadd(ss, sa, &g)
		b.fmadd(ss, sa, &c)
		// f /= 𝑥 (its constant term is now zero) and c *= 𝑥, keeping the
		// top bit of f and the bottom bit of c clear.
		f.divx()
		f.s[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
		f.a[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
		c.mulx()
		c.s[0] &= ^uint(1)
		c.a[0] &= ^uint(1)
		degF--
		// k counts iterations, and f0s/f0a track f's constant term, while
		// stillGoing is set. stillGoing clears once degF reaches zero.
		k += 1 & stillGoing
		f0s = (stillGoing & f.s[0]) | (^stillGoing & f0s)
		f0a = (stillGoing & f.a[0]) | (^stillGoing & f0a)
		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
	}
	// Reduce k mod N without branching, then undo the accumulated 𝑥^k
	// factor by rotation. Multiply by f's final constant term: a nonzero
	// GF(3) element is its own inverse.
	k -= N & lt(N, k)
	*out = b
	out.rot(k)
	out.mulConst(f0s, f0a)
	out.modPhiN()
}
// invertMod2 sets out to the inverse of a modulo (2, Φ(N)), with
// coefficients 0 or 1. Only the low bit of each coefficient of a is used.
// It runs a fixed number of iterations with branch-free swaps.
func (out *poly) invertMod2(a *poly) {
	// This algorithm follows mix of algorithm 10 in the paper and the first
	// page of the PDF linked below. (Although note that the paper appears
	// to have a bug: k should start at zero, not one.) The best explanation
	// for why it works is in the "Why it works" section of
	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
	var k uint
	degF, degG := uint(N-1), uint(N-1)
	var f poly2
	f.fromDiscrete(a)
	var b, c, g poly2
	g.setPhiN()
	b[0] = 1
	stillGoing := ^uint(0)
	for i := 0; i < 2*(N-1)-1; i++ {
		// s is f's constant term, forced to zero once stillGoing clears.
		s := uint(f[0]&1) & stillGoing
		// Swap (f, b) with (g, c) when f[0] = 1 and deg f < deg g. g[0]
		// starts as 1 and only ever receives an f with f[0] = 1, so the
		// fmadd below clears f's constant term.
		shouldSwap := ^(s - 1) & lt(degF, degG)
		f.cswap(&g, shouldSwap)
		b.cswap(&c, shouldSwap)
		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
		f.fmadd(s, &g)
		b.fmadd(s, &c)
		// f /= 𝑥 and c *= 𝑥.
		f.rshift1()
		c.lshift1()
		degF--
		k += 1 & stillGoing
		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
	}
	// Reduce k mod N without branching and undo the 𝑥^k factor.
	k -= N & lt(N, k)
	b.rot(k)
	out.fromMod2(&b)
}
// invert sets out to the inverse of origA modulo Q = 2^13, starting from the
// inverse mod 2 computed by invertMod2. It then performs four Newton
// iterations b ← b·(2 - origA·b). Each iteration doubles the number of
// correct low bits (1 → 16 ≥ 13); see the NTRU paper, page three.
func (out *poly) invert(origA *poly) {
	var inv poly
	inv.invertMod2(origA)

	var negA poly
	for i, v := range origA {
		negA[i] = Q - v
	}

	for iter := 0; iter < 4; iter++ {
		var t, next poly
		t.mul(&negA, &inv)
		t[0] += 2
		next.mul(&inv, &t)
		inv = next
	}
	*out = inv
}
// PublicKey is an HRSS public key.
type PublicKey struct {
	// h is the public polynomial, with coefficients mod Q.
	h poly
}
// ParsePublicKey decodes a public key produced by PublicKey.Marshal. It
// returns (nil, false) if in is not exactly PublicKeySize bytes long or is
// not a valid encoding.
func ParsePublicKey(in []byte) (*PublicKey, bool) {
	// Require the exact encoded length, as Decap does for ciphertexts. This
	// keeps malformed input away from the fixed-offset decoder and rejects
	// trailing bytes.
	if len(in) != PublicKeySize {
		return nil, false
	}
	ret := new(PublicKey)
	if !ret.h.unmarshal(in) {
		return nil, false
	}
	return ret, true
}
// Marshal returns the modQBytes-long serialization of pub, which
// ParsePublicKey accepts.
func (pub *PublicKey) Marshal() []byte {
	encoded := make([]byte, modQBytes)
	pub.h.marshal(encoded)
	return encoded
}
// Encap creates a fresh shared secret for pub using 704 bytes read from
// rand. It returns the ciphertext for the private-key holder (modQBytes
// long) and the shared key. The key is SHA-256("shared key\x00" ‖ m ‖ r ‖
// ciphertext), where m and r are the mod-3 encodings of the sampled
// polynomials. It panics if rand cannot supply enough bytes.
func (pub *PublicKey) Encap(rand io.Reader) (ciphertext []byte, sharedKey []byte) {
	const sampleBytes = 352
	var seed [2 * sampleBytes]byte
	if _, err := io.ReadFull(rand, seed[:]); err != nil {
		panic("rand failed")
	}

	var m, r poly
	m.shortSample(seed[:sampleBytes])
	r.shortSample(seed[sampleBytes:])

	// Encode m and r before calling owf, which rewrites r in place.
	var mBytes, rBytes [mod3Bytes]byte
	m.marshalS3(mBytes[:])
	r.marshalS3(rBytes[:])

	ciphertext = pub.owf(&m, &r)

	digest := sha256.New()
	for _, part := range [][]byte{[]byte("shared key\x00"), mBytes[:], rBytes[:], ciphertext} {
		digest.Write(part)
	}
	sharedKey = digest.Sum(nil)
	return ciphertext, sharedKey
}
// owf returns the serialized mod-Q polynomial r·h + lift(m), with the
// product taken mod (Q, 𝑥^N - 1). r's coefficients are expected to be in
// {0, 1, 2}. They are replaced in place by their mod-Q representatives, so
// r is modified.
func (pub *PublicKey) owf(m, r *poly) []byte {
	for i, v := range r {
		r[i] = mod3ToModQ(v)
	}

	var lifted poly
	lifted.lift(m)

	var e poly
	e.mul(r, &pub.h)
	for i := range e {
		e[i] = (e[i] + lifted[i]) % Q
	}

	encoded := make([]byte, modQBytes)
	e.marshal(encoded)
	return encoded
}
// PrivateKey is an HRSS private key. It embeds the matching PublicKey.
type PrivateKey struct {
	PublicKey
	// f is the secret polynomial. fp is invertMod3(f), its inverse mod
	// (3, Φ(N)).
	f, fp poly3
	// hInv is multiplied into e - lift(m) by Decap to recover r.
	// GenerateKey sets it to inv·f·f (mod Q).
	hInv poly
	// hmacKey keys the HMAC-SHA256 that Decap returns for ciphertexts that
	// fail re-encryption (implicit rejection). It should be secret and
	// random; if it is known, that rejection value can be computed by
	// anyone.
	hmacKey [32]byte
}
// Marshal serializes priv as f (mod3Bytes), then fp (mod3Bytes), then the
// public polynomial h (modQBytes). hInv and hmacKey are not included.
func (priv *PrivateKey) Marshal() []byte {
	out := make([]byte, 2*mod3Bytes+modQBytes)
	fPart, rest := out[:mod3Bytes], out[mod3Bytes:]
	fpPart, hPart := rest[:mod3Bytes], rest[mod3Bytes:]
	priv.f.marshal(fPart)
	priv.fp.marshal(fpPart)
	priv.h.marshal(hPart)
	return out
}
// Decap recovers the shared key from ciphertext using priv. It returns
// (nil, false) only if ciphertext has the wrong length or fails to decode.
// Otherwise ok is always true. If the recovered (m, r) do not re-encrypt to
// exactly ciphertext, the returned key is HMAC-SHA256(hmacKey, ciphertext)
// instead of the real shared key (implicit rejection). The choice between
// the two is made with a mask rather than a branch.
func (priv *PrivateKey) Decap(ciphertext []byte) (sharedKey []byte, ok bool) {
	if len(ciphertext) != modQBytes {
		return nil, false
	}
	var e poly
	if !e.unmarshal(ciphertext) {
		return nil, false
	}
	// v1 = e·f mod (Q, 𝑥^N - 1), with f converted to mod-Q representatives.
	var f poly
	f.fromMod3ToModQ(&priv.f)
	var v1, m poly
	v1.mul(&e, &f)
	// Reduce v1 mod 3, reading its coefficients as signed values.
	var v13 poly3
	v13.fromDiscreteMod3(&v1)
	// Note: v13 is not reduced mod phi(n).
	// m = v13·fp mod (3, Φ(N)).
	var m3 poly3
	m3.mulMod3(&v13, &priv.fp)
	// NOTE(review): mulMod3 already ends with modPhiN, so this call looks
	// redundant; reducing twice leaves the result unchanged.
	m3.modPhiN()
	m.fromMod3(&m3)
	// delta = (e - lift(m))·hInv mod (Q, Φ(N)). For a valid ciphertext this
	// should be r, with coefficients in {0, 1, Q-1}.
	var mLift, delta poly
	mLift.lift(&m)
	for i := range delta {
		delta[i] = (e[i] - mLift[i] + Q) % Q
	}
	delta.mul(&delta, &priv.hInv)
	delta.modPhiN()
	// allOk becomes 0 if any coefficient of delta is outside {0, 1, Q-1}.
	var r poly3
	allOk := r.fromModQ(&delta)
	var mBytes, rBytes [mod3Bytes]byte
	m.marshalS3(mBytes[:])
	r.marshal(rBytes[:])
	// Re-encrypt and require an exact, constant-time match.
	var rPoly poly
	rPoly.fromMod3(&r)
	expectedCiphertext := priv.PublicKey.owf(&m, &rPoly)
	allOk &= subtle.ConstantTimeCompare(ciphertext, expectedCiphertext)
	// Rejection value, computed unconditionally.
	hmacHash := hmac.New(sha256.New, priv.hmacKey[:])
	hmacHash.Write(ciphertext)
	hmacDigest := hmacHash.Sum(nil)
	// Real shared key, derived as in Encap.
	h := sha256.New()
	h.Write([]byte("shared key\x00"))
	h.Write(mBytes[:])
	h.Write(rBytes[:])
	h.Write(ciphertext)
	sharedKey = h.Sum(nil)
	// mask is 0x00 when allOk == 1 (keep sharedKey) and 0xff when
	// allOk == 0 (use the HMAC).
	mask := uint8(allOk - 1)
	for i := range sharedKey {
		sharedKey[i] = (sharedKey[i] & ^mask) | (hmacDigest[i] & mask)
	}
	return sharedKey, true
}
// GenerateKey returns a new HRSS key pair built from randomness read from
// rand. It reads 352 bytes to sample f, 352 bytes to sample g, and 32 bytes
// for the implicit-rejection HMAC key, 736 bytes in total. It panics if
// rand cannot supply them.
//
// The HMAC key used to be left all-zero. Decap's rejection output was then
// HMAC under a public key, so anyone could compute it, which defeats
// implicit rejection. f and g are still derived from the first 704 bytes
// exactly as before.
func GenerateKey(rand io.Reader) PrivateKey {
	const sampleBytes = 352
	var randBytes [2*sampleBytes + 32]byte
	if _, err := io.ReadFull(rand, randBytes[:]); err != nil {
		panic("rand failed")
	}
	fBytes := randBytes[:sampleBytes]
	gBytes := randBytes[sampleBytes : 2*sampleBytes]
	hmacKeyBytes := randBytes[2*sampleBytes:]

	var priv PrivateKey

	// Secret f, as a bitsliced mod-3 polynomial, and its inverse mod 3.
	var f poly
	f.shortSamplePlus(fBytes)
	priv.f.fromDiscrete(&f)
	priv.fp.invertMod3(&priv.f)

	// pgPhi1 = 3·g·(𝑥 - 1) mod Q.
	var g poly
	g.shortSamplePlus(gBytes)
	var pgPhi1 poly
	for i := range g {
		pgPhi1[i] = (mod3ToModQ(g[i]) * 3) % Q
	}
	pgPhi1.mulXMinus1()

	var fModQ poly
	fModQ.fromMod3ToModQ(&priv.f)

	// inv = (f·pgPhi1)^-1 mod Q.
	var pfgPhi1 poly
	pfgPhi1.mul(&fModQ, &pgPhi1)
	var inv poly
	inv.invert(&pfgPhi1)

	// h = inv·pgPhi1², and hInv = inv·f².
	priv.h.mul(&inv, &pgPhi1)
	priv.h.mul(&priv.h, &pgPhi1)
	priv.hInv.mul(&inv, &fModQ)
	priv.hInv.mul(&priv.hInv, &fModQ)

	copy(priv.hmacKey[:], hmacKeyBytes)
	return priv
}