crypto/tls: return from Handshake before the Client Finished in 1.3

This commit is contained in:
Filippo Valsorda 2016-11-22 22:23:34 -05:00 committed by Peter Wu
parent ee3048cfd2
commit 1117f76fcc
6 changed files with 123 additions and 77 deletions

View File

@ -32,7 +32,7 @@ install:
script:
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA
- if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test crypto/tls; fi
- if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi
after_script:
- if [ "$MODE" = "interop" ]; then docker ps -a; docker logs tris-localserver; fi

92
13.go
View File

@ -11,6 +11,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"hash"
"io"
"os"
"runtime/debug"
@ -50,13 +51,10 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
}
hs.hello13.keyShare = serverKS
hash := crypto.SHA256
if hs.suite.flags&suiteSHA384 != 0 {
hash = crypto.SHA384
}
hash := hashForSuite(hs.suite)
hashSize := hash.Size()
earlySecret, isPSK := hs.checkPSK(hash)
earlySecret, isPSK := hs.checkPSK()
if !isPSK {
earlySecret = hkdfExtract(hash, nil, nil)
}
@ -68,16 +66,15 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
return errors.New("tls: bad ECDHE client share")
}
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
hs.finishedHash.Write(hs.clientHello.marshal())
hs.finishedHash.Write(hs.hello13.marshal())
hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())
hs.finishedHash13.Write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err
}
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
handshakeCtx := hs.finishedHash.Sum()
handshakeCtx := hs.finishedHash13.Sum(nil)
cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize)
cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen)
@ -91,7 +88,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
serverCipher := hs.suite.aead(sKey, sIV)
c.out.setCipher(c.vers, serverCipher)
hs.finishedHash.Write(hs.hello13Enc.marshal())
hs.finishedHash13.Write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err
}
@ -103,21 +100,19 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
}
serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize)
clientFinishedKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize)
h := hmac.New(hash.New, serverFinishedKey)
h.Write(hs.finishedHash.Sum())
verifyData := h.Sum(nil)
verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
serverFinished := &finishedMsg{
verifyData: verifyData,
}
hs.finishedHash.Write(serverFinished.marshal())
hs.finishedHash13.Write(serverFinished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
return err
}
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
handshakeCtx = hs.finishedHash.Sum()
handshakeCtx = hs.finishedHash13.Sum(nil)
cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize)
cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen)
@ -126,16 +121,22 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen)
sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12)
hs.clientCipher = hs.suite.aead(cKey, cIV)
serverCipher = hs.suite.aead(sKey, sIV)
c.out.setCipher(c.vers, serverCipher)
// TODO(filippo): here we are ready to send Application Data, but we might want it opt-in.
// It will be refactored anyway for 0-RTT.
c.phase = waitingClientFinished
if _, err := c.flush(); err != nil {
return err
}
return nil
}
// readClientFinished13 is called when, on the second flight of the client,
// a handshake message is received. This might be immediately or after the
// early data. Once done it sends the session tickets. Under c.in lock.
func (hs *serverHandshakeState) readClientFinished13() error {
c := hs.c
c.phase = readingClientFinished
msg, err := c.readHandshake()
if err != nil {
return err
@ -147,20 +148,22 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
return unexpectedMessageError(clientFinished, msg)
}
h = hmac.New(hash.New, clientFinishedKey)
h.Write(hs.finishedHash.Sum())
expectedVerifyData := h.Sum(nil)
hash := hashForSuite(hs.suite)
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey)
if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's Finished message is incorrect")
}
hs.finishedHash.Write(clientFinished.marshal())
hs.finishedHash13.Write(clientFinished.marshal())
clientCipher = hs.suite.aead(cKey, cIV)
c.in.setCipher(c.vers, clientCipher)
c.in.setCipher(c.vers, hs.clientCipher)
return hs.sendSessionTicket13(hash)
// Discard the server handshake state
c.hs = nil
c.phase = handshakeComplete
return hs.sendSessionTicket13()
}
func (hs *serverHandshakeState) sendCertificate13() error {
@ -169,7 +172,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
certMsg := &certificateMsg13{
certificates: hs.cert.Certificate,
}
hs.finishedHash.Write(certMsg.marshal())
hs.finishedHash13.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
@ -186,7 +189,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash.Sum())
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil))
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
if err != nil {
c.sendAlert(alertInternalError)
@ -198,7 +201,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
signatureAndHash: sigSchemeToSigAndHash(sigScheme),
signature: signature,
}
hs.finishedHash.Write(verifyMsg.marshal())
hs.finishedHash13.Write(verifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
return err
}
@ -279,6 +282,13 @@ func hashForSignatureScheme(ss SignatureScheme) crypto.Hash {
}
}
func hashForSuite(suite *cipherSuite) crypto.Hash {
if suite.flags&suiteSHA384 != 0 {
return crypto.SHA384
}
return crypto.SHA256
}
func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte {
message := bytes.Repeat([]byte{32}, 64)
message = append(message, context...)
@ -361,12 +371,18 @@ func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L
return hkdfExpand(hash, secret, hkdfLabel, L)
}
func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
h := hmac.New(f.New, key)
h.Write(hash.Sum(nil))
return h.Sum(nil)
}
// Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second
func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte, ok bool) {
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
if hs.c.config.SessionTicketsDisabled {
return nil, false
}
@ -382,6 +398,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte,
return nil, false
}
hash := hashForSuite(hs.suite)
hashSize := hash.Size()
for i := range hs.clientHello.psks {
sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...)
@ -411,9 +428,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte,
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated)
h := hmac.New(hash.New, binderFinishedKey)
h.Write(chHash.Sum(nil))
expectedBinder := h.Sum(nil)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
hs.hello13.psk = true
@ -425,7 +440,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte,
return nil, false
}
func (hs *serverHandshakeState) sendSessionTicket13(hash crypto.Hash) error {
func (hs *serverHandshakeState) sendSessionTicket13() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
@ -442,7 +457,8 @@ func (hs *serverHandshakeState) sendSessionTicket13(hash crypto.Hash) error {
return nil
}
handshakeCtx := hs.finishedHash.Sum()
hash := hashForSuite(hs.suite)
handshakeCtx := hs.finishedHash13.Sum(nil)
resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
ageAddBuf := make([]byte, 4)

57
conn.go
View File

@ -27,6 +27,8 @@ type Conn struct {
conn net.Conn
isClient bool
phase handshakePhase
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
// handshakeCond, if not nil, indicates that a goroutine is committed
@ -39,9 +41,6 @@ type Conn struct {
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
// handshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
handshakeComplete bool
// handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either
// zero or one.
@ -101,9 +100,21 @@ type Conn struct {
// in Conn.Write.
activeCall int32
// TLS 1.3 needs the server state until it reaches the Client Finished
hs *serverHandshakeState
tmp [16]byte
}
type handshakePhase int
const (
earlyHandshake handshakePhase = iota
waitingClientFinished
readingClientFinished
handshakeComplete
)
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
@ -604,12 +615,12 @@ func (c *Conn) readRecord(want recordType) error {
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete {
if c.phase != earlyHandshake && c.phase != readingClientFinished {
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
}
case recordTypeApplicationData:
if !c.handshakeComplete {
if c.phase == earlyHandshake || c.phase == earlyHandshake {
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
}
@ -741,6 +752,10 @@ Again:
}
case recordTypeApplicationData:
if c.phase == waitingClientFinished {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if typ != want {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
@ -750,10 +765,18 @@ Again:
case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) &&
c.phase != waitingClientFinished {
return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
}
c.hand.Write(data)
if typ != want && c.phase == waitingClientFinished {
if err := c.hs.readClientFinished13(); err != nil {
c.in.setErrorLocked(err)
break
}
goto Again
}
}
if b != nil {
@ -1108,7 +1131,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err
}
if !c.handshakeComplete {
if c.phase == earlyHandshake {
return 0, alertInternalError
}
@ -1175,7 +1198,7 @@ func (c *Conn) handleRenegotiation() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
c.handshakeComplete = false
c.phase = earlyHandshake
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++
}
@ -1277,7 +1300,7 @@ func (c *Conn) Close() error {
var alertErr error
c.handshakeMutex.Lock()
if c.handshakeComplete {
if c.phase != earlyHandshake {
alertErr = c.closeNotify()
}
c.handshakeMutex.Unlock()
@ -1296,7 +1319,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if !c.handshakeComplete {
if c.phase == earlyHandshake {
return errEarlyCloseWrite
}
@ -1319,7 +1342,7 @@ func (c *Conn) closeNotify() error {
// Most uses of this package need not call Handshake
// explicitly: the first Read or Write will call it automatically.
func (c *Conn) Handshake() error {
// c.handshakeErr and c.handshakeComplete are protected by
// c.handshakeErr and c.phase == earlyHandshake are protected by
// c.handshakeMutex. In order to perform a handshake, we need to lock
// c.in also and c.handshakeMutex must be locked after c.in.
//
@ -1348,7 +1371,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil {
return err
}
if c.handshakeComplete {
if c.phase != earlyHandshake {
return nil
}
if c.handshakeCond == nil {
@ -1370,7 +1393,7 @@ func (c *Conn) Handshake() error {
// The handshake cannot have completed when handshakeMutex was unlocked
// because this goroutine set handshakeCond.
if c.handshakeErr != nil || c.handshakeComplete {
if c.handshakeErr != nil || c.phase != earlyHandshake {
panic("handshake should not have been able to complete after handshakeCond was set")
}
@ -1392,7 +1415,7 @@ func (c *Conn) Handshake() error {
c.flush()
}
if c.handshakeErr == nil && !c.handshakeComplete {
if c.handshakeErr == nil && c.phase == earlyHandshake {
panic("handshake should have had a result.")
}
@ -1410,10 +1433,10 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock()
var state ConnectionState
state.HandshakeComplete = c.handshakeComplete
state.HandshakeComplete = c.phase != earlyHandshake
state.ServerName = c.serverName
if c.handshakeComplete {
if state.HandshakeComplete {
state.ConnectionID = c.connID
state.ClientHello = c.clientHello
state.Version = c.vers
@ -1455,7 +1478,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection")
}
if !c.handshakeComplete {
if c.phase == earlyHandshake {
return errors.New("tls: handshake has not yet been performed")
}
if len(c.verifiedChains) == 0 {

View File

@ -251,7 +251,7 @@ NextCipherSuite:
}
c.didResume = isResume
c.handshakeComplete = true
c.phase = handshakeComplete
c.cipherSuite = suite.id
return nil
}

View File

@ -13,6 +13,7 @@ import (
"encoding/asn1"
"errors"
"fmt"
"hash"
"io"
)
@ -24,21 +25,28 @@ type Committer interface {
// It's discarded once the handshake has completed.
type serverHandshakeState struct {
c *Conn
clientHello *clientHelloMsg
hello *serverHelloMsg
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg
suite *cipherSuite
ellipticOk bool
ecdsaOk bool
rsaDecryptOk bool
rsaSignOk bool
sessionState *sessionState
finishedHash finishedHash
masterSecret []byte
certsFromClient [][]byte
cert *Certificate
cachedClientHelloInfo *ClientHelloInfo
clientHello *clientHelloMsg
cert *Certificate
// TLS 1.0-1.2 fields
hello *serverHelloMsg
ellipticOk bool
ecdsaOk bool
rsaDecryptOk bool
rsaSignOk bool
sessionState *sessionState
finishedHash finishedHash
certsFromClient [][]byte
// TLS 1.3 fields
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg
finishedHash13 hash.Hash
clientFinishedKey []byte
clientCipher interface{}
}
// serverHandshake performs a TLS handshake as a server.
@ -60,11 +68,16 @@ func (c *Conn) serverHandshake() error {
}
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
// and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2
c.buffering = true
if hs.hello13 != nil {
if err := hs.doTLS13Handshake(); err != nil {
return err
}
if _, err := c.flush(); err != nil {
return err
}
c.hs = &hs
} else if isResume {
// The client has included a session ticket and so we do an abbreviated handshake.
if err := hs.doResumeHandshake(); err != nil {
@ -92,6 +105,7 @@ func (c *Conn) serverHandshake() error {
return err
}
c.didResume = true
c.phase = handshakeComplete
} else {
// The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake.
@ -115,8 +129,8 @@ func (c *Conn) serverHandshake() error {
if _, err := c.flush(); err != nil {
return err
}
c.phase = handshakeComplete
}
c.handshakeComplete = true
return nil
}

7
prf.go
View File

@ -122,13 +122,6 @@ var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished")
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
if version >= VersionTLS13 {
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
}
switch version {
case VersionSSL30:
return prf30, crypto.Hash(0)