crypto/tls: fix data race on conn.err

Fixes #3862.

There were many areas where conn.err was being accessed
outside the mutex. This proposal moves the err value to
an embedded struct to make it more obvious when the error
value is being accessed.

As there are no Benchmark tests in this package I cannot
feel confident of the impact of this additional locking,
although most will be uncontended.

R=dvyukov, agl
CC=golang-dev
https://golang.org/cl/6497070
This commit is contained in:
Dave Cheney 2012-09-06 17:50:26 +10:00
parent 22777bcc54
commit a0608ba23c
2 changed files with 32 additions and 31 deletions

59
conn.go
View File

@ -44,8 +44,7 @@ type Conn struct {
clientProtocolFallback bool clientProtocolFallback bool
// first permanent error // first permanent error
errMutex sync.Mutex connErr
err error
// input/output // input/output
in, out halfConn // in.Mutex < out.Mutex in, out halfConn // in.Mutex < out.Mutex
@ -56,21 +55,25 @@ type Conn struct {
tmp [16]byte tmp [16]byte
} }
func (c *Conn) setError(err error) error { type connErr struct {
c.errMutex.Lock() mu sync.Mutex
defer c.errMutex.Unlock() value error
}
if c.err == nil { func (e *connErr) setError(err error) error {
c.err = err e.mu.Lock()
defer e.mu.Unlock()
if e.value == nil {
e.value = err
} }
return err return err
} }
func (c *Conn) error() error { func (e *connErr) error() error {
c.errMutex.Lock() e.mu.Lock()
defer c.errMutex.Unlock() defer e.mu.Unlock()
return e.value
return c.err
} }
// Access to net.Conn methods. // Access to net.Conn methods.
@ -660,8 +663,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
c.tmp[0] = alertLevelError c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert)) c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2]) c.writeRecord(recordTypeAlert, c.tmp[0:2])
c.err = &net.OpError{Op: "local error", Err: err} return n, c.setError(&net.OpError{Op: "local error", Err: err})
return n, c.err
} }
} }
return return
@ -672,8 +674,8 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
// c.in.Mutex < L; c.out.Mutex < L. // c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) { func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 { for c.hand.Len() < 4 {
if c.err != nil { if err := c.error(); err != nil {
return nil, c.err return nil, err
} }
if err := c.readRecord(recordTypeHandshake); err != nil { if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err return nil, err
@ -684,11 +686,11 @@ func (c *Conn) readHandshake() (interface{}, error) {
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake { if n > maxHandshake {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return nil, c.err return nil, c.error()
} }
for c.hand.Len() < 4+n { for c.hand.Len() < 4+n {
if c.err != nil { if err := c.error(); err != nil {
return nil, c.err return nil, err
} }
if err := c.readRecord(recordTypeHandshake); err != nil { if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err return nil, err
@ -738,12 +740,12 @@ func (c *Conn) readHandshake() (interface{}, error) {
// Write writes data to the connection. // Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) { func (c *Conn) Write(b []byte) (int, error) {
if c.err != nil { if err := c.error(); err != nil {
return 0, c.err return 0, err
} }
if c.err = c.Handshake(); c.err != nil { if err := c.Handshake(); err != nil {
return 0, c.err return 0, c.setError(err)
} }
c.out.Lock() c.out.Lock()
@ -753,9 +755,8 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, alertInternalError return 0, alertInternalError
} }
var n int n, err := c.writeRecord(recordTypeApplicationData, b)
n, c.err = c.writeRecord(recordTypeApplicationData, b) return n, c.setError(err)
return n, c.err
} }
// Read can be made to time out and return a net.Error with Timeout() == true // Read can be made to time out and return a net.Error with Timeout() == true
@ -768,14 +769,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
c.in.Lock() c.in.Lock()
defer c.in.Unlock() defer c.in.Unlock()
for c.input == nil && c.err == nil { for c.input == nil && c.error() == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil { if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN // Soft error, like EAGAIN
return 0, err return 0, err
} }
} }
if c.err != nil { if err := c.error(); err != nil {
return 0, c.err return 0, err
} }
n, err = c.input.Read(b) n, err = c.input.Read(b)
if c.input.off >= len(c.input.data) { if c.input.off >= len(c.input.data) {

View File

@ -306,8 +306,8 @@ func (c *Conn) clientHandshake() error {
serverHash := suite.mac(c.vers, serverMAC) serverHash := suite.mac(c.vers, serverMAC)
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
c.readRecord(recordTypeChangeCipherSpec) c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil { if err := c.error(); err != nil {
return c.err return err
} }
msg, err = c.readHandshake() msg, err = c.readHandshake()