crypto/tls: fix Conn.phase data races

Phase should only be accessed under in.Mutex. Handshake and all Read
operations obtain that lock. However, many functions checking for
handshakeRunning only obtain handshakeMutex: reintroduce
handshakeCompleted for them. ConnectionState and Close check for
handshakeConfirmed, introduce an atomic flag for them.
This commit is contained in:
Filippo Valsorda 2016-12-05 13:38:08 -05:00 committed by Peter Wu
parent f3fe024dc7
commit 341de96a61
4 changed files with 38 additions and 23 deletions

2
13.go
View File

@ -15,6 +15,7 @@ import (
"io" "io"
"os" "os"
"runtime/debug" "runtime/debug"
"sync/atomic"
"time" "time"
"golang_org/x/crypto/curve25519" "golang_org/x/crypto/curve25519"
@ -157,6 +158,7 @@ func (hs *serverHandshakeState) readClientFinished13() error {
c.hs = nil // Discard the server handshake state c.hs = nil // Discard the server handshake state
c.phase = handshakeConfirmed c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
c.in.setCipher(c.vers, hs.appClientCipher) c.in.setCipher(c.vers, hs.appClientCipher)
c.in.traceErr, c.out.traceErr = nil, nil c.in.traceErr, c.out.traceErr = nil, nil

36
conn.go
View File

@ -27,21 +27,25 @@ type Conn struct {
conn net.Conn conn net.Conn
isClient bool isClient bool
phase handshakeStatus // protected by in.Mutex
// handshakeConfirmed is an atomic bool for phase == handshakeConfirmed
handshakeConfirmed int32
// constant after handshake; protected by handshakeMutex // constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
// handshakeCond, if not nil, indicates that a goroutine is committed // handshakeCond, if not nil, indicates that a goroutine is committed
// to running the handshake for this Conn. Other goroutines that need // to running the handshake for this Conn. Other goroutines that need
// to wait for the handshake can wait on this, under handshakeMutex. // to wait for the handshake can wait on this, under handshakeMutex.
handshakeCond *sync.Cond handshakeCond *sync.Cond
// The transition from handshakeRunning to the next phase is covered by
// handshakeMutex. All others by in.Mutex.
phase handshakeStatus
handshakeErr error // error resulting from handshake handshakeErr error // error resulting from handshake
connID []byte // Random connection id connID []byte // Random connection id
clientHello []byte // ClientHello packet contents clientHello []byte // ClientHello packet contents
vers uint16 // TLS version vers uint16 // TLS version
haveVers bool // version has been negotiated haveVers bool // version has been negotiated
config *Config // configuration passed to constructor config *Config // configuration passed to constructor
// handshakeComplete is true if the connection reached application data
// and it's equivalent to phase > handshakeRunning
handshakeComplete bool
// handshakes counts the number of handshakes performed on the // handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either // connection so far. If renegotiation is disabled then this is either
// zero or one. // zero or one.
@ -1160,7 +1164,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err return 0, err
} }
if c.phase == handshakeRunning { if !c.handshakeComplete {
return 0, alertInternalError return 0, alertInternalError
} }
@ -1232,6 +1236,7 @@ func (c *Conn) handleRenegotiation() error {
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
c.phase = handshakeRunning c.phase = handshakeRunning
c.handshakeComplete = false
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++ c.handshakes++
} }
@ -1289,12 +1294,13 @@ func (c *Conn) ConfirmHandshake() error {
return nil return nil
} }
// earlyDataReader wraps a Conn with in locked and reads only early data, // earlyDataReader wraps a Conn and reads only early data, both buffered
// both buffered and still on the wire. // and still on the wire.
type earlyDataReader struct { type earlyDataReader struct {
c *Conn c *Conn
} }
// c.in.Mutex <= L
func (r earlyDataReader) Read(b []byte) (n int, err error) { func (r earlyDataReader) Read(b []byte) (n int, err error) {
c := r.c c := r.c
@ -1423,7 +1429,7 @@ func (c *Conn) Close() error {
var alertErr error var alertErr error
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
if c.phase != handshakeRunning { if c.handshakeComplete {
alertErr = c.closeNotify() alertErr = c.closeNotify()
} }
c.handshakeMutex.Unlock() c.handshakeMutex.Unlock()
@ -1442,7 +1448,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
func (c *Conn) CloseWrite() error { func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
if c.phase == handshakeRunning { if !c.handshakeComplete {
return errEarlyCloseWrite return errEarlyCloseWrite
} }
@ -1497,7 +1503,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil { if err := c.handshakeErr; err != nil {
return err return err
} }
if c.phase != handshakeRunning { if c.handshakeComplete {
return nil return nil
} }
if c.handshakeCond == nil { if c.handshakeCond == nil {
@ -1519,7 +1525,7 @@ func (c *Conn) Handshake() error {
// The handshake cannot have completed when handshakeMutex was unlocked // The handshake cannot have completed when handshakeMutex was unlocked
// because this goroutine set handshakeCond. // because this goroutine set handshakeCond.
if c.handshakeErr != nil || c.phase != handshakeRunning { if c.handshakeErr != nil || c.handshakeComplete {
panic("handshake should not have been able to complete after handshakeCond was set") panic("handshake should not have been able to complete after handshakeCond was set")
} }
@ -1541,7 +1547,7 @@ func (c *Conn) Handshake() error {
c.flush() c.flush()
} }
if c.handshakeErr == nil && c.phase == handshakeRunning { if c.handshakeErr == nil && !c.handshakeComplete {
panic("handshake should have had a result.") panic("handshake should have had a result.")
} }
@ -1559,10 +1565,10 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
var state ConnectionState var state ConnectionState
state.HandshakeComplete = c.phase != handshakeRunning state.HandshakeComplete = c.handshakeComplete
state.ServerName = c.serverName state.ServerName = c.serverName
if state.HandshakeComplete { if c.handshakeComplete {
state.ConnectionID = c.connID state.ConnectionID = c.connID
state.ClientHello = c.clientHello state.ClientHello = c.clientHello
state.Version = c.vers state.Version = c.vers
@ -1574,7 +1580,7 @@ func (c *Conn) ConnectionState() ConnectionState {
state.VerifiedChains = c.verifiedChains state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse state.OCSPResponse = c.ocspResponse
state.HandshakeConfirmed = c.phase == handshakeConfirmed state.HandshakeConfirmed = atomic.LoadInt32(&c.handshakeConfirmed) == 1
if !c.didResume { if !c.didResume {
if c.clientFinishedIsFirst { if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:] state.TLSUnique = c.clientFinished[:]
@ -1605,7 +1611,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient { if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection") return errors.New("tls: VerifyHostname called on TLS server connection")
} }
if c.phase == handshakeRunning { if !c.handshakeComplete {
return errors.New("tls: handshake has not yet been performed") return errors.New("tls: handshake has not yet been performed")
} }
if len(c.verifiedChains) == 0 { if len(c.verifiedChains) == 0 {

View File

@ -17,6 +17,7 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
) )
type clientHandshakeState struct { type clientHandshakeState struct {
@ -252,6 +253,8 @@ NextCipherSuite:
c.didResume = isResume c.didResume = isResume
c.phase = handshakeConfirmed c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
c.handshakeComplete = true
c.cipherSuite = suite.id c.cipherSuite = suite.id
return nil return nil
} }

View File

@ -15,6 +15,7 @@ import (
"fmt" "fmt"
"hash" "hash"
"io" "io"
"sync/atomic"
) )
type Committer interface { type Committer interface {
@ -78,6 +79,8 @@ func (c *Conn) serverHandshake() error {
return err return err
} }
c.hs = &hs c.hs = &hs
c.handshakeComplete = true
return nil
} else if isResume { } else if isResume {
// The client has included a session ticket and so we do an abbreviated handshake. // The client has included a session ticket and so we do an abbreviated handshake.
if err := hs.doResumeHandshake(); err != nil { if err := hs.doResumeHandshake(); err != nil {
@ -105,7 +108,6 @@ func (c *Conn) serverHandshake() error {
return err return err
} }
c.didResume = true c.didResume = true
c.phase = handshakeConfirmed
} else { } else {
// The client didn't include a session ticket, or it wasn't // The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake. // valid so we do a full handshake.
@ -129,8 +131,10 @@ func (c *Conn) serverHandshake() error {
if _, err := c.flush(); err != nil { if _, err := c.flush(); err != nil {
return err return err
} }
c.phase = handshakeConfirmed
} }
c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
c.handshakeComplete = true
return nil return nil
} }