diff --git a/conn.go b/conn.go index d25ad28..000b23c 100644 --- a/conn.go +++ b/conn.go @@ -451,6 +451,8 @@ func (b *block) readFromUntil(r io.Reader, n int) error { m, err := r.Read(b.data[len(b.data):cap(b.data)]) b.data = b.data[0 : len(b.data)+m] if len(b.data) >= n { + // TODO(bradfitz,agl): slightly suspicious + // that we're throwing away r.Read's err here. break } if err != nil { @@ -906,6 +908,25 @@ func (c *Conn) Read(b []byte) (n int, err error) { c.input = nil } + // If a close-notify alert is waiting, read it so that + // we can return (n, EOF) instead of (n, nil), to signal + // to the HTTP response reading goroutine that the + // connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would + // otherwise not observe the EOF until its next read, + // by which time a client goroutine might have already + // tried to reuse the HTTP connection for a new + // request. + // See https://codereview.appspot.com/76400046 + // and http://golang.org/issue/3514 + if ri := c.rawInput; ri != nil && + n != 0 && err == nil && + c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert { + if recErr := c.readRecord(recordTypeApplicationData); recErr != nil { + err = recErr // will be io.EOF on closeNotify + } + } + if n != 0 || err != nil { return n, err } diff --git a/tls_test.go b/tls_test.go index 5b12610..65a243d 100644 --- a/tls_test.go +++ b/tls_test.go @@ -5,6 +5,7 @@ package tls import ( + "io" "net" "strings" "testing" @@ -109,18 +110,22 @@ func TestX509MixedKeyPair(t *testing.T) { } } -func TestDialTimeout(t *testing.T) { - if testing.Short() { - t.Skip("skipping in short mode") - } - - listener, err := net.Listen("tcp", "127.0.0.1:0") +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { - listener, err = net.Listen("tcp6", "[::1]:0") + ln, err = net.Listen("tcp6", "[::1]:0") } if err != nil { t.Fatal(err) } + return ln +} + +func TestDialTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + listener := newLocalListener(t) addr := listener.Addr().String() defer listener.Close() @@ -142,6 +147,7 @@ func TestDialTimeout(t *testing.T) { Timeout: 10 * time.Millisecond, } + var err error if _, err = DialWithDialer(dialer, "tcp", addr, nil); err == nil { t.Fatal("DialWithTimeout completed successfully") } @@ -150,3 +156,59 @@ func TestDialTimeout(t *testing.T) { t.Errorf("resulting error not a timeout: %s", err) } } + +// tests that Conn.Read returns (non-zero, io.EOF) instead of +// (non-zero, nil) when a Close (alertCloseNotify) is sitting right +// behind the application data in the buffer. +func TestConnReadNonzeroAndEOF(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + srvCh := make(chan *Conn, 1) + go func() { + sconn, err := ln.Accept() + if err != nil { + t.Error(err) + srvCh <- nil + return + } + serverConfig := *testConfig + srv := Server(sconn, &serverConfig) + if err := srv.Handshake(); err != nil { + t.Error("handshake: %v", err) + srvCh <- nil + return + } + srvCh <- srv + }() + + clientConfig := *testConfig + conn, err := Dial("tcp", ln.Addr().String(), &clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + srv := <-srvCh + if srv == nil { + return + } + + buf := make([]byte, 6) + + srv.Write([]byte("foobar")) + n, err := conn.Read(buf) + if n != 6 || err != nil || string(buf) != "foobar" { + t.Fatalf("Read = %d, %v, data %q; want 6, nil, foobar", n, err, buf) + } + + srv.Write([]byte("abcdef")) + srv.Close() + n, err = conn.Read(buf) + if n != 6 || string(buf) != "abcdef" { + t.Fatalf("Read = %d, buf= %q; want 6, abcdef", n, buf) + } + if err != io.EOF { + t.Errorf("Second Read error = %v; want io.EOF", err) + } +}