crypto/tls: fix Conn.phase data races
Phase should only be accessed under in.Mutex. Handshake and all Read operations obtain that lock. However, many functions checking for handshakeRunning only obtain handshakeMutex: reintroduce handshakeCompleted for them. ConnectionState and Close check for handshakeConfirmed, introduce an atomic flag for them.
This commit is contained in:
parent
f3fe024dc7
commit
341de96a61
2
13.go
2
13.go
@ -15,6 +15,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang_org/x/crypto/curve25519"
|
||||
@ -157,6 +158,7 @@ func (hs *serverHandshakeState) readClientFinished13() error {
|
||||
|
||||
c.hs = nil // Discard the server handshake state
|
||||
c.phase = handshakeConfirmed
|
||||
atomic.StoreInt32(&c.handshakeConfirmed, 1)
|
||||
c.in.setCipher(c.vers, hs.appClientCipher)
|
||||
c.in.traceErr, c.out.traceErr = nil, nil
|
||||
|
||||
|
36
conn.go
36
conn.go
@ -27,21 +27,25 @@ type Conn struct {
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
|
||||
phase handshakeStatus // protected by in.Mutex
|
||||
// handshakeConfirmed is an atomic bool for phase == handshakeConfirmed
|
||||
handshakeConfirmed int32
|
||||
|
||||
// constant after handshake; protected by handshakeMutex
|
||||
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
|
||||
// handshakeCond, if not nil, indicates that a goroutine is committed
|
||||
// to running the handshake for this Conn. Other goroutines that need
|
||||
// to wait for the handshake can wait on this, under handshakeMutex.
|
||||
handshakeCond *sync.Cond
|
||||
// The transition from handshakeRunning to the next phase is covered by
|
||||
// handshakeMutex. All others by in.Mutex.
|
||||
phase handshakeStatus
|
||||
handshakeErr error // error resulting from handshake
|
||||
connID []byte // Random connection id
|
||||
clientHello []byte // ClientHello packet contents
|
||||
vers uint16 // TLS version
|
||||
haveVers bool // version has been negotiated
|
||||
config *Config // configuration passed to constructor
|
||||
// handshakeComplete is true if the connection reached application data
|
||||
// and it's equivalent to phase > handshakeRunning
|
||||
handshakeComplete bool
|
||||
// handshakes counts the number of handshakes performed on the
|
||||
// connection so far. If renegotiation is disabled then this is either
|
||||
// zero or one.
|
||||
@ -1160,7 +1164,7 @@ func (c *Conn) Write(b []byte) (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if c.phase == handshakeRunning {
|
||||
if !c.handshakeComplete {
|
||||
return 0, alertInternalError
|
||||
}
|
||||
|
||||
@ -1232,6 +1236,7 @@ func (c *Conn) handleRenegotiation() error {
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
c.phase = handshakeRunning
|
||||
c.handshakeComplete = false
|
||||
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
|
||||
c.handshakes++
|
||||
}
|
||||
@ -1289,12 +1294,13 @@ func (c *Conn) ConfirmHandshake() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// earlyDataReader wraps a Conn with in locked and reads only early data,
|
||||
// both buffered and still on the wire.
|
||||
// earlyDataReader wraps a Conn and reads only early data, both buffered
|
||||
// and still on the wire.
|
||||
type earlyDataReader struct {
|
||||
c *Conn
|
||||
}
|
||||
|
||||
// c.in.Mutex <= L
|
||||
func (r earlyDataReader) Read(b []byte) (n int, err error) {
|
||||
c := r.c
|
||||
|
||||
@ -1423,7 +1429,7 @@ func (c *Conn) Close() error {
|
||||
var alertErr error
|
||||
|
||||
c.handshakeMutex.Lock()
|
||||
if c.phase != handshakeRunning {
|
||||
if c.handshakeComplete {
|
||||
alertErr = c.closeNotify()
|
||||
}
|
||||
c.handshakeMutex.Unlock()
|
||||
@ -1442,7 +1448,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
|
||||
func (c *Conn) CloseWrite() error {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
if c.phase == handshakeRunning {
|
||||
if !c.handshakeComplete {
|
||||
return errEarlyCloseWrite
|
||||
}
|
||||
|
||||
@ -1497,7 +1503,7 @@ func (c *Conn) Handshake() error {
|
||||
if err := c.handshakeErr; err != nil {
|
||||
return err
|
||||
}
|
||||
if c.phase != handshakeRunning {
|
||||
if c.handshakeComplete {
|
||||
return nil
|
||||
}
|
||||
if c.handshakeCond == nil {
|
||||
@ -1519,7 +1525,7 @@ func (c *Conn) Handshake() error {
|
||||
|
||||
// The handshake cannot have completed when handshakeMutex was unlocked
|
||||
// because this goroutine set handshakeCond.
|
||||
if c.handshakeErr != nil || c.phase != handshakeRunning {
|
||||
if c.handshakeErr != nil || c.handshakeComplete {
|
||||
panic("handshake should not have been able to complete after handshakeCond was set")
|
||||
}
|
||||
|
||||
@ -1541,7 +1547,7 @@ func (c *Conn) Handshake() error {
|
||||
c.flush()
|
||||
}
|
||||
|
||||
if c.handshakeErr == nil && c.phase == handshakeRunning {
|
||||
if c.handshakeErr == nil && !c.handshakeComplete {
|
||||
panic("handshake should have had a result.")
|
||||
}
|
||||
|
||||
@ -1559,10 +1565,10 @@ func (c *Conn) ConnectionState() ConnectionState {
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
var state ConnectionState
|
||||
state.HandshakeComplete = c.phase != handshakeRunning
|
||||
state.HandshakeComplete = c.handshakeComplete
|
||||
state.ServerName = c.serverName
|
||||
|
||||
if state.HandshakeComplete {
|
||||
if c.handshakeComplete {
|
||||
state.ConnectionID = c.connID
|
||||
state.ClientHello = c.clientHello
|
||||
state.Version = c.vers
|
||||
@ -1574,7 +1580,7 @@ func (c *Conn) ConnectionState() ConnectionState {
|
||||
state.VerifiedChains = c.verifiedChains
|
||||
state.SignedCertificateTimestamps = c.scts
|
||||
state.OCSPResponse = c.ocspResponse
|
||||
state.HandshakeConfirmed = c.phase == handshakeConfirmed
|
||||
state.HandshakeConfirmed = atomic.LoadInt32(&c.handshakeConfirmed) == 1
|
||||
if !c.didResume {
|
||||
if c.clientFinishedIsFirst {
|
||||
state.TLSUnique = c.clientFinished[:]
|
||||
@ -1605,7 +1611,7 @@ func (c *Conn) VerifyHostname(host string) error {
|
||||
if !c.isClient {
|
||||
return errors.New("tls: VerifyHostname called on TLS server connection")
|
||||
}
|
||||
if c.phase == handshakeRunning {
|
||||
if !c.handshakeComplete {
|
||||
return errors.New("tls: handshake has not yet been performed")
|
||||
}
|
||||
if len(c.verifiedChains) == 0 {
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type clientHandshakeState struct {
|
||||
@ -252,6 +253,8 @@ NextCipherSuite:
|
||||
|
||||
c.didResume = isResume
|
||||
c.phase = handshakeConfirmed
|
||||
atomic.StoreInt32(&c.handshakeConfirmed, 1)
|
||||
c.handshakeComplete = true
|
||||
c.cipherSuite = suite.id
|
||||
return nil
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"fmt"
|
||||
"hash"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type Committer interface {
|
||||
@ -78,6 +79,8 @@ func (c *Conn) serverHandshake() error {
|
||||
return err
|
||||
}
|
||||
c.hs = &hs
|
||||
c.handshakeComplete = true
|
||||
return nil
|
||||
} else if isResume {
|
||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
||||
if err := hs.doResumeHandshake(); err != nil {
|
||||
@ -105,7 +108,6 @@ func (c *Conn) serverHandshake() error {
|
||||
return err
|
||||
}
|
||||
c.didResume = true
|
||||
c.phase = handshakeConfirmed
|
||||
} else {
|
||||
// The client didn't include a session ticket, or it wasn't
|
||||
// valid so we do a full handshake.
|
||||
@ -129,8 +131,10 @@ func (c *Conn) serverHandshake() error {
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.phase = handshakeConfirmed
|
||||
}
|
||||
c.phase = handshakeConfirmed
|
||||
atomic.StoreInt32(&c.handshakeConfirmed, 1)
|
||||
c.handshakeComplete = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user