@@ -28,6 +28,7 @@ type Conn struct {
// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
handshakeErr error // error resulting from handshake
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
@@ -45,9 +46,6 @@ type Conn struct {
clientProtocol string
clientProtocolFallback bool
// first permanent error
connErr
// input/output
in, out halfConn // in.Mutex < out.Mutex
rawInput *block // raw input, right off the wire
@@ -57,27 +55,6 @@ type Conn struct {
tmp [16]byte
}
type connErr struct {
mu sync.Mutex
value error
}
func (e *connErr) setError(err error) error {
e.mu.Lock()
defer e.mu.Unlock()
if e.value == nil {
e.value = err
}
return err
}
func (e *connErr) error() error {
e.mu.Lock()
defer e.mu.Unlock()
return e.value
}
// Access to net.Conn methods.
// Cannot just embed net.Conn because that would
// export the struct field too.
@@ -116,6 +93,8 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// connection, either sending or receiving.
type halfConn struct {
sync.Mutex
err error // first permanent error
version uint16 // protocol version
cipher interface{} // cipher algorithm
mac macFunction
@@ -129,6 +108,18 @@ type halfConn struct {
inDigestBuf, outDigestBuf []byte
}
func (hc *halfConn) setErrorLocked(err error) error {
hc.err = err
return err
}
func (hc *halfConn) error() error {
hc.Lock()
err := hc.err
hc.Unlock()
return err
}
// prepareCipherSpec sets the encryption and MAC states
// that a subsequent changeCipherSpec will use.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
@@ -520,16 +511,16 @@ func (c *Conn) readRecord(want recordType) error {
switch want {
default:
c.sendAlert(alertInternalError)
return errors.New("tls: unknown record type requested")
return c.in.setErrorLocked( errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.handshakeComplete {
c.sendAlert(alertInternalError)
return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")
return c.in.setErrorLocked( errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete"))
}
case recordTypeApplicationData:
if !c.handshakeComplete {
c.sendAlert(alertInternalError)
return errors.New("tls: application data record requested before handshake complete")
return c.in.setErrorLocked( errors.New("tls: application data record requested before handshake complete"))
}
}
@@ -548,7 +539,7 @@ Again:
// err = io.ErrUnexpectedEOF
// }
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.setError(err)
c.in. setErrorLocked (err)
}
return err
}
@@ -560,18 +551,18 @@ Again:
// an SSLv2 client.
if want == recordTypeHandshake && typ == 0x80 {
c.sendAlert(alertProtocolVersion)
return errors.New("tls: unsupported SSLv2 handshake received")
return c.in.setErrorLocked( errors.New("tls: unsupported SSLv2 handshake received"))
}
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
n := int(b.data[3])<<8 | int(b.data[4])
if c.haveVers && vers != c.vers {
c.sendAlert(alertProtocolVersion)
return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)
return c.in.setErrorLocked( fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers))
}
if n > maxCiphertext {
c.sendAlert(alertRecordOverflow)
return fmt.Errorf("tls: oversized record received with length %d", n)
return c.in.setErrorLocked( fmt.Errorf("tls: oversized record received with length %d", n))
}
if !c.haveVers {
// First message, be extra suspicious:
@@ -584,7 +575,7 @@ Again:
// it's probably not real.
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
c.sendAlert(alertUnexpectedMessage)
return fmt.Errorf("tls: first record does not look like a TLS handshake")
return c.in.setErrorLocked( fmt.Errorf("tls: first record does not look like a TLS handshake"))
}
}
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
@@ -592,7 +583,7 @@ Again:
err = io.ErrUnexpectedEOF
}
if e, ok := err.(net.Error); !ok || !e.Temporary() {
c.setError(err)
c.in. setErrorLocked (err)
}
return err
}
@@ -601,27 +592,27 @@ Again:
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
ok, off, err := c.in.decrypt(b)
if !ok {
return c.sendAlert(err )
c.in.setErrorLocked(c.sendAlert(err) )
}
b.off = off
data := b.data[b.off:]
if len(data) > maxPlaintext {
c.sendAlert(alertRecordOverflow)
err := c.sendAlert(alertRecordOverflow)
c.in.freeBlock(b)
return c.error()
return c.in.s etE rrorLocked (err )
}
switch typ {
default:
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c. sendAlert(alertUnexpectedMessage))
case recordTypeAlert:
if len(data) != 2 {
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c. sendAlert(alertUnexpectedMessage))
break
}
if alert(data[1]) == alertCloseNotify {
c.setError(io.EOF)
c.in. setErrorLocked (io.EOF)
break
}
switch data[0] {
@@ -630,24 +621,24 @@ Again:
c.in.freeBlock(b)
goto Again
case alertLevelError:
c.setError(&net.OpError{Op: "remote error", Err: alert(data[1])})
c.in. setErrorLocked (&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c. sendAlert(alertUnexpectedMessage))
}
case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 {
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c. sendAlert(alertUnexpectedMessage))
break
}
err := c.in.changeCipherSpec()
if err != nil {
c.sendAlert(err.(alert))
c.in.setErrorLocked(c. sendAlert(err.(alert) ))
}
case recordTypeApplicationData:
if typ != want {
c.sendAlert(alertUnexpectedMessage)
c.in.setErrorLocked(c. sendAlert(alertUnexpectedMessage))
break
}
c.input = b
@@ -656,7 +647,7 @@ Again:
case recordTypeHandshake:
// TODO(rsc): Should at least pick off connection close.
if typ != want {
return c.sendAlert(alertNoRenegotiation)
return c.in.setErrorLocked(c. sendAlert(alertNoRenegotiation))
}
c.hand.Write(data)
}
@@ -664,7 +655,7 @@ Again:
if b != nil {
c.in.freeBlock(b)
}
return c.error()
return c.in. err
}
// sendAlert sends a TLS alert message.
@@ -680,7 +671,7 @@ func (c *Conn) sendAlertLocked(err alert) error {
c.writeRecord(recordTypeAlert, c.tmp[0:2])
// closeNotify is a special case in that it isn't an error:
if err != alertCloseNotify {
return c.setError(&net.OpError{Op: "local error", Err: err})
return c.out. setErrorLocked (&net.OpError{Op: "local error", Err: err})
}
return nil
}
@@ -766,7 +757,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
c.tmp[0] = alertLevelError
c.tmp[1] = byte(err.(alert))
c.writeRecord(recordTypeAlert, c.tmp[0:2])
return n, c.setError(&net.OpError{Op: "local error", Err: err})
return n, c.out. setErrorLocked (&net.OpError{Op: "local error", Err: err})
}
}
return
@@ -777,7 +768,7 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) {
// c.in.Mutex < L; c.out.Mutex < L.
func (c *Conn) readHandshake() (interface{}, error) {
for c.hand.Len() < 4 {
if err := c.error() ; err != nil {
if err := c.in. err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -788,11 +779,10 @@ func (c *Conn) readHandshake() (interface{}, error) {
data := c.hand.Bytes()
n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
if n > maxHandshake {
c.sendAlert(alertInternalError)
return nil, c.error()
return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
}
for c.hand.Len() < 4+n {
if err := c.error() ; err != nil {
if err := c.in. err; err != nil {
return nil, err
}
if err := c.readRecord(recordTypeHandshake); err != nil {
@@ -831,8 +821,7 @@ func (c *Conn) readHandshake() (interface{}, error) {
case typeFinished:
m = new(finishedMsg)
default:
c.sendAlert(alertUnexpectedMessage)
return nil, alertUnexpectedMessage
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
// The handshake message unmarshallers
@@ -841,25 +830,24 @@ func (c *Conn) readHandshake() (interface{}, error) {
data = append([]byte(nil), data...)
if !m.unmarshal(data) {
c.sendAlert(alertUnexpectedMessage)
return nil, alertUnexpectedMessage
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
return m, nil
}
// Write writes data to the connection.
func (c *Conn) Write(b []byte) (int, error) {
if err := c.error(); err != nil {
return 0, err
}
if err := c.Handshake(); err != nil {
return 0, c.setError( err)
return 0, err
}
c.out.Lock()
defer c.out.Unlock()
if err := c.out.err; err != nil {
return 0, err
}
if !c.handshakeComplete {
return 0, alertInternalError
}
@@ -878,14 +866,14 @@ func (c *Conn) Write(b []byte) (int, error) {
if _, ok := c.out.cipher.(cipher.BlockMode); ok {
n, err := c.writeRecord(recordTypeApplicationData, b[:1])
if err != nil {
return n, c.setError(err)
return n, c.out. setErrorLocked (err)
}
m, b = 1, b[1:]
}
}
n, err := c.writeRecord(recordTypeApplicationData, b)
return n + m, c.setError(err)
return n + m, c.out. setErrorLocked (err)
}
// Read can be made to time out and return a net.Error with Timeout() == true
@@ -902,13 +890,13 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// CBC IV. So this loop ignores a limited number of empty records.
const maxConsecutiveEmptyRecords = 100
for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
for c.input == nil && c.error() == nil {
for c.input == nil && c.in. err == nil {
if err := c.readRecord(recordTypeApplicationData); err != nil {
// Soft error, like EAGAIN
return 0, err
}
}
if err := c.error() ; err != nil {
if err := c.in. err; err != nil {
return 0, err
}
@@ -949,16 +937,19 @@ func (c *Conn) Close() error {
func (c *Conn) Handshake() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if err := c.error() ; err != nil {
if err := c.handshak eE rr; err != nil {
return err
}
if c.handshakeComplete {
return nil
}
if c.isClient {
return c.clientHandshake()
c.handshakeErr = c.clientHandshake()
} else {
c.handshakeErr = c.serverHandshake()
}
return c.serverHandshake()
return c.handshakeErr
}
// ConnectionState returns basic TLS details about the connection.