th5/ticket.go
Peter Wu 69dddf0612 tris: update NewSessionTicket for draft -19 and -21
D19: use early_data instead of custom ticket_early_data_info extension
codepoint. D21: new ticket nonce field and change in PSK calculation.
This nonce provides some minor security advantage in case one of the PSK
is compromised (which would leak the resumption master secret).

Rename "resumptionSecret" to "pskSecret" in sessionState13 to reflect
the D21 change and use constant-time comparison for the secret.

Also fix potential panic if the ticket is large enough, but the
extensions are missing.
2017-12-13 17:49:25 +00:00

319 lines
8.2 KiB
Go

// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package tls
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"errors"
"io"
)
// A SessionTicketSealer provides a way to securely encapsulate
// session state for storage on the client. All methods are safe for
// concurrent use.
type SessionTicketSealer interface {
// Seal returns a session ticket value that can be later passed to Unseal
// to recover the content, usually by encrypting it. The ticket will be sent
// to the client to be stored, and will be sent back in plaintext, so it can
// be read and modified by an attacker.
Seal(cs *ConnectionState, content []byte) (ticket []byte, err error)
// Unseal returns a session ticket contents. The ticket can't be safely
// assumed to have been generated by Seal.
// If unable to unseal the ticket, the connection will proceed with a
// complete handshake.
Unseal(chi *ClientHelloInfo, ticket []byte) (content []byte, success bool)
}
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
type sessionState struct {
vers uint16
cipherSuite uint16
masterSecret []byte
certificates [][]byte
// usedOldKey is true if the ticket from which this session came from
// was encrypted with an older key and thus should be refreshed.
usedOldKey bool
}
func (s *sessionState) equal(i interface{}) bool {
s1, ok := i.(*sessionState)
if !ok {
return false
}
if s.vers != s1.vers ||
s.cipherSuite != s1.cipherSuite ||
!bytes.Equal(s.masterSecret, s1.masterSecret) {
return false
}
if len(s.certificates) != len(s1.certificates) {
return false
}
for i := range s.certificates {
if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
return false
}
}
return true
}
func (s *sessionState) marshal() []byte {
length := 2 + 2 + 2 + len(s.masterSecret) + 2
for _, cert := range s.certificates {
length += 4 + len(cert)
}
ret := make([]byte, length)
x := ret
x[0] = byte(s.vers >> 8)
x[1] = byte(s.vers)
x[2] = byte(s.cipherSuite >> 8)
x[3] = byte(s.cipherSuite)
x[4] = byte(len(s.masterSecret) >> 8)
x[5] = byte(len(s.masterSecret))
x = x[6:]
copy(x, s.masterSecret)
x = x[len(s.masterSecret):]
x[0] = byte(len(s.certificates) >> 8)
x[1] = byte(len(s.certificates))
x = x[2:]
for _, cert := range s.certificates {
x[0] = byte(len(cert) >> 24)
x[1] = byte(len(cert) >> 16)
x[2] = byte(len(cert) >> 8)
x[3] = byte(len(cert))
copy(x[4:], cert)
x = x[4+len(cert):]
}
return ret
}
func (s *sessionState) unmarshal(data []byte) alert {
if len(data) < 8 {
return alertDecodeError
}
s.vers = uint16(data[0])<<8 | uint16(data[1])
s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
masterSecretLen := int(data[4])<<8 | int(data[5])
data = data[6:]
if len(data) < masterSecretLen {
return alertDecodeError
}
s.masterSecret = data[:masterSecretLen]
data = data[masterSecretLen:]
if len(data) < 2 {
return alertDecodeError
}
numCerts := int(data[0])<<8 | int(data[1])
data = data[2:]
s.certificates = make([][]byte, numCerts)
for i := range s.certificates {
if len(data) < 4 {
return alertDecodeError
}
certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
data = data[4:]
if certLen < 0 {
return alertDecodeError
}
if len(data) < certLen {
return alertDecodeError
}
s.certificates[i] = data[:certLen]
data = data[certLen:]
}
if len(data) != 0 {
return alertDecodeError
}
return alertSuccess
}
type sessionState13 struct {
vers uint16
suite uint16
ageAdd uint32
createdAt uint64
maxEarlyDataLen uint32
pskSecret []byte
alpnProtocol string
SNI string
}
func (s *sessionState13) equal(i interface{}) bool {
s1, ok := i.(*sessionState13)
if !ok {
return false
}
return s.vers == s1.vers &&
s.suite == s1.suite &&
s.ageAdd == s1.ageAdd &&
s.createdAt == s1.createdAt &&
s.maxEarlyDataLen == s1.maxEarlyDataLen &&
subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 &&
s.alpnProtocol == s1.alpnProtocol &&
s.SNI == s1.SNI
}
func (s *sessionState13) marshal() []byte {
length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI)
x := make([]byte, length)
x[0] = byte(s.vers >> 8)
x[1] = byte(s.vers)
x[2] = byte(s.suite >> 8)
x[3] = byte(s.suite)
x[4] = byte(s.ageAdd >> 24)
x[5] = byte(s.ageAdd >> 16)
x[6] = byte(s.ageAdd >> 8)
x[7] = byte(s.ageAdd)
x[8] = byte(s.createdAt >> 56)
x[9] = byte(s.createdAt >> 48)
x[10] = byte(s.createdAt >> 40)
x[11] = byte(s.createdAt >> 32)
x[12] = byte(s.createdAt >> 24)
x[13] = byte(s.createdAt >> 16)
x[14] = byte(s.createdAt >> 8)
x[15] = byte(s.createdAt)
x[16] = byte(s.maxEarlyDataLen >> 24)
x[17] = byte(s.maxEarlyDataLen >> 16)
x[18] = byte(s.maxEarlyDataLen >> 8)
x[19] = byte(s.maxEarlyDataLen)
x[20] = byte(len(s.pskSecret) >> 8)
x[21] = byte(len(s.pskSecret))
copy(x[22:], s.pskSecret)
z := x[22+len(s.pskSecret):]
z[0] = byte(len(s.alpnProtocol) >> 8)
z[1] = byte(len(s.alpnProtocol))
copy(z[2:], s.alpnProtocol)
z = z[2+len(s.alpnProtocol):]
z[0] = byte(len(s.SNI) >> 8)
z[1] = byte(len(s.SNI))
copy(z[2:], s.SNI)
return x
}
func (s *sessionState13) unmarshal(data []byte) alert {
if len(data) < 24 {
return alertDecodeError
}
s.vers = uint16(data[0])<<8 | uint16(data[1])
s.suite = uint16(data[2])<<8 | uint16(data[3])
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19])
l := int(data[20])<<8 | int(data[21])
if len(data) < 22+l+2 {
return alertDecodeError
}
s.pskSecret = data[22 : 22+l]
z := data[22+l:]
l = int(z[0])<<8 | int(z[1])
if len(z) < 2+l+2 {
return alertDecodeError
}
s.alpnProtocol = string(z[2 : 2+l])
z = z[2+l:]
l = int(z[0])<<8 | int(z[1])
if len(z) != 2+l {
return alertDecodeError
}
s.SNI = string(z[2 : 2+l])
return alertSuccess
}
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
return nil, err
}
key := c.config.ticketKeys()[0]
copy(keyName, key.keyName[:])
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
}
cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
mac.Sum(macBytes[:0])
return encrypted, nil
}
func (c *Conn) decryptTicket(encrypted []byte) (serialized []byte, usedOldKey bool) {
if c.config.SessionTicketsDisabled ||
len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
return nil, false
}
keyName := encrypted[:ticketKeyNameLen]
iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
macBytes := encrypted[len(encrypted)-sha256.Size:]
keys := c.config.ticketKeys()
keyIndex := -1
for i, candidateKey := range keys {
if bytes.Equal(keyName, candidateKey.keyName[:]) {
keyIndex = i
break
}
}
if keyIndex == -1 {
return nil, false
}
key := &keys[keyIndex]
mac := hmac.New(sha256.New, key.hmacKey[:])
mac.Write(encrypted[:len(encrypted)-sha256.Size])
expected := mac.Sum(nil)
if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
return nil, false
}
block, err := aes.NewCipher(key.aesKey[:])
if err != nil {
return nil, false
}
ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
plaintext := ciphertext
cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
return plaintext, keyIndex > 0
}