Add support for TLS 1.3 PSK resumption in Go.

Change-Id: I998f69269cdf813da19ccccc208b476f3501c8c4
Reviewed-on: https://boringssl-review.googlesource.com/8991
Reviewed-by: Steven Valdez <svaldez@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
This commit is contained in:
Nick Harper 2016-07-25 16:16:28 -07:00 committed by CQ bot account: commit-bot@chromium.org
parent afc64dec74
commit 0b3625bcfd
9 changed files with 220 additions and 98 deletions

View File

@ -41,6 +41,7 @@ const (
alertNoRenegotiation alert = 100 alertNoRenegotiation alert = 100
alertMissingExtension alert = 109 alertMissingExtension alert = 109
alertUnsupportedExtension alert = 110 alertUnsupportedExtension alert = 110
alertUnknownPSKIdentity alert = 115
) )
var alertText = map[alert]string{ var alertText = map[alert]string{
@ -69,6 +70,7 @@ var alertText = map[alert]string{
alertNoRenegotiation: "no renegotiation", alertNoRenegotiation: "no renegotiation",
alertMissingExtension: "missing extension", alertMissingExtension: "missing extension",
alertUnsupportedExtension: "unsupported extension", alertUnsupportedExtension: "unsupported extension",
alertUnknownPSKIdentity: "unknown PSK identity",
} }
func (e alert) String() string { func (e alert) String() string {

View File

@ -101,6 +101,27 @@ func (cs cipherSuite) hash() crypto.Hash {
return crypto.SHA256 return crypto.SHA256
} }
// TODO(nharper): Remove this function when TLS 1.3 cipher suites get
// refactored to break out the AEAD/PRF from everything else. Once that's
// done, this won't be necessary anymore.
func ecdhePSKSuite(id uint16) uint16 {
switch id {
case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256:
return TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256
case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256:
return TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256
case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384:
return TLS_ECDHE_PSK_WITH_AES_256_GCM_SHA384
}
return 0
}
var cipherSuites = []*cipherSuite{ var cipherSuites = []*cipherSuite{
// Ciphersuite order is chosen so that ECDHE comes before plain RSA // Ciphersuite order is chosen so that ECDHE comes before plain RSA
// and RC4 comes before AES (because of the Lucky13 attack). // and RC4 comes before AES (because of the Lucky13 attack).

View File

@ -28,7 +28,7 @@ const (
// The draft version of TLS 1.3 that is implemented here and sent in the draft // The draft version of TLS 1.3 that is implemented here and sent in the draft
// indicator extension. // indicator extension.
const tls13DraftVersion = 13 const tls13DraftVersion = 14
const ( const (
maxPlaintext = 16384 // maximum plaintext payload length maxPlaintext = 16384 // maximum plaintext payload length
@ -242,6 +242,10 @@ type ClientSessionState struct {
extendedMasterSecret bool // Whether an extended master secret was used to generate the session extendedMasterSecret bool // Whether an extended master secret was used to generate the session
sctList []byte sctList []byte
ocspResponse []byte ocspResponse []byte
ticketCreationTime time.Time
ticketExpiration time.Time
ticketFlags uint32
ticketAgeAdd uint32
} }
// ClientSessionCache is a cache of ClientSessionState objects that can be used // ClientSessionCache is a cache of ClientSessionState objects that can be used

View File

@ -1389,6 +1389,10 @@ func (c *Conn) handlePostHandshakeMessage() error {
serverCertificates: c.peerCertificates, serverCertificates: c.peerCertificates,
sctList: c.sctList, sctList: c.sctList,
ocspResponse: c.ocspResponse, ocspResponse: c.ocspResponse,
ticketCreationTime: c.config.time(),
ticketExpiration: c.config.time().Add(time.Duration(newSessionTicket.ticketLifetime) * time.Second),
ticketFlags: newSessionTicket.ticketFlags,
ticketAgeAdd: newSessionTicket.ticketAgeAdd,
} }
cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
@ -1667,11 +1671,10 @@ func (c *Conn) SendNewSessionTicket() error {
for _, cert := range c.peerCertificates { for _, cert := range c.peerCertificates {
peerCertificatesRaw = append(peerCertificatesRaw, cert.Raw) peerCertificatesRaw = append(peerCertificatesRaw, cert.Raw)
} }
state := sessionState{
vers: c.vers, var ageAdd uint32
cipherSuite: c.cipherSuite.id, if err := binary.Read(c.config.rand(), binary.LittleEndian, &ageAdd); err != nil {
masterSecret: c.resumptionSecret, return err
certificates: peerCertificatesRaw,
} }
// TODO(davidben): Allow configuring these values. // TODO(davidben): Allow configuring these values.
@ -1679,7 +1682,20 @@ func (c *Conn) SendNewSessionTicket() error {
version: c.vers, version: c.vers,
ticketLifetime: uint32(24 * time.Hour / time.Second), ticketLifetime: uint32(24 * time.Hour / time.Second),
ticketFlags: ticketAllowDHEResumption | ticketAllowPSKResumption, ticketFlags: ticketAllowDHEResumption | ticketAllowPSKResumption,
ticketAgeAdd: ageAdd,
} }
state := sessionState{
vers: c.vers,
cipherSuite: c.cipherSuite.id,
masterSecret: c.resumptionSecret,
certificates: peerCertificatesRaw,
ticketCreationTime: c.config.time(),
ticketExpiration: c.config.time().Add(time.Duration(m.ticketLifetime) * time.Second),
ticketFlags: m.ticketFlags,
ticketAgeAdd: ageAdd,
}
if !c.config.Bugs.SendEmptySessionTicket { if !c.config.Bugs.SendEmptySessionTicket {
var err error var err error
m.ticket, err = c.encryptTicket(&state) m.ticket, err = c.encryptTicket(&state)

View File

@ -18,6 +18,7 @@ import (
"math/big" "math/big"
"net" "net"
"strconv" "strconv"
"time"
) )
type clientHandshakeState struct { type clientHandshakeState struct {
@ -197,6 +198,8 @@ NextCipherSuite:
// Try to resume a previously negotiated TLS session, if // Try to resume a previously negotiated TLS session, if
// available. // available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
// TODO(nharper): Support storing more than one session
// ticket for TLS 1.3.
candidateSession, ok := sessionCache.Get(cacheKey) candidateSession, ok := sessionCache.Get(cacheKey)
if ok { if ok {
ticketOk := !c.config.SessionTicketsDisabled || candidateSession.sessionTicket == nil ticketOk := !c.config.SessionTicketsDisabled || candidateSession.sessionTicket == nil
@ -219,7 +222,7 @@ NextCipherSuite:
} }
} }
if session != nil { if session != nil && c.config.time().Before(session.ticketExpiration) {
ticket := session.sessionTicket ticket := session.sessionTicket
if c.config.Bugs.CorruptTicket && len(ticket) > 0 { if c.config.Bugs.CorruptTicket && len(ticket) > 0 {
ticket = make([]byte, len(session.sessionTicket)) ticket = make([]byte, len(session.sessionTicket))
@ -232,7 +235,21 @@ NextCipherSuite:
} }
if session.vers >= VersionTLS13 { if session.vers >= VersionTLS13 {
// TODO(davidben): Offer TLS 1.3 tickets. // TODO(nharper): Support sending more
// than one PSK identity.
if session.ticketFlags&ticketAllowDHEResumption != 0 {
var found bool
for _, id := range hello.cipherSuites {
if id == session.cipherSuite {
found = true
break
}
}
if found {
hello.pskIdentities = [][]uint8{ticket}
hello.cipherSuites = append(hello.cipherSuites, ecdhePSKSuite(session.cipherSuite))
}
}
} else if ticket != nil { } else if ticket != nil {
hello.sessionTicket = ticket hello.sessionTicket = ticket
// A random session ID is used to detect when the // A random session ID is used to detect when the
@ -411,7 +428,7 @@ NextCipherSuite:
} }
} }
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite)
if suite == nil { if suite == nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return fmt.Errorf("tls: server selected an unsupported cipher suite") return fmt.Errorf("tls: server selected an unsupported cipher suite")
@ -546,9 +563,18 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
return errors.New("tls: server omitted the PSK identity extension") return errors.New("tls: server omitted the PSK identity extension")
} }
// TODO(davidben): Support PSK ciphers and PSK resumption. Set // We send at most one PSK identity.
// the resumption context appropriately if resuming. if hs.session == nil || hs.serverHello.pskIdentity != 0 {
return errors.New("tls: PSK ciphers not implemented for TLS 1.3") c.sendAlert(alertUnknownPSKIdentity)
return errors.New("tls: server sent unknown PSK identity")
}
if ecdhePSKSuite(hs.session.cipherSuite) != hs.suite.id {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server sent invalid cipher suite for PSK")
}
psk = deriveResumptionPSK(hs.suite, hs.session.masterSecret)
hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.session.masterSecret))
c.didResume = true
} else { } else {
if hs.serverHello.hasPSKIdentity { if hs.serverHello.hasPSKIdentity {
c.sendAlert(alertUnsupportedExtension) c.sendAlert(alertUnsupportedExtension)
@ -626,6 +652,11 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
c.sendAlert(alertUnsupportedExtension) c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: server sent SCT list without a certificate") return errors.New("tls: server sent SCT list without a certificate")
} }
// Copy over authentication from the session.
c.peerCertificates = hs.session.serverCertificates
c.sctList = hs.session.sctList
c.ocspResponse = hs.session.ocspResponse
} else { } else {
c.ocspResponse = encryptedExtensions.extensions.ocspResponse c.ocspResponse = encryptedExtensions.extensions.ocspResponse
c.sctList = encryptedExtensions.extensions.sctList c.sctList = encryptedExtensions.extensions.sctList
@ -1223,6 +1254,7 @@ func (hs *clientHandshakeState) readSessionTicket() error {
serverCertificates: c.peerCertificates, serverCertificates: c.peerCertificates,
sctList: c.sctList, sctList: c.sctList,
ocspResponse: c.ocspResponse, ocspResponse: c.ocspResponse,
ticketExpiration: c.config.time().Add(time.Duration(7 * 24 * time.Hour)),
} }
if !hs.serverHello.extensions.ticketSupported { if !hs.serverHello.extensions.ticketSupported {

View File

@ -4,7 +4,10 @@
package runner package runner
import "bytes" import (
"bytes"
"encoding/binary"
)
func writeLen(buf []byte, v, size int) { func writeLen(buf []byte, v, size int) {
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
@ -67,6 +70,13 @@ func (bb *byteBuilder) addU32(u uint32) {
*bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u)) *bb.buf = append(*bb.buf, byte(u>>24), byte(u>>16), byte(u>>8), byte(u))
} }
func (bb *byteBuilder) addU64(u uint64) {
bb.flush()
var b [8]byte
binary.BigEndian.PutUint64(b[:], u)
*bb.buf = append(*bb.buf, b[:]...)
}
func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder { func (bb *byteBuilder) addU8LengthPrefixed() *byteBuilder {
return bb.createChild(1) return bb.createChild(1)
} }
@ -79,6 +89,10 @@ func (bb *byteBuilder) addU24LengthPrefixed() *byteBuilder {
return bb.createChild(3) return bb.createChild(3)
} }
func (bb *byteBuilder) addU32LengthPrefixed() *byteBuilder {
return bb.createChild(4)
}
func (bb *byteBuilder) addBytes(b []byte) { func (bb *byteBuilder) addBytes(b []byte) {
bb.flush() bb.flush()
*bb.buf = append(*bb.buf, b...) *bb.buf = append(*bb.buf, b...)

View File

@ -305,22 +305,54 @@ Curves:
_, ecdsaOk := hs.cert.PrivateKey.(*ecdsa.PrivateKey) _, ecdsaOk := hs.cert.PrivateKey.(*ecdsa.PrivateKey)
// TODO(davidben): Implement PSK support. for i, pskIdentity := range hs.clientHello.pskIdentities {
pskOk := false sessionState, ok := c.decryptTicket(pskIdentity)
if !ok {
// Select the cipher suite. continue
var preferenceList, supportedList []uint16 }
if config.PreferServerCipherSuites { if sessionState.vers != c.vers {
preferenceList = config.cipherSuites() continue
supportedList = hs.clientHello.cipherSuites }
} else { if sessionState.ticketFlags&ticketAllowDHEResumption == 0 {
preferenceList = hs.clientHello.cipherSuites continue
supportedList = config.cipherSuites() }
if sessionState.ticketExpiration.Before(c.config.time()) {
continue
}
suiteId := ecdhePSKSuite(sessionState.cipherSuite)
suite := mutualCipherSuite(hs.clientHello.cipherSuites, suiteId)
var found bool
for _, id := range config.cipherSuites() {
if id == sessionState.cipherSuite {
found = true
break
}
}
if suite != nil && found {
hs.sessionState = sessionState
hs.suite = suite
hs.hello.hasPSKIdentity = true
hs.hello.pskIdentity = uint16(i)
c.didResume = true
break
}
} }
for _, id := range preferenceList { // If not resuming, select the cipher suite.
if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, pskOk); hs.suite != nil { if hs.suite == nil {
break var preferenceList, supportedList []uint16
if config.PreferServerCipherSuites {
preferenceList = config.cipherSuites()
supportedList = hs.clientHello.cipherSuites
} else {
preferenceList = hs.clientHello.cipherSuites
supportedList = config.cipherSuites()
}
for _, id := range preferenceList {
if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, supportedCurve, ecdsaOk, false); hs.suite != nil {
break
}
} }
} }
@ -339,9 +371,19 @@ Curves:
hs.writeClientHash(hs.clientHello.marshal()) hs.writeClientHash(hs.clientHello.marshal())
// Resolve PSK and compute the early secret. // Resolve PSK and compute the early secret.
// TODO(davidben): Implement PSK in TLS 1.3. var psk []byte
psk := hs.finishedHash.zeroSecret() // The only way for hs.suite to be a PSK suite yet for there to be
hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret()) // no sessionState is if config.Bugs.EnableAllCiphers is true and
// the test runner forced us to negotiated a PSK suite. It doesn't
// really matter what we do here so long as we continue the
// handshake and let the client error out.
if hs.suite.flags&suitePSK != 0 && hs.sessionState != nil {
psk = deriveResumptionPSK(hs.suite, hs.sessionState.masterSecret)
hs.finishedHash.setResumptionContext(deriveResumptionContext(hs.suite, hs.sessionState.masterSecret))
} else {
psk = hs.finishedHash.zeroSecret()
hs.finishedHash.setResumptionContext(hs.finishedHash.zeroSecret())
}
earlySecret := hs.finishedHash.extractKey(hs.finishedHash.zeroSecret(), psk) earlySecret := hs.finishedHash.extractKey(hs.finishedHash.zeroSecret(), psk)
@ -494,9 +536,7 @@ Curves:
c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite) c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite)
c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite) c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite)
if hs.suite.flags&suitePSK != 0 { if hs.suite.flags&suitePSK == 0 {
return errors.New("tls: PSK ciphers not implemented for TLS 1.3")
} else {
if hs.clientHello.ocspStapling { if hs.clientHello.ocspStapling {
encryptedExtensions.extensions.ocspResponse = hs.cert.OCSPStaple encryptedExtensions.extensions.ocspResponse = hs.cert.OCSPStaple
} }
@ -574,6 +614,15 @@ Curves:
hs.writeServerHash(certVerify.marshal()) hs.writeServerHash(certVerify.marshal())
c.writeRecord(recordTypeHandshake, certVerify.marshal()) c.writeRecord(recordTypeHandshake, certVerify.marshal())
} else {
// Pick up certificates from the session instead.
// hs.sessionState may be nil if config.Bugs.EnableAllCiphers is
// true.
if hs.sessionState != nil && len(hs.sessionState.certificates) > 0 {
if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
return err
}
}
} }
finished := new(finishedMsg) finished := new(finishedMsg)

View File

@ -497,3 +497,11 @@ func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret, phase []byte,
func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte { func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte {
return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size()) return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size())
} }
func deriveResumptionPSK(suite *cipherSuite, resumptionSecret []byte) []byte {
return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption psk"), nil, suite.hash().Size())
}
func deriveResumptionContext(suite *cipherSuite, resumptionSecret []byte) []byte {
return hkdfExpandLabel(suite.hash(), resumptionSecret, []byte("resumption context"), nil, suite.hash().Size())
}

View File

@ -5,14 +5,15 @@
package runner package runner
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/binary"
"errors" "errors"
"io" "io"
"time"
) )
// sessionState contains the information that is serialized into a session // sessionState contains the information that is serialized into a session
@ -24,79 +25,40 @@ type sessionState struct {
handshakeHash []byte handshakeHash []byte
certificates [][]byte certificates [][]byte
extendedMasterSecret bool extendedMasterSecret bool
} ticketCreationTime time.Time
ticketExpiration time.Time
func (s *sessionState) equal(i interface{}) bool { ticketFlags uint32
s1, ok := i.(*sessionState) ticketAgeAdd uint32
if !ok {
return false
}
if s.vers != s1.vers ||
s.cipherSuite != s1.cipherSuite ||
!bytes.Equal(s.masterSecret, s1.masterSecret) ||
!bytes.Equal(s.handshakeHash, s1.handshakeHash) ||
s.extendedMasterSecret != s1.extendedMasterSecret {
return false
}
if len(s.certificates) != len(s1.certificates) {
return false
}
for i := range s.certificates {
if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
return false
}
}
return true
} }
func (s *sessionState) marshal() []byte { func (s *sessionState) marshal() []byte {
length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2 msg := newByteBuilder()
msg.addU16(s.vers)
msg.addU16(s.cipherSuite)
masterSecret := msg.addU16LengthPrefixed()
masterSecret.addBytes(s.masterSecret)
handshakeHash := msg.addU16LengthPrefixed()
handshakeHash.addBytes(s.handshakeHash)
msg.addU16(uint16(len(s.certificates)))
for _, cert := range s.certificates { for _, cert := range s.certificates {
length += 4 + len(cert) certMsg := msg.addU32LengthPrefixed()
} certMsg.addBytes(cert)
length++
ret := make([]byte, length)
x := ret
x[0] = byte(s.vers >> 8)
x[1] = byte(s.vers)
x[2] = byte(s.cipherSuite >> 8)
x[3] = byte(s.cipherSuite)
x[4] = byte(len(s.masterSecret) >> 8)
x[5] = byte(len(s.masterSecret))
x = x[6:]
copy(x, s.masterSecret)
x = x[len(s.masterSecret):]
x[0] = byte(len(s.handshakeHash) >> 8)
x[1] = byte(len(s.handshakeHash))
x = x[2:]
copy(x, s.handshakeHash)
x = x[len(s.handshakeHash):]
x[0] = byte(len(s.certificates) >> 8)
x[1] = byte(len(s.certificates))
x = x[2:]
for _, cert := range s.certificates {
x[0] = byte(len(cert) >> 24)
x[1] = byte(len(cert) >> 16)
x[2] = byte(len(cert) >> 8)
x[3] = byte(len(cert))
copy(x[4:], cert)
x = x[4+len(cert):]
} }
if s.extendedMasterSecret { if s.extendedMasterSecret {
x[0] = 1 msg.addU8(1)
} else {
msg.addU8(0)
} }
x = x[1:]
return ret if s.vers >= VersionTLS13 {
msg.addU64(uint64(s.ticketCreationTime.UnixNano()))
msg.addU64(uint64(s.ticketExpiration.UnixNano()))
msg.addU32(s.ticketFlags)
msg.addU32(s.ticketAgeAdd)
}
return msg.finish()
} }
func (s *sessionState) unmarshal(data []byte) bool { func (s *sessionState) unmarshal(data []byte) bool {
@ -162,6 +124,20 @@ func (s *sessionState) unmarshal(data []byte) bool {
} }
data = data[1:] data = data[1:]
if s.vers >= VersionTLS13 {
if len(data) < 24 {
return false
}
s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
data = data[8:]
s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
data = data[8:]
s.ticketFlags = binary.BigEndian.Uint32(data)
data = data[4:]
s.ticketAgeAdd = binary.BigEndian.Uint32(data)
data = data[4:]
}
if len(data) > 0 { if len(data) > 0 {
return false return false
} }