diff --git a/handshake_messages.go b/handshake_messages.go index ab8e60a..2ea4ddb 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -802,12 +802,9 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } l := int(d[0])<<8 | int(d[1]) d = d[2:] - if len(d) != l { + if len(d) != l || l == 0 { return false } - if l == 0 { - continue - } m.scts = make([][]byte, 0, 3) for len(d) != 0 { diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 95d825b..cb3634c 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -5,6 +5,7 @@ package tls import ( + "bytes" "math/rand" "reflect" "testing" @@ -260,3 +261,47 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { } return reflect.ValueOf(s) } + +func TestRejectEmptySCTList(t *testing.T) { + // https://tools.ietf.org/html/rfc6962#section-3.3.1 specifies that + // empty SCT lists are invalid. + + var random [32]byte + sct := []byte{0x42, 0x42, 0x42, 0x42} + serverHello := serverHelloMsg{ + vers: VersionTLS12, + random: random[:], + scts: [][]byte{sct}, + } + serverHelloBytes := serverHello.marshal() + + var serverHelloCopy serverHelloMsg + if !serverHelloCopy.unmarshal(serverHelloBytes) { + t.Fatal("Failed to unmarshal initial message") + } + + // Change serverHelloBytes so that the SCT list is empty + i := bytes.Index(serverHelloBytes, sct) + if i < 0 { + t.Fatal("Cannot find SCT in ServerHello") + } + + var serverHelloEmptySCT []byte + serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[:i-6]...) + // Append the extension length and SCT list length for an empty list. + serverHelloEmptySCT = append(serverHelloEmptySCT, []byte{0, 2, 0, 0}...) + serverHelloEmptySCT = append(serverHelloEmptySCT, serverHelloBytes[i+4:]...) + + // Update the handshake message length. + serverHelloEmptySCT[1] = byte((len(serverHelloEmptySCT) - 4) >> 16) + serverHelloEmptySCT[2] = byte((len(serverHelloEmptySCT) - 4) >> 8) + serverHelloEmptySCT[3] = byte(len(serverHelloEmptySCT) - 4) + + // Update the extensions length + serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) + serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) + + if serverHelloCopy.unmarshal(serverHelloEmptySCT) { + t.Fatal("Unmarshaled ServerHello with empty SCT list") + } +}