crypto/tls: fix Conn.phase data races

Phase should only be accessed under in.Mutex. Handshake and all Read
operations obtain that lock. However, many functions checking for
handshakeRunning only obtain handshakeMutex: reintroduce
handshakeCompleted for them. ConnectionState and Close check for
handshakeConfirmed, introduce an atomic flag for them.
This commit is contained in:
Filippo Valsorda 2016-12-05 13:38:08 -05:00 committed by Peter Wu
parent f3fe024dc7
commit 341de96a61
4 changed files with 38 additions and 23 deletions

2
13.go
View File

@ -15,6 +15,7 @@ import (
"io"
"os"
"runtime/debug"
"sync/atomic"
"time"
"golang_org/x/crypto/curve25519"
@ -157,6 +158,7 @@ func (hs *serverHandshakeState) readClientFinished13() error {
c.hs = nil // Discard the server handshake state
c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
c.in.setCipher(c.vers, hs.appClientCipher)
c.in.traceErr, c.out.traceErr = nil, nil

48
conn.go
View File

@ -27,21 +27,25 @@ type Conn struct {
conn net.Conn
isClient bool
phase handshakeStatus // protected by in.Mutex
// handshakeConfirmed is an atomic bool for phase == handshakeConfirmed
handshakeConfirmed int32
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
// handshakeCond, if not nil, indicates that a goroutine is committed
// to running the handshake for this Conn. Other goroutines that need
// to wait for the handshake can wait on this, under handshakeMutex.
handshakeCond *sync.Cond
// The transition from handshakeRunning to the next phase is covered by
// handshakeMutex. All others by in.Mutex.
phase handshakeStatus
handshakeErr error // error resulting from handshake
connID []byte // Random connection id
clientHello []byte // ClientHello packet contents
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
handshakeErr error // error resulting from handshake
connID []byte // Random connection id
clientHello []byte // ClientHello packet contents
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
// handshakeComplete is true if the connection reached application data
// and it's equivalent to phase > handshakeRunning
handshakeComplete bool
// handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either
// zero or one.
@ -1160,7 +1164,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err
}
if c.phase == handshakeRunning {
if !c.handshakeComplete {
return 0, alertInternalError
}
@ -1232,6 +1236,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock()
c.phase = handshakeRunning
c.handshakeComplete = false
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++
}
@ -1289,12 +1294,13 @@ func (c *Conn) ConfirmHandshake() error {
return nil
}
// earlyDataReader wraps a Conn with in locked and reads only early data,
// both buffered and still on the wire.
// earlyDataReader wraps a Conn and reads only early data, both buffered
// and still on the wire.
type earlyDataReader struct {
c *Conn
}
// c.in.Mutex <= L
func (r earlyDataReader) Read(b []byte) (n int, err error) {
c := r.c
@ -1423,7 +1429,7 @@ func (c *Conn) Close() error {
var alertErr error
c.handshakeMutex.Lock()
if c.phase != handshakeRunning {
if c.handshakeComplete {
alertErr = c.closeNotify()
}
c.handshakeMutex.Unlock()
@ -1442,7 +1448,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.phase == handshakeRunning {
if !c.handshakeComplete {
return errEarlyCloseWrite
}
@ -1497,7 +1503,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil {
return err
}
if c.phase != handshakeRunning {
if c.handshakeComplete {
return nil
}
if c.handshakeCond == nil {
@ -1519,7 +1525,7 @@ func (c *Conn) Handshake() error {
// The handshake cannot have completed when handshakeMutex was unlocked
// because this goroutine set handshakeCond.
if c.handshakeErr != nil || c.phase != handshakeRunning {
if c.handshakeErr != nil || c.handshakeComplete {
panic("handshake should not have been able to complete after handshakeCond was set")
}
@ -1541,7 +1547,7 @@ func (c *Conn) Handshake() error {
c.flush()
}
if c.handshakeErr == nil && c.phase == handshakeRunning {
if c.handshakeErr == nil && !c.handshakeComplete {
panic("handshake should have had a result.")
}
@ -1559,10 +1565,10 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock()
var state ConnectionState
state.HandshakeComplete = c.phase != handshakeRunning
state.HandshakeComplete = c.handshakeComplete
state.ServerName = c.serverName
if state.HandshakeComplete {
if c.handshakeComplete {
state.ConnectionID = c.connID
state.ClientHello = c.clientHello
state.Version = c.vers
@ -1574,7 +1580,7 @@ func (c *Conn) ConnectionState() ConnectionState {
state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse
state.HandshakeConfirmed = c.phase == handshakeConfirmed
state.HandshakeConfirmed = atomic.LoadInt32(&c.handshakeConfirmed) == 1
if !c.didResume {
if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:]
@ -1605,7 +1611,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection")
}
if c.phase == handshakeRunning {
if !c.handshakeComplete {
return errors.New("tls: handshake has not yet been performed")
}
if len(c.verifiedChains) == 0 {

View File

@ -17,6 +17,7 @@ import (
"net"
"strconv"
"strings"
"sync/atomic"
)
type clientHandshakeState struct {
@ -252,6 +253,8 @@ NextCipherSuite:
c.didResume = isResume
c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
c.handshakeComplete = true
c.cipherSuite = suite.id
return nil
}

View File

@ -15,6 +15,7 @@ import (
"fmt"
"hash"
"io"
"sync/atomic"
)
type Committer interface {
@ -78,6 +79,8 @@ func (c *Conn) serverHandshake() error {
return err
}
c.hs = &hs
c.handshakeComplete = true
return nil
} else if isResume {
// The client has included a session ticket and so we do an abbreviated handshake.
if err := hs.doResumeHandshake(); err != nil {
@ -105,7 +108,6 @@ func (c *Conn) serverHandshake() error {
return err
}
c.didResume = true
c.phase = handshakeConfirmed
} else {
// The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake.
@ -129,8 +131,10 @@ func (c *Conn) serverHandshake() error {
if _, err := c.flush(); err != nil {
return err
}
c.phase = handshakeConfirmed
}
c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
c.handshakeComplete = true
return nil
}