From 326f5bb02b8e849826e98de3fb17a0356eb54b5d Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 26 Feb 2016 14:17:29 -0500 Subject: [PATCH] crypto/tls: check errors from (*Conn).writeRecord This promotes a connection hang during TLS handshake to a proper error. This doesn't fully address #14539 because the error reported in that case is a write-on-socket-not-connected error, which implies that an earlier error during connection setup is not being checked, but it is an improvement over the current behaviour. Updates #14539. Change-Id: I0571a752d32d5303db48149ab448226868b19495 Reviewed-on: https://go-review.googlesource.com/19990 Reviewed-by: Adam Langley --- conn.go | 40 ++++++++++++++++---------------- handshake_client.go | 28 +++++++++++++++++------ handshake_client_test.go | 49 ++++++++++++++++++++++++++++++++++++++++ handshake_server.go | 40 ++++++++++++++++++++++++-------- handshake_server_test.go | 5 +++- 5 files changed, 123 insertions(+), 39 deletions(-) diff --git a/conn.go b/conn.go index 65b1d4b..89e4c2f 100644 --- a/conn.go +++ b/conn.go @@ -694,12 +694,14 @@ func (c *Conn) sendAlertLocked(err alert) error { c.tmp[0] = alertLevelError } c.tmp[1] = byte(err) - c.writeRecord(recordTypeAlert, c.tmp[0:2]) - // closeNotify is a special case in that it isn't an error: - if err != alertCloseNotify { - return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + + _, writeErr := c.writeRecord(recordTypeAlert, c.tmp[0:2]) + if err == alertCloseNotify { + // closeNotify is a special case in that it isn't an error. + return writeErr } - return nil + + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) } // sendAlert sends a TLS alert message. @@ -713,8 +715,11 @@ func (c *Conn) sendAlert(err alert) error { // writeRecord writes a TLS record with the given type and payload // to the connection and updates the record layer state. // c.out.Mutex <= L. -func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { +func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) { b := c.out.newBlock() + defer c.out.freeBlock(b) + + var n int for len(data) > 0 { m := len(data) if m > maxPlaintext { @@ -759,34 +764,27 @@ func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { if explicitIVIsSeq { copy(explicitIV, c.out.seq[:]) } else { - if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { - break + if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil { + return n, err } } } copy(b.data[recordHeaderLen+explicitIVLen:], data) c.out.encrypt(b, explicitIVLen) - _, err = c.conn.Write(b.data) - if err != nil { - break + if _, err := c.conn.Write(b.data); err != nil { + return n, err } n += m data = data[m:] } - c.out.freeBlock(b) if typ == recordTypeChangeCipherSpec { - err = c.out.changeCipherSpec() - if err != nil { - // Cannot call sendAlert directly, - // because we already hold c.out.Mutex. - c.tmp[0] = alertLevelError - c.tmp[1] = byte(err.(alert)) - c.writeRecord(recordTypeAlert, c.tmp[0:2]) - return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + if err := c.out.changeCipherSpec(); err != nil { + return n, c.sendAlertLocked(err.(alert)) } } - return + + return n, nil } // readHandshake reads the next handshake message from diff --git a/handshake_client.go b/handshake_client.go index b129922..d38b061 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -138,7 +138,9 @@ NextCipherSuite: } } - c.writeRecord(recordTypeHandshake, hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + return err + } msg, err := c.readHandshake() if err != nil { @@ -419,7 +421,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { certMsg.certificates = chainToSend.Certificate } hs.finishedHash.Write(certMsg.marshal()) - c.writeRecord(recordTypeHandshake, certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } } preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, certs[0]) @@ -429,7 +433,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { } if ckx != nil { hs.finishedHash.Write(ckx.marshal()) - c.writeRecord(recordTypeHandshake, ckx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil { + return err + } } if chainToSend != nil { @@ -471,7 +477,9 @@ func (hs *clientHandshakeState) doFullHandshake() error { } hs.finishedHash.Write(certVerify.marshal()) - c.writeRecord(recordTypeHandshake, certVerify.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil { + return err + } } hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) @@ -615,7 +623,9 @@ func (hs *clientHandshakeState) readSessionTicket() error { func (hs *clientHandshakeState) sendFinished(out []byte) error { c := hs.c - c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } if hs.serverHello.nextProtoNeg { nextProto := new(nextProtoMsg) proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) @@ -624,13 +634,17 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error { c.clientProtocolFallback = fallback hs.finishedHash.Write(nextProto.marshal()) - c.writeRecord(recordTypeHandshake, nextProto.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil { + return err + } } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) hs.finishedHash.Write(finished.marshal()) - c.writeRecord(recordTypeHandshake, finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } copy(out, finished.verifyData) return nil } diff --git a/handshake_client_test.go b/handshake_client_test.go index 9b6c432..322c64e 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -12,6 +12,7 @@ import ( "encoding/base64" "encoding/binary" "encoding/pem" + "errors" "fmt" "io" "net" @@ -725,3 +726,51 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { t.Fatalf("Expected error about unconfigured cipher suite but got %q", err) } } + +// brokenConn wraps a net.Conn and causes all Writes after a certain number to +// fail with brokenConnErr. +type brokenConn struct { + net.Conn + + // breakAfter is the number of successful writes that will be allowed + // before all subsequent writes fail. + breakAfter int + + // numWrites is the number of writes that have been done. + numWrites int +} + +// brokenConnErr is the error that brokenConn returns once exhausted. +var brokenConnErr = errors.New("too many writes to brokenConn") + +func (b *brokenConn) Write(data []byte) (int, error) { + if b.numWrites >= b.breakAfter { + return 0, brokenConnErr + } + + b.numWrites++ + return b.Conn.Write(data) +} + +func TestFailedWrite(t *testing.T) { + // Test that a write error during the handshake is returned. + for _, breakAfter := range []int{0, 1, 2, 3} { + c, s := net.Pipe() + done := make(chan bool) + + go func() { + Server(s, testConfig).Handshake() + s.Close() + done <- true + }() + + brokenC := &brokenConn{Conn: c, breakAfter: breakAfter} + err := Client(brokenC, testConfig).Handshake() + if err != brokenConnErr { + t.Errorf("#%d: expected error from brokenConn but got %q", breakAfter, err) + } + brokenC.Close() + + <-done + } +} diff --git a/handshake_server.go b/handshake_server.go index dbab60b..facc17d 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -322,7 +322,9 @@ func (hs *serverHandshakeState) doResumeHandshake() error { hs.finishedHash.discardHandshakeBuffer() hs.finishedHash.Write(hs.clientHello.marshal()) hs.finishedHash.Write(hs.hello.marshal()) - c.writeRecord(recordTypeHandshake, hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } if len(hs.sessionState.certificates) > 0 { if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { @@ -354,19 +356,25 @@ func (hs *serverHandshakeState) doFullHandshake() error { } hs.finishedHash.Write(hs.clientHello.marshal()) hs.finishedHash.Write(hs.hello.marshal()) - c.writeRecord(recordTypeHandshake, hs.hello.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { + return err + } certMsg := new(certificateMsg) certMsg.certificates = hs.cert.Certificate hs.finishedHash.Write(certMsg.marshal()) - c.writeRecord(recordTypeHandshake, certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } if hs.hello.ocspStapling { certStatus := new(certificateStatusMsg) certStatus.statusType = statusTypeOCSP certStatus.response = hs.cert.OCSPStaple hs.finishedHash.Write(certStatus.marshal()) - c.writeRecord(recordTypeHandshake, certStatus.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { + return err + } } keyAgreement := hs.suite.ka(c.vers) @@ -377,7 +385,9 @@ func (hs *serverHandshakeState) doFullHandshake() error { } if skx != nil { hs.finishedHash.Write(skx.marshal()) - c.writeRecord(recordTypeHandshake, skx.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil { + return err + } } if config.ClientAuth >= RequestClientCert { @@ -401,12 +411,16 @@ func (hs *serverHandshakeState) doFullHandshake() error { certReq.certificateAuthorities = config.ClientCAs.Subjects() } hs.finishedHash.Write(certReq.marshal()) - c.writeRecord(recordTypeHandshake, certReq.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { + return err + } } helloDone := new(serverHelloDoneMsg) hs.finishedHash.Write(helloDone.marshal()) - c.writeRecord(recordTypeHandshake, helloDone.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil { + return err + } var pub crypto.PublicKey // public key for client auth, if any @@ -632,7 +646,9 @@ func (hs *serverHandshakeState) sendSessionTicket() error { } hs.finishedHash.Write(m.marshal()) - c.writeRecord(recordTypeHandshake, m.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } return nil } @@ -640,12 +656,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error { func (hs *serverHandshakeState) sendFinished(out []byte) error { c := hs.c - c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil { + return err + } finished := new(finishedMsg) finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) hs.finishedHash.Write(finished.marshal()) - c.writeRecord(recordTypeHandshake, finished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil { + return err + } c.cipherSuite = hs.suite.id copy(out, finished.verifyData) diff --git a/handshake_server_test.go b/handshake_server_test.go index f8de4e4..afadd62 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -80,7 +80,10 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa cli.writeRecord(recordTypeHandshake, m.marshal()) c.Close() }() - err := Server(s, serverConfig).Handshake() + hs := serverHandshakeState{ + c: Server(s, serverConfig), + } + _, err := hs.readClientHello() s.Close() if len(expectedSubStr) == 0 { if err != nil && err != io.EOF {