diff --git a/13.go b/13.go index f65237d..2eeb39e 100644 --- a/13.go +++ b/13.go @@ -15,6 +15,7 @@ import ( "io" "os" "runtime/debug" + "sync/atomic" "time" "golang_org/x/crypto/curve25519" @@ -157,6 +158,7 @@ func (hs *serverHandshakeState) readClientFinished13() error { c.hs = nil // Discard the server handshake state c.phase = handshakeConfirmed + atomic.StoreInt32(&c.handshakeConfirmed, 1) c.in.setCipher(c.vers, hs.appClientCipher) c.in.traceErr, c.out.traceErr = nil, nil diff --git a/conn.go b/conn.go index 27acbef..ac78161 100644 --- a/conn.go +++ b/conn.go @@ -27,21 +27,25 @@ type Conn struct { conn net.Conn isClient bool + phase handshakeStatus // protected by in.Mutex + // handshakeConfirmed is an atomic bool for phase == handshakeConfirmed + handshakeConfirmed int32 + // constant after handshake; protected by handshakeMutex handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex // handshakeCond, if not nil, indicates that a goroutine is committed // to running the handshake for this Conn. Other goroutines that need // to wait for the handshake can wait on this, under handshakeMutex. handshakeCond *sync.Cond - // The transition from handshakeRunning to the next phase is covered by - // handshakeMutex. All others by in.Mutex. - phase handshakeStatus - handshakeErr error // error resulting from handshake - connID []byte // Random connection id - clientHello []byte // ClientHello packet contents - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *Config // configuration passed to constructor + handshakeErr error // error resulting from handshake + connID []byte // Random connection id + clientHello []byte // ClientHello packet contents + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor + // handshakeComplete is true if the connection reached application data + // and it's equivalent to phase > handshakeRunning + handshakeComplete bool // handshakes counts the number of handshakes performed on the // connection so far. If renegotiation is disabled then this is either // zero or one. @@ -1160,7 +1164,7 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, err } - if c.phase == handshakeRunning { + if !c.handshakeComplete { return 0, alertInternalError } @@ -1232,6 +1236,7 @@ func (c *Conn) handleRenegotiation() error { defer c.handshakeMutex.Unlock() c.phase = handshakeRunning + c.handshakeComplete = false if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { c.handshakes++ } @@ -1289,12 +1294,13 @@ func (c *Conn) ConfirmHandshake() error { return nil } -// earlyDataReader wraps a Conn with in locked and reads only early data, -// both buffered and still on the wire. +// earlyDataReader wraps a Conn and reads only early data, both buffered +// and still on the wire. type earlyDataReader struct { c *Conn } +// c.in.Mutex <= L func (r earlyDataReader) Read(b []byte) (n int, err error) { c := r.c @@ -1423,7 +1429,7 @@ func (c *Conn) Close() error { var alertErr error c.handshakeMutex.Lock() - if c.phase != handshakeRunning { + if c.handshakeComplete { alertErr = c.closeNotify() } c.handshakeMutex.Unlock() @@ -1442,7 +1448,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com func (c *Conn) CloseWrite() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - if c.phase == handshakeRunning { + if !c.handshakeComplete { return errEarlyCloseWrite } @@ -1497,7 +1503,7 @@ func (c *Conn) Handshake() error { if err := c.handshakeErr; err != nil { return err } - if c.phase != handshakeRunning { + if c.handshakeComplete { return nil } if c.handshakeCond == nil { @@ -1519,7 +1525,7 @@ func (c *Conn) Handshake() error { // The handshake cannot have completed when handshakeMutex was unlocked // because this goroutine set handshakeCond. - if c.handshakeErr != nil || c.phase != handshakeRunning { + if c.handshakeErr != nil || c.handshakeComplete { panic("handshake should not have been able to complete after handshakeCond was set") } @@ -1541,7 +1547,7 @@ func (c *Conn) Handshake() error { c.flush() } - if c.handshakeErr == nil && c.phase == handshakeRunning { + if c.handshakeErr == nil && !c.handshakeComplete { panic("handshake should have had a result.") } @@ -1559,10 +1565,10 @@ func (c *Conn) ConnectionState() ConnectionState { defer c.handshakeMutex.Unlock() var state ConnectionState - state.HandshakeComplete = c.phase != handshakeRunning + state.HandshakeComplete = c.handshakeComplete state.ServerName = c.serverName - if state.HandshakeComplete { + if c.handshakeComplete { state.ConnectionID = c.connID state.ClientHello = c.clientHello state.Version = c.vers @@ -1574,7 +1580,7 @@ func (c *Conn) ConnectionState() ConnectionState { state.VerifiedChains = c.verifiedChains state.SignedCertificateTimestamps = c.scts state.OCSPResponse = c.ocspResponse - state.HandshakeConfirmed = c.phase == handshakeConfirmed + state.HandshakeConfirmed = atomic.LoadInt32(&c.handshakeConfirmed) == 1 if !c.didResume { if c.clientFinishedIsFirst { state.TLSUnique = c.clientFinished[:] @@ -1605,7 +1611,7 @@ func (c *Conn) VerifyHostname(host string) error { if !c.isClient { return errors.New("tls: VerifyHostname called on TLS server connection") } - if c.phase == handshakeRunning { + if !c.handshakeComplete { return errors.New("tls: handshake has not yet been performed") } if len(c.verifiedChains) == 0 { diff --git a/handshake_client.go b/handshake_client.go index 4ead1eb..a570649 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -17,6 +17,7 @@ import ( "net" "strconv" "strings" + "sync/atomic" ) type clientHandshakeState struct { @@ -252,6 +253,8 @@ NextCipherSuite: c.didResume = isResume c.phase = handshakeConfirmed + atomic.StoreInt32(&c.handshakeConfirmed, 1) + c.handshakeComplete = true c.cipherSuite = suite.id return nil } diff --git a/handshake_server.go b/handshake_server.go index 32630d8..652e8a1 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -15,6 +15,7 @@ import ( "fmt" "hash" "io" + "sync/atomic" ) type Committer interface { @@ -78,6 +79,8 @@ func (c *Conn) serverHandshake() error { return err } c.hs = &hs + c.handshakeComplete = true + return nil } else if isResume { // The client has included a session ticket and so we do an abbreviated handshake. if err := hs.doResumeHandshake(); err != nil { @@ -105,7 +108,6 @@ func (c *Conn) serverHandshake() error { return err } c.didResume = true - c.phase = handshakeConfirmed } else { // The client didn't include a session ticket, or it wasn't // valid so we do a full handshake. @@ -129,8 +131,10 @@ func (c *Conn) serverHandshake() error { if _, err := c.flush(); err != nil { return err } - c.phase = handshakeConfirmed } + c.phase = handshakeConfirmed + atomic.StoreInt32(&c.handshakeConfirmed, 1) + c.handshakeComplete = true return nil }