diff --git a/conn.go b/conn.go index 28d111a..8d6f7e7 100644 --- a/conn.go +++ b/conn.go @@ -64,13 +64,6 @@ type Conn struct { // the first transmitted Finished message is the tls-unique // channel-binding value. clientFinishedIsFirst bool - - // closeNotifyErr is any error from sending the alertCloseNotify record. - closeNotifyErr error - // closeNotifySent is true if the Conn attempted to send an - // alertCloseNotify record. - closeNotifySent bool - // clientFinished and serverFinished contain the Finished message sent // by the client or server in the most recent handshake. This is // retained to support the renegotiation extension and tls-unique @@ -1007,10 +1000,7 @@ func (c *Conn) readHandshake() (interface{}, error) { return m, nil } -var ( - errClosed = errors.New("tls: use of closed connection") - errShutdown = errors.New("tls: protocol is shutdown") -) +var errClosed = errors.New("tls: use of closed connection") // Write writes data to the connection. func (c *Conn) Write(b []byte) (int, error) { @@ -1041,10 +1031,6 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, alertInternalError } - if c.closeNotifySent { - return 0, errShutdown - } - // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext // attack when using block mode ciphers due to predictable IVs. // This can be prevented by splitting each Application Data @@ -1208,7 +1194,7 @@ func (c *Conn) Close() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() if c.handshakeComplete { - alertErr = c.closeNotify() + alertErr = c.sendAlert(alertCloseNotify) } if err := c.conn.Close(); err != nil { @@ -1217,32 +1203,6 @@ func (c *Conn) Close() error { return alertErr } -var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete") - -// CloseWrite shuts down the writing side of the connection. It should only be -// called once the handshake has completed and does not call CloseWrite on the -// underlying connection. Most callers should just use Close. -func (c *Conn) CloseWrite() error { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - if !c.handshakeComplete { - return errEarlyCloseWrite - } - - return c.closeNotify() -} - -func (c *Conn) closeNotify() error { - c.out.Lock() - defer c.out.Unlock() - - if !c.closeNotifySent { - c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify) - c.closeNotifySent = true - } - return c.closeNotifyErr -} - // Handshake runs the client or server handshake // protocol if it has not yet been run. // Most uses of this package need not call Handshake diff --git a/tls_test.go b/tls_test.go index 04da492..8b8dfa4 100644 --- a/tls_test.go +++ b/tls_test.go @@ -11,7 +11,6 @@ import ( "fmt" "internal/testenv" "io" - "io/ioutil" "math" "math/rand" "net" @@ -459,91 +458,6 @@ func TestConnCloseBreakingWrite(t *testing.T) { } } -func TestConnCloseWrite(t *testing.T) { - ln := newLocalListener(t) - defer ln.Close() - - go func() { - sconn, err := ln.Accept() - if err != nil { - t.Fatal(err) - } - - serverConfig := testConfig.Clone() - srv := Server(sconn, serverConfig) - if err := srv.Handshake(); err != nil { - t.Fatalf("handshake: %v", err) - } - defer srv.Close() - - data, err := ioutil.ReadAll(srv) - if err != nil { - t.Fatal(err) - } - if len(data) > 0 { - t.Errorf("Read data = %q; want nothing", data) - } - - if err = srv.CloseWrite(); err != nil { - t.Errorf("server CloseWrite: %v", err) - } - }() - - clientConfig := testConfig.Clone() - conn, err := Dial("tcp", ln.Addr().String(), clientConfig) - if err != nil { - t.Fatal(err) - } - if err = conn.Handshake(); err != nil { - t.Fatal(err) - } - defer conn.Close() - - if err = conn.CloseWrite(); err != nil { - t.Errorf("client CloseWrite: %v", err) - } - - if _, err := conn.Write([]byte{0}); err != errShutdown { - t.Errorf("CloseWrite error = %v; want errShutdown", err) - } - - data, err := ioutil.ReadAll(conn) - if err != nil { - t.Fatal(err) - } - if len(data) > 0 { - t.Errorf("Read data = %q; want nothing", data) - } - - // test CloseWrite called before handshake finished - - ln2 := newLocalListener(t) - defer ln2.Close() - - go func() { - sconn, err := ln2.Accept() - if err != nil { - t.Fatal(err) - } - - serverConfig := testConfig.Clone() - srv := Server(sconn, serverConfig) - - srv.Handshake() - srv.Close() - }() - - netConn, err := net.Dial("tcp", ln2.Addr().String()) - if err != nil { - t.Fatal(err) - } - conn = Client(netConn, clientConfig) - - if err = conn.CloseWrite(); err != errEarlyCloseWrite { - t.Errorf("CloseWrite error = %v; want errEarlyCloseWrite", err) - } -} - func TestClone(t *testing.T) { var c1 Config v := reflect.ValueOf(&c1).Elem()