crypto/tls: return from Handshake before the Client Finished in 1.3

This commit is contained in:
Filippo Valsorda 2016-11-22 22:23:34 -05:00 committed by Peter Wu
parent ee3048cfd2
commit 1117f76fcc
6 changed files with 123 additions and 77 deletions

View File

@ -32,7 +32,7 @@ install:
script: script:
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA
- if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test crypto/tls; fi - if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi
after_script: after_script:
- if [ "$MODE" = "interop" ]; then docker ps -a; docker logs tris-localserver; fi - if [ "$MODE" = "interop" ]; then docker ps -a; docker logs tris-localserver; fi

92
13.go
View File

@ -11,6 +11,7 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
"os" "os"
"runtime/debug" "runtime/debug"
@ -50,13 +51,10 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
} }
hs.hello13.keyShare = serverKS hs.hello13.keyShare = serverKS
hash := crypto.SHA256 hash := hashForSuite(hs.suite)
if hs.suite.flags&suiteSHA384 != 0 {
hash = crypto.SHA384
}
hashSize := hash.Size() hashSize := hash.Size()
earlySecret, isPSK := hs.checkPSK(hash) earlySecret, isPSK := hs.checkPSK()
if !isPSK { if !isPSK {
earlySecret = hkdfExtract(hash, nil, nil) earlySecret = hkdfExtract(hash, nil, nil)
} }
@ -68,16 +66,15 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
return errors.New("tls: bad ECDHE client share") return errors.New("tls: bad ECDHE client share")
} }
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) hs.finishedHash13 = hash.New()
hs.finishedHash.discardHandshakeBuffer() hs.finishedHash13.Write(hs.clientHello.marshal())
hs.finishedHash.Write(hs.clientHello.marshal()) hs.finishedHash13.Write(hs.hello13.marshal())
hs.finishedHash.Write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err return err
} }
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
handshakeCtx := hs.finishedHash.Sum() handshakeCtx := hs.finishedHash13.Sum(nil)
cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize) cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize)
cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen) cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen)
@ -91,7 +88,7 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
serverCipher := hs.suite.aead(sKey, sIV) serverCipher := hs.suite.aead(sKey, sIV)
c.out.setCipher(c.vers, serverCipher) c.out.setCipher(c.vers, serverCipher)
hs.finishedHash.Write(hs.hello13Enc.marshal()) hs.finishedHash13.Write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err return err
} }
@ -103,21 +100,19 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
} }
serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize) serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize)
clientFinishedKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize) hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize)
h := hmac.New(hash.New, serverFinishedKey) verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
h.Write(hs.finishedHash.Sum())
verifyData := h.Sum(nil)
serverFinished := &finishedMsg{ serverFinished := &finishedMsg{
verifyData: verifyData, verifyData: verifyData,
} }
hs.finishedHash.Write(serverFinished.marshal()) hs.finishedHash13.Write(serverFinished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
return err return err
} }
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
handshakeCtx = hs.finishedHash.Sum() handshakeCtx = hs.finishedHash13.Sum(nil)
cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize) cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize)
cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen) cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen)
@ -126,16 +121,22 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen) sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen)
sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12) sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12)
hs.clientCipher = hs.suite.aead(cKey, cIV)
serverCipher = hs.suite.aead(sKey, sIV) serverCipher = hs.suite.aead(sKey, sIV)
c.out.setCipher(c.vers, serverCipher) c.out.setCipher(c.vers, serverCipher)
// TODO(filippo): here we are ready to send Application Data, but we might want it opt-in. c.phase = waitingClientFinished
// It will be refactored anyway for 0-RTT.
if _, err := c.flush(); err != nil { return nil
return err }
}
// readClientFinished13 is called when, on the second flight of the client,
// a handshake message is received. This might be immediately or after the
// early data. Once done it sends the session tickets. Under c.in lock.
func (hs *serverHandshakeState) readClientFinished13() error {
c := hs.c
c.phase = readingClientFinished
msg, err := c.readHandshake() msg, err := c.readHandshake()
if err != nil { if err != nil {
return err return err
@ -147,20 +148,22 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
return unexpectedMessageError(clientFinished, msg) return unexpectedMessageError(clientFinished, msg)
} }
h = hmac.New(hash.New, clientFinishedKey) hash := hashForSuite(hs.suite)
h.Write(hs.finishedHash.Sum()) expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey)
expectedVerifyData := h.Sum(nil)
if len(expectedVerifyData) != len(clientFinished.verifyData) || if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return errors.New("tls: client's Finished message is incorrect") return errors.New("tls: client's Finished message is incorrect")
} }
hs.finishedHash.Write(clientFinished.marshal()) hs.finishedHash13.Write(clientFinished.marshal())
clientCipher = hs.suite.aead(cKey, cIV) c.in.setCipher(c.vers, hs.clientCipher)
c.in.setCipher(c.vers, clientCipher)
return hs.sendSessionTicket13(hash) // Discard the server handshake state
c.hs = nil
c.phase = handshakeComplete
return hs.sendSessionTicket13()
} }
func (hs *serverHandshakeState) sendCertificate13() error { func (hs *serverHandshakeState) sendCertificate13() error {
@ -169,7 +172,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
certMsg := &certificateMsg13{ certMsg := &certificateMsg13{
certificates: hs.cert.Certificate, certificates: hs.cert.Certificate,
} }
hs.finishedHash.Write(certMsg.marshal()) hs.finishedHash13.Write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err return err
} }
@ -186,7 +189,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
} }
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash.Sum()) toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash13.Sum(nil))
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts) signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
@ -198,7 +201,7 @@ func (hs *serverHandshakeState) sendCertificate13() error {
signatureAndHash: sigSchemeToSigAndHash(sigScheme), signatureAndHash: sigSchemeToSigAndHash(sigScheme),
signature: signature, signature: signature,
} }
hs.finishedHash.Write(verifyMsg.marshal()) hs.finishedHash13.Write(verifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
return err return err
} }
@ -279,6 +282,13 @@ func hashForSignatureScheme(ss SignatureScheme) crypto.Hash {
} }
} }
func hashForSuite(suite *cipherSuite) crypto.Hash {
if suite.flags&suiteSHA384 != 0 {
return crypto.SHA384
}
return crypto.SHA256
}
func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte { func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte {
message := bytes.Repeat([]byte{32}, 64) message := bytes.Repeat([]byte{32}, 64)
message = append(message, context...) message = append(message, context...)
@ -361,12 +371,18 @@ func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L
return hkdfExpand(hash, secret, hkdfLabel, L) return hkdfExpand(hash, secret, hkdfLabel, L)
} }
func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
h := hmac.New(f.New, key)
h.Write(hash.Sum(nil))
return h.Sum(nil)
}
// Maximum allowed mismatch between the stated age of a ticket // Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See // and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second const ticketAgeSkewAllowance = 10 * time.Second
func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte, ok bool) { func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
if hs.c.config.SessionTicketsDisabled { if hs.c.config.SessionTicketsDisabled {
return nil, false return nil, false
} }
@ -382,6 +398,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte,
return nil, false return nil, false
} }
hash := hashForSuite(hs.suite)
hashSize := hash.Size() hashSize := hash.Size()
for i := range hs.clientHello.psks { for i := range hs.clientHello.psks {
sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...) sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...)
@ -411,9 +428,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte,
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize) binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New() chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated) chHash.Write(hs.clientHello.rawTruncated)
h := hmac.New(hash.New, binderFinishedKey) expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
h.Write(chHash.Sum(nil))
expectedBinder := h.Sum(nil)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 { if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
hs.hello13.psk = true hs.hello13.psk = true
@ -425,7 +440,7 @@ func (hs *serverHandshakeState) checkPSK(hash crypto.Hash) (earlySecret []byte,
return nil, false return nil, false
} }
func (hs *serverHandshakeState) sendSessionTicket13(hash crypto.Hash) error { func (hs *serverHandshakeState) sendSessionTicket13() error {
c := hs.c c := hs.c
if c.config.SessionTicketsDisabled { if c.config.SessionTicketsDisabled {
return nil return nil
@ -442,7 +457,8 @@ func (hs *serverHandshakeState) sendSessionTicket13(hash crypto.Hash) error {
return nil return nil
} }
handshakeCtx := hs.finishedHash.Sum() hash := hashForSuite(hs.suite)
handshakeCtx := hs.finishedHash13.Sum(nil)
resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size()) resumptionSecret := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "resumption master secret", hash.Size())
ageAddBuf := make([]byte, 4) ageAddBuf := make([]byte, 4)

57
conn.go
View File

@ -27,6 +27,8 @@ type Conn struct {
conn net.Conn conn net.Conn
isClient bool isClient bool
phase handshakePhase
// constant after handshake; protected by handshakeMutex // constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
// handshakeCond, if not nil, indicates that a goroutine is committed // handshakeCond, if not nil, indicates that a goroutine is committed
@ -39,9 +41,6 @@ type Conn struct {
vers uint16 // TLS version vers uint16 // TLS version
haveVers bool // version has been negotiated haveVers bool // version has been negotiated
config *Config // configuration passed to constructor config *Config // configuration passed to constructor
// handshakeComplete is true if the connection is currently transferring
// application data (i.e. is not currently processing a handshake).
handshakeComplete bool
// handshakes counts the number of handshakes performed on the // handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either // connection so far. If renegotiation is disabled then this is either
// zero or one. // zero or one.
@ -101,9 +100,21 @@ type Conn struct {
// in Conn.Write. // in Conn.Write.
activeCall int32 activeCall int32
// TLS 1.3 needs the server state until it reaches the Client Finished
hs *serverHandshakeState
tmp [16]byte tmp [16]byte
} }
type handshakePhase int
const (
earlyHandshake handshakePhase = iota
waitingClientFinished
readingClientFinished
handshakeComplete
)
// Access to net.Conn methods. // Access to net.Conn methods.
// Cannot just embed net.Conn because that would // Cannot just embed net.Conn because that would
// export the struct field too. // export the struct field too.
@ -604,12 +615,12 @@ func (c *Conn) readRecord(want recordType) error {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec: case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete { if c.phase != earlyHandshake && c.phase != readingClientFinished {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
} }
case recordTypeApplicationData: case recordTypeApplicationData:
if !c.handshakeComplete { if c.phase == earlyHandshake || c.phase == earlyHandshake {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
} }
@ -741,6 +752,10 @@ Again:
} }
case recordTypeApplicationData: case recordTypeApplicationData:
if c.phase == waitingClientFinished {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if typ != want { if typ != want {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break break
@ -750,10 +765,18 @@ Again:
case recordTypeHandshake: case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close. // TODO(rsc): Should at least pick off connection close.
if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) { if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) &&
c.phase != waitingClientFinished {
return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
} }
c.hand.Write(data) c.hand.Write(data)
if typ != want && c.phase == waitingClientFinished {
if err := c.hs.readClientFinished13(); err != nil {
c.in.setErrorLocked(err)
break
}
goto Again
}
} }
if b != nil { if b != nil {
@ -1108,7 +1131,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err return 0, err
} }
if !c.handshakeComplete { if c.phase == earlyHandshake {
return 0, alertInternalError return 0, alertInternalError
} }
@ -1175,7 +1198,7 @@ func (c *Conn) handleRenegotiation() error {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
c.handshakeComplete = false c.phase = earlyHandshake
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++ c.handshakes++
} }
@ -1277,7 +1300,7 @@ func (c *Conn) Close() error {
var alertErr error var alertErr error
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
if c.handshakeComplete { if c.phase != earlyHandshake {
alertErr = c.closeNotify() alertErr = c.closeNotify()
} }
c.handshakeMutex.Unlock() c.handshakeMutex.Unlock()
@ -1296,7 +1319,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
func (c *Conn) CloseWrite() error { func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
if !c.handshakeComplete { if c.phase == earlyHandshake {
return errEarlyCloseWrite return errEarlyCloseWrite
} }
@ -1319,7 +1342,7 @@ func (c *Conn) closeNotify() error {
// Most uses of this package need not call Handshake // Most uses of this package need not call Handshake
// explicitly: the first Read or Write will call it automatically. // explicitly: the first Read or Write will call it automatically.
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
// c.handshakeErr and c.handshakeComplete are protected by // c.handshakeErr and c.phase == earlyHandshake are protected by
// c.handshakeMutex. In order to perform a handshake, we need to lock // c.handshakeMutex. In order to perform a handshake, we need to lock
// c.in also and c.handshakeMutex must be locked after c.in. // c.in also and c.handshakeMutex must be locked after c.in.
// //
@ -1348,7 +1371,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil { if err := c.handshakeErr; err != nil {
return err return err
} }
if c.handshakeComplete { if c.phase != earlyHandshake {
return nil return nil
} }
if c.handshakeCond == nil { if c.handshakeCond == nil {
@ -1370,7 +1393,7 @@ func (c *Conn) Handshake() error {
// The handshake cannot have completed when handshakeMutex was unlocked // The handshake cannot have completed when handshakeMutex was unlocked
// because this goroutine set handshakeCond. // because this goroutine set handshakeCond.
if c.handshakeErr != nil || c.handshakeComplete { if c.handshakeErr != nil || c.phase != earlyHandshake {
panic("handshake should not have been able to complete after handshakeCond was set") panic("handshake should not have been able to complete after handshakeCond was set")
} }
@ -1392,7 +1415,7 @@ func (c *Conn) Handshake() error {
c.flush() c.flush()
} }
if c.handshakeErr == nil && !c.handshakeComplete { if c.handshakeErr == nil && c.phase == earlyHandshake {
panic("handshake should have had a result.") panic("handshake should have had a result.")
} }
@ -1410,10 +1433,10 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
var state ConnectionState var state ConnectionState
state.HandshakeComplete = c.handshakeComplete state.HandshakeComplete = c.phase != earlyHandshake
state.ServerName = c.serverName state.ServerName = c.serverName
if c.handshakeComplete { if state.HandshakeComplete {
state.ConnectionID = c.connID state.ConnectionID = c.connID
state.ClientHello = c.clientHello state.ClientHello = c.clientHello
state.Version = c.vers state.Version = c.vers
@ -1455,7 +1478,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient { if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection") return errors.New("tls: VerifyHostname called on TLS server connection")
} }
if !c.handshakeComplete { if c.phase == earlyHandshake {
return errors.New("tls: handshake has not yet been performed") return errors.New("tls: handshake has not yet been performed")
} }
if len(c.verifiedChains) == 0 { if len(c.verifiedChains) == 0 {

View File

@ -251,7 +251,7 @@ NextCipherSuite:
} }
c.didResume = isResume c.didResume = isResume
c.handshakeComplete = true c.phase = handshakeComplete
c.cipherSuite = suite.id c.cipherSuite = suite.id
return nil return nil
} }

View File

@ -13,6 +13,7 @@ import (
"encoding/asn1" "encoding/asn1"
"errors" "errors"
"fmt" "fmt"
"hash"
"io" "io"
) )
@ -24,21 +25,28 @@ type Committer interface {
// It's discarded once the handshake has completed. // It's discarded once the handshake has completed.
type serverHandshakeState struct { type serverHandshakeState struct {
c *Conn c *Conn
clientHello *clientHelloMsg
hello *serverHelloMsg
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg
suite *cipherSuite suite *cipherSuite
masterSecret []byte
cachedClientHelloInfo *ClientHelloInfo
clientHello *clientHelloMsg
cert *Certificate
// TLS 1.0-1.2 fields
hello *serverHelloMsg
ellipticOk bool ellipticOk bool
ecdsaOk bool ecdsaOk bool
rsaDecryptOk bool rsaDecryptOk bool
rsaSignOk bool rsaSignOk bool
sessionState *sessionState sessionState *sessionState
finishedHash finishedHash finishedHash finishedHash
masterSecret []byte
certsFromClient [][]byte certsFromClient [][]byte
cert *Certificate
cachedClientHelloInfo *ClientHelloInfo // TLS 1.3 fields
hello13 *serverHelloMsg13
hello13Enc *encryptedExtensionsMsg
finishedHash13 hash.Hash
clientFinishedKey []byte
clientCipher interface{}
} }
// serverHandshake performs a TLS handshake as a server. // serverHandshake performs a TLS handshake as a server.
@ -60,11 +68,16 @@ func (c *Conn) serverHandshake() error {
} }
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
// and https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2
c.buffering = true c.buffering = true
if hs.hello13 != nil { if hs.hello13 != nil {
if err := hs.doTLS13Handshake(); err != nil { if err := hs.doTLS13Handshake(); err != nil {
return err return err
} }
if _, err := c.flush(); err != nil {
return err
}
c.hs = &hs
} else if isResume { } else if isResume {
// The client has included a session ticket and so we do an abbreviated handshake. // The client has included a session ticket and so we do an abbreviated handshake.
if err := hs.doResumeHandshake(); err != nil { if err := hs.doResumeHandshake(); err != nil {
@ -92,6 +105,7 @@ func (c *Conn) serverHandshake() error {
return err return err
} }
c.didResume = true c.didResume = true
c.phase = handshakeComplete
} else { } else {
// The client didn't include a session ticket, or it wasn't // The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake. // valid so we do a full handshake.
@ -115,8 +129,8 @@ func (c *Conn) serverHandshake() error {
if _, err := c.flush(); err != nil { if _, err := c.flush(); err != nil {
return err return err
} }
c.phase = handshakeComplete
} }
c.handshakeComplete = true
return nil return nil
} }

7
prf.go
View File

@ -122,13 +122,6 @@ var clientFinishedLabel = []byte("client finished")
var serverFinishedLabel = []byte("server finished") var serverFinishedLabel = []byte("server finished")
func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) {
if version >= VersionTLS13 {
if suite.flags&suiteSHA384 != 0 {
return prf12(sha512.New384), crypto.SHA384
}
return prf12(sha256.New), crypto.SHA256
}
switch version { switch version {
case VersionSSL30: case VersionSSL30:
return prf30, crypto.Hash(0) return prf30, crypto.Hash(0)