// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package trs import ( "bytes" "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/sha256" "crypto/subtle" "errors" "io" ) // A SessionTicketSealer provides a way to securely encapsulate // session state for storage on the client. All methods are safe for // concurrent use. type SessionTicketSealer interface { // Seal returns a session ticket value that can be later passed to Unseal // to recover the content, usually by encrypting it. The ticket will be sent // to the client to be stored, and will be sent back in plaintext, so it can // be read and modified by an attacker. Seal(cs *ConnectionState, content []byte) (ticket []byte, err error) // Unseal returns a session ticket contents. The ticket can't be safely // assumed to have been generated by Seal. // If unable to unseal the ticket, the connection will proceed with a // complete handshake. Unseal(chi *ClientHelloInfo, ticket []byte) (content []byte, success bool) } // sessionState contains the information that is serialized into a session // ticket in order to later resume a connection. type sessionState struct { vers uint16 cipherSuite uint16 usedEMS bool masterSecret []byte certificates [][]byte // usedOldKey is true if the ticket from which this session came from // was encrypted with an older key and thus should be refreshed. usedOldKey bool } func (s *sessionState) equal(i interface{}) bool { s1, ok := i.(*sessionState) if !ok { return false } if s.vers != s1.vers || s.usedEMS != s1.usedEMS || s.cipherSuite != s1.cipherSuite || !bytes.Equal(s.masterSecret, s1.masterSecret) { return false } if len(s.certificates) != len(s1.certificates) { return false } for i := range s.certificates { if !bytes.Equal(s.certificates[i], s1.certificates[i]) { return false } } return true } func (s *sessionState) marshal() []byte { length := 2 + 2 + 2 + len(s.masterSecret) + 2 for _, cert := range s.certificates { length += 4 + len(cert) } ret := make([]byte, length) x := ret was_used := byte(0) if s.usedEMS { was_used = byte(0x80) } x[0] = byte(s.vers>>8) | byte(was_used) x[1] = byte(s.vers) x[2] = byte(s.cipherSuite >> 8) x[3] = byte(s.cipherSuite) x[4] = byte(len(s.masterSecret) >> 8) x[5] = byte(len(s.masterSecret)) x = x[6:] copy(x, s.masterSecret) x = x[len(s.masterSecret):] x[0] = byte(len(s.certificates) >> 8) x[1] = byte(len(s.certificates)) x = x[2:] for _, cert := range s.certificates { x[0] = byte(len(cert) >> 24) x[1] = byte(len(cert) >> 16) x[2] = byte(len(cert) >> 8) x[3] = byte(len(cert)) copy(x[4:], cert) x = x[4+len(cert):] } return ret } func (s *sessionState) unmarshal(data []byte) alert { if len(data) < 8 { return alertDecodeError } s.vers = (uint16(data[0])<<8 | uint16(data[1])) & 0x7fff s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) s.usedEMS = (data[0] & 0x80) == 0x80 masterSecretLen := int(data[4])<<8 | int(data[5]) data = data[6:] if len(data) < masterSecretLen { return alertDecodeError } s.masterSecret = data[:masterSecretLen] data = data[masterSecretLen:] if len(data) < 2 { return alertDecodeError } numCerts := int(data[0])<<8 | int(data[1]) data = data[2:] s.certificates = make([][]byte, numCerts) for i := range s.certificates { if len(data) < 4 { return alertDecodeError } certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) data = data[4:] if certLen < 0 { return alertDecodeError } if len(data) < certLen { return alertDecodeError } s.certificates[i] = data[:certLen] data = data[certLen:] } if len(data) != 0 { return alertDecodeError } return alertSuccess } type sessionState13 struct { vers uint16 suite uint16 ageAdd uint32 createdAt uint64 maxEarlyDataLen uint32 pskSecret []byte alpnProtocol string SNI string } func (s *sessionState13) equal(i interface{}) bool { s1, ok := i.(*sessionState13) if !ok { return false } return s.vers == s1.vers && s.suite == s1.suite && s.ageAdd == s1.ageAdd && s.createdAt == s1.createdAt && s.maxEarlyDataLen == s1.maxEarlyDataLen && subtle.ConstantTimeCompare(s.pskSecret, s1.pskSecret) == 1 && s.alpnProtocol == s1.alpnProtocol && s.SNI == s1.SNI } func (s *sessionState13) marshal() []byte { length := 2 + 2 + 4 + 8 + 4 + 2 + len(s.pskSecret) + 2 + len(s.alpnProtocol) + 2 + len(s.SNI) x := make([]byte, length) x[0] = byte(s.vers >> 8) x[1] = byte(s.vers) x[2] = byte(s.suite >> 8) x[3] = byte(s.suite) x[4] = byte(s.ageAdd >> 24) x[5] = byte(s.ageAdd >> 16) x[6] = byte(s.ageAdd >> 8) x[7] = byte(s.ageAdd) x[8] = byte(s.createdAt >> 56) x[9] = byte(s.createdAt >> 48) x[10] = byte(s.createdAt >> 40) x[11] = byte(s.createdAt >> 32) x[12] = byte(s.createdAt >> 24) x[13] = byte(s.createdAt >> 16) x[14] = byte(s.createdAt >> 8) x[15] = byte(s.createdAt) x[16] = byte(s.maxEarlyDataLen >> 24) x[17] = byte(s.maxEarlyDataLen >> 16) x[18] = byte(s.maxEarlyDataLen >> 8) x[19] = byte(s.maxEarlyDataLen) x[20] = byte(len(s.pskSecret) >> 8) x[21] = byte(len(s.pskSecret)) copy(x[22:], s.pskSecret) z := x[22+len(s.pskSecret):] z[0] = byte(len(s.alpnProtocol) >> 8) z[1] = byte(len(s.alpnProtocol)) copy(z[2:], s.alpnProtocol) z = z[2+len(s.alpnProtocol):] z[0] = byte(len(s.SNI) >> 8) z[1] = byte(len(s.SNI)) copy(z[2:], s.SNI) return x } func (s *sessionState13) unmarshal(data []byte) alert { if len(data) < 24 { return alertDecodeError } s.vers = uint16(data[0])<<8 | uint16(data[1]) s.suite = uint16(data[2])<<8 | uint16(data[3]) s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 | uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15]) s.maxEarlyDataLen = uint32(data[16])<<24 | uint32(data[17])<<16 | uint32(data[18])<<8 | uint32(data[19]) l := int(data[20])<<8 | int(data[21]) if len(data) < 22+l+2 { return alertDecodeError } s.pskSecret = data[22 : 22+l] z := data[22+l:] l = int(z[0])<<8 | int(z[1]) if len(z) < 2+l+2 { return alertDecodeError } s.alpnProtocol = string(z[2 : 2+l]) z = z[2+l:] l = int(z[0])<<8 | int(z[1]) if len(z) != 2+l { return alertDecodeError } s.SNI = string(z[2 : 2+l]) return alertSuccess } func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size) keyName := encrypted[:ticketKeyNameLen] iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] macBytes := encrypted[len(encrypted)-sha256.Size:] if _, err := io.ReadFull(c.config.rand(), iv); err != nil { return nil, err } key := c.config.ticketKeys()[0] copy(keyName, key.keyName[:]) block, err := aes.NewCipher(key.aesKey[:]) if err != nil { return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) } cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized) mac := hmac.New(sha256.New, key.hmacKey[:]) mac.Write(encrypted[:len(encrypted)-sha256.Size]) mac.Sum(macBytes[:0]) return encrypted, nil } func (c *Conn) decryptTicket(encrypted []byte) (serialized []byte, usedOldKey bool) { if c.config.SessionTicketsDisabled || len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size { return nil, false } keyName := encrypted[:ticketKeyNameLen] iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize] macBytes := encrypted[len(encrypted)-sha256.Size:] keys := c.config.ticketKeys() keyIndex := -1 for i, candidateKey := range keys { if bytes.Equal(keyName, candidateKey.keyName[:]) { keyIndex = i break } } if keyIndex == -1 { return nil, false } key := &keys[keyIndex] mac := hmac.New(sha256.New, key.hmacKey[:]) mac.Write(encrypted[:len(encrypted)-sha256.Size]) expected := mac.Sum(nil) if subtle.ConstantTimeCompare(macBytes, expected) != 1 { return nil, false } block, err := aes.NewCipher(key.aesKey[:]) if err != nil { return nil, false } ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size] plaintext := ciphertext cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) return plaintext, keyIndex > 0 }