crypto/tls: fix data race on conn.err

Fixes #3862.

There were many areas where conn.err was being accessed
outside the mutex. This proposal moves the err value to
an embedded struct to make it more obvious when the error
value is being accessed.

As there are no Benchmark tests in this package I cannot
feel confident of the impact of this additional locking,
although most will be uncontended.

R=dvyukov, agl
CC=golang-dev
https://golang.org/cl/6497070
This commit is contained in:
Dave Cheney 2012-09-06 17:50:26 +10:00
parent 22777bcc54
commit a0608ba23c
2 changed files with 32 additions and 31 deletions

59
conn.go
View File

@ -44,8 +44,7 @@ type Conn struct {
clientProtocolFallback bool
// first permanent error
errMutex sync.Mutex
err error
connErr
// input/output
in, out halfConn // in.Mutex < out.Mutex
@ -56,21 +55,25 @@ type Conn struct {
tmp [16]byte
}
func (c *Conn) setError(err error) error {
c.errMutex.Lock()
defer c.errMutex.Unlock()
type connErr struct {
mu sync.Mutex
value error
}
if c.err == nil {
c.err = err
func (e *connErr) setError(err error) error {
e.mu.Lock()
defer e.mu.Unlock()
if e.value == nil {
e.value = err
}
return err
}
func (c *Conn) error() error {
c.errMutex.Lock()
defer c.errMutex.Unlock()
return c.err
func (e *connErr) error() error {
e.mu.Lock()
defer e.mu.Unlock()
return e.value
}
// Access to net.Conn methods.
@ -660,8 +663,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
c.err = &net.OpError{Op: "local error", Err: err}
return n, c.err
return n, c.setError(&net.OpError{Op: "local error", Err: err})
}
}
return
@ -672,8 +674,8 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 {
if c.err != nil {
return nil, c.err
if err := c.error(); err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err
@ -684,11 +686,11 @@ func (c *Conn) readHandshake() (interface{}, error) {
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlert(alertInternalError)
return nil, c.err
return nil, c.error()
}
for c.hand.Len() < 4+n {
if c.err != nil {
return nil, c.err
if err := c.error(); err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err
@ -738,12 +740,12 @@ func (c *Conn) readHandshake() (interface{}, error) {
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
if c.err != nil {
return 0, c.err
if err := c.error(); err != nil {
return 0, err
}
if c.err = c.Handshake(); c.err != nil {
return 0, c.err
if err := c.Handshake(); err != nil {
return 0, c.setError(err)
}
c.out.Lock()
@ -753,9 +755,8 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, alertInternalError
}
var n int
n, c.err = c.writeRecord(recordTypeApplicationData, b)
return n, c.err
n, err := c.writeRecord(recordTypeApplicationData, b)
return n, c.setError(err)
}
// Read can be made to time out and return a net.Error with Timeout() == true
@ -768,14 +769,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
c.in.Lock()
defer c.in.Unlock()
for c.input == nil && c.err == nil {
for c.input == nil && c.error() == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err
}
}
if c.err != nil {
return 0, c.err
if err := c.error(); err != nil {
return 0, err
}
n, err = c.input.Read(b)
if c.input.off >= len(c.input.data) {

View File

@ -306,8 +306,8 @@ func (c *Conn) clientHandshake() error {
serverHash := suite.mac(c.vers, serverMAC)
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil {
return c.err
if err := c.error(); err != nil {
return err
}
msg, err = c.readHandshake()