tris: use keySchedule13 for the server

Use the new keySchedule13 type instead of hash.Hash to avoid tracking
the hashContext and intermediate secrets manually.

checkPSK is modified not to return the calculated early secret, this is
internal to keySchedule13 now. The caller just learns whether it was
resumed using a PSK or not.
This commit is contained in:
Peter Wu 2017-09-18 16:01:36 +01:00
parent 9f9f06de80
commit 6f580251ca
2 changed files with 37 additions and 48 deletions

82
13.go
View File

@ -146,23 +146,24 @@ CurvePreferenceLoop:
hash := hashForSuite(hs.suite)
hashSize := hash.Size()
hs.keySchedule = newKeySchedule13(hs.suite)
earlySecret, pskAlert := hs.checkPSK()
// Check for PSK and update key schedule with new early secret key
isResumed, pskAlert := hs.checkPSK()
switch {
case pskAlert != alertSuccess:
c.sendAlert(pskAlert)
return errors.New("tls: invalid client PSK")
case earlySecret == nil:
earlySecret = hkdfExtract(hash, nil, nil)
case earlySecret != nil:
case !isResumed:
// apply an empty PSK if not resumed.
hs.keySchedule.setSecret(nil)
case isResumed:
c.didResume = true
}
hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())
hs.keySchedule.write(hs.clientHello.marshal())
handshakeCtx := hs.finishedHash13.Sum(nil)
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient)
ecdheSecret := deriveECDHESecret(ks, privateKey)
if ecdheSecret == nil {
@ -170,22 +171,21 @@ CurvePreferenceLoop:
return errors.New("tls: bad ECDHE client share")
}
hs.finishedHash13.Write(hs.hello13.marshal())
hs.keySchedule.write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err
}
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
handshakeCtx = hs.finishedHash13.Sum(nil)
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
hs.keySchedule.setSecret(ecdheSecret)
clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
hs.hsClientCipher = clientCipher
serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret")
serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer)
c.out.setCipher(c.vers, serverCipher)
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
hs.finishedHash13.Write(hs.hello13Enc.marshal())
hs.keySchedule.write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err
}
@ -196,19 +196,18 @@ CurvePreferenceLoop:
}
}
verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey)
serverFinished := &finishedMsg{
verifyData: verifyData,
}
hs.finishedHash13.Write(serverFinished.marshal())
hs.keySchedule.write(serverFinished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
return err
}
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
handshakeCtx = hs.finishedHash13.Sum(nil)
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
hs.keySchedule.setSecret(nil) // derive master secret
hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient)
serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer)
c.out.setCipher(c.vers, serverCipher)
if c.hand.Len() > 0 {
@ -247,13 +246,13 @@ func (hs *serverHandshakeState) readClientFinished13() error {
}
hash := hashForSuite(hs.suite)
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey)
expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey)
if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertDecryptError)
return errors.New("tls: client's Finished message is incorrect")
}
hs.finishedHash13.Write(clientFinished.marshal())
hs.keySchedule.write(clientFinished.marshal())
c.hs = nil // Discard the server handshake state
if c.hand.Len() > 0 {
@ -287,7 +286,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
}
certMsg := &certificateMsg13{certificates: certEntries}
hs.finishedHash13.Write(certMsg.marshal())
hs.keySchedule.write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
@ -304,7 +303,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil))
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil))
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
if err != nil {
c.sendAlert(alertInternalError)
@ -316,7 +315,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
signatureAndHash: sigSchemeToSigAndHash(sigScheme),
signature: signature,
}
hs.finishedHash13.Write(verifyMsg.marshal())
hs.keySchedule.write(verifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
return err
}
@ -505,22 +504,16 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
return h.Sum(nil)
}
func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
hash := hashForSuite(hs.suite)
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
return hs.suite.aead(key, iv), trafficSecret
}
// Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
// checkPSK tries to resume using a PSK, returning true (and updating the
// early secret in the key schedule) if the PSK was used and false otherwise.
func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
if hs.c.config.SessionTicketsDisabled {
return nil, alertSuccess
return false, alertSuccess
}
foundDHE := false
@ -531,7 +524,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
}
}
if !foundDHE {
return nil, alertSuccess
return false, alertSuccess
}
hash := hashForSuite(hs.suite)
@ -580,23 +573,22 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
continue
}
earlySecret := hkdfExtract(hash, s.resumptionSecret, nil)
handshakeCtx := hash.New().Sum(nil)
binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize)
hs.keySchedule.setSecret(s.resumptionSecret)
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 {
return nil, alertDecryptError
return false, alertDecryptError
}
if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such.
return nil, alertIllegalParameter
return false, alertIllegalParameter
}
if hs.c.config.Accept0RTTData {
hs.c.binder = expectedBinder
@ -606,10 +598,10 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
}
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
return earlySecret, alertSuccess
return true, alertSuccess
}
return nil, alertSuccess
return false, alertSuccess
}
func (hs *serverHandshakeState) sendSessionTicket13() error {
@ -629,9 +621,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return nil
}
hash := hashForSuite(hs.suite)
handshakeCtx := hs.finishedHash13.Sum(nil)
resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
resumptionSecret := hs.keySchedule.deriveSecret(secretResumption)
ageAddBuf := make([]byte, 4)
sessionState := &sessionState13{

View File

@ -13,7 +13,6 @@ import (
"encoding/asn1"
"errors"
"fmt"
"hash"
"io"
"sync/atomic"
)
@ -45,7 +44,7 @@ type serverHandshakeState struct {
// TLS 1.3 fields
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg
finishedHash13 hash.Hash
keySchedule *keySchedule13
clientFinishedKey []byte
hsClientCipher interface{}
appClientCipher interface{}