diff --git a/13.go b/13.go index 73b7d6d..55f7972 100644 --- a/13.go +++ b/13.go @@ -127,7 +127,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { config := hs.c.config c := hs.c - hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id + hs.c.cipherSuite, hs.hello.cipherSuite = hs.suite.id, hs.suite.id hs.c.clientHello = hs.clientHello.marshal() // When picking the group for the handshake, priority is given to groups @@ -159,7 +159,7 @@ CurvePreferenceLoop: c.sendAlert(alertInternalError) return err } - hs.hello13.keyShare = serverKS + hs.hello.keyShare = serverKS hash := hashForSuite(hs.suite) hashSize := hash.Size() @@ -188,8 +188,8 @@ CurvePreferenceLoop: return errors.New("tls: bad ECDHE client share") } - hs.keySchedule.write(hs.hello13.marshal()) - if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { + hs.keySchedule.write(hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { return err } @@ -603,8 +603,8 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { hs.hello13Enc.earlyData = true } } - hs.hello13.psk = true - hs.hello13.pskIdentity = uint16(i) + hs.hello.psk = true + hs.hello.pskIdentity = uint16(i) return true, alertSuccess } diff --git a/common.go b/common.go index 7c4f7bd..871e338 100644 --- a/common.go +++ b/common.go @@ -28,6 +28,7 @@ const ( VersionTLS12 = 0x0303 VersionTLS13 = 0x0304 VersionTLS13Draft18 = 0x7f00 | 18 + VersionTLS13Draft21 = 0x7f00 | 21 ) const ( diff --git a/conn.go b/conn.go index 27dc1c4..002222d 100644 --- a/conn.go +++ b/conn.go @@ -1131,11 +1131,7 @@ func (c *Conn) readHandshake() (interface{}, error) { case typeClientHello: m = new(clientHelloMsg) case typeServerHello: - if c.vers >= VersionTLS13 { - m = new(serverHelloMsg13) - } else { - m = new(serverHelloMsg) - } + m = new(serverHelloMsg) case typeEncryptedExtensions: m = new(encryptedExtensionsMsg) case typeNewSessionTicket: diff --git a/handshake_messages.go b/handshake_messages.go index 693cf4f..77b4c6c 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -711,6 +711,11 @@ type serverHelloMsg struct { secureRenegotiation []byte secureRenegotiationSupported bool alpnProtocol string + + // TLS 1.3 + keyShare keyShare + psk bool + pskIdentity uint16 } func (m *serverHelloMsg) equal(i interface{}) bool { @@ -740,7 +745,11 @@ func (m *serverHelloMsg) equal(i interface{}) bool { m.ticketSupported == m1.ticketSupported && m.secureRenegotiationSupported == m1.secureRenegotiationSupported && bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && - m.alpnProtocol == m1.alpnProtocol + m.alpnProtocol == m1.alpnProtocol && + m.keyShare.group == m1.keyShare.group && + bytes.Equal(m.keyShare.data, m1.keyShare.data) && + m.psk == m1.psk && + m.pskIdentity == m1.pskIdentity } func (m *serverHelloMsg) marshal() []byte { @@ -748,7 +757,12 @@ func (m *serverHelloMsg) marshal() []byte { return m.raw } + oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21 length := 38 + len(m.sessionId) + if oldTLS13Draft { + // no compression method, no session ID. + length = 36 + } numExtensions := 0 extensionsLength := 0 @@ -786,6 +800,14 @@ func (m *serverHelloMsg) marshal() []byte { extensionsLength += 2 + sctLen numExtensions++ } + if m.keyShare.group != 0 { + extensionsLength += 4 + len(m.keyShare.data) + numExtensions++ + } + if m.psk { + extensionsLength += 2 + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions @@ -800,14 +822,22 @@ func (m *serverHelloMsg) marshal() []byte { x[4] = uint8(m.vers >> 8) x[5] = uint8(m.vers) copy(x[6:38], m.random) - x[38] = uint8(len(m.sessionId)) - copy(x[39:39+len(m.sessionId)], m.sessionId) - z := x[39+len(m.sessionId):] + z := x[38:] + if !oldTLS13Draft { + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + z = x[39+len(m.sessionId):] + } z[0] = uint8(m.cipherSuite >> 8) z[1] = uint8(m.cipherSuite) - z[2] = m.compressionMethod + if oldTLS13Draft { + // no compression method in older TLS 1.3 drafts. + z = z[2:] + } else { + z[2] = m.compressionMethod + z = z[3:] + } - z = z[3:] if numExtensions > 0 { z[0] = byte(extensionsLength >> 8) z[1] = byte(extensionsLength) @@ -882,6 +912,30 @@ func (m *serverHelloMsg) marshal() []byte { } } + if m.keyShare.group != 0 { + z[0] = uint8(extensionKeyShare >> 8) + z[1] = uint8(extensionKeyShare) + l := 4 + len(m.keyShare.data) + z[2] = uint8(l >> 8) + z[3] = uint8(l) + z[4] = uint8(m.keyShare.group >> 8) + z[5] = uint8(m.keyShare.group) + l -= 4 + z[6] = uint8(l >> 8) + z[7] = uint8(l) + copy(z[8:], m.keyShare.data) + z = z[8+l:] + } + + if m.psk { + z[0] = byte(extensionPreSharedKey >> 8) + z[1] = byte(extensionPreSharedKey) + z[3] = 2 + z[4] = byte(m.pskIdentity >> 8) + z[5] = byte(m.pskIdentity) + z = z[6:] + } + m.raw = x return x @@ -893,19 +947,29 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { } m.raw = data m.vers = uint16(data[4])<<8 | uint16(data[5]) + oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21 m.random = data[6:38] - sessionIdLen := int(data[38]) - if sessionIdLen > 32 || len(data) < 39+sessionIdLen { - return alertDecodeError + if !oldTLS13Draft { + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return alertDecodeError + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + } else { + data = data[38:] } - m.sessionId = data[39 : 39+sessionIdLen] - data = data[39+sessionIdLen:] if len(data) < 3 { return alertDecodeError } m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) - m.compressionMethod = data[2] - data = data[3:] + if oldTLS13Draft { + // no compression method in older TLS 1.3 drafts. + data = data[2:] + } else { + m.compressionMethod = data[2] + data = data[3:] + } m.nextProtoNeg = false m.nextProtos = nil @@ -913,6 +977,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { m.scts = nil m.ticketSupported = false m.alpnProtocol = "" + m.keyShare.group = 0 + m.keyShare.data = nil + m.psk = false + m.pskIdentity = 0 if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -1020,140 +1088,25 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { m.scts = append(m.scts, d[:sctLen]) d = d[sctLen:] } - } - data = data[length:] - } + case extensionKeyShare: + d := data[:length] - return alertSuccess -} - -type serverHelloMsg13 struct { - raw []byte - vers uint16 - random []byte - cipherSuite uint16 - keyShare keyShare - psk bool - pskIdentity uint16 -} - -func (m *serverHelloMsg13) equal(i interface{}) bool { - m1, ok := i.(*serverHelloMsg13) - if !ok { - return false - } - - return bytes.Equal(m.raw, m1.raw) && - m.vers == m1.vers && - bytes.Equal(m.random, m1.random) && - m.cipherSuite == m1.cipherSuite && - m.keyShare.group == m1.keyShare.group && - bytes.Equal(m.keyShare.data, m1.keyShare.data) && - m.psk == m1.psk && - m.pskIdentity == m1.pskIdentity -} - -func (m *serverHelloMsg13) marshal() []byte { - if m.raw != nil { - return m.raw - } - - length := 38 - if m.keyShare.group != 0 { - length += 8 + len(m.keyShare.data) - } - if m.psk { - length += 6 - } - - x := make([]byte, 4+length) - x[0] = typeServerHello - x[1] = uint8(length >> 16) - x[2] = uint8(length >> 8) - x[3] = uint8(length) - x[4] = uint8(m.vers >> 8) - x[5] = uint8(m.vers) - copy(x[6:38], m.random) - x[38] = uint8(m.cipherSuite >> 8) - x[39] = uint8(m.cipherSuite) - - z := x[42:] - x[40] = uint8(len(z) >> 8) - x[41] = uint8(len(z)) - - if m.psk { - z[0] = byte(extensionPreSharedKey >> 8) - z[1] = byte(extensionPreSharedKey) - z[3] = 2 - z[4] = byte(m.pskIdentity >> 8) - z[5] = byte(m.pskIdentity) - z = z[6:] - } - - if m.keyShare.group != 0 { - z[0] = uint8(extensionKeyShare >> 8) - z[1] = uint8(extensionKeyShare) - l := 4 + len(m.keyShare.data) - z[2] = uint8(l >> 8) - z[3] = uint8(l) - z[4] = uint8(m.keyShare.group >> 8) - z[5] = uint8(m.keyShare.group) - l -= 4 - z[6] = uint8(l >> 8) - z[7] = uint8(l) - copy(z[8:], m.keyShare.data) - } - - m.raw = x - return x -} - -func (m *serverHelloMsg13) unmarshal(data []byte) alert { - if len(data) < 50 { - return alertDecodeError - } - m.raw = data - m.vers = uint16(data[4])<<8 | uint16(data[5]) - m.random = data[6:38] - m.cipherSuite = uint16(data[38])<<8 | uint16(data[39]) - m.psk = false - m.pskIdentity = 0 - - extensionsLength := int(data[40])<<8 | int(data[41]) - data = data[42:] - if len(data) != extensionsLength { - return alertDecodeError - } - - for len(data) != 0 { - if len(data) < 4 { - return alertDecodeError - } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { - return alertDecodeError - } - - switch extension { - default: - return alertDecodeError + if len(d) < 4 { + return alertDecodeError + } + m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1]) + l := int(d[2])<<8 | int(d[3]) + d = d[4:] + if len(d) != l { + return alertDecodeError + } + m.keyShare.data = d[:l] case extensionPreSharedKey: if length != 2 { return alertDecodeError } m.psk = true m.pskIdentity = uint16(data[0])<<8 | uint16(data[1]) - case extensionKeyShare: - if length < 2 { - return alertDecodeError - } - m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1]) - if length-4 != int(data[2])<<8|int(data[3]) { - return alertDecodeError - } - m.keyShare.data = data[4:length] } data = data[length:] } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 3f26473..5b23a27 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -26,7 +26,6 @@ var tests = []interface{}{ &nextProtoMsg{}, &newSessionTicketMsg{}, &sessionState{}, - &serverHelloMsg13{}, &encryptedExtensionsMsg{}, &certificateMsg13{}, &newSessionTicketMsg13{}, @@ -209,16 +208,10 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { } } - return reflect.ValueOf(m) -} - -func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value { - m := &serverHelloMsg13{} - m.vers = uint16(rand.Intn(65536)) - m.random = randomBytes(32, rand) - m.cipherSuite = uint16(rand.Int31()) - m.keyShare.group = CurveID(rand.Intn(30000)) - m.keyShare.data = randomBytes(rand.Intn(300), rand) + if rand.Intn(10) > 5 { + m.keyShare.group = CurveID(rand.Intn(30000)) + m.keyShare.data = randomBytes(rand.Intn(300), rand) + } if rand.Intn(10) > 5 { m.psk = true m.pskIdentity = uint16(rand.Int31()) diff --git a/handshake_server.go b/handshake_server.go index 68be6a7..aa4cc4e 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -29,10 +29,10 @@ type serverHandshakeState struct { masterSecret []byte cachedClientHelloInfo *ClientHelloInfo clientHello *clientHelloMsg + hello *serverHelloMsg cert *Certificate // TLS 1.0-1.2 fields - hello *serverHelloMsg ellipticOk bool ecdsaOk bool rsaDecryptOk bool @@ -42,7 +42,6 @@ type serverHandshakeState struct { certsFromClient [][]byte // TLS 1.3 fields - hello13 *serverHelloMsg13 hello13Enc *encryptedExtensionsMsg keySchedule *keySchedule13 clientFinishedKey []byte @@ -70,7 +69,7 @@ func (c *Conn) serverHandshake() error { // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 // and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2 c.buffering = true - if hs.hello13 != nil { + if c.vers >= VersionTLS13 { if err := hs.doTLS13Handshake(); err != nil { return err } @@ -250,11 +249,11 @@ Curves: hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported hs.hello.compressionMethod = compressionNone } else { - hs.hello13 = new(serverHelloMsg13) + hs.hello = new(serverHelloMsg) hs.hello13Enc = new(encryptedExtensionsMsg) - hs.hello13.vers = c.vers - hs.hello13.random = make([]byte, 32) - _, err = io.ReadFull(c.config.rand(), hs.hello13.random) + hs.hello.vers = c.vers + hs.hello.random = make([]byte, 32) + _, err = io.ReadFull(c.config.rand(), hs.hello.random) if err != nil { c.sendAlert(alertInternalError) return false, err