crypto/tls: fix Conn.phase data races
Phase should only be accessed under in.Mutex. Handshake and all Read operations obtain that lock. However, many functions checking for handshakeRunning only obtain handshakeMutex: reintroduce handshakeCompleted for them. ConnectionState and Close check for handshakeConfirmed, introduce an atomic flag for them.
This commit is contained in:
parent
f3fe024dc7
commit
341de96a61
2
13.go
2
13.go
@ -15,6 +15,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang_org/x/crypto/curve25519"
|
"golang_org/x/crypto/curve25519"
|
||||||
@ -157,6 +158,7 @@ func (hs *serverHandshakeState) readClientFinished13() error {
|
|||||||
|
|
||||||
c.hs = nil // Discard the server handshake state
|
c.hs = nil // Discard the server handshake state
|
||||||
c.phase = handshakeConfirmed
|
c.phase = handshakeConfirmed
|
||||||
|
atomic.StoreInt32(&c.handshakeConfirmed, 1)
|
||||||
c.in.setCipher(c.vers, hs.appClientCipher)
|
c.in.setCipher(c.vers, hs.appClientCipher)
|
||||||
c.in.traceErr, c.out.traceErr = nil, nil
|
c.in.traceErr, c.out.traceErr = nil, nil
|
||||||
|
|
||||||
|
48
conn.go
48
conn.go
@ -27,21 +27,25 @@ type Conn struct {
|
|||||||
conn net.Conn
|
conn net.Conn
|
||||||
isClient bool
|
isClient bool
|
||||||
|
|
||||||
|
phase handshakeStatus // protected by in.Mutex
|
||||||
|
// handshakeConfirmed is an atomic bool for phase == handshakeConfirmed
|
||||||
|
handshakeConfirmed int32
|
||||||
|
|
||||||
// constant after handshake; protected by handshakeMutex
|
// constant after handshake; protected by handshakeMutex
|
||||||
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
|
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
|
||||||
// handshakeCond, if not nil, indicates that a goroutine is committed
|
// handshakeCond, if not nil, indicates that a goroutine is committed
|
||||||
// to running the handshake for this Conn. Other goroutines that need
|
// to running the handshake for this Conn. Other goroutines that need
|
||||||
// to wait for the handshake can wait on this, under handshakeMutex.
|
// to wait for the handshake can wait on this, under handshakeMutex.
|
||||||
handshakeCond *sync.Cond
|
handshakeCond *sync.Cond
|
||||||
// The transition from handshakeRunning to the next phase is covered by
|
handshakeErr error // error resulting from handshake
|
||||||
// handshakeMutex. All others by in.Mutex.
|
connID []byte // Random connection id
|
||||||
phase handshakeStatus
|
clientHello []byte // ClientHello packet contents
|
||||||
handshakeErr error // error resulting from handshake
|
vers uint16 // TLS version
|
||||||
connID []byte // Random connection id
|
haveVers bool // version has been negotiated
|
||||||
clientHello []byte // ClientHello packet contents
|
config *Config // configuration passed to constructor
|
||||||
vers uint16 // TLS version
|
// handshakeComplete is true if the connection reached application data
|
||||||
haveVers bool // version has been negotiated
|
// and it's equivalent to phase > handshakeRunning
|
||||||
config *Config // configuration passed to constructor
|
handshakeComplete bool
|
||||||
// handshakes counts the number of handshakes performed on the
|
// handshakes counts the number of handshakes performed on the
|
||||||
// connection so far. If renegotiation is disabled then this is either
|
// connection so far. If renegotiation is disabled then this is either
|
||||||
// zero or one.
|
// zero or one.
|
||||||
@ -1160,7 +1164,7 @@ func (c *Conn) Write(b []byte) (int, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.phase == handshakeRunning {
|
if !c.handshakeComplete {
|
||||||
return 0, alertInternalError
|
return 0, alertInternalError
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1232,6 +1236,7 @@ func (c *Conn) handleRenegotiation() error {
|
|||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
c.phase = handshakeRunning
|
c.phase = handshakeRunning
|
||||||
|
c.handshakeComplete = false
|
||||||
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
|
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
|
||||||
c.handshakes++
|
c.handshakes++
|
||||||
}
|
}
|
||||||
@ -1289,12 +1294,13 @@ func (c *Conn) ConfirmHandshake() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// earlyDataReader wraps a Conn with in locked and reads only early data,
|
// earlyDataReader wraps a Conn and reads only early data, both buffered
|
||||||
// both buffered and still on the wire.
|
// and still on the wire.
|
||||||
type earlyDataReader struct {
|
type earlyDataReader struct {
|
||||||
c *Conn
|
c *Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// c.in.Mutex <= L
|
||||||
func (r earlyDataReader) Read(b []byte) (n int, err error) {
|
func (r earlyDataReader) Read(b []byte) (n int, err error) {
|
||||||
c := r.c
|
c := r.c
|
||||||
|
|
||||||
@ -1423,7 +1429,7 @@ func (c *Conn) Close() error {
|
|||||||
var alertErr error
|
var alertErr error
|
||||||
|
|
||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
if c.phase != handshakeRunning {
|
if c.handshakeComplete {
|
||||||
alertErr = c.closeNotify()
|
alertErr = c.closeNotify()
|
||||||
}
|
}
|
||||||
c.handshakeMutex.Unlock()
|
c.handshakeMutex.Unlock()
|
||||||
@ -1442,7 +1448,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
|
|||||||
func (c *Conn) CloseWrite() error {
|
func (c *Conn) CloseWrite() error {
|
||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
if c.phase == handshakeRunning {
|
if !c.handshakeComplete {
|
||||||
return errEarlyCloseWrite
|
return errEarlyCloseWrite
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1497,7 +1503,7 @@ func (c *Conn) Handshake() error {
|
|||||||
if err := c.handshakeErr; err != nil {
|
if err := c.handshakeErr; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.phase != handshakeRunning {
|
if c.handshakeComplete {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if c.handshakeCond == nil {
|
if c.handshakeCond == nil {
|
||||||
@ -1519,7 +1525,7 @@ func (c *Conn) Handshake() error {
|
|||||||
|
|
||||||
// The handshake cannot have completed when handshakeMutex was unlocked
|
// The handshake cannot have completed when handshakeMutex was unlocked
|
||||||
// because this goroutine set handshakeCond.
|
// because this goroutine set handshakeCond.
|
||||||
if c.handshakeErr != nil || c.phase != handshakeRunning {
|
if c.handshakeErr != nil || c.handshakeComplete {
|
||||||
panic("handshake should not have been able to complete after handshakeCond was set")
|
panic("handshake should not have been able to complete after handshakeCond was set")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1541,7 +1547,7 @@ func (c *Conn) Handshake() error {
|
|||||||
c.flush()
|
c.flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.handshakeErr == nil && c.phase == handshakeRunning {
|
if c.handshakeErr == nil && !c.handshakeComplete {
|
||||||
panic("handshake should have had a result.")
|
panic("handshake should have had a result.")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1559,10 +1565,10 @@ func (c *Conn) ConnectionState() ConnectionState {
|
|||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
|
|
||||||
var state ConnectionState
|
var state ConnectionState
|
||||||
state.HandshakeComplete = c.phase != handshakeRunning
|
state.HandshakeComplete = c.handshakeComplete
|
||||||
state.ServerName = c.serverName
|
state.ServerName = c.serverName
|
||||||
|
|
||||||
if state.HandshakeComplete {
|
if c.handshakeComplete {
|
||||||
state.ConnectionID = c.connID
|
state.ConnectionID = c.connID
|
||||||
state.ClientHello = c.clientHello
|
state.ClientHello = c.clientHello
|
||||||
state.Version = c.vers
|
state.Version = c.vers
|
||||||
@ -1574,7 +1580,7 @@ func (c *Conn) ConnectionState() ConnectionState {
|
|||||||
state.VerifiedChains = c.verifiedChains
|
state.VerifiedChains = c.verifiedChains
|
||||||
state.SignedCertificateTimestamps = c.scts
|
state.SignedCertificateTimestamps = c.scts
|
||||||
state.OCSPResponse = c.ocspResponse
|
state.OCSPResponse = c.ocspResponse
|
||||||
state.HandshakeConfirmed = c.phase == handshakeConfirmed
|
state.HandshakeConfirmed = atomic.LoadInt32(&c.handshakeConfirmed) == 1
|
||||||
if !c.didResume {
|
if !c.didResume {
|
||||||
if c.clientFinishedIsFirst {
|
if c.clientFinishedIsFirst {
|
||||||
state.TLSUnique = c.clientFinished[:]
|
state.TLSUnique = c.clientFinished[:]
|
||||||
@ -1605,7 +1611,7 @@ func (c *Conn) VerifyHostname(host string) error {
|
|||||||
if !c.isClient {
|
if !c.isClient {
|
||||||
return errors.New("tls: VerifyHostname called on TLS server connection")
|
return errors.New("tls: VerifyHostname called on TLS server connection")
|
||||||
}
|
}
|
||||||
if c.phase == handshakeRunning {
|
if !c.handshakeComplete {
|
||||||
return errors.New("tls: handshake has not yet been performed")
|
return errors.New("tls: handshake has not yet been performed")
|
||||||
}
|
}
|
||||||
if len(c.verifiedChains) == 0 {
|
if len(c.verifiedChains) == 0 {
|
||||||
|
@ -17,6 +17,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientHandshakeState struct {
|
type clientHandshakeState struct {
|
||||||
@ -252,6 +253,8 @@ NextCipherSuite:
|
|||||||
|
|
||||||
c.didResume = isResume
|
c.didResume = isResume
|
||||||
c.phase = handshakeConfirmed
|
c.phase = handshakeConfirmed
|
||||||
|
atomic.StoreInt32(&c.handshakeConfirmed, 1)
|
||||||
|
c.handshakeComplete = true
|
||||||
c.cipherSuite = suite.id
|
c.cipherSuite = suite.id
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Committer interface {
|
type Committer interface {
|
||||||
@ -78,6 +79,8 @@ func (c *Conn) serverHandshake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.hs = &hs
|
c.hs = &hs
|
||||||
|
c.handshakeComplete = true
|
||||||
|
return nil
|
||||||
} else if isResume {
|
} else if isResume {
|
||||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
// The client has included a session ticket and so we do an abbreviated handshake.
|
||||||
if err := hs.doResumeHandshake(); err != nil {
|
if err := hs.doResumeHandshake(); err != nil {
|
||||||
@ -105,7 +108,6 @@ func (c *Conn) serverHandshake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.didResume = true
|
c.didResume = true
|
||||||
c.phase = handshakeConfirmed
|
|
||||||
} else {
|
} else {
|
||||||
// The client didn't include a session ticket, or it wasn't
|
// The client didn't include a session ticket, or it wasn't
|
||||||
// valid so we do a full handshake.
|
// valid so we do a full handshake.
|
||||||
@ -129,8 +131,10 @@ func (c *Conn) serverHandshake() error {
|
|||||||
if _, err := c.flush(); err != nil {
|
if _, err := c.flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.phase = handshakeConfirmed
|
|
||||||
}
|
}
|
||||||
|
c.phase = handshakeConfirmed
|
||||||
|
atomic.StoreInt32(&c.handshakeConfirmed, 1)
|
||||||
|
c.handshakeComplete = true
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user