This promotes a connection hang during TLS handshake to a proper error. This doesn't fully address #14539 because the error reported in that case is a write-on-socket-not-connected error, which implies that an earlier error during connection setup is not being checked, but it is an improvement over the current behaviour. Updates #14539. Change-Id: I0571a752d32d5303db48149ab448226868b19495 Reviewed-on: https://go-review.googlesource.com/19990 Reviewed-by: Adam Langley <agl@golang.org>tls13
@@ -694,12 +694,14 @@ func (c *Conn) sendAlertLocked(err alert) error { | |||||
c.tmp[0] = alertLevelError | c.tmp[0] = alertLevelError | ||||
} | } | ||||
c.tmp[1] = byte(err) | c.tmp[1] = byte(err) | ||||
c.writeRecord(recordTypeAlert, c.tmp[0:2]) | |||||
// closeNotify is a special case in that it isn't an error: | |||||
if err != alertCloseNotify { | |||||
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) | |||||
_, writeErr := c.writeRecord(recordTypeAlert, c.tmp[0:2]) | |||||
if err == alertCloseNotify { | |||||
// closeNotify is a special case in that it isn't an error. | |||||
return writeErr | |||||
} | } | ||||
return nil | |||||
return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) | |||||
} | } | ||||
// sendAlert sends a TLS alert message. | // sendAlert sends a TLS alert message. | ||||
@@ -713,8 +715,11 @@ func (c *Conn) sendAlert(err alert) error { | |||||
// writeRecord writes a TLS record with the given type and payload | // writeRecord writes a TLS record with the given type and payload | ||||
// to the connection and updates the record layer state. | // to the connection and updates the record layer state. | ||||
// c.out.Mutex <= L. | // c.out.Mutex <= L. | ||||
func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { | |||||
func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { | |||||
b := c.out.newBlock() | b := c.out.newBlock() | ||||
defer c.out.freeBlock(b) | |||||
var n int | |||||
for len(data) > 0 { | for len(data) > 0 { | ||||
m := len(data) | m := len(data) | ||||
if m > maxPlaintext { | if m > maxPlaintext { | ||||
@@ -759,34 +764,27 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { | |||||
if explicitIVIsSeq { | if explicitIVIsSeq { | ||||
copy(explicitIV, c.out.seq[:]) | copy(explicitIV, c.out.seq[:]) | ||||
} else { | } else { | ||||
if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { | |||||
break | |||||
if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil { | |||||
return n, err | |||||
} | } | ||||
} | } | ||||
} | } | ||||
copy(b.data[recordHeaderLen+explicitIVLen:], data) | copy(b.data[recordHeaderLen+explicitIVLen:], data) | ||||
c.out.encrypt(b, explicitIVLen) | c.out.encrypt(b, explicitIVLen) | ||||
_, err = c.conn.Write(b.data) | |||||
if err != nil { | |||||
break | |||||
if _, err := c.conn.Write(b.data); err != nil { | |||||
return n, err | |||||
} | } | ||||
n += m | n += m | ||||
data = data[m:] | data = data[m:] | ||||
} | } | ||||
c.out.freeBlock(b) | |||||
if typ == recordTypeChangeCipherSpec { | if typ == recordTypeChangeCipherSpec { | ||||
err = c.out.changeCipherSpec() | |||||
if err != nil { | |||||
// Cannot call sendAlert directly, | |||||
// because we already hold c.out.Mutex. | |||||
c.tmp[0] = alertLevelError | |||||
c.tmp[1] = byte(err.(alert)) | |||||
c.writeRecord(recordTypeAlert, c.tmp[0:2]) | |||||
return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) | |||||
if err := c.out.changeCipherSpec(); err != nil { | |||||
return n, c.sendAlertLocked(err.(alert)) | |||||
} | } | ||||
} | } | ||||
return | |||||
return n, nil | |||||
} | } | ||||
// readHandshake reads the next handshake message from | // readHandshake reads the next handshake message from | ||||
@@ -138,7 +138,9 @@ NextCipherSuite: | |||||
} | } | ||||
} | } | ||||
c.writeRecord(recordTypeHandshake, hello.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { | |||||
return err | |||||
} | |||||
msg, err := c.readHandshake() | msg, err := c.readHandshake() | ||||
if err != nil { | if err != nil { | ||||
@@ -419,7 +421,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||||
certMsg.certificates = chainToSend.Certificate | certMsg.certificates = chainToSend.Certificate | ||||
} | } | ||||
hs.finishedHash.Write(certMsg.marshal()) | hs.finishedHash.Write(certMsg.marshal()) | ||||
c.writeRecord(recordTypeHandshake, certMsg.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0]) | preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0]) | ||||
@@ -429,7 +433,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||||
} | } | ||||
if ckx != nil { | if ckx != nil { | ||||
hs.finishedHash.Write(ckx.marshal()) | hs.finishedHash.Write(ckx.marshal()) | ||||
c.writeRecord(recordTypeHandshake, ckx.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
if chainToSend != nil { | if chainToSend != nil { | ||||
@@ -471,7 +477,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||||
} | } | ||||
hs.finishedHash.Write(certVerify.marshal()) | hs.finishedHash.Write(certVerify.marshal()) | ||||
c.writeRecord(recordTypeHandshake, certVerify.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) | hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) | ||||
@@ -615,7 +623,9 @@ func (hs *clientHandshakeState) readSessionTicket() error { | |||||
func (hs *clientHandshakeState) sendFinished(out []byte) error { | func (hs *clientHandshakeState) sendFinished(out []byte) error { | ||||
c := hs.c | c := hs.c | ||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) | |||||
if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { | |||||
return err | |||||
} | |||||
if hs.serverHello.nextProtoNeg { | if hs.serverHello.nextProtoNeg { | ||||
nextProto := new(nextProtoMsg) | nextProto := new(nextProtoMsg) | ||||
proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) | proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) | ||||
@@ -624,13 +634,17 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error { | |||||
c.clientProtocolFallback = fallback | c.clientProtocolFallback = fallback | ||||
hs.finishedHash.Write(nextProto.marshal()) | hs.finishedHash.Write(nextProto.marshal()) | ||||
c.writeRecord(recordTypeHandshake, nextProto.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
finished := new(finishedMsg) | finished := new(finishedMsg) | ||||
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) | finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) | ||||
hs.finishedHash.Write(finished.marshal()) | hs.finishedHash.Write(finished.marshal()) | ||||
c.writeRecord(recordTypeHandshake, finished.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { | |||||
return err | |||||
} | |||||
copy(out, finished.verifyData) | copy(out, finished.verifyData) | ||||
return nil | return nil | ||||
} | } | ||||
@@ -12,6 +12,7 @@ import ( | |||||
"encoding/base64" | "encoding/base64" | ||||
"encoding/binary" | "encoding/binary" | ||||
"encoding/pem" | "encoding/pem" | ||||
"errors" | |||||
"fmt" | "fmt" | ||||
"io" | "io" | ||||
"net" | "net" | ||||
@@ -725,3 +726,51 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { | |||||
t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) | t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) | ||||
} | } | ||||
} | } | ||||
// brokenConn wraps a net.Conn and causes all Writes after a certain number to | |||||
// fail with brokenConnErr. | |||||
type brokenConn struct { | |||||
net.Conn | |||||
// breakAfter is the number of successful writes that will be allowed | |||||
// before all subsequent writes fail. | |||||
breakAfter int | |||||
// numWrites is the number of writes that have been done. | |||||
numWrites int | |||||
} | |||||
// brokenConnErr is the error that brokenConn returns once exhausted. | |||||
var brokenConnErr = errors.New("too many writes to brokenConn") | |||||
func (b *brokenConn) Write(data []byte) (int, error) { | |||||
if b.numWrites >= b.breakAfter { | |||||
return 0, brokenConnErr | |||||
} | |||||
b.numWrites++ | |||||
return b.Conn.Write(data) | |||||
} | |||||
func TestFailedWrite(t *testing.T) { | |||||
// Test that a write error during the handshake is returned. | |||||
for _, breakAfter := range []int{0, 1, 2, 3} { | |||||
c, s := net.Pipe() | |||||
done := make(chan bool) | |||||
go func() { | |||||
Server(s, testConfig).Handshake() | |||||
s.Close() | |||||
done <- true | |||||
}() | |||||
brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} | |||||
err := Client(brokenC, testConfig).Handshake() | |||||
if err != brokenConnErr { | |||||
t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) | |||||
} | |||||
brokenC.Close() | |||||
<-done | |||||
} | |||||
} |
@@ -322,7 +322,9 @@ func (hs *serverHandshakeState) doResumeHandshake() error { | |||||
hs.finishedHash.discardHandshakeBuffer() | hs.finishedHash.discardHandshakeBuffer() | ||||
hs.finishedHash.Write(hs.clientHello.marshal()) | hs.finishedHash.Write(hs.clientHello.marshal()) | ||||
hs.finishedHash.Write(hs.hello.marshal()) | hs.finishedHash.Write(hs.hello.marshal()) | ||||
c.writeRecord(recordTypeHandshake, hs.hello.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { | |||||
return err | |||||
} | |||||
if len(hs.sessionState.certificates) > 0 { | if len(hs.sessionState.certificates) > 0 { | ||||
if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { | if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { | ||||
@@ -354,19 +356,25 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||||
} | } | ||||
hs.finishedHash.Write(hs.clientHello.marshal()) | hs.finishedHash.Write(hs.clientHello.marshal()) | ||||
hs.finishedHash.Write(hs.hello.marshal()) | hs.finishedHash.Write(hs.hello.marshal()) | ||||
c.writeRecord(recordTypeHandshake, hs.hello.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { | |||||
return err | |||||
} | |||||
certMsg := new(certificateMsg) | certMsg := new(certificateMsg) | ||||
certMsg.certificates = hs.cert.Certificate | certMsg.certificates = hs.cert.Certificate | ||||
hs.finishedHash.Write(certMsg.marshal()) | hs.finishedHash.Write(certMsg.marshal()) | ||||
c.writeRecord(recordTypeHandshake, certMsg.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { | |||||
return err | |||||
} | |||||
if hs.hello.ocspStapling { | if hs.hello.ocspStapling { | ||||
certStatus := new(certificateStatusMsg) | certStatus := new(certificateStatusMsg) | ||||
certStatus.statusType = statusTypeOCSP | certStatus.statusType = statusTypeOCSP | ||||
certStatus.response = hs.cert.OCSPStaple | certStatus.response = hs.cert.OCSPStaple | ||||
hs.finishedHash.Write(certStatus.marshal()) | hs.finishedHash.Write(certStatus.marshal()) | ||||
c.writeRecord(recordTypeHandshake, certStatus.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
keyAgreement := hs.suite.ka(c.vers) | keyAgreement := hs.suite.ka(c.vers) | ||||
@@ -377,7 +385,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||||
} | } | ||||
if skx != nil { | if skx != nil { | ||||
hs.finishedHash.Write(skx.marshal()) | hs.finishedHash.Write(skx.marshal()) | ||||
c.writeRecord(recordTypeHandshake, skx.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
if config.ClientAuth >= RequestClientCert { | if config.ClientAuth >= RequestClientCert { | ||||
@@ -401,12 +411,16 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||||
certReq.certificateAuthorities = config.ClientCAs.Subjects() | certReq.certificateAuthorities = config.ClientCAs.Subjects() | ||||
} | } | ||||
hs.finishedHash.Write(certReq.marshal()) | hs.finishedHash.Write(certReq.marshal()) | ||||
c.writeRecord(recordTypeHandshake, certReq.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { | |||||
return err | |||||
} | |||||
} | } | ||||
helloDone := new(serverHelloDoneMsg) | helloDone := new(serverHelloDoneMsg) | ||||
hs.finishedHash.Write(helloDone.marshal()) | hs.finishedHash.Write(helloDone.marshal()) | ||||
c.writeRecord(recordTypeHandshake, helloDone.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { | |||||
return err | |||||
} | |||||
var pub crypto.PublicKey // public key for client auth, if any | var pub crypto.PublicKey // public key for client auth, if any | ||||
@@ -632,7 +646,9 @@ func (hs *serverHandshakeState) sendSessionTicket() error { | |||||
} | } | ||||
hs.finishedHash.Write(m.marshal()) | hs.finishedHash.Write(m.marshal()) | ||||
c.writeRecord(recordTypeHandshake, m.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { | |||||
return err | |||||
} | |||||
return nil | return nil | ||||
} | } | ||||
@@ -640,12 +656,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { | |||||
func (hs *serverHandshakeState) sendFinished(out []byte) error { | func (hs *serverHandshakeState) sendFinished(out []byte) error { | ||||
c := hs.c | c := hs.c | ||||
c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) | |||||
if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { | |||||
return err | |||||
} | |||||
finished := new(finishedMsg) | finished := new(finishedMsg) | ||||
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) | finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) | ||||
hs.finishedHash.Write(finished.marshal()) | hs.finishedHash.Write(finished.marshal()) | ||||
c.writeRecord(recordTypeHandshake, finished.marshal()) | |||||
if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { | |||||
return err | |||||
} | |||||
c.cipherSuite = hs.suite.id | c.cipherSuite = hs.suite.id | ||||
copy(out, finished.verifyData) | copy(out, finished.verifyData) | ||||
@@ -80,7 +80,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa | |||||
cli.writeRecord(recordTypeHandshake, m.marshal()) | cli.writeRecord(recordTypeHandshake, m.marshal()) | ||||
c.Close() | c.Close() | ||||
}() | }() | ||||
err := Server(s, serverConfig).Handshake() | |||||
hs := serverHandshakeState{ | |||||
c: Server(s, serverConfig), | |||||
} | |||||
_, err := hs.readClientHello() | |||||
s.Close() | s.Close() | ||||
if len(expectedSubStr) == 0 { | if len(expectedSubStr) == 0 { | ||||
if err != nil && err != io.EOF { | if err != nil && err != io.EOF { | ||||