c82b70155d
(Code mostly due to agl.) Change-Id: Iec77396141954e5f8e845cc261eadab77f551f08 Reviewed-on: https://boringssl-review.googlesource.com/7990 Reviewed-by: Adam Langley <agl@google.com>
141 lines
2.9 KiB
Go
141 lines
2.9 KiB
Go
package newhope
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"testing"
|
|
)
|
|
|
|
func TestNTTRoundTrip(t *testing.T) {
|
|
var a Poly
|
|
for i := range a {
|
|
a[i] = uint16(i)
|
|
}
|
|
|
|
frequency := forwardNTT(&a)
|
|
original := inverseNTT(frequency)
|
|
|
|
for i, v := range a {
|
|
if v != original[i] {
|
|
t.Errorf("NTT didn't invert correctly: original[%d] = %d", i, original[i])
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNTTInv(t *testing.T) {
|
|
var a Poly
|
|
for i := range a {
|
|
a[i] = uint16(i)
|
|
}
|
|
|
|
result := ntt(&a, invOmega, 1, invSqrtOmega, invN)
|
|
if result[0] != 6656 || result[1] != 1792 || result[2] != 1234 {
|
|
t.Errorf("NTT^-1 gave bad result: %v", result[:8])
|
|
}
|
|
}
|
|
|
|
func disabledTestNoise(t *testing.T) {
|
|
var buckets [1 + 2*k]int
|
|
numSamples := 100
|
|
|
|
for i := 0; i < numSamples; i++ {
|
|
noise := sampleNoise(rand.Reader)
|
|
for _, v := range noise {
|
|
value := (int(v) + k) % q
|
|
buckets[value]++
|
|
}
|
|
}
|
|
|
|
sum := 0
|
|
squareSum := 0
|
|
|
|
for i, count := range buckets {
|
|
sum += (i - k) * count
|
|
squareSum += (i - k) * (i - k) * count
|
|
}
|
|
|
|
mean := float64(sum) / float64(n*numSamples)
|
|
if mean < -0.5 || 0.5 < mean {
|
|
t.Errorf("mean out of range: %f", mean)
|
|
}
|
|
|
|
expectedVariance := 0.5 * 0.5 * float64(k*2) // I think?
|
|
variance := float64(squareSum)/float64(n*numSamples) - mean*mean
|
|
|
|
if variance < expectedVariance-1.0 || expectedVariance+1.0 < variance {
|
|
t.Errorf("variance out of range: got %f, want %f", variance, expectedVariance)
|
|
}
|
|
|
|
file, err := ioutil.TempFile("", "noise")
|
|
fmt.Printf("writing noise to %s\n", file.Name())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for i, count := range buckets {
|
|
dots := ""
|
|
for i := 0; i < count/(3*numSamples); i++ {
|
|
dots += "++"
|
|
}
|
|
fmt.Fprintf(file, "%+d\t%d\t%s\n", i-k, count, dots)
|
|
}
|
|
file.Close()
|
|
}
|
|
|
|
func TestSeedToPolynomial(t *testing.T) {
|
|
seed := make([]byte, 32)
|
|
seed[0] = 1
|
|
seed[31] = 2
|
|
|
|
poly := seedToPolynomial(seed)
|
|
if poly[0] != 3313 || poly[1] != 9277 || poly[2] != 11020 {
|
|
t.Errorf("bad result: %v", poly[:3])
|
|
}
|
|
}
|
|
|
|
func TestEncodeDecodePoly(t *testing.T) {
|
|
poly := randomPolynomial(rand.Reader)
|
|
poly2 := decodePoly(encodePoly(poly))
|
|
if *poly != *poly2 {
|
|
t.Errorf("decodePoly(encodePoly) isn't the identity function")
|
|
}
|
|
}
|
|
|
|
func TestEncodeDecodeRec(t *testing.T) {
|
|
var r reconciliationData
|
|
if _, err := io.ReadFull(rand.Reader, r[:]); err != nil {
|
|
panic(err)
|
|
}
|
|
for i := range r {
|
|
r[i] &= 3
|
|
}
|
|
|
|
encoded := encodeRec(&r)
|
|
decoded := decodeRec(encoded)
|
|
|
|
if *decoded != r {
|
|
t.Errorf("bad decode of rec")
|
|
}
|
|
}
|
|
|
|
func TestExchange(t *testing.T) {
|
|
for count := 0; count < 64; count++ {
|
|
offerMsg, state := Offer(rand.Reader)
|
|
sharedKey1, acceptMsg, err := Accept(rand.Reader, offerMsg)
|
|
if err != nil {
|
|
t.Errorf("Accept: %v", err)
|
|
}
|
|
sharedKey2, err := state.Finish(acceptMsg)
|
|
if err != nil {
|
|
t.Fatal("Finish: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(sharedKey1[:], sharedKey2[:]) {
|
|
t.Fatalf("keys mismatched on iteration %d: %x vs %x", count, sharedKey1, sharedKey2)
|
|
}
|
|
}
|
|
}
|