Kaynağa Gözat

tris: unify ServerHello processing in preparation for D22

Merge serverHelloMsg13 into serverHelloMsg in preparation for draft 22.
This will also simplify the client implementation since only one
structure has to be checked.

Also fixed potential out-of-bounds access with keyShare unmarshal.
tls13
Peter Wu 7 yıl önce
ebeveyn
işleme
4e6ebb63dd
6 değiştirilmiş dosya ile 111 ekleme ve 169 silme
  1. +6
    -6
      13.go
  2. +1
    -0
      common.go
  3. +1
    -5
      conn.go
  4. +93
    -140
      handshake_messages.go
  5. +4
    -11
      handshake_messages_test.go
  6. +6
    -7
      handshake_server.go

+ 6
- 6
13.go Dosyayı Görüntüle

@@ -127,7 +127,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
config := hs.c.config
c := hs.c

hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id
hs.c.cipherSuite, hs.hello.cipherSuite = hs.suite.id, hs.suite.id
hs.c.clientHello = hs.clientHello.marshal()

// When picking the group for the handshake, priority is given to groups
@@ -159,7 +159,7 @@ CurvePreferenceLoop:
c.sendAlert(alertInternalError)
return err
}
hs.hello13.keyShare = serverKS
hs.hello.keyShare = serverKS

hash := hashForSuite(hs.suite)
hashSize := hash.Size()
@@ -188,8 +188,8 @@ CurvePreferenceLoop:
return errors.New("tls: bad ECDHE client share")
}

hs.keySchedule.write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
hs.keySchedule.write(hs.hello.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
return err
}

@@ -603,8 +603,8 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
hs.hello13Enc.earlyData = true
}
}
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
hs.hello.psk = true
hs.hello.pskIdentity = uint16(i)
return true, alertSuccess
}



+ 1
- 0
common.go Dosyayı Görüntüle

@@ -28,6 +28,7 @@ const (
VersionTLS12 = 0x0303
VersionTLS13 = 0x0304
VersionTLS13Draft18 = 0x7f00 | 18
VersionTLS13Draft21 = 0x7f00 | 21
)

const (


+ 1
- 5
conn.go Dosyayı Görüntüle

@@ -1131,11 +1131,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeClientHello:
m = new(clientHelloMsg)
case typeServerHello:
if c.vers >= VersionTLS13 {
m = new(serverHelloMsg13)
} else {
m = new(serverHelloMsg)
}
m = new(serverHelloMsg)
case typeEncryptedExtensions:
m = new(encryptedExtensionsMsg)
case typeNewSessionTicket:


+ 93
- 140
handshake_messages.go Dosyayı Görüntüle

@@ -711,6 +711,11 @@ type serverHelloMsg struct {
secureRenegotiation []byte
secureRenegotiationSupported bool
alpnProtocol string

// TLS 1.3
keyShare keyShare
psk bool
pskIdentity uint16
}

func (m *serverHelloMsg) equal(i interface{}) bool {
@@ -740,7 +745,11 @@ func (m *serverHelloMsg) equal(i interface{}) bool {
m.ticketSupported == m1.ticketSupported &&
m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
m.alpnProtocol == m1.alpnProtocol
m.alpnProtocol == m1.alpnProtocol &&
m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
}

func (m *serverHelloMsg) marshal() []byte {
@@ -748,7 +757,12 @@ func (m *serverHelloMsg) marshal() []byte {
return m.raw
}

oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
length := 38 + len(m.sessionId)
if oldTLS13Draft {
// no compression method, no session ID.
length = 36
}
numExtensions := 0
extensionsLength := 0

@@ -786,6 +800,14 @@ func (m *serverHelloMsg) marshal() []byte {
extensionsLength += 2 + sctLen
numExtensions++
}
if m.keyShare.group != 0 {
extensionsLength += 4 + len(m.keyShare.data)
numExtensions++
}
if m.psk {
extensionsLength += 2
numExtensions++
}

if numExtensions > 0 {
extensionsLength += 4 * numExtensions
@@ -800,14 +822,22 @@ func (m *serverHelloMsg) marshal() []byte {
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
z := x[39+len(m.sessionId):]
z := x[38:]
if !oldTLS13Draft {
x[38] = uint8(len(m.sessionId))
copy(x[39:39+len(m.sessionId)], m.sessionId)
z = x[39+len(m.sessionId):]
}
z[0] = uint8(m.cipherSuite >> 8)
z[1] = uint8(m.cipherSuite)
z[2] = m.compressionMethod
if oldTLS13Draft {
// no compression method in older TLS 1.3 drafts.
z = z[2:]
} else {
z[2] = m.compressionMethod
z = z[3:]
}

z = z[3:]
if numExtensions > 0 {
z[0] = byte(extensionsLength >> 8)
z[1] = byte(extensionsLength)
@@ -882,6 +912,30 @@ func (m *serverHelloMsg) marshal() []byte {
}
}

if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare)
l := 4 + len(m.keyShare.data)
z[2] = uint8(l >> 8)
z[3] = uint8(l)
z[4] = uint8(m.keyShare.group >> 8)
z[5] = uint8(m.keyShare.group)
l -= 4
z[6] = uint8(l >> 8)
z[7] = uint8(l)
copy(z[8:], m.keyShare.data)
z = z[8+l:]
}

if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}

m.raw = x

return x
@@ -893,19 +947,29 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
}
m.raw = data
m.vers = uint16(data[4])<<8 | uint16(data[5])
oldTLS13Draft := m.vers >= VersionTLS13Draft18 && m.vers <= VersionTLS13Draft21
m.random = data[6:38]
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return alertDecodeError
if !oldTLS13Draft {
sessionIdLen := int(data[38])
if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
return alertDecodeError
}
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
} else {
data = data[38:]
}
m.sessionId = data[39 : 39+sessionIdLen]
data = data[39+sessionIdLen:]
if len(data) < 3 {
return alertDecodeError
}
m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
m.compressionMethod = data[2]
data = data[3:]
if oldTLS13Draft {
// no compression method in older TLS 1.3 drafts.
data = data[2:]
} else {
m.compressionMethod = data[2]
data = data[3:]
}

m.nextProtoNeg = false
m.nextProtos = nil
@@ -913,6 +977,10 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
m.scts = nil
m.ticketSupported = false
m.alpnProtocol = ""
m.keyShare.group = 0
m.keyShare.data = nil
m.psk = false
m.pskIdentity = 0

if len(data) == 0 {
// ServerHello is optionally followed by extension data
@@ -1020,140 +1088,25 @@ func (m *serverHelloMsg) unmarshal(data []byte) alert {
m.scts = append(m.scts, d[:sctLen])
d = d[sctLen:]
}
}
data = data[length:]
}

return alertSuccess
}

type serverHelloMsg13 struct {
raw []byte
vers uint16
random []byte
cipherSuite uint16
keyShare keyShare
psk bool
pskIdentity uint16
}

func (m *serverHelloMsg13) equal(i interface{}) bool {
m1, ok := i.(*serverHelloMsg13)
if !ok {
return false
}

return bytes.Equal(m.raw, m1.raw) &&
m.vers == m1.vers &&
bytes.Equal(m.random, m1.random) &&
m.cipherSuite == m1.cipherSuite &&
m.keyShare.group == m1.keyShare.group &&
bytes.Equal(m.keyShare.data, m1.keyShare.data) &&
m.psk == m1.psk &&
m.pskIdentity == m1.pskIdentity
}

func (m *serverHelloMsg13) marshal() []byte {
if m.raw != nil {
return m.raw
}

length := 38
if m.keyShare.group != 0 {
length += 8 + len(m.keyShare.data)
}
if m.psk {
length += 6
}

x := make([]byte, 4+length)
x[0] = typeServerHello
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
x[3] = uint8(length)
x[4] = uint8(m.vers >> 8)
x[5] = uint8(m.vers)
copy(x[6:38], m.random)
x[38] = uint8(m.cipherSuite >> 8)
x[39] = uint8(m.cipherSuite)

z := x[42:]
x[40] = uint8(len(z) >> 8)
x[41] = uint8(len(z))

if m.psk {
z[0] = byte(extensionPreSharedKey >> 8)
z[1] = byte(extensionPreSharedKey)
z[3] = 2
z[4] = byte(m.pskIdentity >> 8)
z[5] = byte(m.pskIdentity)
z = z[6:]
}

if m.keyShare.group != 0 {
z[0] = uint8(extensionKeyShare >> 8)
z[1] = uint8(extensionKeyShare)
l := 4 + len(m.keyShare.data)
z[2] = uint8(l >> 8)
z[3] = uint8(l)
z[4] = uint8(m.keyShare.group >> 8)
z[5] = uint8(m.keyShare.group)
l -= 4
z[6] = uint8(l >> 8)
z[7] = uint8(l)
copy(z[8:], m.keyShare.data)
}

m.raw = x
return x
}

func (m *serverHelloMsg13) unmarshal(data []byte) alert {
if len(data) < 50 {
return alertDecodeError
}
m.raw = data
m.vers = uint16(data[4])<<8 | uint16(data[5])
m.random = data[6:38]
m.cipherSuite = uint16(data[38])<<8 | uint16(data[39])
m.psk = false
m.pskIdentity = 0

extensionsLength := int(data[40])<<8 | int(data[41])
data = data[42:]
if len(data) != extensionsLength {
return alertDecodeError
}

for len(data) != 0 {
if len(data) < 4 {
return alertDecodeError
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return alertDecodeError
}
case extensionKeyShare:
d := data[:length]

switch extension {
default:
return alertDecodeError
case extensionPreSharedKey:
if length != 2 {
if len(d) < 4 {
return alertDecodeError
}
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
case extensionKeyShare:
if length < 2 {
m.keyShare.group = CurveID(d[0])<<8 | CurveID(d[1])
l := int(d[2])<<8 | int(d[3])
d = d[4:]
if len(d) != l {
return alertDecodeError
}
m.keyShare.group = CurveID(data[0])<<8 | CurveID(data[1])
if length-4 != int(data[2])<<8|int(data[3]) {
m.keyShare.data = d[:l]
case extensionPreSharedKey:
if length != 2 {
return alertDecodeError
}
m.keyShare.data = data[4:length]
m.psk = true
m.pskIdentity = uint16(data[0])<<8 | uint16(data[1])
}
data = data[length:]
}


+ 4
- 11
handshake_messages_test.go Dosyayı Görüntüle

@@ -26,7 +26,6 @@ var tests = []interface{}{
&nextProtoMsg{},
&newSessionTicketMsg{},
&sessionState{},
&serverHelloMsg13{},
&encryptedExtensionsMsg{},
&certificateMsg13{},
&newSessionTicketMsg13{},
@@ -209,16 +208,10 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
}
}

return reflect.ValueOf(m)
}

func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
m := &serverHelloMsg13{}
m.vers = uint16(rand.Intn(65536))
m.random = randomBytes(32, rand)
m.cipherSuite = uint16(rand.Int31())
m.keyShare.group = CurveID(rand.Intn(30000))
m.keyShare.data = randomBytes(rand.Intn(300), rand)
if rand.Intn(10) > 5 {
m.keyShare.group = CurveID(rand.Intn(30000))
m.keyShare.data = randomBytes(rand.Intn(300), rand)
}
if rand.Intn(10) > 5 {
m.psk = true
m.pskIdentity = uint16(rand.Int31())


+ 6
- 7
handshake_server.go Dosyayı Görüntüle

@@ -29,10 +29,10 @@ type serverHandshakeState struct {
masterSecret []byte
cachedClientHelloInfo *ClientHelloInfo
clientHello *clientHelloMsg
hello *serverHelloMsg
cert *Certificate

// TLS 1.0-1.2 fields
hello *serverHelloMsg
ellipticOk bool
ecdsaOk bool
rsaDecryptOk bool
@@ -42,7 +42,6 @@ type serverHandshakeState struct {
certsFromClient [][]byte

// TLS 1.3 fields
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg
keySchedule *keySchedule13
clientFinishedKey []byte
@@ -70,7 +69,7 @@ func (c *Conn) serverHandshake() error {
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
// and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2
c.buffering = true
if hs.hello13 != nil {
if c.vers >= VersionTLS13 {
if err := hs.doTLS13Handshake(); err != nil {
return err
}
@@ -250,11 +249,11 @@ Curves:
hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
hs.hello.compressionMethod = compressionNone
} else {
hs.hello13 = new(serverHelloMsg13)
hs.hello = new(serverHelloMsg)
hs.hello13Enc = new(encryptedExtensionsMsg)
hs.hello13.vers = c.vers
hs.hello13.random = make([]byte, 32)
_, err = io.ReadFull(c.config.rand(), hs.hello13.random)
hs.hello.vers = c.vers
hs.hello.random = make([]byte, 32)
_, err = io.ReadFull(c.config.rand(), hs.hello.random)
if err != nil {
c.sendAlert(alertInternalError)
return false, err


Yükleniyor…
İptal
Kaydet