diff --git a/_dev/tris-localserver/server.go b/_dev/tris-localserver/server.go index 58a4876..d040449 100644 --- a/_dev/tris-localserver/server.go +++ b/_dev/tris-localserver/server.go @@ -18,6 +18,7 @@ var tlsVersionToName = map[uint16]string{ tls.VersionTLS13: "1.3", tls.VersionTLS13Draft18: "1.3 (draft 18)", tls.VersionTLS13Draft21: "1.3 (draft 21)", + tls.VersionTLS13Draft22: "1.3 (draft 22)", } func startServer(addr string, rsa, offer0RTT, accept0RTT bool) { diff --git a/common.go b/common.go index 87ad18a..8d6f11e 100644 --- a/common.go +++ b/common.go @@ -29,6 +29,7 @@ const ( VersionTLS13 = 0x0304 VersionTLS13Draft18 = 0x7f00 | 18 VersionTLS13Draft21 = 0x7f00 | 21 + VersionTLS13Draft22 = 0x7f00 | 22 ) const ( @@ -855,7 +856,7 @@ var configSuppVersArray = [...]uint16{VersionTLS13, VersionTLS12, VersionTLS11, // with TLS 1.3 draft versions included. // // TODO: remove once TLS 1.3 is finalised. -var tls13DraftSuppVersArray = [...]uint16{VersionTLS13Draft21, VersionTLS12, VersionTLS11, VersionTLS10, VersionSSL30} +var tls13DraftSuppVersArray = [...]uint16{VersionTLS13Draft22, VersionTLS12, VersionTLS11, VersionTLS10, VersionSSL30} // getSupportedVersions returns the protocol versions that are supported by the // current configuration. diff --git a/handshake_messages.go b/handshake_messages.go index a896cf8..77b1b67 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -808,6 +808,11 @@ func (m *serverHelloMsg) marshal() []byte { extensionsLength += 2 numExtensions++ } + // supported_versions extension + if m.vers >= VersionTLS13 { + extensionsLength += 2 + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions @@ -819,8 +824,13 @@ func (m *serverHelloMsg) marshal() []byte { x[1] = uint8(length >> 16) x[2] = uint8(length >> 8) x[3] = uint8(length) - x[4] = uint8(m.vers >> 8) - x[5] = uint8(m.vers) + if m.vers >= VersionTLS13 { + x[4] = 3 + x[5] = 3 + } else { + x[4] = uint8(m.vers >> 8) + x[5] = uint8(m.vers) + } copy(x[6:38], m.random) z := x[38:] if !oldTLS13Draft { @@ -843,6 +853,14 @@ func (m *serverHelloMsg) marshal() []byte { z[1] = byte(extensionsLength) z = z[2:] } + if m.vers >= VersionTLS13 { + z[0] = byte(extensionSupportedVersions >> 8) + z[1] = byte(extensionSupportedVersions) + z[3] = 2 + z[4] = uint8(m.vers >> 8) + z[5] = uint8(m.vers) + z = z[6:] + } if m.nextProtoNeg { z[0] = byte(extensionNextProtoNeg >> 8) z[1] = byte(extensionNextProtoNeg & 0xff) @@ -996,6 +1014,17 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert { return alertDecodeError } + svData := findExtension(data, extensionSupportedVersions) + if svData != nil { + if len(svData) != 2 { + return alertDecodeError + } + if m.vers != VersionTLS12 { + return alertDecodeError + } + m.vers = uint16(svData[0])<<8 | uint16(svData[1]) + } + for len(data) != 0 { if len(data) < 4 { return alertDecodeError @@ -2384,3 +2413,22 @@ func eqKeyShares(x, y []keyShare) bool { } return true } + +func findExtension(data []byte, extensionType uint16) []byte { + for len(data) != 0 { + if len(data) < 4 { + return nil + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return nil + } + if extension == extensionType { + return data[:length] + } + data = data[length:] + } + return nil +}