@@ -146,23 +146,24 @@ CurvePreferenceLoop:
hash := hashForSuite(hs.suite)
hashSize := hash.Size()
hs.keySchedule = newKeySchedule13(hs.suite)
earlySecret, pskAlert := hs.checkPSK()
// Check for PSK and update key schedule with new early secret key
isResumed, pskAlert := hs.checkPSK()
switch {
case pskAlert != alertSuccess:
c.sendAlert(pskAlert)
return errors.New("tls: invalid client PSK")
case earlySecret == nil:
earlySecret = hkdfExtract(hash, nil, nil)
case earlySecret != nil:
case !isResumed:
// apply an empty PSK if not resumed.
hs.keySchedule.setSecret(nil)
case isResumed:
c.didResume = true
}
hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())
hs.keySchedule.write(hs.clientHello.marshal())
handshakeCtx := hs.finishedHash13.Sum(nil)
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient)
ecdheSecret := deriveECDHESecret(ks, privateKey)
if ecdheSecret == nil {
@@ -170,22 +171,21 @@ CurvePreferenceLoop:
return errors.New("tls: bad ECDHE client share")
}
hs.finishedHash13.W rite(hs.hello13.marshal())
hs.keySchedule.w rite(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err
}
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
handshakeCtx = hs.finishedHash13.Sum(nil)
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
hs.keySchedule.setSecret(ecdheSecret)
clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
hs.hsClientCipher = clientCipher
serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret" )
serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer )
c.out.setCipher(c.vers, serverCipher)
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
hs.finishedHash13.W rite(hs.hello13Enc.marshal())
hs.keySchedule.w rite(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err
}
@@ -196,19 +196,18 @@ CurvePreferenceLoop:
}
}
verifyData := hmacOfSum(hash, hs.finishedHash13 , serverFinishedKey)
verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash , serverFinishedKey)
serverFinished := &finishedMsg{
verifyData: verifyData,
}
hs.finishedHash13.W rite(serverFinished.marshal())
hs.keySchedule.w rite(serverFinished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
return err
}
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
handshakeCtx = hs.finishedHash13.Sum(nil)
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
hs.keySchedule.setSecret(nil) // derive master secret
hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient)
serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer)
c.out.setCipher(c.vers, serverCipher)
if c.hand.Len() > 0 {
@@ -247,13 +246,13 @@ func (hs *serverHandshakeState) readClientFinished13() error {
}
hash := hashForSuite(hs.suite)
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13 , hs.clientFinishedKey)
expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash , hs.clientFinishedKey)
if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertDecryptError)
return errors.New("tls: client's Finished message is incorrect")
}
hs.finishedHash13.W rite(clientFinished.marshal())
hs.keySchedule.w rite(clientFinished.marshal())
c.hs = nil // Discard the server handshake state
if c.hand.Len() > 0 {
@@ -287,7 +286,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
}
certMsg := &certificateMsg13{certificates: certEntries}
hs.finishedHash13.W rite(certMsg.marshal())
hs.keySchedule.w rite(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
@@ -304,7 +303,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13 .Sum(nil))
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash .Sum(nil))
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
if err != nil {
c.sendAlert(alertInternalError)
@@ -316,7 +315,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
signatureAndHash: sigSchemeToSigAndHash(sigScheme),
signature: signature,
}
hs.finishedHash13.W rite(verifyMsg.marshal())
hs.keySchedule.w rite(verifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
return err
}
@@ -505,22 +504,16 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
return h.Sum(nil)
}
func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
hash := hashForSuite(hs.suite)
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
return hs.suite.aead(key, iv), trafficSecret
}
// Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
// checkPSK tries to resume using a PSK, returning true (and updating the
// early secret in the key schedule) if the PSK was used and false otherwise.
func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
if hs.c.config.SessionTicketsDisabled {
return nil, alertSuccess
return false , alertSuccess
}
foundDHE := false
@@ -531,7 +524,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
}
}
if !foundDHE {
return nil , alertSuccess
return false , alertSuccess
}
hash := hashForSuite(hs.suite)
@@ -580,23 +573,22 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
continue
}
earlySecret := hkdfExtract(hash, s.resumptionSecret, nil)
handshakeCtx := hash.New().Sum(nil)
binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize)
hs.keySchedule.setSecret(s.resumptionSecret)
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 {
return nil , alertDecryptError
return false , alertDecryptError
}
if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such.
return nil , alertIllegalParameter
return false , alertIllegalParameter
}
if hs.c.config.Accept0RTTData {
hs.c.binder = expectedBinder
@@ -606,10 +598,10 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
}
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
return earlySecret , alertSuccess
return true , alertSuccess
}
return nil , alertSuccess
return false , alertSuccess
}
func (hs *serverHandshakeState) sendSessionTicket13() error {
@@ -629,9 +621,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return nil
}
hash := hashForSuite(hs.suite)
handshakeCtx := hs.finishedHash13.Sum(nil)
resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
resumptionSecret := hs.keySchedule.deriveSecret(secretResumption)
ageAddBuf := make([]byte, 4)
sessionState := &sessionState13{