tris: use keySchedule13 for the server

Use the new keySchedule13 type instead of hash.Hash to avoid tracking
the hashContext and intermediate secrets manually.

checkPSK is modified not to return the calculated early secret, this is
internal to keySchedule13 now. The caller just learns whether it was
resumed using a PSK or not.
This commit is contained in:
Peter Wu 2017-09-18 16:01:36 +01:00
parent 9f9f06de80
commit 6f580251ca
2 changed files with 37 additions and 48 deletions

82
13.go
View File

@ -146,23 +146,24 @@ CurvePreferenceLoop:
hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
hashSize := hash.Size() hashSize := hash.Size()
hs.keySchedule = newKeySchedule13(hs.suite)
earlySecret, pskAlert := hs.checkPSK() // Check for PSK and update key schedule with new early secret key
isResumed, pskAlert := hs.checkPSK()
switch { switch {
case pskAlert != alertSuccess: case pskAlert != alertSuccess:
c.sendAlert(pskAlert) c.sendAlert(pskAlert)
return errors.New("tls: invalid client PSK") return errors.New("tls: invalid client PSK")
case earlySecret == nil: case !isResumed:
earlySecret = hkdfExtract(hash, nil, nil) // apply an empty PSK if not resumed.
case earlySecret != nil: hs.keySchedule.setSecret(nil)
case isResumed:
c.didResume = true c.didResume = true
} }
hs.finishedHash13 = hash.New() hs.keySchedule.write(hs.clientHello.marshal())
hs.finishedHash13.Write(hs.clientHello.marshal())
handshakeCtx := hs.finishedHash13.Sum(nil) earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient)
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
ecdheSecret := deriveECDHESecret(ks, privateKey) ecdheSecret := deriveECDHESecret(ks, privateKey)
if ecdheSecret == nil { if ecdheSecret == nil {
@ -170,22 +171,21 @@ CurvePreferenceLoop:
return errors.New("tls: bad ECDHE client share") return errors.New("tls: bad ECDHE client share")
} }
hs.finishedHash13.Write(hs.hello13.marshal()) hs.keySchedule.write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err return err
} }
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) hs.keySchedule.setSecret(ecdheSecret)
handshakeCtx = hs.finishedHash13.Sum(nil) clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
hs.hsClientCipher = clientCipher hs.hsClientCipher = clientCipher
serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret") serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer)
c.out.setCipher(c.vers, serverCipher) c.out.setCipher(c.vers, serverCipher)
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize) serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize) hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
hs.finishedHash13.Write(hs.hello13Enc.marshal()) hs.keySchedule.write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err return err
} }
@ -196,19 +196,18 @@ CurvePreferenceLoop:
} }
} }
verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey) verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey)
serverFinished := &finishedMsg{ serverFinished := &finishedMsg{
verifyData: verifyData, verifyData: verifyData,
} }
hs.finishedHash13.Write(serverFinished.marshal()) hs.keySchedule.write(serverFinished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
return err return err
} }
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) hs.keySchedule.setSecret(nil) // derive master secret
handshakeCtx = hs.finishedHash13.Sum(nil) hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient)
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret") serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer)
serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
c.out.setCipher(c.vers, serverCipher) c.out.setCipher(c.vers, serverCipher)
if c.hand.Len() > 0 { if c.hand.Len() > 0 {
@ -247,13 +246,13 @@ func (hs *serverHandshakeState) readClientFinished13() error {
} }
hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey)
if len(expectedVerifyData) != len(clientFinished.verifyData) || if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertDecryptError) c.sendAlert(alertDecryptError)
return errors.New("tls: client's Finished message is incorrect") return errors.New("tls: client's Finished message is incorrect")
} }
hs.finishedHash13.Write(clientFinished.marshal()) hs.keySchedule.write(clientFinished.marshal())
c.hs = nil // Discard the server handshake state c.hs = nil // Discard the server handshake state
if c.hand.Len() > 0 { if c.hand.Len() > 0 {
@ -287,7 +286,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
} }
certMsg := &certificateMsg13{certificates: certEntries} certMsg := &certificateMsg13{certificates: certEntries}
hs.finishedHash13.Write(certMsg.marshal()) hs.keySchedule.write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err return err
} }
@ -304,7 +303,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
} }
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil)) toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil))
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts) signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
@ -316,7 +315,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
signatureAndHash: sigSchemeToSigAndHash(sigScheme), signatureAndHash: sigSchemeToSigAndHash(sigScheme),
signature: signature, signature: signature,
} }
hs.finishedHash13.Write(verifyMsg.marshal()) hs.keySchedule.write(verifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
return err return err
} }
@ -505,22 +504,16 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
return h.Sum(nil) return h.Sum(nil)
} }
func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
hash := hashForSuite(hs.suite)
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
return hs.suite.aead(key, iv), trafficSecret
}
// Maximum allowed mismatch between the stated age of a ticket // Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See // and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second const ticketAgeSkewAllowance = 10 * time.Second
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { // checkPSK tries to resume using a PSK, returning true (and updating the
// early secret in the key schedule) if the PSK was used and false otherwise.
func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
if hs.c.config.SessionTicketsDisabled { if hs.c.config.SessionTicketsDisabled {
return nil, alertSuccess return false, alertSuccess
} }
foundDHE := false foundDHE := false
@ -531,7 +524,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
} }
} }
if !foundDHE { if !foundDHE {
return nil, alertSuccess return false, alertSuccess
} }
hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
@ -580,23 +573,22 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
continue continue
} }
earlySecret := hkdfExtract(hash, s.resumptionSecret, nil) hs.keySchedule.setSecret(s.resumptionSecret)
handshakeCtx := hash.New().Sum(nil) binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New() chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated) chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 { if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 {
return nil, alertDecryptError return false, alertDecryptError
} }
if i == 0 && hs.clientHello.earlyData { if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT // This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 { if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such. // But we had not tagged it as such.
return nil, alertIllegalParameter return false, alertIllegalParameter
} }
if hs.c.config.Accept0RTTData { if hs.c.config.Accept0RTTData {
hs.c.binder = expectedBinder hs.c.binder = expectedBinder
@ -606,10 +598,10 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
} }
hs.hello13.psk = true hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i) hs.hello13.pskIdentity = uint16(i)
return earlySecret, alertSuccess return true, alertSuccess
} }
return nil, alertSuccess return false, alertSuccess
} }
func (hs *serverHandshakeState) sendSessionTicket13() error { func (hs *serverHandshakeState) sendSessionTicket13() error {
@ -629,9 +621,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return nil return nil
} }
hash := hashForSuite(hs.suite) resumptionSecret := hs.keySchedule.deriveSecret(secretResumption)
handshakeCtx := hs.finishedHash13.Sum(nil)
resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
ageAddBuf := make([]byte, 4) ageAddBuf := make([]byte, 4)
sessionState := &sessionState13{ sessionState := &sessionState13{

View File

@ -13,7 +13,6 @@ import (
"encoding/asn1" "encoding/asn1"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"sync/atomic" "sync/atomic"
) )
@ -45,7 +44,7 @@ type serverHandshakeState struct {
// TLS 1.3 fields // TLS 1.3 fields
hello13 *serverHelloMsg13 hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg hello13Enc *encryptedExtensionsMsg
finishedHash13 hash.Hash keySchedule *keySchedule13
clientFinishedKey []byte clientFinishedKey []byte
hsClientCipher interface{} hsClientCipher interface{}
appClientCipher interface{} appClientCipher interface{}