crypto/tls: fix handshake message test
This test breaks when I make reflect.DeepEqual distinguish empty slices from nil slices. R=agl CC=golang-dev https://golang.org/cl/5369110
This commit is contained in:
parent
a070fcf2bd
commit
30373ac5f7
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
package tls
|
package tls
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
type clientHelloMsg struct {
|
type clientHelloMsg struct {
|
||||||
raw []byte
|
raw []byte
|
||||||
vers uint16
|
vers uint16
|
||||||
@ -18,6 +20,25 @@ type clientHelloMsg struct {
|
|||||||
supportedPoints []uint8
|
supportedPoints []uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *clientHelloMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*clientHelloMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.vers == m1.vers &&
|
||||||
|
bytes.Equal(m.random, m1.random) &&
|
||||||
|
bytes.Equal(m.sessionId, m1.sessionId) &&
|
||||||
|
eqUint16s(m.cipherSuites, m1.cipherSuites) &&
|
||||||
|
bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
|
||||||
|
m.nextProtoNeg == m1.nextProtoNeg &&
|
||||||
|
m.serverName == m1.serverName &&
|
||||||
|
m.ocspStapling == m1.ocspStapling &&
|
||||||
|
eqUint16s(m.supportedCurves, m1.supportedCurves) &&
|
||||||
|
bytes.Equal(m.supportedPoints, m1.supportedPoints)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *clientHelloMsg) marshal() []byte {
|
func (m *clientHelloMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -309,6 +330,23 @@ type serverHelloMsg struct {
|
|||||||
ocspStapling bool
|
ocspStapling bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *serverHelloMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*serverHelloMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.vers == m1.vers &&
|
||||||
|
bytes.Equal(m.random, m1.random) &&
|
||||||
|
bytes.Equal(m.sessionId, m1.sessionId) &&
|
||||||
|
m.cipherSuite == m1.cipherSuite &&
|
||||||
|
m.compressionMethod == m1.compressionMethod &&
|
||||||
|
m.nextProtoNeg == m1.nextProtoNeg &&
|
||||||
|
eqStrings(m.nextProtos, m1.nextProtos) &&
|
||||||
|
m.ocspStapling == m1.ocspStapling
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverHelloMsg) marshal() []byte {
|
func (m *serverHelloMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -463,6 +501,16 @@ type certificateMsg struct {
|
|||||||
certificates [][]byte
|
certificates [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
eqByteSlices(m.certificates, m1.certificates)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateMsg) marshal() (x []byte) {
|
func (m *certificateMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
|
|||||||
key []byte
|
key []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*serverKeyExchangeMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.key, m1.key)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverKeyExchangeMsg) marshal() []byte {
|
func (m *serverKeyExchangeMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -571,6 +629,17 @@ type certificateStatusMsg struct {
|
|||||||
response []byte
|
response []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateStatusMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateStatusMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.statusType == m1.statusType &&
|
||||||
|
bytes.Equal(m.response, m1.response)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateStatusMsg) marshal() []byte {
|
func (m *certificateStatusMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
|
|||||||
|
|
||||||
type serverHelloDoneMsg struct{}
|
type serverHelloDoneMsg struct{}
|
||||||
|
|
||||||
|
func (m *serverHelloDoneMsg) equal(i interface{}) bool {
|
||||||
|
_, ok := i.(*serverHelloDoneMsg)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (m *serverHelloDoneMsg) marshal() []byte {
|
func (m *serverHelloDoneMsg) marshal() []byte {
|
||||||
x := make([]byte, 4)
|
x := make([]byte, 4)
|
||||||
x[0] = typeServerHelloDone
|
x[0] = typeServerHelloDone
|
||||||
@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
|
|||||||
ciphertext []byte
|
ciphertext []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*clientKeyExchangeMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.ciphertext, m1.ciphertext)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *clientKeyExchangeMsg) marshal() []byte {
|
func (m *clientKeyExchangeMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -671,6 +755,16 @@ type finishedMsg struct {
|
|||||||
verifyData []byte
|
verifyData []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *finishedMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*finishedMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.verifyData, m1.verifyData)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *finishedMsg) marshal() (x []byte) {
|
func (m *finishedMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -698,6 +792,16 @@ type nextProtoMsg struct {
|
|||||||
proto string
|
proto string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *nextProtoMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*nextProtoMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
m.proto == m1.proto
|
||||||
|
}
|
||||||
|
|
||||||
func (m *nextProtoMsg) marshal() []byte {
|
func (m *nextProtoMsg) marshal() []byte {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -759,6 +863,17 @@ type certificateRequestMsg struct {
|
|||||||
certificateAuthorities [][]byte
|
certificateAuthorities [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateRequestMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateRequestMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
|
||||||
|
eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateRequestMsg) marshal() (x []byte) {
|
func (m *certificateRequestMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
|
|||||||
signature []byte
|
signature []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *certificateVerifyMsg) equal(i interface{}) bool {
|
||||||
|
m1, ok := i.(*certificateVerifyMsg)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Equal(m.raw, m1.raw) &&
|
||||||
|
bytes.Equal(m.signature, m1.signature)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *certificateVerifyMsg) marshal() (x []byte) {
|
func (m *certificateVerifyMsg) marshal() (x []byte) {
|
||||||
if m.raw != nil {
|
if m.raw != nil {
|
||||||
return m.raw
|
return m.raw
|
||||||
@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func eqUint16s(x, y []uint16) bool {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range x {
|
||||||
|
if y[i] != v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func eqStrings(x, y []string) bool {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range x {
|
||||||
|
if y[i] != v {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func eqByteSlices(x, y [][]byte) bool {
|
||||||
|
if len(x) != len(y) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range x {
|
||||||
|
if !bytes.Equal(v, y[i]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
@ -27,10 +27,12 @@ var tests = []interface{}{
|
|||||||
type testMessage interface {
|
type testMessage interface {
|
||||||
marshal() []byte
|
marshal() []byte
|
||||||
unmarshal([]byte) bool
|
unmarshal([]byte) bool
|
||||||
|
equal(interface{}) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMarshalUnmarshal(t *testing.T) {
|
func TestMarshalUnmarshal(t *testing.T) {
|
||||||
rand := rand.New(rand.NewSource(0))
|
rand := rand.New(rand.NewSource(0))
|
||||||
|
|
||||||
for i, iface := range tests {
|
for i, iface := range tests {
|
||||||
ty := reflect.ValueOf(iface).Type()
|
ty := reflect.ValueOf(iface).Type()
|
||||||
|
|
||||||
@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
|
|||||||
}
|
}
|
||||||
m2.marshal() // to fill any marshal cache in the message
|
m2.marshal() // to fill any marshal cache in the message
|
||||||
|
|
||||||
if !reflect.DeepEqual(m1, m2) {
|
if !m1.equal(m2) {
|
||||||
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
|
t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user