tris: update NewSessionTicket for draft -19 and -21

D19: use early_data instead of custom ticket_early_data_info extension
codepoint. D21: new ticket nonce field and change in PSK calculation.
This nonce provides some minor security advantage in case one of the PSK
is compromised (which would leak the resumption master secret).

Rename "resumptionSecret" to "pskSecret" in sessionState13 to reflect
the D21 change and use constant-time comparison for the secret.

Also fix potential panic if the ticket is large enough, but the
extensions are missing.
This commit is contained in:
Peter Wu 2017-11-12 02:11:32 +00:00 committed by Peter Wu
parent fd93e9ecf6
commit 69dddf0612
4 changed files with 63 additions and 40 deletions

25
13.go
View File

@ -612,7 +612,7 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
continue continue
} }
hs.keySchedule.setSecret(s.resumptionSecret) hs.keySchedule.setSecret(s.pskSecret)
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder) binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New() chHash := hash.New()
@ -660,18 +660,18 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return nil return nil
} }
resumptionSecret := hs.keySchedule.deriveSecret(secretResumption) resumptionMasterSecret := hs.keySchedule.deriveSecret(secretResumption)
ageAddBuf := make([]byte, 4) ageAddBuf := make([]byte, 4)
sessionState := &sessionState13{ sessionState := &sessionState13{
vers: c.vers, vers: c.vers,
suite: hs.suite.id, suite: hs.suite.id,
createdAt: uint64(time.Now().Unix()), createdAt: uint64(time.Now().Unix()),
resumptionSecret: resumptionSecret, alpnProtocol: c.clientProtocol,
alpnProtocol: c.clientProtocol, SNI: c.serverName,
SNI: c.serverName, maxEarlyDataLen: c.config.Max0RTTDataSize,
maxEarlyDataLen: c.config.Max0RTTDataSize,
} }
hash := hashForSuite(hs.suite)
for i := 0; i < numSessionTickets; i++ { for i := 0; i < numSessionTickets; i++ {
if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil { if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil {
@ -680,6 +680,12 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
} }
sessionState.ageAdd = uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 | sessionState.ageAdd = uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 |
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]) uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3])
// ticketNonce must be a unique value for this connection.
// Assume there are no more than 255 tickets, otherwise two
// tickets might have the same PSK which could be a problem if
// one of them is compromised.
ticketNonce := []byte{byte(i)}
sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
ticket := sessionState.marshal() ticket := sessionState.marshal()
var err error var err error
if c.config.SessionTicketSealer != nil { if c.config.SessionTicketSealer != nil {
@ -700,6 +706,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
maxEarlyDataLength: c.config.Max0RTTDataSize, maxEarlyDataLength: c.config.Max0RTTDataSize,
withEarlyDataInfo: c.config.Max0RTTDataSize > 0, withEarlyDataInfo: c.config.Max0RTTDataSize > 0,
ageAdd: sessionState.ageAdd, ageAdd: sessionState.ageAdd,
nonce: ticketNonce,
ticket: ticket, ticket: ticket,
} }
if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil {

View File

@ -2146,6 +2146,7 @@ type newSessionTicketMsg13 struct {
raw []byte raw []byte
lifetime uint32 lifetime uint32
ageAdd uint32 ageAdd uint32
nonce []byte
ticket []byte ticket []byte
withEarlyDataInfo bool withEarlyDataInfo bool
maxEarlyDataLength uint32 maxEarlyDataLength uint32
@ -2160,6 +2161,7 @@ func (m *newSessionTicketMsg13) equal(i interface{}) bool {
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
m.lifetime == m1.lifetime && m.lifetime == m1.lifetime &&
m.ageAdd == m1.ageAdd && m.ageAdd == m1.ageAdd &&
bytes.Equal(m.nonce, m1.nonce) &&
bytes.Equal(m.ticket, m1.ticket) && bytes.Equal(m.ticket, m1.ticket) &&
m.withEarlyDataInfo == m1.withEarlyDataInfo && m.withEarlyDataInfo == m1.withEarlyDataInfo &&
m.maxEarlyDataLength == m1.maxEarlyDataLength m.maxEarlyDataLength == m1.maxEarlyDataLength
@ -2170,9 +2172,10 @@ func (m *newSessionTicketMsg13) marshal() (x []byte) {
return m.raw return m.raw
} }
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.6 // See https://tools.ietf.org/html/draft-ietf-tls-tls13-21#section-4.6.1
nonceLen := len(m.nonce)
ticketLen := len(m.ticket) ticketLen := len(m.ticket)
length := 12 + ticketLen length := 13 + nonceLen + ticketLen
if m.withEarlyDataInfo { if m.withEarlyDataInfo {
length += 8 length += 8
} }
@ -2191,15 +2194,20 @@ func (m *newSessionTicketMsg13) marshal() (x []byte) {
x[10] = uint8(m.ageAdd >> 8) x[10] = uint8(m.ageAdd >> 8)
x[11] = uint8(m.ageAdd) x[11] = uint8(m.ageAdd)
x[12] = uint8(ticketLen >> 8) x[12] = uint8(nonceLen)
x[13] = uint8(ticketLen) copy(x[13:13+nonceLen], m.nonce)
copy(x[14:], m.ticket)
y := x[13+nonceLen:]
y[0] = uint8(ticketLen >> 8)
y[1] = uint8(ticketLen)
copy(y[2:2+ticketLen], m.ticket)
if m.withEarlyDataInfo { if m.withEarlyDataInfo {
z := x[14+ticketLen:] z := y[2+ticketLen:]
// z[0] is already 0, this is the extensions vector length.
z[1] = 8 z[1] = 8
z[2] = uint8(extensionTicketEarlyDataInfo >> 8) z[2] = uint8(extensionEarlyData >> 8)
z[3] = uint8(extensionTicketEarlyDataInfo) z[3] = uint8(extensionEarlyData)
z[5] = 4 z[5] = 4
z[6] = uint8(m.maxEarlyDataLength >> 24) z[6] = uint8(m.maxEarlyDataLength >> 24)
z[7] = uint8(m.maxEarlyDataLength >> 16) z[7] = uint8(m.maxEarlyDataLength >> 16)
@ -2217,7 +2225,7 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert {
m.maxEarlyDataLength = 0 m.maxEarlyDataLength = 0
m.withEarlyDataInfo = false m.withEarlyDataInfo = false
if len(data) < 16 { if len(data) < 17 {
return alertDecodeError return alertDecodeError
} }
@ -2231,13 +2239,20 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert {
m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 | m.ageAdd = uint32(data[8])<<24 | uint32(data[9])<<16 |
uint32(data[10])<<8 | uint32(data[11]) uint32(data[10])<<8 | uint32(data[11])
ticketLen := int(data[12])<<8 + int(data[13]) nonceLen := int(data[12])
if 14+ticketLen > len(data) { if nonceLen == 0 || 13+nonceLen+2 > len(data) {
return alertDecodeError return alertDecodeError
} }
m.ticket = data[14 : 14+ticketLen] m.nonce = data[13 : 13+nonceLen]
data = data[14+ticketLen:] data = data[13+nonceLen:]
ticketLen := int(data[0])<<8 + int(data[1])
if ticketLen == 0 || 2+ticketLen+2 > len(data) {
return alertDecodeError
}
m.ticket = data[2 : 2+ticketLen]
data = data[2+ticketLen:]
extLen := int(data[0])<<8 + int(data[1]) extLen := int(data[0])<<8 + int(data[1])
if extLen != len(data)-2 { if extLen != len(data)-2 {
return alertDecodeError return alertDecodeError
@ -2253,7 +2268,7 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert {
data = data[4:] data = data[4:]
switch extType { switch extType {
case extensionTicketEarlyDataInfo: case extensionEarlyData:
if length != 4 { if length != 4 {
return alertDecodeError return alertDecodeError
} }

View File

@ -317,7 +317,8 @@ func (*newSessionTicketMsg13) Generate(rand *rand.Rand, size int) reflect.Value
m := &newSessionTicketMsg13{} m := &newSessionTicketMsg13{}
m.ageAdd = uint32(rand.Intn(0xffffffff)) m.ageAdd = uint32(rand.Intn(0xffffffff))
m.lifetime = uint32(rand.Intn(0xffffffff)) m.lifetime = uint32(rand.Intn(0xffffffff))
m.ticket = randomBytes(rand.Intn(40), rand) m.nonce = randomBytes(1+rand.Intn(255), rand)
m.ticket = randomBytes(1+rand.Intn(40), rand)
if rand.Intn(10) > 5 { if rand.Intn(10) > 5 {
m.withEarlyDataInfo = true m.withEarlyDataInfo = true
m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff)) m.maxEarlyDataLength = uint32(rand.Intn(0xffffffff))
@ -345,7 +346,7 @@ func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
s.ageAdd = uint32(rand.Intn(0xffffffff)) s.ageAdd = uint32(rand.Intn(0xffffffff))
s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff)) s.maxEarlyDataLen = uint32(rand.Intn(0xffffffff))
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
s.resumptionSecret = randomBytes(rand.Intn(100), rand) s.pskSecret = randomBytes(rand.Intn(100), rand)
s.alpnProtocol = randomString(rand.Intn(100), rand) s.alpnProtocol = randomString(rand.Intn(100), rand)
s.SNI = randomString(rand.Intn(100), rand) s.SNI = randomString(rand.Intn(100), rand)
return reflect.ValueOf(s) return reflect.ValueOf(s)

View File

@ -150,14 +150,14 @@ func (s *sessionState) unmarshal(data []byte) alert {
} }
type sessionState13 struct { type sessionState13 struct {
vers uint16 vers uint16
suite uint16 suite uint16
ageAdd uint32 ageAdd uint32
createdAt uint64 createdAt uint64
maxEarlyDataLen uint32 maxEarlyDataLen uint32
resumptionSecret []byte pskSecret []byte
alpnProtocol string alpnProtocol string
SNI string SNI string
} }
func (s *sessionState13) equal(i interface{}) bool { func (s *sessionState13) equal(i interface{}) bool {
@ -171,13 +171,13 @@ func (s *sessionState13) equal(i interface{}) bool {
s.ageAdd == s1.ageAdd && s.ageAdd == s1.ageAdd &&
s.createdAt == s1.createdAt && s.createdAt == s1.createdAt &&
s.maxEarlyDataLen == s1.maxEarlyDataLen && s.maxEarlyDataLen == s1.maxEarlyDataLen &&
bytes.Equal(s.resumptionSecret, s1.resumptionSecret) && subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 &&
s.alpnProtocol == s1.alpnProtocol && s.alpnProtocol == s1.alpnProtocol &&
s.SNI == s1.SNI s.SNI == s1.SNI
} }
func (s *sessionState13) marshal() []byte { func (s *sessionState13) marshal() []byte {
length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
x := make([]byte, length) x := make([]byte, length)
x[0] = byte(s.vers >> 8) x[0] = byte(s.vers >> 8)
@ -200,10 +200,10 @@ func (s *sessionState13) marshal() []byte {
x[17] = byte(s.maxEarlyDataLen >> 16) x[17] = byte(s.maxEarlyDataLen >> 16)
x[18] = byte(s.maxEarlyDataLen >> 8) x[18] = byte(s.maxEarlyDataLen >> 8)
x[19] = byte(s.maxEarlyDataLen) x[19] = byte(s.maxEarlyDataLen)
x[20] = byte(len(s.resumptionSecret) >> 8) x[20] = byte(len(s.pskSecret) >> 8)
x[21] = byte(len(s.resumptionSecret)) x[21] = byte(len(s.pskSecret))
copy(x[22:], s.resumptionSecret) copy(x[22:], s.pskSecret)
z := x[22+len(s.resumptionSecret):] z := x[22+len(s.pskSecret):]
z[0] = byte(len(s.alpnProtocol) >> 8) z[0] = byte(len(s.alpnProtocol) >> 8)
z[1] = byte(len(s.alpnProtocol)) z[1] = byte(len(s.alpnProtocol))
copy(z[2:], s.alpnProtocol) copy(z[2:], s.alpnProtocol)
@ -231,7 +231,7 @@ func (s *sessionState13) unmarshal(data []byte) alert {
if len(data) < 22+l+2 { if len(data) < 22+l+2 {
return alertDecodeError return alertDecodeError
} }
s.resumptionSecret = data[22 : 22+l] s.pskSecret = data[22 : 22+l]
z := data[22+l:] z := data[22+l:]
l = int(z[0])<<8 | int(z[1]) l = int(z[0])<<8 | int(z[1])