diff --git a/conn.go b/conn.go index 8d6f7e7..28d111a 100644 --- a/conn.go +++ b/conn.go @@ -64,6 +64,13 @@ type Conn struct { // the first transmitted Finished message is the tls-unique // channel-binding value. clientFinishedIsFirst bool + + // closeNotifyErr is any error from sending the alertCloseNotify record. + closeNotifyErr error + // closeNotifySent is true if the Conn attempted to send an + // alertCloseNotify record. + closeNotifySent bool + // clientFinished and serverFinished contain the Finished message sent // by the client or server in the most recent handshake. This is // retained to support the renegotiation extension and tls-unique @@ -1000,7 +1007,10 @@ func (c *Conn) readHandshake() (interface{}, error) { return m, nil } -var errClosed = errors.New("tls: use of closed connection") +var ( + errClosed = errors.New("tls: use of closed connection") + errShutdown = errors.New("tls: protocol is shutdown") +) // Write writes data to the connection. func (c *Conn) Write(b []byte) (int, error) { @@ -1031,6 +1041,10 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, alertInternalError } + if c.closeNotifySent { + return 0, errShutdown + } + // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext // attack when using block mode ciphers due to predictable IVs. // This can be prevented by splitting each Application Data @@ -1194,7 +1208,7 @@ func (c *Conn) Close() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() if c.handshakeComplete { - alertErr = c.sendAlert(alertCloseNotify) + alertErr = c.closeNotify() } if err := c.conn.Close(); err != nil { @@ -1203,6 +1217,32 @@ func (c *Conn) Close() error { return alertErr } +var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") + +// CloseWrite shuts down the writing side of the connection. It should only be +// called once the handshake has completed and does not call CloseWrite on the +// underlying connection. Most callers should just use Close. +func (c *Conn) CloseWrite() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.handshakeComplete { + return errEarlyCloseWrite + } + + return c.closeNotify() +} + +func (c *Conn) closeNotify() error { + c.out.Lock() + defer c.out.Unlock() + + if !c.closeNotifySent { + c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) + c.closeNotifySent = true + } + return c.closeNotifyErr +} + // Handshake runs the client or server handshake // protocol if it has not yet been run. // Most uses of this package need not call Handshake diff --git a/tls_test.go b/tls_test.go index 99153a5..b4fedd5 100644 --- a/tls_test.go +++ b/tls_test.go @@ -11,6 +11,7 @@ import ( "fmt" "internal/testenv" "io" + "io/ioutil" "math" "math/rand" "net" @@ -458,6 +459,112 @@ func TestConnCloseBreakingWrite(t *testing.T) { } } +func TestConnCloseWrite(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + clientDoneChan := make(chan struct{}) + + serverCloseWrite := func() error { + sconn, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %v", err) + } + defer sconn.Close() + + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + return fmt.Errorf("handshake: %v", err) + } + defer srv.Close() + + data, err := ioutil.ReadAll(srv) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + + if err := srv.CloseWrite(); err != nil { + return fmt.Errorf("server CloseWrite: %v", err) + } + + // Wait for clientCloseWrite to finish, so we know we + // tested the CloseWrite before we defer the + // sconn.Close above, which would also cause the + // client to unblock like CloseWrite. + <-clientDoneChan + return nil + } + + clientCloseWrite := func() error { + defer close(clientDoneChan) + + clientConfig := testConfig.Clone() + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + return err + } + if err := conn.Handshake(); err != nil { + return err + } + defer conn.Close() + + if err := conn.CloseWrite(); err != nil { + return fmt.Errorf("client CloseWrite: %v", err) + } + + if _, err := conn.Write([]byte{0}); err != errShutdown { + return fmt.Errorf("CloseWrite error = %v; want errShutdown", err) + } + + data, err := ioutil.ReadAll(conn) + if err != nil { + return err + } + if len(data) > 0 { + return fmt.Errorf("Read data = %q; want nothing", data) + } + return nil + } + + errChan := make(chan error, 2) + + go func() { errChan <- serverCloseWrite() }() + go func() { errChan <- clientCloseWrite() }() + + for i := 0; i < 2; i++ { + select { + case err := <-errChan: + if err != nil { + t.Fatal(err) + } + case <-time.After(10 * time.Second): + t.Fatal("deadlock") + } + } + + // Also test CloseWrite being called before the handshake is + // finished: + { + ln2 := newLocalListener(t) + defer ln2.Close() + + netConn, err := net.Dial("tcp", ln2.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer netConn.Close() + conn := Client(netConn, testConfig.Clone()) + + if err := conn.CloseWrite(); err != errEarlyCloseWrite { + t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) + } + } +} + func TestClone(t *testing.T) { var c1 Config v := reflect.ValueOf(&c1).Elem()