diff --git a/13.go b/13.go index 2cf8fee..e529517 100644 --- a/13.go +++ b/13.go @@ -612,7 +612,7 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { continue } - hs.keySchedule.setSecret(s.resumptionSecret) + hs.keySchedule.setSecret(s.pskSecret) binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) chHash := hash.New() @@ -660,18 +660,18 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { return nil } - resumptionSecret := hs.keySchedule.deriveSecret(secretResumption) + resumptionMasterSecret := hs.keySchedule.deriveSecret(secretResumption) ageAddBuf := make([]byte, 4) sessionState := &sessionState13{ - vers: c.vers, - suite: hs.suite.id, - createdAt: uint64(time.Now().Unix()), - resumptionSecret: resumptionSecret, - alpnProtocol: c.clientProtocol, - SNI: c.serverName, - maxEarlyDataLen: c.config.Max0RTTDataSize, + vers: c.vers, + suite: hs.suite.id, + createdAt: uint64(time.Now().Unix()), + alpnProtocol: c.clientProtocol, + SNI: c.serverName, + maxEarlyDataLen: c.config.Max0RTTDataSize, } + hash := hashForSuite(hs.suite) for i := 0; i < numSessionTickets; i++ { if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil { @@ -680,6 +680,12 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { } sessionState.ageAdd = uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 | uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]) + // ticketNonce must be a unique value for this connection. + // Assume there are no more than 255 tickets, otherwise two + // tickets might have the same PSK which could be a problem if + // one of them is compromised. + ticketNonce := []byte{byte(i)} + sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size()) ticket := sessionState.marshal() var err error if c.config.SessionTicketSealer != nil { @@ -700,6 +706,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { maxEarlyDataLength: c.config.Max0RTTDataSize, withEarlyDataInfo: c.config.Max0RTTDataSize > 0, ageAdd: sessionState.ageAdd, + nonce: ticketNonce, ticket: ticket, } if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil { diff --git a/handshake_messages.go b/handshake_messages.go index 77b4c6c..e9c433c 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -2146,6 +2146,7 @@ type newSessionTicketMsg13 struct { raw []byte lifetime uint32 ageAdd uint32 + nonce []byte ticket []byte withEarlyDataInfo bool maxEarlyDataLength uint32 @@ -2160,6 +2161,7 @@ func (m *newSessionTicketMsg13) equal(i interface{}) bool { return bytes.Equal(m.raw, m1.raw) && m.lifetime == m1.lifetime && m.ageAdd == m1.ageAdd && + bytes.Equal(m.nonce, m1.nonce) && bytes.Equal(m.ticket, m1.ticket) && m.withEarlyDataInfo == m1.withEarlyDataInfo && m.maxEarlyDataLength == m1.maxEarlyDataLength @@ -2170,9 +2172,10 @@ func (m *newSessionTicketMsg13) marshal() (x []byte) { return m.raw } - // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.6.1 + nonceLen := len(m.nonce) ticketLen := len(m.ticket) - length := 12 + ticketLen + length := 13 + nonceLen + ticketLen if m.withEarlyDataInfo { length += 8 } @@ -2191,15 +2194,20 @@ func (m *newSessionTicketMsg13) marshal() (x []byte) { x[10] = uint8(m.ageAdd >> 8) x[11] = uint8(m.ageAdd) - x[12] = uint8(ticketLen >> 8) - x[13] = uint8(ticketLen) - copy(x[14:], m.ticket) + x[12] = uint8(nonceLen) + copy(x[13:13+nonceLen], m.nonce) + + y := x[13+nonceLen:] + y[0] = uint8(ticketLen >> 8) + y[1] = uint8(ticketLen) + copy(y[2:2+ticketLen], m.ticket) if m.withEarlyDataInfo { - z := x[14+ticketLen:] + z := y[2+ticketLen:] + // z[0] is already 0, this is the extensions vector length. z[1] = 8 - z[2] = uint8(extensionTicketEarlyDataInfo >> 8) - z[3] = uint8(extensionTicketEarlyDataInfo) + z[2] = uint8(extensionEarlyData >> 8) + z[3] = uint8(extensionEarlyData) z[5] = 4 z[6] = uint8(m.maxEarlyDataLength >> 24) z[7] = uint8(m.maxEarlyDataLength >> 16) @@ -2217,7 +2225,7 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { m.maxEarlyDataLength = 0 m.withEarlyDataInfo = false - if len(data) < 16 { + if len(data) < 17 { return alertDecodeError } @@ -2231,13 +2239,20 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 | uint32(data[10])<<8 | uint32(data[11]) - ticketLen := int(data[12])<<8 + int(data[13]) - if 14+ticketLen > len(data) { + nonceLen := int(data[12]) + if nonceLen == 0 || 13+nonceLen+2 > len(data) { return alertDecodeError } - m.ticket = data[14 : 14+ticketLen] + m.nonce = data[13 : 13+nonceLen] - data = data[14+ticketLen:] + data = data[13+nonceLen:] + ticketLen := int(data[0])<<8 + int(data[1]) + if ticketLen == 0 || 2+ticketLen+2 > len(data) { + return alertDecodeError + } + m.ticket = data[2 : 2+ticketLen] + + data = data[2+ticketLen:] extLen := int(data[0])<<8 + int(data[1]) if extLen != len(data)-2 { return alertDecodeError @@ -2253,7 +2268,7 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { data = data[4:] switch extType { - case extensionTicketEarlyDataInfo: + case extensionEarlyData: if length != 4 { return alertDecodeError } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 5b23a27..79d8119 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -317,7 +317,8 @@ func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value m := &newSessionTicketMsg13{} m.ageAdd = uint32(rand.Intn(0xffffffff)) m.lifetime = uint32(rand.Intn(0xffffffff)) - m.ticket = randomBytes(rand.Intn(40), rand) + m.nonce = randomBytes(1+rand.Intn(255), rand) + m.ticket = randomBytes(1+rand.Intn(40), rand) if rand.Intn(10) > 5 { m.withEarlyDataInfo = true m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff)) @@ -345,7 +346,7 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { s.ageAdd = uint32(rand.Intn(0xffffffff)) s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff)) s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) - s.resumptionSecret = randomBytes(rand.Intn(100), rand) + s.pskSecret = randomBytes(rand.Intn(100), rand) s.alpnProtocol = randomString(rand.Intn(100), rand) s.SNI = randomString(rand.Intn(100), rand) return reflect.ValueOf(s) diff --git a/ticket.go b/ticket.go index b8de132..32b40d8 100644 --- a/ticket.go +++ b/ticket.go @@ -150,14 +150,14 @@ func (s *sessionState) unmarshal(data []byte) alert { } type sessionState13 struct { - vers uint16 - suite uint16 - ageAdd uint32 - createdAt uint64 - maxEarlyDataLen uint32 - resumptionSecret []byte - alpnProtocol string - SNI string + vers uint16 + suite uint16 + ageAdd uint32 + createdAt uint64 + maxEarlyDataLen uint32 + pskSecret []byte + alpnProtocol string + SNI string } func (s *sessionState13) equal(i interface{}) bool { @@ -171,13 +171,13 @@ func (s *sessionState13) equal(i interface{}) bool { s.ageAdd == s1.ageAdd && s.createdAt == s1.createdAt && s.maxEarlyDataLen == s1.maxEarlyDataLen && - bytes.Equal(s.resumptionSecret, s1.resumptionSecret) && + subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 && s.alpnProtocol == s1.alpnProtocol && s.SNI == s1.SNI } func (s *sessionState13) marshal() []byte { - length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) + length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) x := make([]byte, length) x[0] = byte(s.vers >> 8) @@ -200,10 +200,10 @@ func (s *sessionState13) marshal() []byte { x[17] = byte(s.maxEarlyDataLen >> 16) x[18] = byte(s.maxEarlyDataLen >> 8) x[19] = byte(s.maxEarlyDataLen) - x[20] = byte(len(s.resumptionSecret) >> 8) - x[21] = byte(len(s.resumptionSecret)) - copy(x[22:], s.resumptionSecret) - z := x[22+len(s.resumptionSecret):] + x[20] = byte(len(s.pskSecret) >> 8) + x[21] = byte(len(s.pskSecret)) + copy(x[22:], s.pskSecret) + z := x[22+len(s.pskSecret):] z[0] = byte(len(s.alpnProtocol) >> 8) z[1] = byte(len(s.alpnProtocol)) copy(z[2:], s.alpnProtocol) @@ -231,7 +231,7 @@ func (s *sessionState13) unmarshal(data []byte) alert { if len(data) < 22+l+2 { return alertDecodeError } - s.resumptionSecret = data[22 : 22+l] + s.pskSecret = data[22 : 22+l] z := data[22+l:] l = int(z[0])<<8 | int(z[1])