diff --git a/dh/sidh/sidh.go b/dh/sidh/sidh.go index 2d11e64..50d9b9d 100644 --- a/dh/sidh/sidh.go +++ b/dh/sidh/sidh.go @@ -186,7 +186,7 @@ func (prv *PrivateKey) generatePrivateKeyA(rand io.Reader) error { // shared secret computation. func (prv *PrivateKey) generatePrivateKeyB(rand io.Reader) error { // Perform rejection sampling to obtain a random value in [0,3^238]: - var ok uint64 + var ok uint8 for i := uint(0); i < prv.params.SampleRate; i++ { _, err := io.ReadFull(rand, prv.Scalar) if err != nil { diff --git a/dh/sidh/sidh_amd64.s b/dh/sidh/sidh_amd64.s index 4b1e4da..9d24181 100644 --- a/dh/sidh/sidh_amd64.s +++ b/dh/sidh/sidh_amd64.s @@ -10,9 +10,9 @@ #define THREE238M1_4 $0xb858a87e8f4222c7 #define THREE238M1_5 $0x254c9c6b525eaf5 -// Set result to zero if the input scalar is <= 3^238. scalar must be 48-byte array -// of bytes. -// func checkLessThanThree238(s_base uintptr, s_len uint, s_cap uint) uint64 +// Set result to zero if the input scalar is <= 3^238, otherwise result is 1. +// Scalar must be array of 48 bytes +// func checkLessThanThree238(s_base uintptr, s_len uint, s_cap uint) uint8 TEXT ·checkLessThanThree238(SB), NOSPLIT, $0-16 MOVQ scalar+0(FP), SI @@ -34,9 +34,9 @@ TEXT ·checkLessThanThree238(SB), NOSPLIT, $0-16 SBBQ 32(SI), R14 SBBQ 40(SI), R15 - // Save borrow flag indicating 3^238 - scalar < 0 as a mask in AX (eax) - SBBL $0, AX - MOVL AX, ret+24(FP) + // Save borrow flag indicating 3^238 - scalar < 0 as a mask in AX (rax) + ADCB $0, AX + MOVB AX, ret+24(FP) RET diff --git a/dh/sidh/sidh_decl.go b/dh/sidh/sidh_decl.go index f5ea729..d3ec68e 100644 --- a/dh/sidh/sidh_decl.go +++ b/dh/sidh/sidh_decl.go @@ -2,10 +2,10 @@ package sidh -// Returns zero if the input scalar is <= 3^238. scalar must be 48-byte array -// of bytes. This function is specific to P751. +// Set result to zero if the input scalar is <= 3^238, otherwise result is 1. +// Scalar must be array of 48 bytes. This function is specific to P751. //go:noescape -func checkLessThanThree238(scalar []byte) uint64 +func checkLessThanThree238(scalar []byte) uint8 // Multiply 48-byte scalar by 3 to get a scalar in 3*[0,3^238). This // function is specific to P751. diff --git a/dh/sidh/sidh_generic.go b/dh/sidh/sidh_generic.go index 03f8e03..ef45b93 100644 --- a/dh/sidh/sidh_generic.go +++ b/dh/sidh/sidh_generic.go @@ -27,14 +27,14 @@ func subc8(bIn, a, b uint8) (ret, bOut uint8) { return } -// Set result to zero if the input scalar is <= 3^238. scalar must be 48-byte array -// of bytes. This function is specific to P751. -func checkLessThanThree238(scalar []byte) uint64 { +// Set result to zero if the input scalar is <= 3^238, otherwise result is 1. +// Scalar must be array of 48 bytes. This function is specific to P751. +func checkLessThanThree238(scalar []byte) uint8 { var borrow uint8 for i := 0; i < len(three238m1); i++ { _, borrow = subc8(borrow, three238m1[i], scalar[i]) } - return uint64(borrow) + return borrow } // Multiply 48-byte scalar by 3 to get a scalar in 3*[0,3^238). This diff --git a/dh/sidh/sidh_test.go b/dh/sidh/sidh_test.go index e0848df..d6f10d3 100644 --- a/dh/sidh/sidh_test.go +++ b/dh/sidh/sidh_test.go @@ -259,19 +259,28 @@ func TestCheckLessThanThree238(t *testing.T) { 212, 191, 53, 59, 115, 56, 207, 215, 148, 207, 41, 130, 248, 214, 42, 124, 12, 153, 108, 197, 99, 199, 34, 66, 143, 126, 168, 88, 184, 245, 234, 37, 181, 198, 201, 84, 2} + // makes second 64-bit digits bigger than in three238. checks if carries are correctly propagated + var three238plus2power65 = [48]byte{249, 132, 131, 130, 138, 113, 205, 237, 22, 122, + 66, 212, 191, 53, 59, 115, 56, 207, 215, 148, 207, 41, 130, 248, 214, 42, 124, 12, + 153, 108, 197, 99, 199, 34, 66, 143, 126, 168, 88, 184, 245, 234, 37, 181, 198, + 201, 84, 2} - var result uint64 + var result uint8 result = checkLessThanThree238(three238minus1[:]) if result != 0 { t.Error("expected 0, got", result) } result = checkLessThanThree238(three238[:]) - if result == 0 { + if result != 1 { t.Error("expected nonzero, got", result) } result = checkLessThanThree238(three238plus1[:]) - if result == 0 { + if result != 1 { + t.Error("expected nonzero, got", result) + } + result = checkLessThanThree238(three238plus2power65[:]) + if result != 1 { t.Error("expected nonzero, got", result) } }