diff --git a/13.go b/13.go index 661caef..efeb305 100644 --- a/13.go +++ b/13.go @@ -146,23 +146,24 @@ CurvePreferenceLoop: hash := hashForSuite(hs.suite) hashSize := hash.Size() + hs.keySchedule = newKeySchedule13(hs.suite) - earlySecret, pskAlert := hs.checkPSK() + // Check for PSK and update key schedule with new early secret key + isResumed, pskAlert := hs.checkPSK() switch { case pskAlert != alertSuccess: c.sendAlert(pskAlert) return errors.New("tls: invalid client PSK") - case earlySecret == nil: - earlySecret = hkdfExtract(hash, nil, nil) - case earlySecret != nil: + case !isResumed: + // apply an empty PSK if not resumed. + hs.keySchedule.setSecret(nil) + case isResumed: c.didResume = true } - hs.finishedHash13 = hash.New() - hs.finishedHash13.Write(hs.clientHello.marshal()) + hs.keySchedule.write(hs.clientHello.marshal()) - handshakeCtx := hs.finishedHash13.Sum(nil) - earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret") + earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient) ecdheSecret := deriveECDHESecret(ks, privateKey) if ecdheSecret == nil { @@ -170,22 +171,21 @@ CurvePreferenceLoop: return errors.New("tls: bad ECDHE client share") } - hs.finishedHash13.Write(hs.hello13.marshal()) + hs.keySchedule.write(hs.hello13.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { return err } - handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) - handshakeCtx = hs.finishedHash13.Sum(nil) - clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret") + hs.keySchedule.setSecret(ecdheSecret) + clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient) hs.hsClientCipher = clientCipher - serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret") + serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer) c.out.setCipher(c.vers, serverCipher) serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize) hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize) - hs.finishedHash13.Write(hs.hello13Enc.marshal()) + hs.keySchedule.write(hs.hello13Enc.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { return err } @@ -196,19 +196,18 @@ CurvePreferenceLoop: } } - verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey) + verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey) serverFinished := &finishedMsg{ verifyData: verifyData, } - hs.finishedHash13.Write(serverFinished.marshal()) + hs.keySchedule.write(serverFinished.marshal()) if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { return err } - hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) - handshakeCtx = hs.finishedHash13.Sum(nil) - hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret") - serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret") + hs.keySchedule.setSecret(nil) // derive master secret + hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient) + serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer) c.out.setCipher(c.vers, serverCipher) if c.hand.Len() > 0 { @@ -247,13 +246,13 @@ func (hs *serverHandshakeState) readClientFinished13() error { } hash := hashForSuite(hs.suite) - expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) + expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey) if len(expectedVerifyData) != len(clientFinished.verifyData) || subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { c.sendAlert(alertDecryptError) return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash13.Write(clientFinished.marshal()) + hs.keySchedule.write(clientFinished.marshal()) c.hs = nil // Discard the server handshake state if c.hand.Len() > 0 { @@ -287,7 +286,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { } certMsg := &certificateMsg13{certificates: certEntries} - hs.finishedHash13.Write(certMsg.marshal()) + hs.keySchedule.write(certMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { return err } @@ -304,7 +303,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} } - toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil)) + toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil)) signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts) if err != nil { c.sendAlert(alertInternalError) @@ -316,7 +315,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { signatureAndHash: sigSchemeToSigAndHash(sigScheme), signature: signature, } - hs.finishedHash13.Write(verifyMsg.marshal()) + hs.keySchedule.write(verifyMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { return err } @@ -505,22 +504,16 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte { return h.Sum(nil) } -func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) { - hash := hashForSuite(hs.suite) - trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size()) - key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen) - iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12) - return hs.suite.aead(key, iv), trafficSecret -} - // Maximum allowed mismatch between the stated age of a ticket // and the server-observed one. See // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. const ticketAgeSkewAllowance = 10 * time.Second -func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { +// checkPSK tries to resume using a PSK, returning true (and updating the +// early secret in the key schedule) if the PSK was used and false otherwise. +func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) { if hs.c.config.SessionTicketsDisabled { - return nil, alertSuccess + return false, alertSuccess } foundDHE := false @@ -531,7 +524,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { } } if !foundDHE { - return nil, alertSuccess + return false, alertSuccess } hash := hashForSuite(hs.suite) @@ -580,23 +573,22 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { continue } - earlySecret := hkdfExtract(hash, s.resumptionSecret, nil) - handshakeCtx := hash.New().Sum(nil) - binderKey := hkdfExpandLabel(hash, earlySecret, handshakeCtx, "resumption psk binder key", hashSize) + hs.keySchedule.setSecret(s.resumptionSecret) + binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) chHash := hash.New() chHash.Write(hs.clientHello.rawTruncated) expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 { - return nil, alertDecryptError + return false, alertDecryptError } if i == 0 && hs.clientHello.earlyData { // This is a ticket intended to be used for 0-RTT if s.maxEarlyDataLen == 0 { // But we had not tagged it as such. - return nil, alertIllegalParameter + return false, alertIllegalParameter } if hs.c.config.Accept0RTTData { hs.c.binder = expectedBinder @@ -606,10 +598,10 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { } hs.hello13.psk = true hs.hello13.pskIdentity = uint16(i) - return earlySecret, alertSuccess + return true, alertSuccess } - return nil, alertSuccess + return false, alertSuccess } func (hs *serverHandshakeState) sendSessionTicket13() error { @@ -629,9 +621,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { return nil } - hash := hashForSuite(hs.suite) - handshakeCtx := hs.finishedHash13.Sum(nil) - resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size()) + resumptionSecret := hs.keySchedule.deriveSecret(secretResumption) ageAddBuf := make([]byte, 4) sessionState := &sessionState13{ diff --git a/handshake_server.go b/handshake_server.go index b71348d..f3bc79c 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -13,7 +13,6 @@ import ( "encoding/asn1" "errors" "fmt" - "hash" "io" "sync/atomic" ) @@ -45,7 +44,7 @@ type serverHandshakeState struct { // TLS 1.3 fields hello13 *serverHelloMsg13 hello13Enc *encryptedExtensionsMsg - finishedHash13 hash.Hash + keySchedule *keySchedule13 clientFinishedKey []byte hsClientCipher interface{} appClientCipher interface{}