tris: use keySchedule13 for the server
Use the new keySchedule13 type instead of hash.Hash to avoid tracking the hashContext and intermediate secrets manually. checkPSK is modified not to return the calculated early secret, this is internal to keySchedule13 now. The caller just learns whether it was resumed using a PSK or not.
This commit is contained in:
parent
9f9f06de80
commit
6f580251ca
82
13.go
82
13.go
@ -146,23 +146,24 @@ CurvePreferenceLoop:
|
||||
|
||||
hash := hashForSuite(hs.suite)
|
||||
hashSize := hash.Size()
|
||||
hs.keySchedule = newKeySchedule13(hs.suite)
|
||||
|
||||
earlySecret, pskAlert := hs.checkPSK()
|
||||
// Check for PSK and update key schedule with new early secret key
|
||||
isResumed, pskAlert := hs.checkPSK()
|
||||
switch {
|
||||
case pskAlert != alertSuccess:
|
||||
c.sendAlert(pskAlert)
|
||||
return errors.New("tls: invalid client PSK")
|
||||
case earlySecret == nil:
|
||||
earlySecret = hkdfExtract(hash, nil, nil)
|
||||
case earlySecret != nil:
|
||||
case !isResumed:
|
||||
// apply an empty PSK if not resumed.
|
||||
hs.keySchedule.setSecret(nil)
|
||||
case isResumed:
|
||||
c.didResume = true
|
||||
}
|
||||
|
||||
hs.finishedHash13 = hash.New()
|
||||
hs.finishedHash13.Write(hs.clientHello.marshal())
|
||||
hs.keySchedule.write(hs.clientHello.marshal())
|
||||
|
||||
handshakeCtx := hs.finishedHash13.Sum(nil)
|
||||
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
|
||||
earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient)
|
||||
|
||||
ecdheSecret := deriveECDHESecret(ks, privateKey)
|
||||
if ecdheSecret == nil {
|
||||
@ -170,22 +171,21 @@ CurvePreferenceLoop:
|
||||
return errors.New("tls: bad ECDHE client share")
|
||||
}
|
||||
|
||||
hs.finishedHash13.Write(hs.hello13.marshal())
|
||||
hs.keySchedule.write(hs.hello13.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
|
||||
handshakeCtx = hs.finishedHash13.Sum(nil)
|
||||
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
|
||||
hs.keySchedule.setSecret(ecdheSecret)
|
||||
clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
|
||||
hs.hsClientCipher = clientCipher
|
||||
serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret")
|
||||
serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer)
|
||||
c.out.setCipher(c.vers, serverCipher)
|
||||
|
||||
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
|
||||
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
|
||||
|
||||
hs.finishedHash13.Write(hs.hello13Enc.marshal())
|
||||
hs.keySchedule.write(hs.hello13Enc.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -196,19 +196,18 @@ CurvePreferenceLoop:
|
||||
}
|
||||
}
|
||||
|
||||
verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
|
||||
verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey)
|
||||
serverFinished := &finishedMsg{
|
||||
verifyData: verifyData,
|
||||
}
|
||||
hs.finishedHash13.Write(serverFinished.marshal())
|
||||
hs.keySchedule.write(serverFinished.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
|
||||
handshakeCtx = hs.finishedHash13.Sum(nil)
|
||||
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
|
||||
serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
|
||||
hs.keySchedule.setSecret(nil) // derive master secret
|
||||
hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient)
|
||||
serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer)
|
||||
c.out.setCipher(c.vers, serverCipher)
|
||||
|
||||
if c.hand.Len() > 0 {
|
||||
@ -247,13 +246,13 @@ func (hs *serverHandshakeState) readClientFinished13() error {
|
||||
}
|
||||
|
||||
hash := hashForSuite(hs.suite)
|
||||
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey)
|
||||
expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey)
|
||||
if len(expectedVerifyData) != len(clientFinished.verifyData) ||
|
||||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
|
||||
c.sendAlert(alertDecryptError)
|
||||
return errors.New("tls: client's Finished message is incorrect")
|
||||
}
|
||||
hs.finishedHash13.Write(clientFinished.marshal())
|
||||
hs.keySchedule.write(clientFinished.marshal())
|
||||
|
||||
c.hs = nil // Discard the server handshake state
|
||||
if c.hand.Len() > 0 {
|
||||
@ -287,7 +286,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
|
||||
}
|
||||
certMsg := &certificateMsg13{certificates: certEntries}
|
||||
|
||||
hs.finishedHash13.Write(certMsg.marshal())
|
||||
hs.keySchedule.write(certMsg.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -304,7 +303,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
|
||||
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
|
||||
}
|
||||
|
||||
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil))
|
||||
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil))
|
||||
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
|
||||
if err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
@ -316,7 +315,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
|
||||
signatureAndHash: sigSchemeToSigAndHash(sigScheme),
|
||||
signature: signature,
|
||||
}
|
||||
hs.finishedHash13.Write(verifyMsg.marshal())
|
||||
hs.keySchedule.write(verifyMsg.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -505,22 +504,16 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
|
||||
hash := hashForSuite(hs.suite)
|
||||
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
|
||||
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
|
||||
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
|
||||
return hs.suite.aead(key, iv), trafficSecret
|
||||
}
|
||||
|
||||
// Maximum allowed mismatch between the stated age of a ticket
|
||||
// and the server-observed one. See
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
|
||||
const ticketAgeSkewAllowance = 10 * time.Second
|
||||
|
||||
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
|
||||
// checkPSK tries to resume using a PSK, returning true (and updating the
|
||||
// early secret in the key schedule) if the PSK was used and false otherwise.
|
||||
func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
|
||||
if hs.c.config.SessionTicketsDisabled {
|
||||
return nil, alertSuccess
|
||||
return false, alertSuccess
|
||||
}
|
||||
|
||||
foundDHE := false
|
||||
@ -531,7 +524,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
|
||||
}
|
||||
}
|
||||
if !foundDHE {
|
||||
return nil, alertSuccess
|
||||
return false, alertSuccess
|
||||
}
|
||||
|
||||
hash := hashForSuite(hs.suite)
|
||||
@ -580,23 +573,22 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
|
||||
continue
|
||||
}
|
||||
|
||||
earlySecret := hkdfExtract(hash, s.resumptionSecret, nil)
|
||||
handshakeCtx := hash.New().Sum(nil)
|
||||
binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize)
|
||||
hs.keySchedule.setSecret(s.resumptionSecret)
|
||||
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
|
||||
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
|
||||
chHash := hash.New()
|
||||
chHash.Write(hs.clientHello.rawTruncated)
|
||||
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
|
||||
|
||||
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 {
|
||||
return nil, alertDecryptError
|
||||
return false, alertDecryptError
|
||||
}
|
||||
|
||||
if i == 0 && hs.clientHello.earlyData {
|
||||
// This is a ticket intended to be used for 0-RTT
|
||||
if s.maxEarlyDataLen == 0 {
|
||||
// But we had not tagged it as such.
|
||||
return nil, alertIllegalParameter
|
||||
return false, alertIllegalParameter
|
||||
}
|
||||
if hs.c.config.Accept0RTTData {
|
||||
hs.c.binder = expectedBinder
|
||||
@ -606,10 +598,10 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
|
||||
}
|
||||
hs.hello13.psk = true
|
||||
hs.hello13.pskIdentity = uint16(i)
|
||||
return earlySecret, alertSuccess
|
||||
return true, alertSuccess
|
||||
}
|
||||
|
||||
return nil, alertSuccess
|
||||
return false, alertSuccess
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) sendSessionTicket13() error {
|
||||
@ -629,9 +621,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hash := hashForSuite(hs.suite)
|
||||
handshakeCtx := hs.finishedHash13.Sum(nil)
|
||||
resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
|
||||
resumptionSecret := hs.keySchedule.deriveSecret(secretResumption)
|
||||
|
||||
ageAddBuf := make([]byte, 4)
|
||||
sessionState := &sessionState13{
|
||||
|
@ -13,7 +13,6 @@ import (
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
@ -45,7 +44,7 @@ type serverHandshakeState struct {
|
||||
// TLS 1.3 fields
|
||||
hello13 *serverHelloMsg13
|
||||
hello13Enc *encryptedExtensionsMsg
|
||||
finishedHash13 hash.Hash
|
||||
keySchedule *keySchedule13
|
||||
clientFinishedKey []byte
|
||||
hsClientCipher interface{}
|
||||
appClientCipher interface{}
|
||||
|
Loading…
Reference in New Issue
Block a user