From 69dddf0612e2ef3ec73f1ec79d511ed4badeb622 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sun, 12 Nov 2017 02:11:32 +0000 Subject: [PATCH] tris: update NewSessionTicket for draft -19 and -21 D19: use early_data instead of custom ticket_early_data_info extension codepoint. D21: new ticket nonce field and change in PSK calculation. This nonce provides some minor security advantage in case one of the PSK is compromised (which would leak the resumption master secret). Rename "resumptionSecret" to "pskSecret" in sessionState13 to reflect the D21 change and use constant-time comparison for the secret. Also fix potential panic if the ticket is large enough, but the extensions are missing. --- 13.go | 25 ++++++++++++++-------- handshake_messages.go | 43 +++++++++++++++++++++++++------------- handshake_messages_test.go | 5 +++-- ticket.go | 30 +++++++++++++------------- 4 files changed, 63 insertions(+), 40 deletions(-) diff --git a/13.go b/13.go index 2cf8fee..e529517 100644 --- a/13.go +++ b/13.go @@ -612,7 +612,7 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { continue } - hs.keySchedule.setSecret(s.resumptionSecret) + hs.keySchedule.setSecret(s.pskSecret) binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) chHash := hash.New() @@ -660,18 +660,18 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { return nil } - resumptionSecret := hs.keySchedule.deriveSecret(secretResumption) + resumptionMasterSecret := hs.keySchedule.deriveSecret(secretResumption) ageAddBuf := make([]byte, 4) sessionState := &sessionState13{ - vers: c.vers, - suite: hs.suite.id, - createdAt: uint64(time.Now().Unix()), - resumptionSecret: resumptionSecret, - alpnProtocol: c.clientProtocol, - SNI: c.serverName, - maxEarlyDataLen: c.config.Max0RTTDataSize, + vers: c.vers, + suite: hs.suite.id, + createdAt: uint64(time.Now().Unix()), + alpnProtocol: c.clientProtocol, + SNI: c.serverName, + maxEarlyDataLen: c.config.Max0RTTDataSize, } + hash := hashForSuite(hs.suite) for i := 0; i < numSessionTickets; i++ { if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil { @@ -680,6 +680,12 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { } sessionState.ageAdd = uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 | uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]) + // ticketNonce must be a unique value for this connection. + // Assume there are no more than 255 tickets, otherwise two + // tickets might have the same PSK which could be a problem if + // one of them is compromised. + ticketNonce := []byte{byte(i)} + sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size()) ticket := sessionState.marshal() var err error if c.config.SessionTicketSealer != nil { @@ -700,6 +706,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { maxEarlyDataLength: c.config.Max0RTTDataSize, withEarlyDataInfo: c.config.Max0RTTDataSize > 0, ageAdd: sessionState.ageAdd, + nonce: ticketNonce, ticket: ticket, } if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil { diff --git a/handshake_messages.go b/handshake_messages.go index 77b4c6c..e9c433c 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -2146,6 +2146,7 @@ type newSessionTicketMsg13 struct { raw []byte lifetime uint32 ageAdd uint32 + nonce []byte ticket []byte withEarlyDataInfo bool maxEarlyDataLength uint32 @@ -2160,6 +2161,7 @@ func (m *newSessionTicketMsg13) equal(i interface{}) bool { return bytes.Equal(m.raw, m1.raw) && m.lifetime == m1.lifetime && m.ageAdd == m1.ageAdd && + bytes.Equal(m.nonce, m1.nonce) && bytes.Equal(m.ticket, m1.ticket) && m.withEarlyDataInfo == m1.withEarlyDataInfo && m.maxEarlyDataLength == m1.maxEarlyDataLength @@ -2170,9 +2172,10 @@ func (m *newSessionTicketMsg13) marshal() (x []byte) { return m.raw } - // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.6.1 + nonceLen := len(m.nonce) ticketLen := len(m.ticket) - length := 12 + ticketLen + length := 13 + nonceLen + ticketLen if m.withEarlyDataInfo { length += 8 } @@ -2191,15 +2194,20 @@ func (m *newSessionTicketMsg13) marshal() (x []byte) { x[10] = uint8(m.ageAdd >> 8) x[11] = uint8(m.ageAdd) - x[12] = uint8(ticketLen >> 8) - x[13] = uint8(ticketLen) - copy(x[14:], m.ticket) + x[12] = uint8(nonceLen) + copy(x[13:13+nonceLen], m.nonce) + + y := x[13+nonceLen:] + y[0] = uint8(ticketLen >> 8) + y[1] = uint8(ticketLen) + copy(y[2:2+ticketLen], m.ticket) if m.withEarlyDataInfo { - z := x[14+ticketLen:] + z := y[2+ticketLen:] + // z[0] is already 0, this is the extensions vector length. z[1] = 8 - z[2] = uint8(extensionTicketEarlyDataInfo >> 8) - z[3] = uint8(extensionTicketEarlyDataInfo) + z[2] = uint8(extensionEarlyData >> 8) + z[3] = uint8(extensionEarlyData) z[5] = 4 z[6] = uint8(m.maxEarlyDataLength >> 24) z[7] = uint8(m.maxEarlyDataLength >> 16) @@ -2217,7 +2225,7 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { m.maxEarlyDataLength = 0 m.withEarlyDataInfo = false - if len(data) < 16 { + if len(data) < 17 { return alertDecodeError } @@ -2231,13 +2239,20 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 | uint32(data[10])<<8 | uint32(data[11]) - ticketLen := int(data[12])<<8 + int(data[13]) - if 14+ticketLen > len(data) { + nonceLen := int(data[12]) + if nonceLen == 0 || 13+nonceLen+2 > len(data) { return alertDecodeError } - m.ticket = data[14 : 14+ticketLen] + m.nonce = data[13 : 13+nonceLen] - data = data[14+ticketLen:] + data = data[13+nonceLen:] + ticketLen := int(data[0])<<8 + int(data[1]) + if ticketLen == 0 || 2+ticketLen+2 > len(data) { + return alertDecodeError + } + m.ticket = data[2 : 2+ticketLen] + + data = data[2+ticketLen:] extLen := int(data[0])<<8 + int(data[1]) if extLen != len(data)-2 { return alertDecodeError @@ -2253,7 +2268,7 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { data = data[4:] switch extType { - case extensionTicketEarlyDataInfo: + case extensionEarlyData: if length != 4 { return alertDecodeError } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 5b23a27..79d8119 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -317,7 +317,8 @@ func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value m := &newSessionTicketMsg13{} m.ageAdd = uint32(rand.Intn(0xffffffff)) m.lifetime = uint32(rand.Intn(0xffffffff)) - m.ticket = randomBytes(rand.Intn(40), rand) + m.nonce = randomBytes(1+rand.Intn(255), rand) + m.ticket = randomBytes(1+rand.Intn(40), rand) if rand.Intn(10) > 5 { m.withEarlyDataInfo = true m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff)) @@ -345,7 +346,7 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { s.ageAdd = uint32(rand.Intn(0xffffffff)) s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff)) s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) - s.resumptionSecret = randomBytes(rand.Intn(100), rand) + s.pskSecret = randomBytes(rand.Intn(100), rand) s.alpnProtocol = randomString(rand.Intn(100), rand) s.SNI = randomString(rand.Intn(100), rand) return reflect.ValueOf(s) diff --git a/ticket.go b/ticket.go index b8de132..32b40d8 100644 --- a/ticket.go +++ b/ticket.go @@ -150,14 +150,14 @@ func (s *sessionState) unmarshal(data []byte) alert { } type sessionState13 struct { - vers uint16 - suite uint16 - ageAdd uint32 - createdAt uint64 - maxEarlyDataLen uint32 - resumptionSecret []byte - alpnProtocol string - SNI string + vers uint16 + suite uint16 + ageAdd uint32 + createdAt uint64 + maxEarlyDataLen uint32 + pskSecret []byte + alpnProtocol string + SNI string } func (s *sessionState13) equal(i interface{}) bool { @@ -171,13 +171,13 @@ func (s *sessionState13) equal(i interface{}) bool { s.ageAdd == s1.ageAdd && s.createdAt == s1.createdAt && s.maxEarlyDataLen == s1.maxEarlyDataLen && - bytes.Equal(s.resumptionSecret, s1.resumptionSecret) && + subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 && s.alpnProtocol == s1.alpnProtocol && s.SNI == s1.SNI } func (s *sessionState13) marshal() []byte { - length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) + length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) x := make([]byte, length) x[0] = byte(s.vers >> 8) @@ -200,10 +200,10 @@ func (s *sessionState13) marshal() []byte { x[17] = byte(s.maxEarlyDataLen >> 16) x[18] = byte(s.maxEarlyDataLen >> 8) x[19] = byte(s.maxEarlyDataLen) - x[20] = byte(len(s.resumptionSecret) >> 8) - x[21] = byte(len(s.resumptionSecret)) - copy(x[22:], s.resumptionSecret) - z := x[22+len(s.resumptionSecret):] + x[20] = byte(len(s.pskSecret) >> 8) + x[21] = byte(len(s.pskSecret)) + copy(x[22:], s.pskSecret) + z := x[22+len(s.pskSecret):] z[0] = byte(len(s.alpnProtocol) >> 8) z[1] = byte(len(s.alpnProtocol)) copy(z[2:], s.alpnProtocol) @@ -231,7 +231,7 @@ func (s *sessionState13) unmarshal(data []byte) alert { if len(data) < 22+l+2 { return alertDecodeError } - s.resumptionSecret = data[22 : 22+l] + s.pskSecret = data[22 : 22+l] z := data[22+l:] l = int(z[0])<<8 | int(z[1])