From 4e6ebb63dd8304f3f08eb6b97704795fc8c8f92b Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Fri, 24 Nov 2017 18:10:50 +0000 Subject: [PATCH] tris: unify ServerHello processing in preparation for D22 Merge serverHelloMsg13 into serverHelloMsg in preparation for draft 22. This will also simplify the client implementation since only one structure has to be checked. Also fixed potential out-of-bounds access with keyShare unmarshal. --- 13.go | 12 +- common.go | 1 + conn.go | 6 +- handshake_messages.go | 233 +++++++++++++++---------------------- handshake_messages_test.go | 15 +-- handshake_server.go | 13 +-- 6 files changed, 111 insertions(+), 169 deletions(-) diff --git a/13.go b/13.go index 73b7d6d..55f7972 100644 --- a/13.go +++ b/13.go @@ -127,7 +127,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { config := hs.c.config c := hs.c - hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id + hs.c.cipherSuite, hs.hello.cipherSuite = hs.suite.id, hs.suite.id hs.c.clientHello = hs.clientHello.marshal() // When picking the group for the handshake, priority is given to groups @@ -159,7 +159,7 @@ CurvePreferenceLoop: c.sendAlert(alertInternalError) return err } - hs.hello13.keyShare = serverKS + hs.hello.keyShare = serverKS hash := hashForSuite(hs.suite) hashSize := hash.Size() @@ -188,8 +188,8 @@ CurvePreferenceLoop: return errors.New("tls: bad ECDHE client share") } - hs.keySchedule.write(hs.hello13.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { + hs.keySchedule.write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { return err } @@ -603,8 +603,8 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { hs.hello13Enc.earlyData = true } } - hs.hello13.psk = true - hs.hello13.pskIdentity = uint16(i) + hs.hello.psk = true + hs.hello.pskIdentity = uint16(i) return true, alertSuccess } diff --git a/common.go b/common.go index 7c4f7bd..871e338 100644 --- a/common.go +++ b/common.go @@ -28,6 +28,7 @@ const ( VersionTLS12 = 0x0303 VersionTLS13 = 0x0304 VersionTLS13Draft18 = 0x7f00 | 18 + VersionTLS13Draft21 = 0x7f00 | 21 ) const ( diff --git a/conn.go b/conn.go index 27dc1c4..002222d 100644 --- a/conn.go +++ b/conn.go @@ -1131,11 +1131,7 @@ func (c *Conn) readHandshake() (interface{}, error) { case typeClientHello: m = new(clientHelloMsg) case typeServerHello: - if c.vers >= VersionTLS13 { - m = new(serverHelloMsg13) - } else { - m = new(serverHelloMsg) - } + m = new(serverHelloMsg) case typeEncryptedExtensions: m = new(encryptedExtensionsMsg) case typeNewSessionTicket: diff --git a/handshake_messages.go b/handshake_messages.go index 693cf4f..77b4c6c 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -711,6 +711,11 @@ type serverHelloMsg struct { secureRenegotiation []byte secureRenegotiationSupported bool alpnProtocol string + + // TLS 1.3 + keyShare keyShare + psk bool + pskIdentity uint16 } func (m *serverHelloMsg) equal(i interface{}) bool { @@ -740,7 +745,11 @@ func (m *serverHelloMsg) equal(i interface{}) bool { m.ticketSupported == m1.ticketSupported && m.secureRenegotiationSupported == m1.secureRenegotiationSupported && bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - m.alpnProtocol == m1.alpnProtocol + m.alpnProtocol == m1.alpnProtocol && + m.keyShare.group == m1.keyShare.group && + bytes.Equal(m.keyShare.data, m1.keyShare.data) && + m.psk == m1.psk && + m.pskIdentity == m1.pskIdentity } func (m *serverHelloMsg) marshal() []byte { @@ -748,7 +757,12 @@ func (m *serverHelloMsg) marshal() []byte { return m.raw } + oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21 length := 38 + len(m.sessionId) + if oldTLS13Draft { + // no compression method, no session ID. + length = 36 + } numExtensions := 0 extensionsLength := 0 @@ -786,6 +800,14 @@ func (m *serverHelloMsg) marshal() []byte { extensionsLength += 2 + sctLen numExtensions++ } + if m.keyShare.group != 0 { + extensionsLength += 4 + len(m.keyShare.data) + numExtensions++ + } + if m.psk { + extensionsLength += 2 + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions @@ -800,14 +822,22 @@ func (m *serverHelloMsg) marshal() []byte { x[4] = uint8(m.vers >> 8) x[5] = uint8(m.vers) copy(x[6:38], m.random) - x[38] = uint8(len(m.sessionId)) - copy(x[39:39+len(m.sessionId)], m.sessionId) - z := x[39+len(m.sessionId):] + z := x[38:] + if !oldTLS13Draft { + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + z = x[39+len(m.sessionId):] + } z[0] = uint8(m.cipherSuite >> 8) z[1] = uint8(m.cipherSuite) - z[2] = m.compressionMethod + if oldTLS13Draft { + // no compression method in older TLS 1.3 drafts. + z = z[2:] + } else { + z[2] = m.compressionMethod + z = z[3:] + } - z = z[3:] if numExtensions > 0 { z[0] = byte(extensionsLength >> 8) z[1] = byte(extensionsLength) @@ -882,6 +912,30 @@ func (m *serverHelloMsg) marshal() []byte { } } + if m.keyShare.group != 0 { + z[0] = uint8(extensionKeyShare >> 8) + z[1] = uint8(extensionKeyShare) + l := 4 + len(m.keyShare.data) + z[2] = uint8(l >> 8) + z[3] = uint8(l) + z[4] = uint8(m.keyShare.group >> 8) + z[5] = uint8(m.keyShare.group) + l -= 4 + z[6] = uint8(l >> 8) + z[7] = uint8(l) + copy(z[8:], m.keyShare.data) + z = z[8+l:] + } + + if m.psk { + z[0] = byte(extensionPreSharedKey >> 8) + z[1] = byte(extensionPreSharedKey) + z[3] = 2 + z[4] = byte(m.pskIdentity >> 8) + z[5] = byte(m.pskIdentity) + z = z[6:] + } + m.raw = x return x @@ -893,19 +947,29 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { } m.raw = data m.vers = uint16(data[4])<<8 | uint16(data[5]) + oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21 m.random = data[6:38] - sessionIdLen := int(data[38]) - if sessionIdLen > 32 || len(data) < 39+sessionIdLen { - return alertDecodeError + if !oldTLS13Draft { + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return alertDecodeError + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + } else { + data = data[38:] } - m.sessionId = data[39 : 39+sessionIdLen] - data = data[39+sessionIdLen:] if len(data) < 3 { return alertDecodeError } m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) - m.compressionMethod = data[2] - data = data[3:] + if oldTLS13Draft { + // no compression method in older TLS 1.3 drafts. + data = data[2:] + } else { + m.compressionMethod = data[2] + data = data[3:] + } m.nextProtoNeg = false m.nextProtos = nil @@ -913,6 +977,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { m.scts = nil m.ticketSupported = false m.alpnProtocol = "" + m.keyShare.group = 0 + m.keyShare.data = nil + m.psk = false + m.pskIdentity = 0 if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -1020,140 +1088,25 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { m.scts = append(m.scts, d[:sctLen]) d = d[sctLen:] } - } - data = data[length:] - } - - return alertSuccess -} - -type serverHelloMsg13 struct { - raw []byte - vers uint16 - random []byte - cipherSuite uint16 - keyShare keyShare - psk bool - pskIdentity uint16 -} - -func (m *serverHelloMsg13) equal(i interface{}) bool { - m1, ok := i.(*serverHelloMsg13) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - m.cipherSuite == m1.cipherSuite && - m.keyShare.group == m1.keyShare.group && - bytes.Equal(m.keyShare.data, m1.keyShare.data) && - m.psk == m1.psk && - m.pskIdentity == m1.pskIdentity -} - -func (m *serverHelloMsg13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - length := 38 - if m.keyShare.group != 0 { - length += 8 + len(m.keyShare.data) - } - if m.psk { - length += 6 - } - - x := make([]byte, 4+length) - x[0] = typeServerHello - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[4] = uint8(m.vers >> 8) - x[5] = uint8(m.vers) - copy(x[6:38], m.random) - x[38] = uint8(m.cipherSuite >> 8) - x[39] = uint8(m.cipherSuite) - - z := x[42:] - x[40] = uint8(len(z) >> 8) - x[41] = uint8(len(z)) - - if m.psk { - z[0] = byte(extensionPreSharedKey >> 8) - z[1] = byte(extensionPreSharedKey) - z[3] = 2 - z[4] = byte(m.pskIdentity >> 8) - z[5] = byte(m.pskIdentity) - z = z[6:] - } - - if m.keyShare.group != 0 { - z[0] = uint8(extensionKeyShare >> 8) - z[1] = uint8(extensionKeyShare) - l := 4 + len(m.keyShare.data) - z[2] = uint8(l >> 8) - z[3] = uint8(l) - z[4] = uint8(m.keyShare.group >> 8) - z[5] = uint8(m.keyShare.group) - l -= 4 - z[6] = uint8(l >> 8) - z[7] = uint8(l) - copy(z[8:], m.keyShare.data) - } - - m.raw = x - return x -} - -func (m *serverHelloMsg13) unmarshal(data []byte) alert { - if len(data) < 50 { - return alertDecodeError - } - m.raw = data - m.vers = uint16(data[4])<<8 | uint16(data[5]) - m.random = data[6:38] - m.cipherSuite = uint16(data[38])<<8 | uint16(data[39]) - m.psk = false - m.pskIdentity = 0 - - extensionsLength := int(data[40])<<8 | int(data[41]) - data = data[42:] - if len(data) != extensionsLength { - return alertDecodeError - } - - for len(data) != 0 { - if len(data) < 4 { - return alertDecodeError - } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { - return alertDecodeError - } + case extensionKeyShare: + d := data[:length] - switch extension { - default: - return alertDecodeError - case extensionPreSharedKey: - if length != 2 { + if len(d) < 4 { return alertDecodeError } - m.psk = true - m.pskIdentity = uint16(data[0])<<8 | uint16(data[1]) - case extensionKeyShare: - if length < 2 { + m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1]) + l := int(d[2])<<8 | int(d[3]) + d = d[4:] + if len(d) != l { return alertDecodeError } - m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1]) - if length-4 != int(data[2])<<8|int(data[3]) { + m.keyShare.data = d[:l] + case extensionPreSharedKey: + if length != 2 { return alertDecodeError } - m.keyShare.data = data[4:length] + m.psk = true + m.pskIdentity = uint16(data[0])<<8 | uint16(data[1]) } data = data[length:] } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 3f26473..5b23a27 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -26,7 +26,6 @@ var tests = []interface{}{ &nextProtoMsg{}, &newSessionTicketMsg{}, &sessionState{}, - &serverHelloMsg13{}, &encryptedExtensionsMsg{}, &certificateMsg13{}, &newSessionTicketMsg13{}, @@ -209,16 +208,10 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } } - return reflect.ValueOf(m) -} - -func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value { - m := &serverHelloMsg13{} - m.vers = uint16(rand.Intn(65536)) - m.random = randomBytes(32, rand) - m.cipherSuite = uint16(rand.Int31()) - m.keyShare.group = CurveID(rand.Intn(30000)) - m.keyShare.data = randomBytes(rand.Intn(300), rand) + if rand.Intn(10) > 5 { + m.keyShare.group = CurveID(rand.Intn(30000)) + m.keyShare.data = randomBytes(rand.Intn(300), rand) + } if rand.Intn(10) > 5 { m.psk = true m.pskIdentity = uint16(rand.Int31()) diff --git a/handshake_server.go b/handshake_server.go index 68be6a7..aa4cc4e 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -29,10 +29,10 @@ type serverHandshakeState struct { masterSecret []byte cachedClientHelloInfo *ClientHelloInfo clientHello *clientHelloMsg + hello *serverHelloMsg cert *Certificate // TLS 1.0-1.2 fields - hello *serverHelloMsg ellipticOk bool ecdsaOk bool rsaDecryptOk bool @@ -42,7 +42,6 @@ type serverHandshakeState struct { certsFromClient [][]byte // TLS 1.3 fields - hello13 *serverHelloMsg13 hello13Enc *encryptedExtensionsMsg keySchedule *keySchedule13 clientFinishedKey []byte @@ -70,7 +69,7 @@ func (c *Conn) serverHandshake() error { // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 // and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2 c.buffering = true - if hs.hello13 != nil { + if c.vers >= VersionTLS13 { if err := hs.doTLS13Handshake(); err != nil { return err } @@ -250,11 +249,11 @@ Curves: hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported hs.hello.compressionMethod = compressionNone } else { - hs.hello13 = new(serverHelloMsg13) + hs.hello = new(serverHelloMsg) hs.hello13Enc = new(encryptedExtensionsMsg) - hs.hello13.vers = c.vers - hs.hello13.random = make([]byte, 32) - _, err = io.ReadFull(c.config.rand(), hs.hello13.random) + hs.hello.vers = c.vers + hs.hello.random = make([]byte, 32) + _, err = io.ReadFull(c.config.rand(), hs.hello.random) if err != nil { c.sendAlert(alertInternalError) return false, err