瀏覽代碼

crypto/tls: fix data race on conn.err

Fixes #3862.

There were many areas where conn.err was being accessed
outside the mutex. This proposal moves the err value to
an embedded struct to make it more obvious when the error
value is being accessed.

As there are no Benchmark tests in this package I cannot
feel confident of the impact of this additional locking,
although most will be uncontended.

R=dvyukov, agl
CC=golang-dev
https://golang.org/cl/6497070
v1.2.3
Dave Cheney 12 年之前
父節點
當前提交
a0608ba23c
共有 2 個文件被更改,包括 32 次插入31 次删除
  1. +30
    -29
      conn.go
  2. +2
    -2
      handshake_client.go

+ 30
- 29
conn.go 查看文件

@@ -44,8 +44,7 @@ type Conn struct {
clientProtocolFallback bool

// first permanent error
errMutex sync.Mutex
err error
connErr

// input/output
in, out halfConn // in.Mutex < out.Mutex
@@ -56,21 +55,25 @@ type Conn struct {
tmp [16]byte
}

func (c *Conn) setError(err error) error {
c.errMutex.Lock()
defer c.errMutex.Unlock()
type connErr struct {
mu sync.Mutex
value error
}

func (e *connErr) setError(err error) error {
e.mu.Lock()
defer e.mu.Unlock()

if c.err == nil {
c.err = err
if e.value == nil {
e.value = err
}
return err
}

func (c *Conn) error() error {
c.errMutex.Lock()
defer c.errMutex.Unlock()

return c.err
func (e *connErr) error() error {
e.mu.Lock()
defer e.mu.Unlock()
return e.value
}

// Access to net.Conn methods.
@@ -660,8 +663,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
c.err = &net.OpError{Op: "local error", Err: err}
return n, c.err
return n, c.setError(&net.OpError{Op: "local error", Err: err})
}
}
return
@@ -672,8 +674,8 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 {
if c.err != nil {
return nil, c.err
if err := c.error(); err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err
@@ -684,11 +686,11 @@ func (c *Conn) readHandshake() (interface{}, error) {
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlert(alertInternalError)
return nil, c.err
return nil, c.error()
}
for c.hand.Len() < 4+n {
if c.err != nil {
return nil, c.err
if err := c.error(); err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
return nil, err
@@ -738,12 +740,12 @@ func (c *Conn) readHandshake() (interface{}, error) {

// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
if c.err != nil {
return 0, c.err
if err := c.error(); err != nil {
return 0, err
}

if c.err = c.Handshake(); c.err != nil {
return 0, c.err
if err := c.Handshake(); err != nil {
return 0, c.setError(err)
}

c.out.Lock()
@@ -753,9 +755,8 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, alertInternalError
}

var n int
n, c.err = c.writeRecord(recordTypeApplicationData, b)
return n, c.err
n, err := c.writeRecord(recordTypeApplicationData, b)
return n, c.setError(err)
}

// Read can be made to time out and return a net.Error with Timeout() == true
@@ -768,14 +769,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
c.in.Lock()
defer c.in.Unlock()

for c.input == nil && c.err == nil {
for c.input == nil && c.error() == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err
}
}
if c.err != nil {
return 0, c.err
if err := c.error(); err != nil {
return 0, err
}
n, err = c.input.Read(b)
if c.input.off >= len(c.input.data) {


+ 2
- 2
handshake_client.go 查看文件

@@ -306,8 +306,8 @@ func (c *Conn) clientHandshake() error {
serverHash := suite.mac(c.vers, serverMAC)
c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
c.readRecord(recordTypeChangeCipherSpec)
if c.err != nil {
return c.err
if err := c.error(); err != nil {
return err
}

msg, err = c.readHandshake()


Loading…
取消
儲存