crypto/tls: Client side support for TLS session resumption.
Adam (agl@) had already done an initial review of this CL in a branch. Added ClientSessionState to Config which now allows clients to keep state required to resume a TLS session with a server. A client handshake will try and use the SessionTicket/MasterSecret in this cached state if the server acknowledged resumption. We also added support to cache ClientSessionState object in Config that will be looked up by server remote address during the handshake. R=golang-codereviews, agl, rsc, agl, agl, bradfitz, mikioh.mikioh CC=golang-codereviews https://golang.org/cl/15680043
This commit is contained in:
parent
6f38414b48
commit
9323f900fd
99
common.go
99
common.go
@ -5,6 +5,7 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
@ -173,6 +174,29 @@ const (
|
||||
RequireAndVerifyClientCert
|
||||
)
|
||||
|
||||
// ClientSessionState contains the state needed by clients to resume TLS
|
||||
// sessions.
|
||||
type ClientSessionState struct {
|
||||
sessionTicket []uint8 // Encrypted ticket used for session resumption with server
|
||||
vers uint16 // SSL/TLS version negotiated for the session
|
||||
cipherSuite uint16 // Ciphersuite negotiated for the session
|
||||
masterSecret []byte // MasterSecret generated by client on a full handshake
|
||||
serverCertificates []*x509.Certificate // Certificate chain presented by the server
|
||||
}
|
||||
|
||||
// ClientSessionCache is a cache of ClientSessionState objects that can be used
|
||||
// by a client to resume a TLS session with a given server. ClientSessionCache
|
||||
// implementations should expect to be called concurrently from different
|
||||
// goroutines.
|
||||
type ClientSessionCache interface {
|
||||
// Get searches for a ClientSessionState associated with the given key.
|
||||
// On return, ok is true if one was found.
|
||||
Get(sessionKey string) (session *ClientSessionState, ok bool)
|
||||
|
||||
// Put adds the ClientSessionState to the cache with the given key.
|
||||
Put(sessionKey string, cs *ClientSessionState)
|
||||
}
|
||||
|
||||
// A Config structure is used to configure a TLS client or server. After one
|
||||
// has been passed to a TLS function it must not be modified.
|
||||
type Config struct {
|
||||
@ -251,6 +275,10 @@ type Config struct {
|
||||
// connections using that key are compromised.
|
||||
SessionTicketKey [32]byte
|
||||
|
||||
// SessionCache is a cache of ClientSessionState entries for TLS session
|
||||
// resumption.
|
||||
ClientSessionCache ClientSessionCache
|
||||
|
||||
// MinVersion contains the minimum SSL/TLS version that is acceptable.
|
||||
// If zero, then SSLv3 is taken as the minimum.
|
||||
MinVersion uint16
|
||||
@ -412,6 +440,77 @@ type handshakeMessage interface {
|
||||
unmarshal([]byte) bool
|
||||
}
|
||||
|
||||
// lruSessionCache is a ClientSessionCache implementation that uses an LRU
|
||||
// caching strategy.
|
||||
type lruSessionCache struct {
|
||||
sync.Mutex
|
||||
|
||||
m map[string]*list.Element
|
||||
q *list.List
|
||||
capacity int
|
||||
}
|
||||
|
||||
type lruSessionCacheEntry struct {
|
||||
sessionKey string
|
||||
state *ClientSessionState
|
||||
}
|
||||
|
||||
// NewLRUClientSessionCache returns a ClientSessionCache with the given
|
||||
// capacity that uses an LRU strategy. If capacity is < 1, a default capacity
|
||||
// is used instead.
|
||||
func NewLRUClientSessionCache(capacity int) ClientSessionCache {
|
||||
const defaultSessionCacheCapacity = 64
|
||||
|
||||
if capacity < 1 {
|
||||
capacity = defaultSessionCacheCapacity
|
||||
}
|
||||
return &lruSessionCache{
|
||||
m: make(map[string]*list.Element),
|
||||
q: list.New(),
|
||||
capacity: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// Put adds the provided (sessionKey, cs) pair to the cache.
|
||||
func (c *lruSessionCache) Put(sessionKey string, cs *ClientSessionState) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if elem, ok := c.m[sessionKey]; ok {
|
||||
entry := elem.Value.(*lruSessionCacheEntry)
|
||||
entry.state = cs
|
||||
c.q.MoveToFront(elem)
|
||||
return
|
||||
}
|
||||
|
||||
if c.q.Len() < c.capacity {
|
||||
entry := &lruSessionCacheEntry{sessionKey, cs}
|
||||
c.m[sessionKey] = c.q.PushFront(entry)
|
||||
return
|
||||
}
|
||||
|
||||
elem := c.q.Back()
|
||||
entry := elem.Value.(*lruSessionCacheEntry)
|
||||
delete(c.m, entry.sessionKey)
|
||||
entry.sessionKey = sessionKey
|
||||
entry.state = cs
|
||||
c.q.MoveToFront(elem)
|
||||
c.m[sessionKey] = elem
|
||||
}
|
||||
|
||||
// Get returns the ClientSessionState value associated with a given key. It
|
||||
// returns (nil, false) if no value is found.
|
||||
func (c *lruSessionCache) Get(sessionKey string) (*ClientSessionState, bool) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
if elem, ok := c.m[sessionKey]; ok {
|
||||
c.q.MoveToFront(elem)
|
||||
return elem.Value.(*lruSessionCacheEntry).state, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// TODO(jsing): Make these available to both crypto/x509 and crypto/tls.
|
||||
type dsaSignature struct {
|
||||
R, S *big.Int
|
||||
|
2
conn.go
2
conn.go
@ -799,6 +799,8 @@ func (c *Conn) readHandshake() (interface{}, error) {
|
||||
m = new(clientHelloMsg)
|
||||
case typeServerHello:
|
||||
m = new(serverHelloMsg)
|
||||
case typeNewSessionTicket:
|
||||
m = new(newSessionTicketMsg)
|
||||
case typeCertificate:
|
||||
m = new(certificateMsg)
|
||||
case typeCertificateRequest:
|
||||
|
@ -13,9 +13,20 @@ import (
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type clientHandshakeState struct {
|
||||
c *Conn
|
||||
serverHello *serverHelloMsg
|
||||
hello *clientHelloMsg
|
||||
suite *cipherSuite
|
||||
finishedHash finishedHash
|
||||
masterSecret []byte
|
||||
session *ClientSessionState
|
||||
}
|
||||
|
||||
func (c *Conn) clientHandshake() error {
|
||||
if c.config == nil {
|
||||
c.config = defaultConfig()
|
||||
@ -60,13 +71,58 @@ NextCipherSuite:
|
||||
_, err := io.ReadFull(c.config.rand(), hello.random[4:])
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("short read from Rand")
|
||||
return errors.New("tls: short read from Rand: " + err.Error())
|
||||
}
|
||||
|
||||
if hello.vers >= VersionTLS12 {
|
||||
hello.signatureAndHashes = supportedSKXSignatureAlgorithms
|
||||
}
|
||||
|
||||
var session *ClientSessionState
|
||||
var cacheKey string
|
||||
sessionCache := c.config.ClientSessionCache
|
||||
if c.config.SessionTicketsDisabled {
|
||||
sessionCache = nil
|
||||
}
|
||||
|
||||
if sessionCache != nil {
|
||||
hello.ticketSupported = true
|
||||
|
||||
// Try to resume a previously negotiated TLS session, if
|
||||
// available.
|
||||
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
|
||||
candidateSession, ok := sessionCache.Get(cacheKey)
|
||||
if ok {
|
||||
// Check that the ciphersuite/version used for the
|
||||
// previous session are still valid.
|
||||
cipherSuiteOk := false
|
||||
for _, id := range hello.cipherSuites {
|
||||
if id == candidateSession.cipherSuite {
|
||||
cipherSuiteOk = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
versOk := candidateSession.vers >= c.config.minVersion() &&
|
||||
candidateSession.vers <= c.config.maxVersion()
|
||||
if versOk && cipherSuiteOk {
|
||||
session = candidateSession
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
hello.sessionTicket = session.sessionTicket
|
||||
// A random session ID is used to detect when the
|
||||
// server accepted the ticket and is resuming a session
|
||||
// (see RFC 5077).
|
||||
hello.sessionId = make([]byte, 16)
|
||||
if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: short read from Rand: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
c.writeRecord(recordTypeHandshake, hello.marshal())
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
@ -86,25 +142,73 @@ NextCipherSuite:
|
||||
c.vers = vers
|
||||
c.haveVers = true
|
||||
|
||||
finishedHash := newFinishedHash(c.vers)
|
||||
finishedHash.Write(hello.marshal())
|
||||
finishedHash.Write(serverHello.marshal())
|
||||
|
||||
if serverHello.compressionMethod != compressionNone {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
if !hello.nextProtoNeg && serverHello.nextProtoNeg {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return errors.New("server advertised unrequested NPN")
|
||||
}
|
||||
|
||||
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
|
||||
if suite == nil {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
hs := &clientHandshakeState{
|
||||
c: c,
|
||||
serverHello: serverHello,
|
||||
hello: hello,
|
||||
suite: suite,
|
||||
finishedHash: newFinishedHash(c.vers),
|
||||
session: session,
|
||||
}
|
||||
|
||||
hs.finishedHash.Write(hs.hello.marshal())
|
||||
hs.finishedHash.Write(hs.serverHello.marshal())
|
||||
|
||||
isResume, err := hs.processServerHello()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isResume {
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := hs.doFullHandshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.establishKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.sendFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readSessionTicket(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := hs.readFinished(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if sessionCache != nil && hs.session != nil && session != hs.session {
|
||||
sessionCache.Put(cacheKey, hs.session)
|
||||
}
|
||||
|
||||
c.didResume = isResume
|
||||
c.handshakeComplete = true
|
||||
c.cipherSuite = suite.id
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
c := hs.c
|
||||
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -112,7 +216,7 @@ NextCipherSuite:
|
||||
if !ok || len(certMsg.certificates) == 0 {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(certMsg.marshal())
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
|
||||
certs := make([]*x509.Certificate, len(certMsg.certificates))
|
||||
for i, asn1Data := range certMsg.certificates {
|
||||
@ -154,7 +258,7 @@ NextCipherSuite:
|
||||
|
||||
c.peerCertificates = certs
|
||||
|
||||
if serverHello.ocspStapling {
|
||||
if hs.serverHello.ocspStapling {
|
||||
msg, err = c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -163,7 +267,7 @@ NextCipherSuite:
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(cs.marshal())
|
||||
hs.finishedHash.Write(cs.marshal())
|
||||
|
||||
if cs.statusType == statusTypeOCSP {
|
||||
c.ocspResponse = cs.response
|
||||
@ -175,12 +279,12 @@ NextCipherSuite:
|
||||
return err
|
||||
}
|
||||
|
||||
keyAgreement := suite.ka(c.vers)
|
||||
keyAgreement := hs.suite.ka(c.vers)
|
||||
|
||||
skx, ok := msg.(*serverKeyExchangeMsg)
|
||||
if ok {
|
||||
finishedHash.Write(skx.marshal())
|
||||
err = keyAgreement.processServerKeyExchange(c.config, hello, serverHello, certs[0], skx)
|
||||
hs.finishedHash.Write(skx.marshal())
|
||||
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, certs[0], skx)
|
||||
if err != nil {
|
||||
c.sendAlert(alertUnexpectedMessage)
|
||||
return err
|
||||
@ -209,7 +313,7 @@ NextCipherSuite:
|
||||
// ClientCertificateType, unless there is some external
|
||||
// arrangement to the contrary.
|
||||
|
||||
finishedHash.Write(certReq.marshal())
|
||||
hs.finishedHash.Write(certReq.marshal())
|
||||
|
||||
var rsaAvail, ecdsaAvail bool
|
||||
for _, certType := range certReq.certificateTypes {
|
||||
@ -274,7 +378,7 @@ NextCipherSuite:
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
finishedHash.Write(shd.marshal())
|
||||
hs.finishedHash.Write(shd.marshal())
|
||||
|
||||
// If the server requested a certificate then we have to send a
|
||||
// Certificate message, even if it's empty because we don't have a
|
||||
@ -284,17 +388,17 @@ NextCipherSuite:
|
||||
if chainToSend != nil {
|
||||
certMsg.certificates = chainToSend.Certificate
|
||||
}
|
||||
finishedHash.Write(certMsg.marshal())
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certMsg.marshal())
|
||||
}
|
||||
|
||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hello, certs[0])
|
||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0])
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
if ckx != nil {
|
||||
finishedHash.Write(ckx.marshal())
|
||||
hs.finishedHash.Write(ckx.marshal())
|
||||
c.writeRecord(recordTypeHandshake, ckx.marshal())
|
||||
}
|
||||
|
||||
@ -306,7 +410,7 @@ NextCipherSuite:
|
||||
|
||||
switch key := c.config.Certificates[0].PrivateKey.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
digest, _, hashId := finishedHash.hashForClientCertificate(signatureECDSA)
|
||||
digest, _, hashId := hs.finishedHash.hashForClientCertificate(signatureECDSA)
|
||||
r, s, err := ecdsa.Sign(c.config.rand(), key, digest)
|
||||
if err == nil {
|
||||
signed, err = asn1.Marshal(ecdsaSignature{r, s})
|
||||
@ -314,7 +418,7 @@ NextCipherSuite:
|
||||
certVerify.signatureAndHash.signature = signatureECDSA
|
||||
certVerify.signatureAndHash.hash = hashId
|
||||
case *rsa.PrivateKey:
|
||||
digest, hashFunc, hashId := finishedHash.hashForClientCertificate(signatureRSA)
|
||||
digest, hashFunc, hashId := hs.finishedHash.hashForClientCertificate(signatureRSA)
|
||||
signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest)
|
||||
certVerify.signatureAndHash.signature = signatureRSA
|
||||
certVerify.signatureAndHash.hash = hashId
|
||||
@ -326,56 +430,73 @@ NextCipherSuite:
|
||||
}
|
||||
certVerify.signature = signed
|
||||
|
||||
finishedHash.Write(certVerify.marshal())
|
||||
hs.finishedHash.Write(certVerify.marshal())
|
||||
c.writeRecord(recordTypeHandshake, certVerify.marshal())
|
||||
}
|
||||
|
||||
masterSecret := masterFromPreMasterSecret(c.vers, preMasterSecret, hello.random, serverHello.random)
|
||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, preMasterSecret, hs.hello.random, hs.serverHello.random)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) establishKeys() error {
|
||||
c := hs.c
|
||||
|
||||
clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
|
||||
keysFromMasterSecret(c.vers, masterSecret, hello.random, serverHello.random, suite.macLen, suite.keyLen, suite.ivLen)
|
||||
|
||||
var clientCipher interface{}
|
||||
var clientHash macFunction
|
||||
if suite.cipher != nil {
|
||||
clientCipher = suite.cipher(clientKey, clientIV, false /* not for reading */)
|
||||
clientHash = suite.mac(c.vers, clientMAC)
|
||||
keysFromMasterSecret(c.vers, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
|
||||
var clientCipher, serverCipher interface{}
|
||||
var clientHash, serverHash macFunction
|
||||
if hs.suite.cipher != nil {
|
||||
clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */)
|
||||
clientHash = hs.suite.mac(c.vers, clientMAC)
|
||||
serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */)
|
||||
serverHash = hs.suite.mac(c.vers, serverMAC)
|
||||
} else {
|
||||
clientCipher = suite.aead(clientKey, clientIV)
|
||||
}
|
||||
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
|
||||
if serverHello.nextProtoNeg {
|
||||
nextProto := new(nextProtoMsg)
|
||||
proto, fallback := mutualProtocol(c.config.NextProtos, serverHello.nextProtos)
|
||||
nextProto.proto = proto
|
||||
c.clientProtocol = proto
|
||||
c.clientProtocolFallback = fallback
|
||||
|
||||
finishedHash.Write(nextProto.marshal())
|
||||
c.writeRecord(recordTypeHandshake, nextProto.marshal())
|
||||
clientCipher = hs.suite.aead(clientKey, clientIV)
|
||||
serverCipher = hs.suite.aead(serverKey, serverIV)
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = finishedHash.clientSum(masterSecret)
|
||||
finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
|
||||
var serverCipher interface{}
|
||||
var serverHash macFunction
|
||||
if suite.cipher != nil {
|
||||
serverCipher = suite.cipher(serverKey, serverIV, true /* for reading */)
|
||||
serverHash = suite.mac(c.vers, serverMAC)
|
||||
} else {
|
||||
serverCipher = suite.aead(serverKey, serverIV)
|
||||
}
|
||||
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
|
||||
c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) serverResumedSession() bool {
|
||||
// If the server responded with the same sessionId then it means the
|
||||
// sessionTicket is being used to resume a TLS session.
|
||||
return hs.session != nil && hs.hello.sessionId != nil &&
|
||||
bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
||||
c := hs.c
|
||||
|
||||
if hs.serverHello.compressionMethod != compressionNone {
|
||||
return false, c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
|
||||
c.sendAlert(alertHandshakeFailure)
|
||||
return false, errors.New("server advertised unrequested NPN")
|
||||
}
|
||||
|
||||
if hs.serverResumedSession() {
|
||||
// Restore masterSecret and peerCerts from previous state
|
||||
hs.masterSecret = hs.session.masterSecret
|
||||
c.peerCertificates = hs.session.serverCertificates
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) readFinished() error {
|
||||
c := hs.c
|
||||
|
||||
c.readRecord(recordTypeChangeCipherSpec)
|
||||
if err := c.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -384,17 +505,73 @@ NextCipherSuite:
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
|
||||
verify := finishedHash.serverSum(masterSecret)
|
||||
verify := hs.finishedHash.serverSum(hs.masterSecret)
|
||||
if len(verify) != len(serverFinished.verifyData) ||
|
||||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
|
||||
return c.sendAlert(alertHandshakeFailure)
|
||||
}
|
||||
|
||||
c.handshakeComplete = true
|
||||
c.cipherSuite = suite.id
|
||||
hs.finishedHash.Write(serverFinished.marshal())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) readSessionTicket() error {
|
||||
if !hs.serverHello.ticketSupported {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := hs.c
|
||||
msg, err := c.readHandshake()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
|
||||
if !ok {
|
||||
return c.sendAlert(alertUnexpectedMessage)
|
||||
}
|
||||
hs.finishedHash.Write(sessionTicketMsg.marshal())
|
||||
|
||||
hs.session = &ClientSessionState{
|
||||
sessionTicket: sessionTicketMsg.ticket,
|
||||
vers: c.vers,
|
||||
cipherSuite: hs.suite.id,
|
||||
masterSecret: hs.masterSecret,
|
||||
serverCertificates: c.peerCertificates,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *clientHandshakeState) sendFinished() error {
|
||||
c := hs.c
|
||||
|
||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
|
||||
if hs.serverHello.nextProtoNeg {
|
||||
nextProto := new(nextProtoMsg)
|
||||
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
|
||||
nextProto.proto = proto
|
||||
c.clientProtocol = proto
|
||||
c.clientProtocolFallback = fallback
|
||||
|
||||
hs.finishedHash.Write(nextProto.marshal())
|
||||
c.writeRecord(recordTypeHandshake, nextProto.marshal())
|
||||
}
|
||||
|
||||
finished := new(finishedMsg)
|
||||
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
|
||||
hs.finishedHash.Write(finished.marshal())
|
||||
c.writeRecord(recordTypeHandshake, finished.marshal())
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientSessionCacheKey returns a key used to cache sessionTickets that could
|
||||
// be used to resume previously negotiated TLS sessions with a server.
|
||||
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
|
||||
if len(config.ServerName) > 0 {
|
||||
return config.ServerName
|
||||
}
|
||||
return serverAddr.String()
|
||||
}
|
||||
|
||||
// mutualProtocol finds the mutual Next Protocol Negotiation protocol given the
|
||||
// set of client and server supported protocols. The set of client supported
|
||||
// protocols must not be empty. It returns the resulting protocol and flag
|
||||
|
@ -353,3 +353,87 @@ func TestHandshakeClientCertECDSA(t *testing.T) {
|
||||
runClientTestTLS10(t, test)
|
||||
runClientTestTLS12(t, test)
|
||||
}
|
||||
|
||||
func TestClientResumption(t *testing.T) {
|
||||
serverConfig := &Config{
|
||||
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_ECDHE_RSA_WITH_RC4_128_SHA},
|
||||
Certificates: testConfig.Certificates,
|
||||
}
|
||||
clientConfig := &Config{
|
||||
CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
||||
InsecureSkipVerify: true,
|
||||
ClientSessionCache: NewLRUClientSessionCache(32),
|
||||
}
|
||||
|
||||
testResumeState := func(test string, didResume bool) {
|
||||
hs, err := testHandshake(clientConfig, serverConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("%s: handshake failed: %s", test, err)
|
||||
}
|
||||
if hs.DidResume != didResume {
|
||||
t.Fatalf("%s resumed: %v, expected: %v", test, hs.DidResume, didResume)
|
||||
}
|
||||
}
|
||||
|
||||
testResumeState("Handshake", false)
|
||||
testResumeState("Resume", true)
|
||||
|
||||
if _, err := io.ReadFull(serverConfig.rand(), serverConfig.SessionTicketKey[:]); err != nil {
|
||||
t.Fatalf("Failed to invalidate SessionTicketKey")
|
||||
}
|
||||
testResumeState("InvalidSessionTicketKey", false)
|
||||
testResumeState("ResumeAfterInvalidSessionTicketKey", true)
|
||||
|
||||
clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_RC4_128_SHA}
|
||||
testResumeState("DifferentCipherSuite", false)
|
||||
testResumeState("DifferentCipherSuiteRecovers", true)
|
||||
|
||||
clientConfig.ClientSessionCache = nil
|
||||
testResumeState("WithoutSessionCache", false)
|
||||
}
|
||||
|
||||
func TestLRUClientSessionCache(t *testing.T) {
|
||||
// Initialize cache of capacity 4.
|
||||
cache := NewLRUClientSessionCache(4)
|
||||
cs := make([]ClientSessionState, 6)
|
||||
keys := []string{"0", "1", "2", "3", "4", "5", "6"}
|
||||
|
||||
// Add 4 entries to the cache and look them up.
|
||||
for i := 0; i < 4; i++ {
|
||||
cache.Put(keys[i], &cs[i])
|
||||
}
|
||||
for i := 0; i < 4; i++ {
|
||||
if s, ok := cache.Get(keys[i]); !ok || s != &cs[i] {
|
||||
t.Fatalf("session cache failed lookup for added key: %s", keys[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Add 2 more entries to the cache. First 2 should be evicted.
|
||||
for i := 4; i < 6; i++ {
|
||||
cache.Put(keys[i], &cs[i])
|
||||
}
|
||||
for i := 0; i < 2; i++ {
|
||||
if s, ok := cache.Get(keys[i]); ok || s != nil {
|
||||
t.Fatalf("session cache should have evicted key: %s", keys[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Touch entry 2. LRU should evict 3 next.
|
||||
cache.Get(keys[2])
|
||||
cache.Put(keys[0], &cs[0])
|
||||
if s, ok := cache.Get(keys[3]); ok || s != nil {
|
||||
t.Fatalf("session cache should have evicted key 3")
|
||||
}
|
||||
|
||||
// Update entry 0 in place.
|
||||
cache.Put(keys[0], &cs[3])
|
||||
if s, ok := cache.Get(keys[0]); !ok || s != &cs[3] {
|
||||
t.Fatalf("session cache failed update for key 0")
|
||||
}
|
||||
|
||||
// Adding a nil entry is valid.
|
||||
cache.Put(keys[0], nil)
|
||||
if s, ok := cache.Get(keys[0]); !ok || s != nil {
|
||||
t.Fatalf("failed to add nil entry to cache")
|
||||
}
|
||||
}
|
||||
|
@ -177,10 +177,12 @@ func TestClose(t *testing.T) {
|
||||
|
||||
func testHandshake(clientConfig, serverConfig *Config) (state ConnectionState, err error) {
|
||||
c, s := net.Pipe()
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
cli := Client(c, clientConfig)
|
||||
cli.Handshake()
|
||||
c.Close()
|
||||
done <- true
|
||||
}()
|
||||
server := Server(s, serverConfig)
|
||||
err = server.Handshake()
|
||||
@ -188,6 +190,7 @@ func testHandshake(clientConfig, serverConfig *Config) (state ConnectionState, e
|
||||
state = server.ConnectionState()
|
||||
}
|
||||
s.Close()
|
||||
<-done
|
||||
return
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user