From 6f580251ca61da39a69d58f045ba5ace7e2a6637 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Mon, 18 Sep 2017 16:01:36 +0100 Subject: [PATCH] tris: use keySchedule13 for the server Use the new keySchedule13 type instead of hash.Hash to avoid tracking the hashContext and intermediate secrets manually. checkPSK is modified not to return the calculated early secret, this is internal to keySchedule13 now. The caller just learns whether it was resumed using a PSK or not. --- 13.go | 82 ++++++++++++++++++++------------------------- handshake_server.go | 3 +- 2 files changed, 37 insertions(+), 48 deletions(-) diff --git a/13.go b/13.go index 661caef..efeb305 100644 --- a/13.go +++ b/13.go @@ -146,23 +146,24 @@ CurvePreferenceLoop: hash := hashForSuite(hs.suite) hashSize := hash.Size() + hs.keySchedule = newKeySchedule13(hs.suite) - earlySecret, pskAlert := hs.checkPSK() + // Check for PSK and update key schedule with new early secret key + isResumed, pskAlert := hs.checkPSK() switch { case pskAlert != alertSuccess: c.sendAlert(pskAlert) return errors.New("tls: invalid client PSK") - case earlySecret == nil: - earlySecret = hkdfExtract(hash, nil, nil) - case earlySecret != nil: + case !isResumed: + // apply an empty PSK if not resumed. + hs.keySchedule.setSecret(nil) + case isResumed: c.didResume = true } - hs.finishedHash13 = hash.New() - hs.finishedHash13.Write(hs.clientHello.marshal()) + hs.keySchedule.write(hs.clientHello.marshal()) - handshakeCtx := hs.finishedHash13.Sum(nil) - earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret") + earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient) ecdheSecret := deriveECDHESecret(ks, privateKey) if ecdheSecret == nil { @@ -170,22 +171,21 @@ CurvePreferenceLoop: return errors.New("tls: bad ECDHE client share") } - hs.finishedHash13.Write(hs.hello13.marshal()) + hs.keySchedule.write(hs.hello13.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { return err } - handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) - handshakeCtx = hs.finishedHash13.Sum(nil) - clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret") + hs.keySchedule.setSecret(ecdheSecret) + clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient) hs.hsClientCipher = clientCipher - serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret") + serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer) c.out.setCipher(c.vers, serverCipher) serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize) hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize) - hs.finishedHash13.Write(hs.hello13Enc.marshal()) + hs.keySchedule.write(hs.hello13Enc.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { return err } @@ -196,19 +196,18 @@ CurvePreferenceLoop: } } - verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey) + verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey) serverFinished := &finishedMsg{ verifyData: verifyData, } - hs.finishedHash13.Write(serverFinished.marshal()) + hs.keySchedule.write(serverFinished.marshal()) if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { return err } - hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) - handshakeCtx = hs.finishedHash13.Sum(nil) - hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret") - serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret") + hs.keySchedule.setSecret(nil) // derive master secret + hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient) + serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer) c.out.setCipher(c.vers, serverCipher) if c.hand.Len() > 0 { @@ -247,13 +246,13 @@ func (hs *serverHandshakeState) readClientFinished13() error { } hash := hashForSuite(hs.suite) - expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) + expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey) if len(expectedVerifyData) != len(clientFinished.verifyData) || subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { c.sendAlert(alertDecryptError) return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash13.Write(clientFinished.marshal()) + hs.keySchedule.write(clientFinished.marshal()) c.hs = nil // Discard the server handshake state if c.hand.Len() > 0 { @@ -287,7 +286,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { } certMsg := &certificateMsg13{certificates: certEntries} - hs.finishedHash13.Write(certMsg.marshal()) + hs.keySchedule.write(certMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { return err } @@ -304,7 +303,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} } - toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil)) + toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil)) signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts) if err != nil { c.sendAlert(alertInternalError) @@ -316,7 +315,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { signatureAndHash: sigSchemeToSigAndHash(sigScheme), signature: signature, } - hs.finishedHash13.Write(verifyMsg.marshal()) + hs.keySchedule.write(verifyMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { return err } @@ -505,22 +504,16 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte { return h.Sum(nil) } -func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) { - hash := hashForSuite(hs.suite) - trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size()) - key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen) - iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12) - return hs.suite.aead(key, iv), trafficSecret -} - // Maximum allowed mismatch between the stated age of a ticket // and the server-observed one. See // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. const ticketAgeSkewAllowance = 10 * time.Second -func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { +// checkPSK tries to resume using a PSK, returning true (and updating the +// early secret in the key schedule) if the PSK was used and false otherwise. +func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { if hs.c.config.SessionTicketsDisabled { - return nil, alertSuccess + return false, alertSuccess } foundDHE := false @@ -531,7 +524,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { } } if !foundDHE { - return nil, alertSuccess + return false, alertSuccess } hash := hashForSuite(hs.suite) @@ -580,23 +573,22 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { continue } - earlySecret := hkdfExtract(hash, s.resumptionSecret, nil) - handshakeCtx := hash.New().Sum(nil) - binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize) + hs.keySchedule.setSecret(s.resumptionSecret) + binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) chHash := hash.New() chHash.Write(hs.clientHello.rawTruncated) expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 { - return nil, alertDecryptError + return false, alertDecryptError } if i == 0 && hs.clientHello.earlyData { // This is a ticket intended to be used for 0-RTT if s.maxEarlyDataLen == 0 { // But we had not tagged it as such. - return nil, alertIllegalParameter + return false, alertIllegalParameter } if hs.c.config.Accept0RTTData { hs.c.binder = expectedBinder @@ -606,10 +598,10 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { } hs.hello13.psk = true hs.hello13.pskIdentity = uint16(i) - return earlySecret, alertSuccess + return true, alertSuccess } - return nil, alertSuccess + return false, alertSuccess } func (hs *serverHandshakeState) sendSessionTicket13() error { @@ -629,9 +621,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { return nil } - hash := hashForSuite(hs.suite) - handshakeCtx := hs.finishedHash13.Sum(nil) - resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size()) + resumptionSecret := hs.keySchedule.deriveSecret(secretResumption) ageAddBuf := make([]byte, 4) sessionState := &sessionState13{ diff --git a/handshake_server.go b/handshake_server.go index b71348d..f3bc79c 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -13,7 +13,6 @@ import ( "encoding/asn1" "errors" "fmt" - "hash" "io" "sync/atomic" ) @@ -45,7 +44,7 @@ type serverHandshakeState struct { // TLS 1.3 fields hello13 *serverHelloMsg13 hello13Enc *encryptedExtensionsMsg - finishedHash13 hash.Hash + keySchedule *keySchedule13 clientFinishedKey []byte hsClientCipher interface{} appClientCipher interface{}