tris: unify ServerHello processing in preparation for D22
Merge serverHelloMsg13 into serverHelloMsg in preparation for draft 22. This will also simplify the client implementation since only one structure has to be checked. Also fixed potential out-of-bounds access with keyShare unmarshal.
This commit is contained in:
parent
0bbbecd894
commit
4e6ebb63dd
12
13.go
12
13.go
@ -127,7 +127,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
|
||||
config := hs.c.config
|
||||
c := hs.c
|
||||
|
||||
hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id
|
||||
hs.c.cipherSuite, hs.hello.cipherSuite = hs.suite.id, hs.suite.id
|
||||
hs.c.clientHello = hs.clientHello.marshal()
|
||||
|
||||
// When picking the group for the handshake, priority is given to groups
|
||||
@ -159,7 +159,7 @@ CurvePreferenceLoop:
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
hs.hello13.keyShare = serverKS
|
||||
hs.hello.keyShare = serverKS
|
||||
|
||||
hash := hashForSuite(hs.suite)
|
||||
hashSize := hash.Size()
|
||||
@ -188,8 +188,8 @@ CurvePreferenceLoop:
|
||||
return errors.New("tls: bad ECDHE client share")
|
||||
}
|
||||
|
||||
hs.keySchedule.write(hs.hello13.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
|
||||
hs.keySchedule.write(hs.hello.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -603,8 +603,8 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
|
||||
hs.hello13Enc.earlyData = true
|
||||
}
|
||||
}
|
||||
hs.hello13.psk = true
|
||||
hs.hello13.pskIdentity = uint16(i)
|
||||
hs.hello.psk = true
|
||||
hs.hello.pskIdentity = uint16(i)
|
||||
return true, alertSuccess
|
||||
}
|
||||
|
||||
|
@ -28,6 +28,7 @@ const (
|
||||
VersionTLS12 = 0x0303
|
||||
VersionTLS13 = 0x0304
|
||||
VersionTLS13Draft18 = 0x7f00 | 18
|
||||
VersionTLS13Draft21 = 0x7f00 | 21
|
||||
)
|
||||
|
||||
const (
|
||||
|
4
conn.go
4
conn.go
@ -1131,11 +1131,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
|
||||
case typeClientHello:
|
||||
m = new(clientHelloMsg)
|
||||
case typeServerHello:
|
||||
if c.vers >= VersionTLS13 {
|
||||
m = new(serverHelloMsg13)
|
||||
} else {
|
||||
m = new(serverHelloMsg)
|
||||
}
|
||||
case typeEncryptedExtensions:
|
||||
m = new(encryptedExtensionsMsg)
|
||||
case typeNewSessionTicket:
|
||||
|
@ -711,6 +711,11 @@ type serverHelloMsg struct {
|
||||
secureRenegotiation []byte
|
||||
secureRenegotiationSupported bool
|
||||
alpnProtocol string
|
||||
|
||||
// TLS 1.3
|
||||
keyShare keyShare
|
||||
psk bool
|
||||
pskIdentity uint16
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg) equal(i interface{}) bool {
|
||||
@ -740,7 +745,11 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
|
||||
m.ticketSupported == m1.ticketSupported &&
|
||||
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
|
||||
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
|
||||
m.alpnProtocol == m1.alpnProtocol
|
||||
m.alpnProtocol == m1.alpnProtocol &&
|
||||
m.keyShare.group == m1.keyShare.group &&
|
||||
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
|
||||
m.psk == m1.psk &&
|
||||
m.pskIdentity == m1.pskIdentity
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg) marshal() []byte {
|
||||
@ -748,7 +757,12 @@ func (m *serverHelloMsg) marshal() []byte {
|
||||
return m.raw
|
||||
}
|
||||
|
||||
oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
|
||||
length := 38 + len(m.sessionId)
|
||||
if oldTLS13Draft {
|
||||
// no compression method, no session ID.
|
||||
length = 36
|
||||
}
|
||||
numExtensions := 0
|
||||
extensionsLength := 0
|
||||
|
||||
@ -786,6 +800,14 @@ func (m *serverHelloMsg) marshal() []byte {
|
||||
extensionsLength += 2 + sctLen
|
||||
numExtensions++
|
||||
}
|
||||
if m.keyShare.group != 0 {
|
||||
extensionsLength += 4 + len(m.keyShare.data)
|
||||
numExtensions++
|
||||
}
|
||||
if m.psk {
|
||||
extensionsLength += 2
|
||||
numExtensions++
|
||||
}
|
||||
|
||||
if numExtensions > 0 {
|
||||
extensionsLength += 4 * numExtensions
|
||||
@ -800,14 +822,22 @@ func (m *serverHelloMsg) marshal() []byte {
|
||||
x[4] = uint8(m.vers >> 8)
|
||||
x[5] = uint8(m.vers)
|
||||
copy(x[6:38], m.random)
|
||||
z := x[38:]
|
||||
if !oldTLS13Draft {
|
||||
x[38] = uint8(len(m.sessionId))
|
||||
copy(x[39:39+len(m.sessionId)], m.sessionId)
|
||||
z := x[39+len(m.sessionId):]
|
||||
z = x[39+len(m.sessionId):]
|
||||
}
|
||||
z[0] = uint8(m.cipherSuite >> 8)
|
||||
z[1] = uint8(m.cipherSuite)
|
||||
if oldTLS13Draft {
|
||||
// no compression method in older TLS 1.3 drafts.
|
||||
z = z[2:]
|
||||
} else {
|
||||
z[2] = m.compressionMethod
|
||||
|
||||
z = z[3:]
|
||||
}
|
||||
|
||||
if numExtensions > 0 {
|
||||
z[0] = byte(extensionsLength >> 8)
|
||||
z[1] = byte(extensionsLength)
|
||||
@ -882,6 +912,30 @@ func (m *serverHelloMsg) marshal() []byte {
|
||||
}
|
||||
}
|
||||
|
||||
if m.keyShare.group != 0 {
|
||||
z[0] = uint8(extensionKeyShare >> 8)
|
||||
z[1] = uint8(extensionKeyShare)
|
||||
l := 4 + len(m.keyShare.data)
|
||||
z[2] = uint8(l >> 8)
|
||||
z[3] = uint8(l)
|
||||
z[4] = uint8(m.keyShare.group >> 8)
|
||||
z[5] = uint8(m.keyShare.group)
|
||||
l -= 4
|
||||
z[6] = uint8(l >> 8)
|
||||
z[7] = uint8(l)
|
||||
copy(z[8:], m.keyShare.data)
|
||||
z = z[8+l:]
|
||||
}
|
||||
|
||||
if m.psk {
|
||||
z[0] = byte(extensionPreSharedKey >> 8)
|
||||
z[1] = byte(extensionPreSharedKey)
|
||||
z[3] = 2
|
||||
z[4] = byte(m.pskIdentity >> 8)
|
||||
z[5] = byte(m.pskIdentity)
|
||||
z = z[6:]
|
||||
}
|
||||
|
||||
m.raw = x
|
||||
|
||||
return x
|
||||
@ -893,19 +947,29 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
|
||||
}
|
||||
m.raw = data
|
||||
m.vers = uint16(data[4])<<8 | uint16(data[5])
|
||||
oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
|
||||
m.random = data[6:38]
|
||||
if !oldTLS13Draft {
|
||||
sessionIdLen := int(data[38])
|
||||
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.sessionId = data[39 : 39+sessionIdLen]
|
||||
data = data[39+sessionIdLen:]
|
||||
} else {
|
||||
data = data[38:]
|
||||
}
|
||||
if len(data) < 3 {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
|
||||
if oldTLS13Draft {
|
||||
// no compression method in older TLS 1.3 drafts.
|
||||
data = data[2:]
|
||||
} else {
|
||||
m.compressionMethod = data[2]
|
||||
data = data[3:]
|
||||
}
|
||||
|
||||
m.nextProtoNeg = false
|
||||
m.nextProtos = nil
|
||||
@ -913,6 +977,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
|
||||
m.scts = nil
|
||||
m.ticketSupported = false
|
||||
m.alpnProtocol = ""
|
||||
m.keyShare.group = 0
|
||||
m.keyShare.data = nil
|
||||
m.psk = false
|
||||
m.pskIdentity = 0
|
||||
|
||||
if len(data) == 0 {
|
||||
// ServerHello is optionally followed by extension data
|
||||
@ -1020,140 +1088,25 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
|
||||
m.scts = append(m.scts, d[:sctLen])
|
||||
d = d[sctLen:]
|
||||
}
|
||||
}
|
||||
data = data[length:]
|
||||
}
|
||||
case extensionKeyShare:
|
||||
d := data[:length]
|
||||
|
||||
return alertSuccess
|
||||
}
|
||||
|
||||
type serverHelloMsg13 struct {
|
||||
raw []byte
|
||||
vers uint16
|
||||
random []byte
|
||||
cipherSuite uint16
|
||||
keyShare keyShare
|
||||
psk bool
|
||||
pskIdentity uint16
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg13) equal(i interface{}) bool {
|
||||
m1, ok := i.(*serverHelloMsg13)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.vers == m1.vers &&
|
||||
bytes.Equal(m.random, m1.random) &&
|
||||
m.cipherSuite == m1.cipherSuite &&
|
||||
m.keyShare.group == m1.keyShare.group &&
|
||||
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
|
||||
m.psk == m1.psk &&
|
||||
m.pskIdentity == m1.pskIdentity
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg13) marshal() []byte {
|
||||
if m.raw != nil {
|
||||
return m.raw
|
||||
}
|
||||
|
||||
length := 38
|
||||
if m.keyShare.group != 0 {
|
||||
length += 8 + len(m.keyShare.data)
|
||||
}
|
||||
if m.psk {
|
||||
length += 6
|
||||
}
|
||||
|
||||
x := make([]byte, 4+length)
|
||||
x[0] = typeServerHello
|
||||
x[1] = uint8(length >> 16)
|
||||
x[2] = uint8(length >> 8)
|
||||
x[3] = uint8(length)
|
||||
x[4] = uint8(m.vers >> 8)
|
||||
x[5] = uint8(m.vers)
|
||||
copy(x[6:38], m.random)
|
||||
x[38] = uint8(m.cipherSuite >> 8)
|
||||
x[39] = uint8(m.cipherSuite)
|
||||
|
||||
z := x[42:]
|
||||
x[40] = uint8(len(z) >> 8)
|
||||
x[41] = uint8(len(z))
|
||||
|
||||
if m.psk {
|
||||
z[0] = byte(extensionPreSharedKey >> 8)
|
||||
z[1] = byte(extensionPreSharedKey)
|
||||
z[3] = 2
|
||||
z[4] = byte(m.pskIdentity >> 8)
|
||||
z[5] = byte(m.pskIdentity)
|
||||
z = z[6:]
|
||||
}
|
||||
|
||||
if m.keyShare.group != 0 {
|
||||
z[0] = uint8(extensionKeyShare >> 8)
|
||||
z[1] = uint8(extensionKeyShare)
|
||||
l := 4 + len(m.keyShare.data)
|
||||
z[2] = uint8(l >> 8)
|
||||
z[3] = uint8(l)
|
||||
z[4] = uint8(m.keyShare.group >> 8)
|
||||
z[5] = uint8(m.keyShare.group)
|
||||
l -= 4
|
||||
z[6] = uint8(l >> 8)
|
||||
z[7] = uint8(l)
|
||||
copy(z[8:], m.keyShare.data)
|
||||
}
|
||||
|
||||
m.raw = x
|
||||
return x
|
||||
}
|
||||
|
||||
func (m *serverHelloMsg13) unmarshal(data []byte) alert {
|
||||
if len(data) < 50 {
|
||||
if len(d) < 4 {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.raw = data
|
||||
m.vers = uint16(data[4])<<8 | uint16(data[5])
|
||||
m.random = data[6:38]
|
||||
m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
|
||||
m.psk = false
|
||||
m.pskIdentity = 0
|
||||
|
||||
extensionsLength := int(data[40])<<8 | int(data[41])
|
||||
data = data[42:]
|
||||
if len(data) != extensionsLength {
|
||||
m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1])
|
||||
l := int(d[2])<<8 | int(d[3])
|
||||
d = d[4:]
|
||||
if len(d) != l {
|
||||
return alertDecodeError
|
||||
}
|
||||
|
||||
for len(data) != 0 {
|
||||
if len(data) < 4 {
|
||||
return alertDecodeError
|
||||
}
|
||||
extension := uint16(data[0])<<8 | uint16(data[1])
|
||||
length := int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
if len(data) < length {
|
||||
return alertDecodeError
|
||||
}
|
||||
|
||||
switch extension {
|
||||
default:
|
||||
return alertDecodeError
|
||||
m.keyShare.data = d[:l]
|
||||
case extensionPreSharedKey:
|
||||
if length != 2 {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.psk = true
|
||||
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
|
||||
case extensionKeyShare:
|
||||
if length < 2 {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1])
|
||||
if length-4 != int(data[2])<<8|int(data[3]) {
|
||||
return alertDecodeError
|
||||
}
|
||||
m.keyShare.data = data[4:length]
|
||||
}
|
||||
data = data[length:]
|
||||
}
|
||||
|
@ -26,7 +26,6 @@ var tests = []interface{}{
|
||||
&nextProtoMsg{},
|
||||
&newSessionTicketMsg{},
|
||||
&sessionState{},
|
||||
&serverHelloMsg13{},
|
||||
&encryptedExtensionsMsg{},
|
||||
&certificateMsg13{},
|
||||
&newSessionTicketMsg13{},
|
||||
@ -209,16 +208,10 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
}
|
||||
}
|
||||
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
|
||||
func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
m := &serverHelloMsg13{}
|
||||
m.vers = uint16(rand.Intn(65536))
|
||||
m.random = randomBytes(32, rand)
|
||||
m.cipherSuite = uint16(rand.Int31())
|
||||
if rand.Intn(10) > 5 {
|
||||
m.keyShare.group = CurveID(rand.Intn(30000))
|
||||
m.keyShare.data = randomBytes(rand.Intn(300), rand)
|
||||
}
|
||||
if rand.Intn(10) > 5 {
|
||||
m.psk = true
|
||||
m.pskIdentity = uint16(rand.Int31())
|
||||
|
@ -29,10 +29,10 @@ type serverHandshakeState struct {
|
||||
masterSecret []byte
|
||||
cachedClientHelloInfo *ClientHelloInfo
|
||||
clientHello *clientHelloMsg
|
||||
hello *serverHelloMsg
|
||||
cert *Certificate
|
||||
|
||||
// TLS 1.0-1.2 fields
|
||||
hello *serverHelloMsg
|
||||
ellipticOk bool
|
||||
ecdsaOk bool
|
||||
rsaDecryptOk bool
|
||||
@ -42,7 +42,6 @@ type serverHandshakeState struct {
|
||||
certsFromClient [][]byte
|
||||
|
||||
// TLS 1.3 fields
|
||||
hello13 *serverHelloMsg13
|
||||
hello13Enc *encryptedExtensionsMsg
|
||||
keySchedule *keySchedule13
|
||||
clientFinishedKey []byte
|
||||
@ -70,7 +69,7 @@ func (c *Conn) serverHandshake() error {
|
||||
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
|
||||
// and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2
|
||||
c.buffering = true
|
||||
if hs.hello13 != nil {
|
||||
if c.vers >= VersionTLS13 {
|
||||
if err := hs.doTLS13Handshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -250,11 +249,11 @@ Curves:
|
||||
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
|
||||
hs.hello.compressionMethod = compressionNone
|
||||
} else {
|
||||
hs.hello13 = new(serverHelloMsg13)
|
||||
hs.hello = new(serverHelloMsg)
|
||||
hs.hello13Enc = new(encryptedExtensionsMsg)
|
||||
hs.hello13.vers = c.vers
|
||||
hs.hello13.random = make([]byte, 32)
|
||||
_, err = io.ReadFull(c.config.rand(), hs.hello13.random)
|
||||
hs.hello.vers = c.vers
|
||||
hs.hello.random = make([]byte, 32)
|
||||
_, err = io.ReadFull(c.config.rand(), hs.hello.random)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return false, err
|
||||
|
Loading…
Reference in New Issue
Block a user