Kaynağa Gözat

crypto/tls: split connErr to avoid read/write races.

Currently a write error will cause future reads to return that same error.
However, there may have been extra information from a peer pending on
the read direction that is now unavailable.

This change splits the single connErr into errors for the read, write and
handshake. (Splitting off the handshake error is needed because both read
and write paths check the handshake error.)

Fixes #7414.

LGTM=bradfitz, r
R=golang-codereviews, r, bradfitz
CC=golang-codereviews
https://golang.org/cl/69090044
tls13
Adam Langley 10 yıl önce
ebeveyn
işleme
ef4934a9ed
3 değiştirilmiş dosya ile 61 ekleme ve 70 silme
  1. +59
    -68
      conn.go
  2. +1
    -1
      handshake_client.go
  3. +1
    -1
      handshake_server.go

+ 59
- 68
conn.go Dosyayı Görüntüle

@@ -28,6 +28,7 @@ type Conn struct {

// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
handshakeErr error // error resulting from handshake
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
@@ -45,9 +46,6 @@ type Conn struct {
clientProtocol string
clientProtocolFallback bool

// first permanent error
connErr

// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
@@ -57,27 +55,6 @@ type Conn struct {
tmp [16]byte
}

type connErr struct {
mu sync.Mutex
value error
}

func (e *connErr) setError(err error) error {
e.mu.Lock()
defer e.mu.Unlock()

if e.value == nil {
e.value = err
}
return err
}

func (e *connErr) error() error {
e.mu.Lock()
defer e.mu.Unlock()
return e.value
}

// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
@@ -116,6 +93,8 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex

err error // first permanent error
version uint16 // protocol version
cipher interface{} // cipher algorithm
mac macFunction
@@ -129,6 +108,18 @@ type halfConn struct {
inDigestBuf, outDigestBuf []byte
}

func (hc *halfConn) setErrorLocked(err error) error {
hc.err = err
return err
}

func (hc *halfConn) error() error {
hc.Lock()
err := hc.err
hc.Unlock()
return err
}

// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
@@ -520,16 +511,16 @@ func (c *Conn) readRecord(want recordType) error {
switch want {
default:
c.sendAlert(alertInternalError)
return errors.New("tls: unknown record type requested")
return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete {
c.sendAlert(alertInternalError)
return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete"))
}
case recordTypeApplicationData:
if !c.handshakeComplete {
c.sendAlert(alertInternalError)
return errors.New("tls: application data record requested before handshake complete")
return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete"))
}
}

@@ -548,7 +539,7 @@ Again:
// err = io.ErrUnexpectedEOF
// }
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.setError(err)
c.in.setErrorLocked(err)
}
return err
}
@@ -560,18 +551,18 @@ Again:
// an SSLv2 client.
if want == recordTypeHandshake && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: unsupported SSLv2 handshake received")
return c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received"))
}

vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if c.haveVers && vers != c.vers {
c.sendAlert(alertProtocolVersion)
return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)
return c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
}
if n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
return fmt.Errorf("tls: oversized record received with length %d", n)
return c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n))
}
if !c.haveVers {
// First message, be extra suspicious:
@@ -584,7 +575,7 @@ Again:
// it's probably not real.
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: first record does not look like a TLS handshake")
return c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake"))
}
}
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
@@ -592,7 +583,7 @@ Again:
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.setError(err)
c.in.setErrorLocked(err)
}
return err
}
@@ -601,27 +592,27 @@ Again:
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
ok, off, err := c.in.decrypt(b)
if !ok {
return c.sendAlert(err)
c.in.setErrorLocked(c.sendAlert(err))
}
b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext {
c.sendAlert(alertRecordOverflow)
err := c.sendAlert(alertRecordOverflow)
c.in.freeBlock(b)
return c.error()
return c.in.setErrorLocked(err)
}

switch typ {
default:
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

case recordTypeAlert:
if len(data) != 2 {
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if alert(data[1]) == alertCloseNotify {
c.setError(io.EOF)
c.in.setErrorLocked(io.EOF)
break
}
switch data[0] {
@@ -630,24 +621,24 @@ Again:
c.in.freeBlock(b)
goto Again
case alertLevelError:
c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}

case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 {
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
err := c.in.changeCipherSpec()
if err != nil {
c.sendAlert(err.(alert))
c.in.setErrorLocked(c.sendAlert(err.(alert)))
}

case recordTypeApplicationData:
if typ != want {
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
c.input = b
@@ -656,7 +647,7 @@ Again:
case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want {
return c.sendAlert(alertNoRenegotiation)
return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
}
c.hand.Write(data)
}
@@ -664,7 +655,7 @@ Again:
if b != nil {
c.in.freeBlock(b)
}
return c.error()
return c.in.err
}

// sendAlert sends a TLS alert message.
@@ -680,7 +671,7 @@ func (c *Conn) sendAlertLocked(err alert) error {
c.writeRecord(recordTypeAlert, c.tmp[0:2])
// closeNotify is a special case in that it isn't an error:
if err != alertCloseNotify {
return c.setError(&net.OpError{Op: "local error", Err: err})
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
return nil
}
@@ -766,7 +757,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
return n, c.setError(&net.OpError{Op: "local error", Err: err})
return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
}
}
return
@@ -777,7 +768,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 {
if err := c.error(); err != nil {
if err := c.in.err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -788,11 +779,10 @@ func (c *Conn) readHandshake() (interface{}, error) {
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlert(alertInternalError)
return nil, c.error()
return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
for c.hand.Len() < 4+n {
if err := c.error(); err != nil {
if err := c.in.err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -831,8 +821,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeFinished:
m = new(finishedMsg)
default:
c.sendAlert(alertUnexpectedMessage)
return nil, alertUnexpectedMessage
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}

// The handshake message unmarshallers
@@ -841,25 +830,24 @@ func (c *Conn) readHandshake() (interface{}, error) {
data = append([]byte(nil), data...)

if !m.unmarshal(data) {
c.sendAlert(alertUnexpectedMessage)
return nil, alertUnexpectedMessage
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
return m, nil
}

// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
if err := c.error(); err != nil {
return 0, err
}

if err := c.Handshake(); err != nil {
return 0, c.setError(err)
return 0, err
}

c.out.Lock()
defer c.out.Unlock()

if err := c.out.err; err != nil {
return 0, err
}

if !c.handshakeComplete {
return 0, alertInternalError
}
@@ -878,14 +866,14 @@ func (c *Conn) Write(b []byte) (int, error) {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecord(recordTypeApplicationData, b[:1])
if err != nil {
return n, c.setError(err)
return n, c.out.setErrorLocked(err)
}
m, b = 1, b[1:]
}
}

n, err := c.writeRecord(recordTypeApplicationData, b)
return n + m, c.setError(err)
return n + m, c.out.setErrorLocked(err)
}

// Read can be made to time out and return a net.Error with Timeout() == true
@@ -902,13 +890,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// CBC IV. So this loop ignores a limited number of empty records.
const maxConsecutiveEmptyRecords = 100
for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
for c.input == nil && c.error() == nil {
for c.input == nil && c.in.err == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err
}
}
if err := c.error(); err != nil {
if err := c.in.err; err != nil {
return 0, err
}

@@ -949,16 +937,19 @@ func (c *Conn) Close() error {
func (c *Conn) Handshake() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if err := c.error(); err != nil {
if err := c.handshakeErr; err != nil {
return err
}
if c.handshakeComplete {
return nil
}

if c.isClient {
return c.clientHandshake()
c.handshakeErr = c.clientHandshake()
} else {
c.handshakeErr = c.serverHandshake()
}
return c.serverHandshake()
return c.handshakeErr
}

// ConnectionState returns basic TLS details about the connection.


+ 1
- 1
handshake_client.go Dosyayı Görüntüle

@@ -501,7 +501,7 @@ func (hs *clientHandshakeState) readFinished() error {
c := hs.c

c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
if err := c.in.error(); err != nil {
return err
}



+ 1
- 1
handshake_server.go Dosyayı Görüntüle

@@ -470,7 +470,7 @@ func (hs *serverHandshakeState) readFinished() error {
c := hs.c

c.readRecord(recordTypeChangeCipherSpec)
if err := c.error(); err != nil {
if err := c.in.error(); err != nil {
return err
}



Yükleniyor…
İptal
Kaydet