From 1117f76fcc0faf3266e7ec61f6d1c62d45bb170f Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Tue, 22 Nov 2016 22:23:34 -0500 Subject: [PATCH] crypto/tls: return from Handshake before the Client Finished in 1.3 --- .travis.yml | 2 +- 13.go | 92 ++++++++++++++++++++++++++------------------- conn.go | 57 +++++++++++++++++++--------- handshake_client.go | 2 +- handshake_server.go | 40 +++++++++++++------- prf.go | 7 ---- 6 files changed, 123 insertions(+), 77 deletions(-) diff --git a/.travis.yml b/.travis.yml index 37c25d9..34728f5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -32,7 +32,7 @@ install: script: - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA - - if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test crypto/tls; fi + - if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi after_script: - if [ "$MODE" = "interop" ]; then docker ps -a; docker logs tris-localserver; fi diff --git a/13.go b/13.go index 12fb96a..3be0305 100644 --- a/13.go +++ b/13.go @@ -11,6 +11,7 @@ import ( "encoding/hex" "errors" "fmt" + "hash" "io" "os" "runtime/debug" @@ -50,13 +51,10 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { } hs.hello13.keyShare = serverKS - hash := crypto.SHA256 - if hs.suite.flags&suiteSHA384 != 0 { - hash = crypto.SHA384 - } + hash := hashForSuite(hs.suite) hashSize := hash.Size() - earlySecret, isPSK := hs.checkPSK(hash) + earlySecret, isPSK := hs.checkPSK() if !isPSK { earlySecret = hkdfExtract(hash, nil, nil) } @@ -68,16 +66,15 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { return errors.New("tls: bad ECDHE client share") } - hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) - hs.finishedHash.discardHandshakeBuffer() - hs.finishedHash.Write(hs.clientHello.marshal()) - hs.finishedHash.Write(hs.hello13.marshal()) + hs.finishedHash13 = hash.New() + hs.finishedHash13.Write(hs.clientHello.marshal()) + hs.finishedHash13.Write(hs.hello13.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { return err } handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) - handshakeCtx := hs.finishedHash.Sum() + handshakeCtx := hs.finishedHash13.Sum(nil) cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize) cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen) @@ -91,7 +88,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { serverCipher := hs.suite.aead(sKey, sIV) c.out.setCipher(c.vers, serverCipher) - hs.finishedHash.Write(hs.hello13Enc.marshal()) + hs.finishedHash13.Write(hs.hello13Enc.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { return err } @@ -103,21 +100,19 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { } serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize) - clientFinishedKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize) + hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize) - h := hmac.New(hash.New, serverFinishedKey) - h.Write(hs.finishedHash.Sum()) - verifyData := h.Sum(nil) + verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey) serverFinished := &finishedMsg{ verifyData: verifyData, } - hs.finishedHash.Write(serverFinished.marshal()) + hs.finishedHash13.Write(serverFinished.marshal()) if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { return err } hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) - handshakeCtx = hs.finishedHash.Sum() + handshakeCtx = hs.finishedHash13.Sum(nil) cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize) cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen) @@ -126,16 +121,22 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen) sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12) + hs.clientCipher = hs.suite.aead(cKey, cIV) serverCipher = hs.suite.aead(sKey, sIV) c.out.setCipher(c.vers, serverCipher) - // TODO(filippo): here we are ready to send Application Data, but we might want it opt-in. - // It will be refactored anyway for 0-RTT. + c.phase = waitingClientFinished - if _, err := c.flush(); err != nil { - return err - } + return nil +} +// readClientFinished13 is called when, on the second flight of the client, +// a handshake message is received. This might be immediately or after the +// early data. Once done it sends the session tickets. Under c.in lock. +func (hs *serverHandshakeState) readClientFinished13() error { + c := hs.c + + c.phase = readingClientFinished msg, err := c.readHandshake() if err != nil { return err @@ -147,20 +148,22 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { return unexpectedMessageError(clientFinished, msg) } - h = hmac.New(hash.New, clientFinishedKey) - h.Write(hs.finishedHash.Sum()) - expectedVerifyData := h.Sum(nil) + hash := hashForSuite(hs.suite) + expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) if len(expectedVerifyData) != len(clientFinished.verifyData) || subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { c.sendAlert(alertHandshakeFailure) return errors.New("tls: client's Finished message is incorrect") } - hs.finishedHash.Write(clientFinished.marshal()) + hs.finishedHash13.Write(clientFinished.marshal()) - clientCipher = hs.suite.aead(cKey, cIV) - c.in.setCipher(c.vers, clientCipher) + c.in.setCipher(c.vers, hs.clientCipher) - return hs.sendSessionTicket13(hash) + // Discard the server handshake state + c.hs = nil + c.phase = handshakeComplete + + return hs.sendSessionTicket13() } func (hs *serverHandshakeState) sendCertificate13() error { @@ -169,7 +172,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { certMsg := &certificateMsg13{ certificates: hs.cert.Certificate, } - hs.finishedHash.Write(certMsg.marshal()) + hs.finishedHash13.Write(certMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { return err } @@ -186,7 +189,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} } - toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash.Sum()) + toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil)) signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts) if err != nil { c.sendAlert(alertInternalError) @@ -198,7 +201,7 @@ func (hs *serverHandshakeState) sendCertificate13() error { signatureAndHash: sigSchemeToSigAndHash(sigScheme), signature: signature, } - hs.finishedHash.Write(verifyMsg.marshal()) + hs.finishedHash13.Write(verifyMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { return err } @@ -279,6 +282,13 @@ func hashForSignatureScheme(ss SignatureScheme) crypto.Hash { } } +func hashForSuite(suite *cipherSuite) crypto.Hash { + if suite.flags&suiteSHA384 != 0 { + return crypto.SHA384 + } + return crypto.SHA256 +} + func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte { message := bytes.Repeat([]byte{32}, 64) message = append(message, context...) @@ -361,12 +371,18 @@ func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L return hkdfExpand(hash, secret, hkdfLabel, L) } +func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte { + h := hmac.New(f.New, key) + h.Write(hash.Sum(nil)) + return h.Sum(nil) +} + // Maximum allowed mismatch between the stated age of a ticket // and the server-observed one. See // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. const ticketAgeSkewAllowance = 10 * time.Second -func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte, ok bool) { +func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { if hs.c.config.SessionTicketsDisabled { return nil, false } @@ -382,6 +398,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte, return nil, false } + hash := hashForSuite(hs.suite) hashSize := hash.Size() for i := range hs.clientHello.psks { sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...) @@ -411,9 +428,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte, binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) chHash := hash.New() chHash.Write(hs.clientHello.rawTruncated) - h := hmac.New(hash.New, binderFinishedKey) - h.Write(chHash.Sum(nil)) - expectedBinder := h.Sum(nil) + expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 { hs.hello13.psk = true @@ -425,7 +440,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte, return nil, false } -func (hs *serverHandshakeState) sendSessionTicket13(hash crypto.Hash) error { +func (hs *serverHandshakeState) sendSessionTicket13() error { c := hs.c if c.config.SessionTicketsDisabled { return nil @@ -442,7 +457,8 @@ func (hs *serverHandshakeState) sendSessionTicket13(hash crypto.Hash) error { return nil } - handshakeCtx := hs.finishedHash.Sum() + hash := hashForSuite(hs.suite) + handshakeCtx := hs.finishedHash13.Sum(nil) resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size()) ageAddBuf := make([]byte, 4) diff --git a/conn.go b/conn.go index d1e358c..1adf7ea 100644 --- a/conn.go +++ b/conn.go @@ -27,6 +27,8 @@ type Conn struct { conn net.Conn isClient bool + phase handshakePhase + // constant after handshake; protected by handshakeMutex handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex // handshakeCond, if not nil, indicates that a goroutine is committed @@ -39,9 +41,6 @@ type Conn struct { vers uint16 // TLS version haveVers bool // version has been negotiated config *Config // configuration passed to constructor - // handshakeComplete is true if the connection is currently transferring - // application data (i.e. is not currently processing a handshake). - handshakeComplete bool // handshakes counts the number of handshakes performed on the // connection so far. If renegotiation is disabled then this is either // zero or one. @@ -101,9 +100,21 @@ type Conn struct { // in Conn.Write. activeCall int32 + // TLS 1.3 needs the server state until it reaches the Client Finished + hs *serverHandshakeState + tmp [16]byte } +type handshakePhase int + +const ( + earlyHandshake handshakePhase = iota + waitingClientFinished + readingClientFinished + handshakeComplete +) + // Access to net.Conn methods. // Cannot just embed net.Conn because that would // export the struct field too. @@ -604,12 +615,12 @@ func (c *Conn) readRecord(want recordType) error { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) case recordTypeHandshake, recordTypeChangeCipherSpec: - if c.handshakeComplete { + if c.phase != earlyHandshake && c.phase != readingClientFinished { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) } case recordTypeApplicationData: - if !c.handshakeComplete { + if c.phase == earlyHandshake || c.phase == earlyHandshake { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) } @@ -741,6 +752,10 @@ Again: } case recordTypeApplicationData: + if c.phase == waitingClientFinished { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } if typ != want { c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) break @@ -750,10 +765,18 @@ Again: case recordTypeHandshake: // TODO(rsc): Should at least pick off connection close. - if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) { + if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) && + c.phase != waitingClientFinished { return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) } c.hand.Write(data) + if typ != want && c.phase == waitingClientFinished { + if err := c.hs.readClientFinished13(); err != nil { + c.in.setErrorLocked(err) + break + } + goto Again + } } if b != nil { @@ -1108,7 +1131,7 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, err } - if !c.handshakeComplete { + if c.phase == earlyHandshake { return 0, alertInternalError } @@ -1175,7 +1198,7 @@ func (c *Conn) handleRenegotiation() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - c.handshakeComplete = false + c.phase = earlyHandshake if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { c.handshakes++ } @@ -1277,7 +1300,7 @@ func (c *Conn) Close() error { var alertErr error c.handshakeMutex.Lock() - if c.handshakeComplete { + if c.phase != earlyHandshake { alertErr = c.closeNotify() } c.handshakeMutex.Unlock() @@ -1296,7 +1319,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com func (c *Conn) CloseWrite() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - if !c.handshakeComplete { + if c.phase == earlyHandshake { return errEarlyCloseWrite } @@ -1319,7 +1342,7 @@ func (c *Conn) closeNotify() error { // Most uses of this package need not call Handshake // explicitly: the first Read or Write will call it automatically. func (c *Conn) Handshake() error { - // c.handshakeErr and c.handshakeComplete are protected by + // c.handshakeErr and c.phase == earlyHandshake are protected by // c.handshakeMutex. In order to perform a handshake, we need to lock // c.in also and c.handshakeMutex must be locked after c.in. // @@ -1348,7 +1371,7 @@ func (c *Conn) Handshake() error { if err := c.handshakeErr; err != nil { return err } - if c.handshakeComplete { + if c.phase != earlyHandshake { return nil } if c.handshakeCond == nil { @@ -1370,7 +1393,7 @@ func (c *Conn) Handshake() error { // The handshake cannot have completed when handshakeMutex was unlocked // because this goroutine set handshakeCond. - if c.handshakeErr != nil || c.handshakeComplete { + if c.handshakeErr != nil || c.phase != earlyHandshake { panic("handshake should not have been able to complete after handshakeCond was set") } @@ -1392,7 +1415,7 @@ func (c *Conn) Handshake() error { c.flush() } - if c.handshakeErr == nil && !c.handshakeComplete { + if c.handshakeErr == nil && c.phase == earlyHandshake { panic("handshake should have had a result.") } @@ -1410,10 +1433,10 @@ func (c *Conn) ConnectionState() ConnectionState { defer c.handshakeMutex.Unlock() var state ConnectionState - state.HandshakeComplete = c.handshakeComplete + state.HandshakeComplete = c.phase != earlyHandshake state.ServerName = c.serverName - if c.handshakeComplete { + if state.HandshakeComplete { state.ConnectionID = c.connID state.ClientHello = c.clientHello state.Version = c.vers @@ -1455,7 +1478,7 @@ func (c *Conn) VerifyHostname(host string) error { if !c.isClient { return errors.New("tls: VerifyHostname called on TLS server connection") } - if !c.handshakeComplete { + if c.phase == earlyHandshake { return errors.New("tls: handshake has not yet been performed") } if len(c.verifiedChains) == 0 { diff --git a/handshake_client.go b/handshake_client.go index 6cac2d6..3f120e0 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -251,7 +251,7 @@ NextCipherSuite: } c.didResume = isResume - c.handshakeComplete = true + c.phase = handshakeComplete c.cipherSuite = suite.id return nil } diff --git a/handshake_server.go b/handshake_server.go index 8388076..a66af85 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -13,6 +13,7 @@ import ( "encoding/asn1" "errors" "fmt" + "hash" "io" ) @@ -24,21 +25,28 @@ type Committer interface { // It's discarded once the handshake has completed. type serverHandshakeState struct { c *Conn - clientHello *clientHelloMsg - hello *serverHelloMsg - hello13 *serverHelloMsg13 - hello13Enc *encryptedExtensionsMsg suite *cipherSuite - ellipticOk bool - ecdsaOk bool - rsaDecryptOk bool - rsaSignOk bool - sessionState *sessionState - finishedHash finishedHash masterSecret []byte - certsFromClient [][]byte - cert *Certificate cachedClientHelloInfo *ClientHelloInfo + clientHello *clientHelloMsg + cert *Certificate + + // TLS 1.0-1.2 fields + hello *serverHelloMsg + ellipticOk bool + ecdsaOk bool + rsaDecryptOk bool + rsaSignOk bool + sessionState *sessionState + finishedHash finishedHash + certsFromClient [][]byte + + // TLS 1.3 fields + hello13 *serverHelloMsg13 + hello13Enc *encryptedExtensionsMsg + finishedHash13 hash.Hash + clientFinishedKey []byte + clientCipher interface{} } // serverHandshake performs a TLS handshake as a server. @@ -60,11 +68,16 @@ func (c *Conn) serverHandshake() error { } // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 + // and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2 c.buffering = true if hs.hello13 != nil { if err := hs.doTLS13Handshake(); err != nil { return err } + if _, err := c.flush(); err != nil { + return err + } + c.hs = &hs } else if isResume { // The client has included a session ticket and so we do an abbreviated handshake. if err := hs.doResumeHandshake(); err != nil { @@ -92,6 +105,7 @@ func (c *Conn) serverHandshake() error { return err } c.didResume = true + c.phase = handshakeComplete } else { // The client didn't include a session ticket, or it wasn't // valid so we do a full handshake. @@ -115,8 +129,8 @@ func (c *Conn) serverHandshake() error { if _, err := c.flush(); err != nil { return err } + c.phase = handshakeComplete } - c.handshakeComplete = true return nil } diff --git a/prf.go b/prf.go index e397bb9..5833fc1 100644 --- a/prf.go +++ b/prf.go @@ -122,13 +122,6 @@ var clientFinishedLabel = []byte("client finished") var serverFinishedLabel = []byte("server finished") func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { - if version >= VersionTLS13 { - if suite.flags&suiteSHA384 != 0 { - return prf12(sha512.New384), crypto.SHA384 - } - return prf12(sha256.New), crypto.SHA256 - } - switch version { case VersionSSL30: return prf30, crypto.Hash(0)