@@ -711,6 +711,11 @@ type serverHelloMsg struct {
secureRenegotiation []byte
secureRenegotiation []byte
secureRenegotiationSupported bool
secureRenegotiationSupported bool
alpnProtocol string
alpnProtocol string
// TLS 1.3
keyShare keyShare
psk bool
pskIdentity uint16
}
}
func (m *serverHelloMsg) equal(i interface{}) bool {
func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -740,7 +745,11 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
m.ticketSupported == m1.ticketSupported &&
m.ticketSupported == m1.ticketSupported &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
m.alpnProtocol == m1.alpnProtocol
m.alpnProtocol == m1.alpnProtocol &&
m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
}
}
func (m *serverHelloMsg) marshal() []byte {
func (m *serverHelloMsg) marshal() []byte {
@@ -748,7 +757,12 @@ func (m *serverHelloMsg) marshal() []byte {
return m.raw
return m.raw
}
}
oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
length := 38 + len(m.sessionId)
length := 38 + len(m.sessionId)
if oldTLS13Draft {
// no compression method, no session ID.
length = 36
}
numExtensions := 0
numExtensions := 0
extensionsLength := 0
extensionsLength := 0
@@ -786,6 +800,14 @@ func (m *serverHelloMsg) marshal() []byte {
extensionsLength += 2 + sctLen
extensionsLength += 2 + sctLen
numExtensions++
numExtensions++
}
}
if m.keyShare.group != 0 {
extensionsLength += 4 + len(m.keyShare.data)
numExtensions++
}
if m.psk {
extensionsLength += 2
numExtensions++
}
if numExtensions > 0 {
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
extensionsLength += 4 * numExtensions
@@ -800,14 +822,22 @@ func (m *serverHelloMsg) marshal() []byte {
x[4] = uint8(m.vers >> 8)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
z := x[39+len(m.sessionId):]
z := x[38:]
if !oldTLS13Draft {
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
z = x[39+len(m.sessionId):]
}
z[0] = uint8(m.cipherSuite >> 8)
z[0] = uint8(m.cipherSuite >> 8)
z[1] = uint8(m.cipherSuite)
z[1] = uint8(m.cipherSuite)
z[2] = m.compressionMethod
if oldTLS13Draft {
// no compression method in older TLS 1.3 drafts.
z = z[2:]
} else {
z[2] = m.compressionMethod
z = z[3:]
}
z = z[3:]
if numExtensions > 0 {
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
z[1] = byte(extensionsLength)
@@ -882,6 +912,30 @@ func (m *serverHelloMsg) marshal() []byte {
}
}
}
}
if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare)
l := 4 + len(m.keyShare.data)
z[2] = uint8(l >> 8)
z[3] = uint8(l)
z[4] = uint8(m.keyShare.group >> 8)
z[5] = uint8(m.keyShare.group)
l -= 4
z[6] = uint8(l >> 8)
z[7] = uint8(l)
copy(z[8:], m.keyShare.data)
z = z[8+l:]
}
if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}
m.raw = x
m.raw = x
return x
return x
@@ -893,19 +947,29 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
}
}
m.raw = data
m.raw = data
m.vers = uint16(data[4])<<8 | uint16(data[5])
m.vers = uint16(data[4])<<8 | uint16(data[5])
oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
m.random = data[6:38]
m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return alertDecodeError
if !oldTLS13Draft {
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return alertDecodeError
}
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
} else {
data = data[38:]
}
}
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
if len(data) < 3 {
if len(data) < 3 {
return alertDecodeError
return alertDecodeError
}
}
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
m.compressionMethod = data[2]
data = data[3:]
if oldTLS13Draft {
// no compression method in older TLS 1.3 drafts.
data = data[2:]
} else {
m.compressionMethod = data[2]
data = data[3:]
}
m.nextProtoNeg = false
m.nextProtoNeg = false
m.nextProtos = nil
m.nextProtos = nil
@@ -913,6 +977,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
m.scts = nil
m.scts = nil
m.ticketSupported = false
m.ticketSupported = false
m.alpnProtocol = ""
m.alpnProtocol = ""
m.keyShare.group = 0
m.keyShare.data = nil
m.psk = false
m.pskIdentity = 0
if len(data) == 0 {
if len(data) == 0 {
// ServerHello is optionally followed by extension data
// ServerHello is optionally followed by extension data
@@ -1020,140 +1088,25 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
m.scts = append(m.scts, d[:sctLen])
m.scts = append(m.scts, d[:sctLen])
d = d[sctLen:]
d = d[sctLen:]
}
}
}
data = data[length:]
}
return alertSuccess
}
type serverHelloMsg13 struct {
raw []byte
vers uint16
random []byte
cipherSuite uint16
keyShare keyShare
psk bool
pskIdentity uint16
}
func (m *serverHelloMsg13) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg13)
if !ok {
return false
}
return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
m.cipherSuite == m1.cipherSuite &&
m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
}
func (m *serverHelloMsg13) marshal() []byte {
if m.raw != nil {
return m.raw
}
length := 38
if m.keyShare.group != 0 {
length += 8 + len(m.keyShare.data)
}
if m.psk {
length += 6
}
x := make([]byte, 4+length)
x[0] = typeServerHello
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(m.cipherSuite >> 8)
x[39] = uint8(m.cipherSuite)
z := x[42:]
x[40] = uint8(len(z) >> 8)
x[41] = uint8(len(z))
if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}
if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare)
l := 4 + len(m.keyShare.data)
z[2] = uint8(l >> 8)
z[3] = uint8(l)
z[4] = uint8(m.keyShare.group >> 8)
z[5] = uint8(m.keyShare.group)
l -= 4
z[6] = uint8(l >> 8)
z[7] = uint8(l)
copy(z[8:], m.keyShare.data)
}
m.raw = x
return x
}
func (m *serverHelloMsg13) unmarshal(data []byte) alert {
if len(data) < 50 {
return alertDecodeError
}
m.raw = data
m.vers = uint16(data[4])<<8 | uint16(data[5])
m.random = data[6:38]
m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
m.psk = false
m.pskIdentity = 0
extensionsLength := int(data[40])<<8 | int(data[41])
data = data[42:]
if len(data) != extensionsLength {
return alertDecodeError
}
for len(data) != 0 {
if len(data) < 4 {
return alertDecodeError
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return alertDecodeError
}
case extensionKeyShare:
d := data[:length]
switch extension {
default:
return alertDecodeError
case extensionPreSharedKey:
if length != 2 {
if len(d) < 4 {
return alertDecodeError
return alertDecodeError
}
}
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1 ])
case extensionKeyShare:
if length < 2 {
m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1])
l := int(d[2])<<8 | int(d[3])
d = d[4:]
if len(d) != l {
return alertDecodeError
return alertDecodeError
}
}
m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1])
if length-4 != int(data[2])<<8|int(data[3]) {
m.keyShare.data = d[:l]
case extensionPreSharedKey:
if length != 2 {
return alertDecodeError
return alertDecodeError
}
}
m.keyShare.data = data[4:length]
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
}
}
data = data[length:]
data = data[length:]
}
}