diff --git a/13.go b/13.go
index 6b96a6e..a55bb78 100644
--- a/13.go
+++ b/13.go
@@ -121,6 +121,7 @@ type kexNIST struct{ defaultServerKEX }     // Used by NIST curves; P-256, P-384
 type kexX25519 struct{ defaultServerKEX }   // Used by X25519
 type kexSIDHp503 struct{ defaultServerKEX } // Used by SIDH/P503
 type kexSIKEp503 struct{}                   // Used by SIKE/P503
+type kexHRSS struct{}                       // Used by HRSS
 type kexHybridSIDHp503X25519 struct {
 	defaultServerKEX
 	classicKEX kexX25519
@@ -132,6 +133,11 @@ type kexHybridSIKEp503X25519 struct {
 	pqKEX      kexSIKEp503
 } // Used by SIKE-ECDHE hybrid scheme
 
+type kexHybridHRSSCurve25519 struct {
+	classicKEX kexX25519
+	pqKEX      kexHRSS
+} // Used by HRSS-ECDHE hybrid scheme
+
 // Routing map for key exchange strategies
 var kexStrat = map[CurveID]kex{
 	CurveP256:                &kexNIST{},
@@ -140,6 +146,7 @@ var kexStrat = map[CurveID]kex{
 	X25519:                   &kexX25519{},
 	HybridSIDHp503Curve25519: &kexHybridSIDHp503X25519{},
 	HybridSIKEp503Curve25519: &kexHybridSIKEp503X25519{},
+	HybridHRSSCurve25519:     &kexHybridHRSSCurve25519{},
 }
 
 func newKeySchedule13(suite *cipherSuite, config *Config, clientRandom []byte) *keySchedule13 {
diff --git a/common.go b/common.go
index ce869c7..137d036 100644
--- a/common.go
+++ b/common.go
@@ -125,6 +125,7 @@ const (
 	// Experimental KEX
 	HybridSIDHp503Curve25519 CurveID = 0xFE30
 	HybridSIKEp503Curve25519 CurveID = 0xFE32
+	HybridHRSSCurve25519     CurveID = 0x4138
 )
 
 // TLS 1.3 Key Share
diff --git a/hrss.go b/hrss.go
new file mode 100644
index 0000000..1a2b6b9
--- /dev/null
+++ b/hrss.go
@@ -0,0 +1,1240 @@
+package tls
+
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"crypto/subtle"
+	"encoding/binary"
+	"io"
+	"math/bits"
+)
+
+// Marshalled sizes, in bytes, of public keys and ciphertexts.
+const (
+	PublicKeySize  = modQBytes
+	CiphertextSize = modQBytes
+)
+
+// N is the number of coefficients per polynomial and Q the public modulus;
+// mod3Bytes and modQBytes are the sizes in bytes of the packed encodings.
+const (
+	N         = 701
+	Q         = 8192
+	mod3Bytes = 140
+	modQBytes = 1138
+)
+
+const (
+	bitsPerWord      = bits.UintSize
+	wordsPerPoly     = (N + bitsPerWord - 1) / bitsPerWord
+	fullWordsPerPoly = N / bitsPerWord
+	bitsInLastWord   = N % bitsPerWord
+)
+
+// poly3 represents a degree-N polynomial over GF(3). Each coefficient is
+// bitsliced across the |s| and |a| arrays, like this:
+//
+//   s  |  a  | value
+//  -----------------
+//   0  |  0  | 0
+//   0  |  1  | 1
+//   1  |  0  | 2 (aka -1)
+//   1  |  1  | <invalid>
+//
+// ('s' is for sign, and 'a' is just a letter.)
+//
+// Once bitsliced as such, the following circuits can be used to implement
+// addition and multiplication mod 3:
+//
+//   (s3, a3) = (s1, a1) × (s2, a2)
+//   s3 = (s2 ∧ a1) ⊕ (s1 ∧ a2)
+//   a3 = (s1 ∧ s2) ⊕ (a1 ∧ a2)
+//
+//   (s3, a3) = (s1, a1) + (s2, a2)
+//   t1 = ~(s1 ∨ a1)
+//   t2 = ~(s2 ∨ a2)
+//   s3 = (a1 ∧ a2) ⊕ (t1 ∧ s2) ⊕ (t2 ∧ s1)
+//   a3 = (s1 ∧ s2) ⊕ (t1 ∧ a2) ⊕ (t2 ∧ a1)
+//
+// Negating a value just involves swapping s and a.
+type poly3 struct {
+	s [wordsPerPoly]uint
+	a [wordsPerPoly]uint
+}
+
+func (p *poly3) trim() {
+	p.s[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1
+	p.a[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1
+}
+
+func (p *poly3) zero() {
+	for i := range p.a {
+		p.s[i] = 0
+		p.a[i] = 0
+	}
+}
+
+func (p *poly3) fromDiscrete(in *poly) {
+	var shift uint
+	s := p.s[:]
+	a := p.a[:]
+	s[0] = 0
+	a[0] = 0
+
+	for _, v := range in {
+		s[0] >>= 1
+		s[0] |= uint((v>>1)&1) << (bitsPerWord - 1)
+		a[0] >>= 1
+		a[0] |= uint(v&1) << (bitsPerWord - 1)
+		shift++
+		if shift == bitsPerWord {
+			s = s[1:]
+			a = a[1:]
+			s[0] = 0
+			a[0] = 0
+			shift = 0
+		}
+	}
+
+	a[0] >>= bitsPerWord - shift
+	s[0] >>= bitsPerWord - shift
+}
+
+func (p *poly3) fromModQ(in *poly) int {
+	var shift uint
+	s := p.s[:]
+	a := p.a[:]
+	s[0] = 0
+	a[0] = 0
+	ok := 1
+
+	for _, v := range in {
+		vMod3, vOk := modQToMod3(v)
+		ok &= vOk
+
+		s[0] >>= 1
+		s[0] |= uint((vMod3>>1)&1) << (bitsPerWord - 1)
+		a[0] >>= 1
+		a[0] |= uint(vMod3&1) << (bitsPerWord - 1)
+		shift++
+		if shift == bitsPerWord {
+			s = s[1:]
+			a = a[1:]
+			s[0] = 0
+			a[0] = 0
+			shift = 0
+		}
+	}
+
+	a[0] >>= bitsPerWord - shift
+	s[0] >>= bitsPerWord - shift
+
+	return ok
+}
+
+func (p *poly3) fromDiscreteMod3(in *poly) {
+	var shift uint
+	s := p.s[:]
+	a := p.a[:]
+	s[0] = 0
+	a[0] = 0
+
+	for _, v := range in {
+		// This duplicates the 13th bit upwards to the top of the
+		// uint16, essentially treating it as a sign bit and converting
+		// into a signed int16. The signed value is reduced mod 3,
+		// yielding {-2, -1, 0, 1, 2}.
+		v = uint16((int16(v<<3)>>3)%3) & 7
+
+		// We want to map v thus:
+		// {-2, -1, 0, 1, 2} -> {1, 2, 0, 1, 2}. We take the bottom
+		// three bits and then the constants below, when shifted by
+		// those three bits, perform the required mapping.
+		s[0] >>= 1
+		s[0] |= (0xbc >> v) << (bitsPerWord - 1)
+		a[0] >>= 1
+		a[0] |= (0x7a >> v) << (bitsPerWord - 1)
+		shift++
+		if shift == bitsPerWord {
+			s = s[1:]
+			a = a[1:]
+			s[0] = 0
+			a[0] = 0
+			shift = 0
+		}
+	}
+
+	a[0] >>= bitsPerWord - shift
+	s[0] >>= bitsPerWord - shift
+}
+
+func (p *poly3) marshal(out []byte) {
+	s := p.s[:]
+	a := p.a[:]
+	sw := s[0]
+	aw := a[0]
+	var shift int
+
+	for i := 0; i < 700; i += 5 {
+		acc, scale := 0, 1
+		for j := 0; j < 5; j++ {
+			v := int(aw&1) | int(sw&1)<<1
+			acc += scale * v
+			scale *= 3
+
+			shift++
+			if shift == bitsPerWord {
+				s = s[1:]
+				a = a[1:]
+				sw = s[0]
+				aw = a[0]
+				shift = 0
+			} else {
+				sw >>= 1
+				aw >>= 1
+			}
+		}
+
+		out[0] = byte(acc)
+		out = out[1:]
+	}
+}
+
+func (p *poly) fromMod2(in *poly2) {
+	var shift uint
+	words := in[:]
+	word := words[0]
+
+	for i := range p {
+		p[i] = uint16(word & 1)
+		word >>= 1
+		shift++
+		if shift == bitsPerWord {
+			words = words[1:]
+			word = words[0]
+			shift = 0
+		}
+	}
+}
+
+func (p *poly) fromMod3(in *poly3) {
+	var shift uint
+	s := in.s[:]
+	a := in.a[:]
+	sw := s[0]
+	aw := a[0]
+
+	for i := range p {
+		p[i] = uint16(aw&1 | (sw&1)<<1)
+		aw >>= 1
+		sw >>= 1
+		shift++
+		if shift == bitsPerWord {
+			a = a[1:]
+			s = s[1:]
+			aw = a[0]
+			sw = s[0]
+			shift = 0
+		}
+	}
+}
+
+func (p *poly) fromMod3ToModQ(in *poly3) {
+	var shift uint
+	s := in.s[:]
+	a := in.a[:]
+	sw := s[0]
+	aw := a[0]
+
+	for i := range p {
+		p[i] = mod3ToModQ(uint16(aw&1 | (sw&1)<<1))
+		aw >>= 1
+		sw >>= 1
+		shift++
+		if shift == bitsPerWord {
+			a = a[1:]
+			s = s[1:]
+			aw = a[0]
+			sw = s[0]
+			shift = 0
+		}
+	}
+}
+
+func lsbToAll(v uint) uint {
+	return uint(int(v<<(bitsPerWord-1)) >> (bitsPerWord - 1))
+}
+
+func (p *poly3) mulConst(ms, ma uint) {
+	ms = lsbToAll(ms)
+	ma = lsbToAll(ma)
+
+	for i := range p.a {
+		p.s[i], p.a[i] = (ma&p.s[i])^(ms&p.a[i]), (ma&p.a[i])^(ms&p.s[i])
+	}
+}
+
+func cmovWords(out, in *[wordsPerPoly]uint, mov uint) {
+	for i := range out {
+		out[i] = (out[i] & ^mov) | (in[i] & mov)
+	}
+}
+
+func rotWords(out, in *[wordsPerPoly]uint, bits uint) {
+	start := bits / bitsPerWord
+	n := (N - bits) / bitsPerWord
+
+	for i := uint(0); i < n; i++ {
+		out[i] = in[start+i]
+	}
+
+	carry := in[wordsPerPoly-1]
+
+	for i := uint(0); i < start; i++ {
+		out[n+i] = carry | in[i]<<bitsInLastWord
+		carry = in[i] >> (bitsPerWord - bitsInLastWord)
+	}
+
+	out[wordsPerPoly-1] = carry
+}
+
+// rotBits right-rotates the bits in |in|. bits must be a non-zero power of two
+// and at most bitsPerWord/2.
+func rotBits(out, in *[wordsPerPoly]uint, bits uint) {
+	if bits == 0 || (bits&(bits-1)) != 0 || bits > bitsPerWord/2 || bitsInLastWord < bitsPerWord/2 {
+		panic("internal error")
+	}
+
+	carry := in[wordsPerPoly-1] << (bitsPerWord - bits)
+
+	for i := wordsPerPoly - 2; i >= 0; i-- {
+		out[i] = carry | in[i]>>bits
+		carry = in[i] << (bitsPerWord - bits)
+	}
+
+	out[wordsPerPoly-1] = carry>>(bitsPerWord-bitsInLastWord) | in[wordsPerPoly-1]>>bits
+}
+
+func (p *poly3) rotWords(bits uint, in *poly3) {
+	rotWords(&p.s, &in.s, bits)
+	rotWords(&p.a, &in.a, bits)
+}
+
+func (p *poly3) rotBits(bits uint, in *poly3) {
+	rotBits(&p.s, &in.s, bits)
+	rotBits(&p.a, &in.a, bits)
+}
+
+func (p *poly3) cmov(in *poly3, mov uint) {
+	cmovWords(&p.s, &in.s, mov)
+	cmovWords(&p.a, &in.a, mov)
+}
+
+func (p *poly3) rot(bits uint) {
+	if bits > N {
+		panic("invalid")
+	}
+	var shifted poly3
+
+	shift := uint(9)
+	for ; (1 << shift) >= bitsPerWord; shift-- {
+		shifted.rotWords(1<<shift, p)
+		p.cmov(&shifted, lsbToAll(bits>>shift))
+	}
+	for ; shift < 9; shift-- {
+		shifted.rotBits(1<<shift, p)
+		p.cmov(&shifted, lsbToAll(bits>>shift))
+	}
+}
+
+func (p *poly3) fmadd(ms, ma uint, in *poly3) {
+	ms = lsbToAll(ms)
+	ma = lsbToAll(ma)
+
+	for i := range p.a {
+		products := (ma & in.s[i]) ^ (ms & in.a[i])
+		producta := (ma & in.a[i]) ^ (ms & in.s[i])
+
+		ns1Ana1 := ^p.s[i] & ^p.a[i]
+		ns2Ana2 := ^products & ^producta
+
+		p.s[i], p.a[i] = (p.a[i]&producta)^(ns1Ana1&products)^(p.s[i]&ns2Ana2), (p.s[i]&products)^(ns1Ana1&producta)^(p.a[i]&ns2Ana2)
+	}
+}
+
+func (p *poly3) modPhiN() {
+	factora := uint(int(p.s[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1))
+	factors := uint(int(p.a[wordsPerPoly-1]<<(bitsPerWord-bitsInLastWord)) >> (bitsPerWord - 1))
+	ns2Ana2 := ^factors & ^factora
+
+	for i := range p.s {
+		ns1Ana1 := ^p.s[i] & ^p.a[i]
+		p.s[i], p.a[i] = (p.a[i]&factora)^(ns1Ana1&factors)^(p.s[i]&ns2Ana2), (p.s[i]&factors)^(ns1Ana1&factora)^(p.a[i]&ns2Ana2)
+	}
+}
+
+func (p *poly3) cswap(other *poly3, swap uint) {
+	for i := range p.s {
+		sums := swap & (p.s[i] ^ other.s[i])
+		p.s[i] ^= sums
+		other.s[i] ^= sums
+
+		suma := swap & (p.a[i] ^ other.a[i])
+		p.a[i] ^= suma
+		other.a[i] ^= suma
+	}
+}
+
+func (p *poly3) mulx() {
+	carrys := (p.s[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1
+	carrya := (p.a[wordsPerPoly-1] >> (bitsInLastWord - 1)) & 1
+
+	for i := range p.s {
+		outCarrys := p.s[i] >> (bitsPerWord - 1)
+		outCarrya := p.a[i] >> (bitsPerWord - 1)
+		p.s[i] <<= 1
+		p.a[i] <<= 1
+		p.s[i] |= carrys
+		p.a[i] |= carrya
+		carrys = outCarrys
+		carrya = outCarrya
+	}
+}
+
+func (p *poly3) divx() {
+	var carrys, carrya uint
+
+	for i := len(p.s) - 1; i >= 0; i-- {
+		outCarrys := p.s[i] & 1
+		outCarrya := p.a[i] & 1
+		p.s[i] >>= 1
+		p.a[i] >>= 1
+		p.s[i] |= carrys << (bitsPerWord - 1)
+		p.a[i] |= carrya << (bitsPerWord - 1)
+		carrys = outCarrys
+		carrya = outCarrya
+	}
+}
+
+type poly2 [wordsPerPoly]uint
+
+func (p *poly2) fromDiscrete(in *poly) {
+	var shift uint
+	words := p[:]
+	words[0] = 0
+
+	for _, v := range in {
+		words[0] >>= 1
+		words[0] |= uint(v&1) << (bitsPerWord - 1)
+		shift++
+		if shift == bitsPerWord {
+			words = words[1:]
+			words[0] = 0
+			shift = 0
+		}
+	}
+
+	words[0] >>= bitsPerWord - shift
+}
+
+func (p *poly2) setPhiN() {
+	for i := range p {
+		p[i] = ^uint(0)
+	}
+	p[wordsPerPoly-1] &= (1 << bitsInLastWord) - 1
+}
+
+func (p *poly2) cswap(other *poly2, swap uint) {
+	for i := range p {
+		sum := swap & (p[i] ^ other[i])
+		p[i] ^= sum
+		other[i] ^= sum
+	}
+}
+
+func (p *poly2) fmadd(m uint, in *poly2) {
+	m = ^(m - 1)
+
+	for i := range p {
+		p[i] ^= in[i] & m
+	}
+}
+
+func (p *poly2) lshift1() {
+	var carry uint
+	for i := range p {
+		nextCarry := p[i] >> (bitsPerWord - 1)
+		p[i] <<= 1
+		p[i] |= carry
+		carry = nextCarry
+	}
+}
+
+func (p *poly2) rshift1() {
+	var carry uint
+	for i := len(p) - 1; i >= 0; i-- {
+		nextCarry := p[i] & 1
+		p[i] >>= 1
+		p[i] |= carry << (bitsPerWord - 1)
+		carry = nextCarry
+	}
+}
+
+func (p *poly2) rot(bits uint) {
+	if bits > N {
+		panic("invalid")
+	}
+	var shifted [wordsPerPoly]uint
+	out := (*[wordsPerPoly]uint)(p)
+
+	shift := uint(9)
+	for ; (1 << shift) >= bitsPerWord; shift-- {
+		rotWords(&shifted, out, 1<<shift)
+		cmovWords(out, &shifted, lsbToAll(bits>>shift))
+	}
+	for ; shift < 9; shift-- {
+		rotBits(&shifted, out, 1<<shift)
+		cmovWords(out, &shifted, lsbToAll(bits>>shift))
+	}
+}
+
+type poly [N]uint16
+
+func (in *poly) marshal(out []byte) {
+	p := in[:]
+
+	for len(p) >= 8 {
+		out[0] = byte(p[0])
+		out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
+		out[2] = byte(p[1] >> 3)
+		out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
+		out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
+		out[5] = byte(p[3] >> 1)
+		out[6] = byte(p[3]>>9) | byte((p[4]&0x0f)<<4)
+		out[7] = byte(p[4] >> 4)
+		out[8] = byte(p[4]>>12) | byte((p[5]&0x7f)<<1)
+		out[9] = byte(p[5]>>7) | byte((p[6]&0x03)<<6)
+		out[10] = byte(p[6] >> 2)
+		out[11] = byte(p[6]>>10) | byte((p[7]&0x1f)<<3)
+		out[12] = byte(p[7] >> 5)
+
+		p = p[8:]
+		out = out[13:]
+	}
+
+	// There are four remaining values.
+	out[0] = byte(p[0])
+	out[1] = byte(p[0]>>8) | byte((p[1]&0x07)<<5)
+	out[2] = byte(p[1] >> 3)
+	out[3] = byte(p[1]>>11) | byte((p[2]&0x3f)<<2)
+	out[4] = byte(p[2]>>6) | byte((p[3]&0x01)<<7)
+	out[5] = byte(p[3] >> 1)
+	out[6] = byte(p[3] >> 9)
+}
+
+func (out *poly) unmarshal(in []byte) bool {
+	p := out[:]
+	for i := 0; i < 87; i++ {
+		p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8
+		p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11
+		p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6
+		p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9
+		p[4] = uint16(in[6]>>4) | uint16(in[7])<<4 | uint16(in[8]&1)<<12
+		p[5] = uint16(in[8]>>1) | uint16(in[9]&0x3f)<<7
+		p[6] = uint16(in[9]>>6) | uint16(in[10])<<2 | uint16(in[11]&7)<<10
+		p[7] = uint16(in[11]>>3) | uint16(in[12])<<5
+
+		p = p[8:]
+		in = in[13:]
+	}
+
+	// There are four coefficients left over
+	p[0] = uint16(in[0]) | uint16(in[1]&0x1f)<<8
+	p[1] = uint16(in[1]>>5) | uint16(in[2])<<3 | uint16(in[3]&3)<<11
+	p[2] = uint16(in[3]>>2) | uint16(in[4]&0x7f)<<6
+	p[3] = uint16(in[4]>>7) | uint16(in[5])<<1 | uint16(in[6]&0xf)<<9
+
+	if in[6]&0xf0 != 0 {
+		return false
+	}
+
+	out[N-1] = 0
+	var top int
+	for _, v := range out {
+		top += int(v)
+	}
+
+	out[N-1] = uint16(-top) % Q
+	return true
+}
+
+func (in *poly) marshalS3(out []byte) {
+	p := in[:]
+	for len(p) >= 5 {
+		out[0] = byte(p[0] + p[1]*3 + p[2]*9 + p[3]*27 + p[4]*81)
+		out = out[1:]
+		p = p[5:]
+	}
+}
+
+func (out *poly) unmarshalS3(in []byte) bool {
+	p := out[:]
+	for i := 0; i < 140; i++ {
+		c := in[0]
+		if c >= 243 {
+			return false
+		}
+		p[0] = uint16(c % 3)
+		p[1] = uint16((c / 3) % 3)
+		p[2] = uint16((c / 9) % 3)
+		p[3] = uint16((c / 27) % 3)
+		p[4] = uint16((c / 81) % 3)
+
+		p = p[5:]
+		in = in[1:]
+	}
+
+	out[N-1] = 0
+	return true
+}
+
+func (p *poly) modPhiN() {
+	for i := range p {
+		p[i] = (p[i] + Q - p[N-1]) % Q
+	}
+}
+
+func (out *poly) shortSample(in []byte) {
+	//  b  a  result
+	// 00 00 00
+	// 00 01 01
+	// 00 10 10
+	// 00 11 11
+	// 01 00 10
+	// 01 01 00
+	// 01 10 01
+	// 01 11 11
+	// 10 00 01
+	// 10 01 10
+	// 10 10 00
+	// 10 11 11
+	// 11 00 11
+	// 11 01 11
+	// 11 10 11
+	// 11 11 11
+
+	// 1111 1111 1100 1001 1101 0010 1110 0100
+	//   f    f    c    9    d    2    e    4
+	const lookup = uint32(0xffc9d2e4)
+
+	p := out[:]
+	for i := 0; i < 87; i++ {
+		v := binary.LittleEndian.Uint32(in)
+		v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555)
+		for j := 0; j < 8; j++ {
+			p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3)
+			v2 >>= 4
+		}
+		p = p[8:]
+		in = in[4:]
+	}
+
+	// There are four values remaining.
+	v := binary.LittleEndian.Uint32(in)
+	v2 := (v & 0x55555555) + ((v >> 1) & 0x55555555)
+	for j := 0; j < 4; j++ {
+		p[j] = uint16(lookup >> ((v2 & 15) << 1) & 3)
+		v2 >>= 4
+	}
+
+	out[N-1] = 0
+}
+
+func (out *poly) shortSamplePlus(in []byte) {
+	out.shortSample(in)
+
+	var sum uint16
+	for i := 0; i < N-1; i++ {
+		sum += mod3ResultToModQ(out[i] * out[i+1])
+	}
+
+	scale := 1 + (1 & (sum >> 12))
+	for i := 0; i < len(out); i += 2 {
+		out[i] = (out[i] * scale) % 3
+	}
+}
+
+func mul(out, scratch, a, b []uint16) {
+	const schoolbookLimit = 32
+	if len(a) < schoolbookLimit {
+		for i := 0; i < len(a)*2; i++ {
+			out[i] = 0
+		}
+		for i := range a {
+			for j := range b {
+				out[i+j] += a[i] * b[j]
+			}
+		}
+		return
+	}
+
+	lowLen := len(a) / 2
+	highLen := len(a) - lowLen
+	aLow, aHigh := a[:lowLen], a[lowLen:]
+	bLow, bHigh := b[:lowLen], b[lowLen:]
+
+	for i := 0; i < lowLen; i++ {
+		out[i] = aHigh[i] + aLow[i]
+	}
+	if highLen != lowLen {
+		out[lowLen] = aHigh[lowLen]
+	}
+
+	for i := 0; i < lowLen; i++ {
+		out[highLen+i] = bHigh[i] + bLow[i]
+	}
+	if highLen != lowLen {
+		out[highLen+lowLen] = bHigh[lowLen]
+	}
+
+	mul(scratch, scratch[2*highLen:], out[:highLen], out[highLen:highLen*2])
+	mul(out[lowLen*2:], scratch[2*highLen:], aHigh, bHigh)
+	mul(out, scratch[2*highLen:], aLow, bLow)
+
+	for i := 0; i < lowLen*2; i++ {
+		scratch[i] -= out[i] + out[lowLen*2+i]
+	}
+	if lowLen != highLen {
+		scratch[lowLen*2] -= out[lowLen*4]
+	}
+
+	for i := 0; i < 2*highLen; i++ {
+		out[lowLen+i] += scratch[i]
+	}
+}
+
+func (out *poly) mul(a, b *poly) {
+	var prod, scratch [2 * N]uint16
+	mul(prod[:], scratch[:], a[:], b[:])
+	for i := range out {
+		out[i] = (prod[i] + prod[i+N]) % Q
+	}
+}
+
+func (p3 *poly3) mulMod3(x, y *poly3) {
+	// (𝑥^n - 1) is a multiple of Φ(N) so we can work mod (𝑥^n - 1) here and
+	// reduce mod Φ(N) afterwards.
+	x3 := *x
+	y3 := *y
+	s := x3.s[:]
+	a := x3.a[:]
+	sw := s[0]
+	aw := a[0]
+	p3.zero()
+	var shift uint
+	for i := 0; i < N; i++ {
+		p3.fmadd(sw, aw, &y3)
+		sw >>= 1
+		aw >>= 1
+		shift++
+		if shift == bitsPerWord {
+			s = s[1:]
+			a = a[1:]
+			sw = s[0]
+			aw = a[0]
+			shift = 0
+		}
+		y3.mulx()
+	}
+	p3.modPhiN()
+}
+
+// mod3ToModQ maps {0, 1, 2, 3} to {0, 1, Q-1, 0xffff}
+// The case of n == 3 should never happen but is included so that modQToMod3
+// can easily catch invalid inputs.
+func mod3ToModQ(n uint16) uint16 {
+	return uint16(uint64(0xffff1fff00010000) >> (16 * n))
+}
+
+// modQToMod3 maps {0, 1, Q-1} to {(0, 0), (0, 1), (1, 0)} and also returns an int
+// which is one if the input is in range and zero otherwise.
+func modQToMod3(n uint16) (uint16, int) {
+	result := (n&3 - (n>>1)&1)
+	return result, subtle.ConstantTimeEq(int32(mod3ToModQ(result)), int32(n))
+}
+
+// mod3ResultToModQ maps {0, 1, 2, 4} to {0, 1, Q-1, 1}
+func mod3ResultToModQ(n uint16) uint16 {
+	return ((((uint16(0x13) >> n) & 1) - 1) & 0x1fff) | ((uint16(0x12) >> n) & 1)
+	//shift := (uint(0x324) >> (2 * n)) & 3
+	//return uint16(uint64(0x00011fff00010000) >> (16 * shift))
+}
+
+// mulXMinus1 sets out to a×(𝑥 - 1) mod (𝑥^n - 1)
+func (out *poly) mulXMinus1() {
+	// Multiplying by (𝑥 - 1) means negating each coefficient and adding in
+	// the value of the previous one.
+	origOut700 := out[700]
+
+	for i := N - 1; i > 0; i-- {
+		out[i] = (Q - out[i] + out[i-1]) % Q
+	}
+	out[0] = (Q - out[0] + origOut700) % Q
+}
+
+func (out *poly) lift(a *poly) {
+	// We wish to calculate a/(𝑥-1) mod Φ(N) over GF(3), where Φ(N) is the
+	// Nth cyclotomic polynomial, i.e. 1 + 𝑥 + … + 𝑥^700 (since N is prime).
+
+	// 1/(𝑥-1) has a fairly basic structure that we can exploit to speed this up:
+	//
+	//   R.<x> = PolynomialRing(GF(3)…)
+	//   inv = R.cyclotomic_polynomial(1).inverse_mod(R.cyclotomic_polynomial(n))
+	//   list(inv)[:15]
+	//     [1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]
+	//
+	// This three-element pattern of coefficients repeats for the whole
+	// polynomial.
+	//
+	// Next define the overbar operator such that z̅ = z[0] +
+	// reverse(z[1:]). (Index zero of a polynomial here is the coefficient
+	// of the constant term. So index one is the coefficient of 𝑥 and so
+	// on.)
+	//
+	// A less odd way to define this is to see that z̅ negates the indexes,
+	// so z̅[0] = z[-0], z̅[1] = z[-1] and so on.
+	//
+	// The use of z̅ is that, when working mod (𝑥^701 - 1), vz[0] = <v,
+	// z̅>, vz[1] = <v, 𝑥z̅>, …. (Where <a, b> is the inner product: the sum
+	// of the point-wise products.) Although we calculated the inverse mod
+	// Φ(N), we can work mod (𝑥^N - 1) and reduce mod Φ(N) at the end.
+	// (That's because (𝑥^N - 1) is a multiple of Φ(N).)
+	//
+	// When working mod (𝑥^N - 1), multiplication by 𝑥 is a right-rotation
+	// of the list of coefficients.
+	//
+	// Thus we can consider what the pattern of z̅, 𝑥z̅, 𝑥^2z̅, … looks like:
+	//
+	// def reverse(xs):
+	//   suffix = list(xs[1:])
+	//   suffix.reverse()
+	//   return [xs[0]] + suffix
+	//
+	// def rotate(xs):
+	//   return [xs[-1]] + xs[:-1]
+	//
+	// zoverbar = reverse(list(inv) + [0])
+	// xzoverbar = rotate(reverse(list(inv) + [0]))
+	// x2zoverbar = rotate(rotate(reverse(list(inv) + [0])))
+	//
+	// zoverbar[:15]
+	//   [1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]
+	// xzoverbar[:15]
+	//   [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0]
+	// x2zoverbar[:15]
+	//   [2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
+	//
+	// (For a formula for z̅, see lemma two of appendix B.)
+	//
+	// After the first three elements have been taken care of, all then have
+	// a repeating three-element cycle. The next value (𝑥^3z̅) involves
+	// three rotations of the first pattern, thus the three-element cycle
+	// lines up. However, the discontinuity in the first three elements
+	// obviously moves to a different position. Consider the difference
+	// between 𝑥^3z̅ and z̅:
+	//
+	// [x-y for (x,y) in zip(zoverbar, x3zoverbar)][:15]
+	//    [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+	//
+	// This pattern of differences is the same for all elements, although it
+	// obviously moves right with the rotations.
+	//
+	// From this, we reach algorithm eight of appendix B.
+
+	// Handle the first three elements of the inner products.
+	out[0] = a[0] + a[2]
+	out[1] = a[1]
+	out[2] = 2*a[0] + a[2]
+
+	// Use the repeating pattern to complete the first three inner products.
+	for i := 3; i < 699; i += 3 {
+		out[0] += 2*a[i] + a[i+2]
+		out[1] += a[i] + 2*a[i+1]
+		out[2] += a[i+1] + 2*a[i+2]
+	}
+
+	// Handle the fact that the three-element pattern doesn't fill the
+	// polynomial exactly (since 701 isn't a multiple of three).
+	out[2] += a[700]
+	out[0] += 2 * a[699]
+	out[1] += a[699] + 2*a[700]
+
+	out[0] = out[0] % 3
+	out[1] = out[1] % 3
+	out[2] = out[2] % 3
+
+	// Calculate the remaining inner products by taking advantage of the
+	// fact that the pattern repeats every three cycles and the pattern of
+	// differences moves with the rotation.
+	for i := 3; i < N; i++ {
+		// Add twice something is the same as subtracting when working
+		// mod 3. Doing it this way avoids underflow. Underflow is bad
+		// because "% 3" doesn't work correctly for negative numbers
+		// here since underflow will wrap to 2^16-1 and 2^16 isn't a
+		// multiple of three.
+		out[i] = (out[i-3] + 2*(a[i-2]+a[i-1]+a[i])) % 3
+	}
+
+	// Reduce mod Φ(N) by subtracting a multiple of out[700] from every
+	// element and convert to mod Q. (See above about adding twice as
+	// subtraction.)
+	v := out[700] * 2
+	for i := range out {
+		out[i] = mod3ToModQ((out[i] + v) % 3)
+	}
+
+	out.mulXMinus1()
+}
+
+func (a *poly) cswap(b *poly, swap uint16) {
+	for i := range a {
+		sum := swap & (a[i] ^ b[i])
+		a[i] ^= sum
+		b[i] ^= sum
+	}
+}
+
+// lt returns an all-ones word if a < b and zero otherwise, without branching
+// on its operands (the inversion routines call it with secret-dependent values).
+func lt(a, b uint) uint {
+	// The top bit of a^((a^b)|((a-b)^a)) is set iff a < b; the arithmetic
+	// shift then copies that bit into every position of the word.
+	return uint(int(a^((a^b)|((a-b)^a))) >> (bitsPerWord - 1))
+}
+
+func bsMul(s1, a1, s2, a2 uint) (s3, a3 uint) {
+	s3 = (a1 & s2) ^ (s1 & a2)
+	a3 = (a1 & a2) ^ (s1 & s2)
+	return
+}
+
+func (out *poly3) invertMod3(in *poly3) {
+	// This algorithm follows algorithm 10 in the paper. (Although note that
+	// the paper appears to have a bug: k should start at zero, not one.)
+	// The best explanation for why it works is in the "Why it works"
+	// section of
+	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
+	var k uint
+	degF, degG := uint(N-1), uint(N-1)
+
+	var b, c, g poly3
+	f := *in
+
+	for i := range g.a {
+		g.a[i] = ^uint(0)
+	}
+
+	b.a[0] = 1
+
+	var f0s, f0a uint
+	stillGoing := ^uint(0)
+	for i := 0; i < 2*(N-1)-1; i++ {
+		ss, sa := bsMul(f.s[0], f.a[0], g.s[0], g.a[0])
+		ss, sa = sa&stillGoing&1, ss&stillGoing&1
+		shouldSwap := ^uint(int((ss|sa)-1)>>(bitsPerWord-1)) & lt(degF, degG)
+		f.cswap(&g, shouldSwap)
+		b.cswap(&c, shouldSwap)
+		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
+		f.fmadd(ss, sa, &g)
+		b.fmadd(ss, sa, &c)
+
+		f.divx()
+		f.s[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
+		f.a[wordsPerPoly-1] &= ((1 << bitsInLastWord) - 1) >> 1
+		c.mulx()
+		c.s[0] &= ^uint(1)
+		c.a[0] &= ^uint(1)
+
+		degF--
+		k += 1 & stillGoing
+		f0s = (stillGoing & f.s[0]) | (^stillGoing & f0s)
+		f0a = (stillGoing & f.a[0]) | (^stillGoing & f0a)
+		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
+	}
+
+	k -= N & lt(N, k)
+	*out = b
+	out.rot(k)
+	out.mulConst(f0s, f0a)
+	out.modPhiN()
+}
+
+func (out *poly) invertMod2(a *poly) {
+	// This algorithm follows a mix of algorithm 10 in the paper and the first
+	// page of the PDF linked below. (Although note that the paper appears
+	// to have a bug: k should start at zero, not one.) The best explanation
+	// for why it works is in the "Why it works" section of
+	// https://assets.onboardsecurity.com/static/downloads/NTRU/resources/NTRUTech014.pdf.
+	var k uint
+	degF, degG := uint(N-1), uint(N-1)
+
+	var f poly2
+	f.fromDiscrete(a)
+	var b, c, g poly2
+	g.setPhiN()
+	b[0] = 1
+
+	stillGoing := ^uint(0)
+	for i := 0; i < 2*(N-1)-1; i++ {
+		s := uint(f[0]&1) & stillGoing
+		shouldSwap := ^(s - 1) & lt(degF, degG)
+		f.cswap(&g, shouldSwap)
+		b.cswap(&c, shouldSwap)
+		degF, degG = (degG&shouldSwap)|(degF & ^shouldSwap), (degF&shouldSwap)|(degG&^shouldSwap)
+		f.fmadd(s, &g)
+		b.fmadd(s, &c)
+
+		f.rshift1()
+		c.lshift1()
+
+		degF--
+		k += 1 & stillGoing
+		stillGoing = ^uint(int(degF-1) >> (bitsPerWord - 1))
+	}
+
+	k -= N & lt(N, k)
+	b.rot(k)
+	out.fromMod2(&b)
+}
+
+func (out *poly) invert(origA *poly) {
+	// Inversion mod Q, which is done based on the result of inverting mod
+	// 2. See the NTRU paper, page three.
+	var a, tmp, tmp2, b poly
+	b.invertMod2(origA)
+
+	// Negate a.
+	for i := range a {
+		a[i] = Q - origA[i]
+	}
+
+	// We are working mod Q=2**13 and we need to iterate ceil(log_2(13))
+	// times, which is four.
+	for i := 0; i < 4; i++ {
+		tmp.mul(&a, &b)
+		tmp[0] += 2
+		tmp2.mul(&b, &tmp)
+		b = tmp2
+	}
+
+	*out = b
+}
+
+// PublicKey is an HRSS public key: the polynomial h with coefficients mod Q.
+type PublicKey struct {
+	h poly
+}
+
+// ParsePublicKey decodes a public key from its packed form. It returns false
+// if the unused padding bits of the encoding are non-zero.
+func ParsePublicKey(in []byte) (*PublicKey, bool) {
+	ret := new(PublicKey)
+	if !ret.h.unmarshal(in) {
+		return nil, false
+	}
+	return ret, true
+}
+
+// Marshal returns the modQBytes-long packed encoding of the public key.
+func (pub *PublicKey) Marshal() []byte {
+	ret := make([]byte, modQBytes)
+	pub.h.marshal(ret)
+	return ret
+}
+
+// Encap samples m and r from rand and returns the resulting ciphertext together
+// with the shared key, a SHA-256 hash over m, r and the ciphertext. It panics
+// if rand cannot supply 704 bytes.
+func (pub *PublicKey) Encap(rand io.Reader) (ciphertext []byte, sharedKey []byte) {
+	var randBytes [352 + 352]byte
+	if _, err := io.ReadFull(rand, randBytes[:]); err != nil {
+		panic("rand failed")
+	}
+
+	var m, r poly
+	m.shortSample(randBytes[:352])
+	r.shortSample(randBytes[352:])
+
+	var mBytes, rBytes [mod3Bytes]byte
+	m.marshalS3(mBytes[:])
+	r.marshalS3(rBytes[:])
+
+	ciphertext = pub.owf(&m, &r)
+
+	h := sha256.New()
+	h.Write([]byte("shared key\x00"))
+	h.Write(mBytes[:])
+	h.Write(rBytes[:])
+	h.Write(ciphertext)
+	sharedKey = h.Sum(nil)
+
+	return ciphertext, sharedKey
+}
+
+// owf computes the packed ciphertext r×h + lift(m) with coefficients mod Q.
+// Note that r is converted from mod-3 to mod-Q representation in place.
+func (pub *PublicKey) owf(m, r *poly) []byte {
+	for i := range r {
+		r[i] = mod3ToModQ(r[i])
+	}
+
+	var mq poly
+	mq.lift(m)
+
+	var e poly
+	e.mul(r, &pub.h)
+	for i := range e {
+		e[i] = (e[i] + mq[i]) % Q
+	}
+
+	ret := make([]byte, modQBytes)
+	e.marshal(ret[:])
+	return ret
+}
+
+// PrivateKey is an HRSS private key. It embeds the matching PublicKey and
+// holds f, its inverse mod 3 (fp), the inverse of h and the rejection key.
+type PrivateKey struct {
+	PublicKey
+	f, fp   poly3
+	hInv    poly
+	hmacKey [32]byte
+}
+
+// Marshal packs f and fp (mod 3) followed by the public polynomial h. hInv and
+// hmacKey are not part of the output.
+func (priv *PrivateKey) Marshal() []byte {
+	var ret [2*mod3Bytes + modQBytes]byte
+	priv.f.marshal(ret[:])
+	priv.fp.marshal(ret[mod3Bytes:])
+	priv.h.marshal(ret[2*mod3Bytes:])
+	return ret[:]
+}
+
+// Decap recovers the shared key from ciphertext. ok is false only when the
+// ciphertext has the wrong length or does not parse; a ciphertext that fails
+// the re-encryption check yields a key derived from hmacKey instead.
+func (priv *PrivateKey) Decap(ciphertext []byte) (sharedKey []byte, ok bool) {
+	if len(ciphertext) != modQBytes {
+		return nil, false
+	}
+
+	var e poly
+	if !e.unmarshal(ciphertext) {
+		return nil, false
+	}
+
+	var f poly
+	f.fromMod3ToModQ(&priv.f)
+
+	var v1, m poly
+	v1.mul(&e, &f)
+
+	var v13 poly3
+	v13.fromDiscreteMod3(&v1)
+	// Note: v13 is not reduced mod phi(n).
+
+	var m3 poly3
+	m3.mulMod3(&v13, &priv.fp)
+	m3.modPhiN()
+	m.fromMod3(&m3)
+
+	var mLift, delta poly
+	mLift.lift(&m)
+	for i := range delta {
+		delta[i] = (e[i] - mLift[i] + Q) % Q
+	}
+	delta.mul(&delta, &priv.hInv)
+	delta.modPhiN()
+
+	var r poly3
+	allOk := r.fromModQ(&delta)
+
+	var mBytes, rBytes [mod3Bytes]byte
+	m.marshalS3(mBytes[:])
+	r.marshal(rBytes[:])
+
+	var rPoly poly
+	rPoly.fromMod3(&r)
+	expectedCiphertext := priv.PublicKey.owf(&m, &rPoly)
+
+	allOk &= subtle.ConstantTimeCompare(ciphertext, expectedCiphertext)
+
+	hmacHash := hmac.New(sha256.New, priv.hmacKey[:])
+	hmacHash.Write(ciphertext)
+	hmacDigest := hmacHash.Sum(nil)
+
+	h := sha256.New()
+	h.Write([]byte("shared key\x00"))
+	h.Write(mBytes[:])
+	h.Write(rBytes[:])
+	h.Write(ciphertext)
+	sharedKey = h.Sum(nil)
+
+	mask := uint8(allOk - 1)
+	for i := range sharedKey {
+		sharedKey[i] = (sharedKey[i] & ^mask) | (hmacDigest[i] & mask)
+	}
+
+	return sharedKey, true
+}
+
+// GenerateKey creates a fresh private key from rand and panics if rand fails.
+// It consumes 736 bytes: 704 for sampling f and g, followed by 32 for the
+// key used by Decap when a ciphertext is rejected.
+func GenerateKey(rand io.Reader) PrivateKey {
+	var randBytes [352 + 352]byte
+	if _, err := io.ReadFull(rand, randBytes[:]); err != nil {
+		panic("rand failed")
+	}
+
+	var f poly
+	f.shortSamplePlus(randBytes[:352])
+	var priv PrivateKey
+	// Decap's rejection output is an HMAC under hmacKey, so the key must be
+	// secret: left at its zero value, anyone could compute that output.
+	if _, err := io.ReadFull(rand, priv.hmacKey[:]); err != nil {
+		panic("rand failed")
+	}
+	priv.f.fromDiscrete(&f)
+	priv.fp.invertMod3(&priv.f)
+
+	var g poly
+	g.shortSamplePlus(randBytes[352:])
+
+	var pgPhi1 poly
+	for i := range g {
+		pgPhi1[i] = mod3ToModQ(g[i])
+	}
+	for i := range pgPhi1 {
+		pgPhi1[i] = (pgPhi1[i] * 3) % Q
+	}
+	pgPhi1.mulXMinus1()
+
+	var fModQ poly
+	fModQ.fromMod3ToModQ(&priv.f)
+
+	var pfgPhi1 poly
+	pfgPhi1.mul(&fModQ, &pgPhi1)
+
+	var i poly
+	i.invert(&pfgPhi1)
+
+	priv.h.mul(&i, &pgPhi1)
+	priv.h.mul(&priv.h, &pgPhi1)
+
+	priv.hInv.mul(&i, &fModQ)
+	priv.hInv.mul(&priv.hInv, &fModQ)
+
+	return priv
+}