141 lines
2.9 KiB
Go
141 lines
2.9 KiB
Go
|
package newhope
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/rand"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func TestNTTRoundTrip(t *testing.T) {
|
||
|
var a Poly
|
||
|
for i := range a {
|
||
|
a[i] = uint16(i)
|
||
|
}
|
||
|
|
||
|
frequency := forwardNTT(&a)
|
||
|
original := inverseNTT(frequency)
|
||
|
|
||
|
for i, v := range a {
|
||
|
if v != original[i] {
|
||
|
t.Errorf("NTT didn't invert correctly: original[%d] = %d", i, original[i])
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestNTTInv(t *testing.T) {
|
||
|
var a Poly
|
||
|
for i := range a {
|
||
|
a[i] = uint16(i)
|
||
|
}
|
||
|
|
||
|
result := ntt(&a, invOmega, 1, invSqrtOmega, invN)
|
||
|
if result[0] != 6656 || result[1] != 1792 || result[2] != 1234 {
|
||
|
t.Errorf("NTT^-1 gave bad result: %v", result[:8])
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func disabledTestNoise(t *testing.T) {
|
||
|
var buckets [1 + 2*k]int
|
||
|
numSamples := 100
|
||
|
|
||
|
for i := 0; i < numSamples; i++ {
|
||
|
noise := sampleNoise(rand.Reader)
|
||
|
for _, v := range noise {
|
||
|
value := (int(v) + k) % q
|
||
|
buckets[value]++
|
||
|
}
|
||
|
}
|
||
|
|
||
|
sum := 0
|
||
|
squareSum := 0
|
||
|
|
||
|
for i, count := range buckets {
|
||
|
sum += (i - k) * count
|
||
|
squareSum += (i - k) * (i - k) * count
|
||
|
}
|
||
|
|
||
|
mean := float64(sum) / float64(n*numSamples)
|
||
|
if mean < -0.5 || 0.5 < mean {
|
||
|
t.Errorf("mean out of range: %f", mean)
|
||
|
}
|
||
|
|
||
|
expectedVariance := 0.5 * 0.5 * float64(k*2) // I think?
|
||
|
variance := float64(squareSum)/float64(n*numSamples) - mean*mean
|
||
|
|
||
|
if variance < expectedVariance-1.0 || expectedVariance+1.0 < variance {
|
||
|
t.Errorf("variance out of range: got %f, want %f", variance, expectedVariance)
|
||
|
}
|
||
|
|
||
|
file, err := ioutil.TempFile("", "noise")
|
||
|
fmt.Printf("writing noise to %s\n", file.Name())
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
for i, count := range buckets {
|
||
|
dots := ""
|
||
|
for i := 0; i < count/(3*numSamples); i++ {
|
||
|
dots += "++"
|
||
|
}
|
||
|
fmt.Fprintf(file, "%+d\t%d\t%s\n", i-k, count, dots)
|
||
|
}
|
||
|
file.Close()
|
||
|
}
|
||
|
|
||
|
func TestSeedToPolynomial(t *testing.T) {
|
||
|
seed := make([]byte, 32)
|
||
|
seed[0] = 1
|
||
|
seed[31] = 2
|
||
|
|
||
|
poly := seedToPolynomial(seed)
|
||
|
if poly[0] != 3313 || poly[1] != 9277 || poly[2] != 11020 {
|
||
|
t.Errorf("bad result: %v", poly[:3])
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEncodeDecodePoly(t *testing.T) {
|
||
|
poly := randomPolynomial(rand.Reader)
|
||
|
poly2 := decodePoly(encodePoly(poly))
|
||
|
if *poly != *poly2 {
|
||
|
t.Errorf("decodePoly(encodePoly) isn't the identity function")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEncodeDecodeRec(t *testing.T) {
|
||
|
var r reconciliationData
|
||
|
if _, err := io.ReadFull(rand.Reader, r[:]); err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
for i := range r {
|
||
|
r[i] &= 3
|
||
|
}
|
||
|
|
||
|
encoded := encodeRec(&r)
|
||
|
decoded := decodeRec(encoded)
|
||
|
|
||
|
if *decoded != r {
|
||
|
t.Errorf("bad decode of rec")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestExchange(t *testing.T) {
|
||
|
for count := 0; count < 64; count++ {
|
||
|
offerMsg, state := Offer(rand.Reader)
|
||
|
sharedKey1, acceptMsg, err := Accept(rand.Reader, offerMsg)
|
||
|
if err != nil {
|
||
|
t.Errorf("Accept: %v", err)
|
||
|
}
|
||
|
sharedKey2, err := state.Finish(acceptMsg)
|
||
|
if err != nil {
|
||
|
t.Fatal("Finish: %v", err)
|
||
|
}
|
||
|
|
||
|
if !bytes.Equal(sharedKey1[:], sharedKey2[:]) {
|
||
|
t.Fatalf("keys mismatched on iteration %d: %x vs %x", count, sharedKey1, sharedKey2)
|
||
|
}
|
||
|
}
|
||
|
}
|