Bläddra i källkod

tris: unify ServerHello processing in preparation for D22

Merge serverHelloMsg13 into serverHelloMsg in preparation for draft 22.
This will also simplify the client implementation since only one
structure has to be checked.

Also fixed potential out-of-bounds access with keyShare unmarshal.
tls13
Peter Wu 7 år sedan
förälder
incheckning
4e6ebb63dd
6 ändrade filer med 111 tillägg och 169 borttagningar
  1. +6
    -6
      13.go
  2. +1
    -0
      common.go
  3. +1
    -5
      conn.go
  4. +93
    -140
      handshake_messages.go
  5. +4
    -11
      handshake_messages_test.go
  6. +6
    -7
      handshake_server.go

+ 6
- 6
13.go Visa fil

@@ -127,7 +127,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
config := hs.c.config config := hs.c.config
c := hs.c c := hs.c


hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id
hs.c.cipherSuite, hs.hello.cipherSuite = hs.suite.id, hs.suite.id
hs.c.clientHello = hs.clientHello.marshal() hs.c.clientHello = hs.clientHello.marshal()


// When picking the group for the handshake, priority is given to groups // When picking the group for the handshake, priority is given to groups
@@ -159,7 +159,7 @@ CurvePreferenceLoop:
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return err return err
} }
hs.hello13.keyShare = serverKS
hs.hello.keyShare = serverKS


hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
hashSize := hash.Size() hashSize := hash.Size()
@@ -188,8 +188,8 @@ CurvePreferenceLoop:
return errors.New("tls: bad ECDHE client share") return errors.New("tls: bad ECDHE client share")
} }


hs.keySchedule.write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
hs.keySchedule.write(hs.hello.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
return err return err
} }


@@ -603,8 +603,8 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
hs.hello13Enc.earlyData = true hs.hello13Enc.earlyData = true
} }
} }
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
hs.hello.psk = true
hs.hello.pskIdentity = uint16(i)
return true, alertSuccess return true, alertSuccess
} }




+ 1
- 0
common.go Visa fil

@@ -28,6 +28,7 @@ const (
VersionTLS12 = 0x0303 VersionTLS12 = 0x0303
VersionTLS13 = 0x0304 VersionTLS13 = 0x0304
VersionTLS13Draft18 = 0x7f00 | 18 VersionTLS13Draft18 = 0x7f00 | 18
VersionTLS13Draft21 = 0x7f00 | 21
) )


const ( const (


+ 1
- 5
conn.go Visa fil

@@ -1131,11 +1131,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeClientHello: case typeClientHello:
m = new(clientHelloMsg) m = new(clientHelloMsg)
case typeServerHello: case typeServerHello:
if c.vers >= VersionTLS13 {
m = new(serverHelloMsg13)
} else {
m = new(serverHelloMsg)
}
m = new(serverHelloMsg)
case typeEncryptedExtensions: case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg) m = new(encryptedExtensionsMsg)
case typeNewSessionTicket: case typeNewSessionTicket:


+ 93
- 140
handshake_messages.go Visa fil

@@ -711,6 +711,11 @@ type serverHelloMsg struct {
secureRenegotiation []byte secureRenegotiation []byte
secureRenegotiationSupported bool secureRenegotiationSupported bool
alpnProtocol string alpnProtocol string

// TLS 1.3
keyShare keyShare
psk bool
pskIdentity uint16
} }


func (m *serverHelloMsg) equal(i interface{}) bool { func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -740,7 +745,11 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
m.ticketSupported == m1.ticketSupported && m.ticketSupported == m1.ticketSupported &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported && m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
m.alpnProtocol == m1.alpnProtocol
m.alpnProtocol == m1.alpnProtocol &&
m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
} }


func (m *serverHelloMsg) marshal() []byte { func (m *serverHelloMsg) marshal() []byte {
@@ -748,7 +757,12 @@ func (m *serverHelloMsg) marshal() []byte {
return m.raw return m.raw
} }


oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
length := 38 + len(m.sessionId) length := 38 + len(m.sessionId)
if oldTLS13Draft {
// no compression method, no session ID.
length = 36
}
numExtensions := 0 numExtensions := 0
extensionsLength := 0 extensionsLength := 0


@@ -786,6 +800,14 @@ func (m *serverHelloMsg) marshal() []byte {
extensionsLength += 2 + sctLen extensionsLength += 2 + sctLen
numExtensions++ numExtensions++
} }
if m.keyShare.group != 0 {
extensionsLength += 4 + len(m.keyShare.data)
numExtensions++
}
if m.psk {
extensionsLength += 2
numExtensions++
}


if numExtensions > 0 { if numExtensions > 0 {
extensionsLength += 4 * numExtensions extensionsLength += 4 * numExtensions
@@ -800,14 +822,22 @@ func (m *serverHelloMsg) marshal() []byte {
x[4] = uint8(m.vers >> 8) x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers) x[5] = uint8(m.vers)
copy(x[6:38], m.random) copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
z := x[39+len(m.sessionId):]
z := x[38:]
if !oldTLS13Draft {
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
z = x[39+len(m.sessionId):]
}
z[0] = uint8(m.cipherSuite >> 8) z[0] = uint8(m.cipherSuite >> 8)
z[1] = uint8(m.cipherSuite) z[1] = uint8(m.cipherSuite)
z[2] = m.compressionMethod
if oldTLS13Draft {
// no compression method in older TLS 1.3 drafts.
z = z[2:]
} else {
z[2] = m.compressionMethod
z = z[3:]
}


z = z[3:]
if numExtensions > 0 { if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8) z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength) z[1] = byte(extensionsLength)
@@ -882,6 +912,30 @@ func (m *serverHelloMsg) marshal() []byte {
} }
} }


if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare)
l := 4 + len(m.keyShare.data)
z[2] = uint8(l >> 8)
z[3] = uint8(l)
z[4] = uint8(m.keyShare.group >> 8)
z[5] = uint8(m.keyShare.group)
l -= 4
z[6] = uint8(l >> 8)
z[7] = uint8(l)
copy(z[8:], m.keyShare.data)
z = z[8+l:]
}

if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}

m.raw = x m.raw = x


return x return x
@@ -893,19 +947,29 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
} }
m.raw = data m.raw = data
m.vers = uint16(data[4])<<8 | uint16(data[5]) m.vers = uint16(data[4])<<8 | uint16(data[5])
oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
m.random = data[6:38] m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return alertDecodeError
if !oldTLS13Draft {
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return alertDecodeError
}
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
} else {
data = data[38:]
} }
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
if len(data) < 3 { if len(data) < 3 {
return alertDecodeError return alertDecodeError
} }
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
m.compressionMethod = data[2]
data = data[3:]
if oldTLS13Draft {
// no compression method in older TLS 1.3 drafts.
data = data[2:]
} else {
m.compressionMethod = data[2]
data = data[3:]
}


m.nextProtoNeg = false m.nextProtoNeg = false
m.nextProtos = nil m.nextProtos = nil
@@ -913,6 +977,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
m.scts = nil m.scts = nil
m.ticketSupported = false m.ticketSupported = false
m.alpnProtocol = "" m.alpnProtocol = ""
m.keyShare.group = 0
m.keyShare.data = nil
m.psk = false
m.pskIdentity = 0


if len(data) == 0 { if len(data) == 0 {
// ServerHello is optionally followed by extension data // ServerHello is optionally followed by extension data
@@ -1020,140 +1088,25 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
m.scts = append(m.scts, d[:sctLen]) m.scts = append(m.scts, d[:sctLen])
d = d[sctLen:] d = d[sctLen:]
} }
}
data = data[length:]
}

return alertSuccess
}

type serverHelloMsg13 struct {
raw []byte
vers uint16
random []byte
cipherSuite uint16
keyShare keyShare
psk bool
pskIdentity uint16
}

func (m *serverHelloMsg13) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg13)
if !ok {
return false
}

return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
m.cipherSuite == m1.cipherSuite &&
m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
}

func (m *serverHelloMsg13) marshal() []byte {
if m.raw != nil {
return m.raw
}

length := 38
if m.keyShare.group != 0 {
length += 8 + len(m.keyShare.data)
}
if m.psk {
length += 6
}

x := make([]byte, 4+length)
x[0] = typeServerHello
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(m.cipherSuite >> 8)
x[39] = uint8(m.cipherSuite)

z := x[42:]
x[40] = uint8(len(z) >> 8)
x[41] = uint8(len(z))

if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}

if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare)
l := 4 + len(m.keyShare.data)
z[2] = uint8(l >> 8)
z[3] = uint8(l)
z[4] = uint8(m.keyShare.group >> 8)
z[5] = uint8(m.keyShare.group)
l -= 4
z[6] = uint8(l >> 8)
z[7] = uint8(l)
copy(z[8:], m.keyShare.data)
}

m.raw = x
return x
}

func (m *serverHelloMsg13) unmarshal(data []byte) alert {
if len(data) < 50 {
return alertDecodeError
}
m.raw = data
m.vers = uint16(data[4])<<8 | uint16(data[5])
m.random = data[6:38]
m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
m.psk = false
m.pskIdentity = 0

extensionsLength := int(data[40])<<8 | int(data[41])
data = data[42:]
if len(data) != extensionsLength {
return alertDecodeError
}

for len(data) != 0 {
if len(data) < 4 {
return alertDecodeError
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return alertDecodeError
}
case extensionKeyShare:
d := data[:length]


switch extension {
default:
return alertDecodeError
case extensionPreSharedKey:
if length != 2 {
if len(d) < 4 {
return alertDecodeError return alertDecodeError
} }
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
case extensionKeyShare:
if length < 2 {
m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1])
l := int(d[2])<<8 | int(d[3])
d = d[4:]
if len(d) != l {
return alertDecodeError return alertDecodeError
} }
m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1])
if length-4 != int(data[2])<<8|int(data[3]) {
m.keyShare.data = d[:l]
case extensionPreSharedKey:
if length != 2 {
return alertDecodeError return alertDecodeError
} }
m.keyShare.data = data[4:length]
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
} }
data = data[length:] data = data[length:]
} }


+ 4
- 11
handshake_messages_test.go Visa fil

@@ -26,7 +26,6 @@ var tests = []interface{}{
&nextProtoMsg{}, &nextProtoMsg{},
&newSessionTicketMsg{}, &newSessionTicketMsg{},
&sessionState{}, &sessionState{},
&serverHelloMsg13{},
&encryptedExtensionsMsg{}, &encryptedExtensionsMsg{},
&certificateMsg13{}, &certificateMsg13{},
&newSessionTicketMsg13{}, &newSessionTicketMsg13{},
@@ -209,16 +208,10 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
} }
} }


return reflect.ValueOf(m)
}

func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &serverHelloMsg13{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.cipherSuite = uint16(rand.Int31())
m.keyShare.group = CurveID(rand.Intn(30000))
m.keyShare.data = randomBytes(rand.Intn(300), rand)
if rand.Intn(10) > 5 {
m.keyShare.group = CurveID(rand.Intn(30000))
m.keyShare.data = randomBytes(rand.Intn(300), rand)
}
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.psk = true m.psk = true
m.pskIdentity = uint16(rand.Int31()) m.pskIdentity = uint16(rand.Int31())


+ 6
- 7
handshake_server.go Visa fil

@@ -29,10 +29,10 @@ type serverHandshakeState struct {
masterSecret []byte masterSecret []byte
cachedClientHelloInfo *ClientHelloInfo cachedClientHelloInfo *ClientHelloInfo
clientHello *clientHelloMsg clientHello *clientHelloMsg
hello *serverHelloMsg
cert *Certificate cert *Certificate


// TLS 1.0-1.2 fields // TLS 1.0-1.2 fields
hello *serverHelloMsg
ellipticOk bool ellipticOk bool
ecdsaOk bool ecdsaOk bool
rsaDecryptOk bool rsaDecryptOk bool
@@ -42,7 +42,6 @@ type serverHandshakeState struct {
certsFromClient [][]byte certsFromClient [][]byte


// TLS 1.3 fields // TLS 1.3 fields
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg hello13Enc *encryptedExtensionsMsg
keySchedule *keySchedule13 keySchedule *keySchedule13
clientFinishedKey []byte clientFinishedKey []byte
@@ -70,7 +69,7 @@ func (c *Conn) serverHandshake() error {
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
// and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2 // and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2
c.buffering = true c.buffering = true
if hs.hello13 != nil {
if c.vers >= VersionTLS13 {
if err := hs.doTLS13Handshake(); err != nil { if err := hs.doTLS13Handshake(); err != nil {
return err return err
} }
@@ -250,11 +249,11 @@ Curves:
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
hs.hello.compressionMethod = compressionNone hs.hello.compressionMethod = compressionNone
} else { } else {
hs.hello13 = new(serverHelloMsg13)
hs.hello = new(serverHelloMsg)
hs.hello13Enc = new(encryptedExtensionsMsg) hs.hello13Enc = new(encryptedExtensionsMsg)
hs.hello13.vers = c.vers
hs.hello13.random = make([]byte, 32)
_, err = io.ReadFull(c.config.rand(), hs.hello13.random)
hs.hello.vers = c.vers
hs.hello.random = make([]byte, 32)
_, err = io.ReadFull(c.config.rand(), hs.hello.random)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return false, err return false, err


Laddar…
Avbryt
Spara