crypto/tls: buffer handshake messages.
This change causes TLS handshake messages to be buffered and written in a single Write to the underlying net.Conn. There are two reasons to want to do this: Firstly, it's slightly preferable to do this in order to save sending several, small packets over the network where a single one will do. Secondly, since 37c28759ca46cf381a466e32168a793165d9c9e9 errors from Write have been returned from a handshake. This means that, if a peer closes the connection during a handshake, a “broken pipe” error may result from tls.Conn.Handshake(). This can mask any, more detailed, fatal alerts that the peer may have sent because a read will never happen. Buffering handshake messages means that the peer will not receive, and possibly reject, any of a flow while it's still being written. Fixes #15709 Change-Id: I38dcff1abecc06e52b2de647ea98713ce0fb9a21 Reviewed-on: https://go-review.googlesource.com/23609 Reviewed-by: Andrew Gerrand <adg@golang.org> Run-TryBot: Andrew Gerrand <adg@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
a709e2d83c
commit
0d94116736
37
conn.go
37
conn.go
@ -71,10 +71,12 @@ type Conn struct {
|
|||||||
clientProtocolFallback bool
|
clientProtocolFallback bool
|
||||||
|
|
||||||
// input/output
|
// input/output
|
||||||
in, out halfConn // in.Mutex < out.Mutex
|
in, out halfConn // in.Mutex < out.Mutex
|
||||||
rawInput *block // raw input, right off the wire
|
rawInput *block // raw input, right off the wire
|
||||||
input *block // application data waiting to be read
|
input *block // application data waiting to be read
|
||||||
hand bytes.Buffer // handshake data waiting to be read
|
hand bytes.Buffer // handshake data waiting to be read
|
||||||
|
buffering bool // whether records are buffered in sendBuf
|
||||||
|
sendBuf []byte // a buffer of records waiting to be sent
|
||||||
|
|
||||||
// bytesSent counts the bytes of application data sent.
|
// bytesSent counts the bytes of application data sent.
|
||||||
// packetsSent counts packets.
|
// packetsSent counts packets.
|
||||||
@ -803,6 +805,30 @@ func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
|
|||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// c.out.Mutex <= L.
|
||||||
|
func (c *Conn) write(data []byte) (int, error) {
|
||||||
|
if c.buffering {
|
||||||
|
c.sendBuf = append(c.sendBuf, data...)
|
||||||
|
return len(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := c.conn.Write(data)
|
||||||
|
c.bytesSent += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) flush() (int, error) {
|
||||||
|
if len(c.sendBuf) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := c.conn.Write(c.sendBuf)
|
||||||
|
c.bytesSent += int64(n)
|
||||||
|
c.sendBuf = nil
|
||||||
|
c.buffering = false
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
// writeRecordLocked writes a TLS record with the given type and payload to the
|
// writeRecordLocked writes a TLS record with the given type and payload to the
|
||||||
// connection and updates the record layer state.
|
// connection and updates the record layer state.
|
||||||
// c.out.Mutex <= L.
|
// c.out.Mutex <= L.
|
||||||
@ -862,10 +888,9 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
copy(b.data[recordHeaderLen+explicitIVLen:], data)
|
copy(b.data[recordHeaderLen+explicitIVLen:], data)
|
||||||
c.out.encrypt(b, explicitIVLen)
|
c.out.encrypt(b, explicitIVLen)
|
||||||
if _, err := c.conn.Write(b.data); err != nil {
|
if _, err := c.write(b.data); err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
c.bytesSent += int64(m)
|
|
||||||
n += m
|
n += m
|
||||||
data = data[m:]
|
data = data[m:]
|
||||||
}
|
}
|
||||||
|
@ -206,6 +206,7 @@ NextCipherSuite:
|
|||||||
hs.finishedHash.Write(hs.hello.marshal())
|
hs.finishedHash.Write(hs.hello.marshal())
|
||||||
hs.finishedHash.Write(hs.serverHello.marshal())
|
hs.finishedHash.Write(hs.serverHello.marshal())
|
||||||
|
|
||||||
|
c.buffering = true
|
||||||
if isResume {
|
if isResume {
|
||||||
if err := hs.establishKeys(); err != nil {
|
if err := hs.establishKeys(); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -220,6 +221,9 @@ NextCipherSuite:
|
|||||||
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
|
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, err := c.flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := hs.doFullHandshake(); err != nil {
|
if err := hs.doFullHandshake(); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -230,6 +234,9 @@ NextCipherSuite:
|
|||||||
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
|
if err := hs.sendFinished(c.clientFinished[:]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, err := c.flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
c.clientFinishedIsFirst = true
|
c.clientFinishedIsFirst = true
|
||||||
if err := hs.readSessionTicket(); err != nil {
|
if err := hs.readSessionTicket(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -983,7 +983,7 @@ func (b *brokenConn) Write(data []byte) (int, error) {
|
|||||||
|
|
||||||
func TestFailedWrite(t *testing.T) {
|
func TestFailedWrite(t *testing.T) {
|
||||||
// Test that a write error during the handshake is returned.
|
// Test that a write error during the handshake is returned.
|
||||||
for _, breakAfter := range []int{0, 1, 2, 3} {
|
for _, breakAfter := range []int{0, 1} {
|
||||||
c, s := net.Pipe()
|
c, s := net.Pipe()
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
|
|
||||||
@ -1003,3 +1003,45 @@ func TestFailedWrite(t *testing.T) {
|
|||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeCountingConn wraps a net.Conn and counts the number of Write calls.
|
||||||
|
type writeCountingConn struct {
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
// numWrites is the number of writes that have been done.
|
||||||
|
numWrites int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wcc *writeCountingConn) Write(data []byte) (int, error) {
|
||||||
|
wcc.numWrites++
|
||||||
|
return wcc.Conn.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuffering(t *testing.T) {
|
||||||
|
c, s := net.Pipe()
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
clientWCC := &writeCountingConn{Conn: c}
|
||||||
|
serverWCC := &writeCountingConn{Conn: s}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
Server(serverWCC, testConfig).Handshake()
|
||||||
|
serverWCC.Close()
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := Client(clientWCC, testConfig).Handshake()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
clientWCC.Close()
|
||||||
|
<-done
|
||||||
|
|
||||||
|
if n := clientWCC.numWrites; n != 2 {
|
||||||
|
t.Errorf("expected client handshake to complete with only two writes, but saw %d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := serverWCC.numWrites; n != 2 {
|
||||||
|
t.Errorf("expected server handshake to complete with only two writes, but saw %d", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -52,6 +52,7 @@ func (c *Conn) serverHandshake() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
|
// For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3
|
||||||
|
c.buffering = true
|
||||||
if isResume {
|
if isResume {
|
||||||
// The client has included a session ticket and so we do an abbreviated handshake.
|
// The client has included a session ticket and so we do an abbreviated handshake.
|
||||||
if err := hs.doResumeHandshake(); err != nil {
|
if err := hs.doResumeHandshake(); err != nil {
|
||||||
@ -71,6 +72,9 @@ func (c *Conn) serverHandshake() error {
|
|||||||
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
|
if err := hs.sendFinished(c.serverFinished[:]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, err := c.flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
c.clientFinishedIsFirst = false
|
c.clientFinishedIsFirst = false
|
||||||
if err := hs.readFinished(nil); err != nil {
|
if err := hs.readFinished(nil); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -89,12 +93,16 @@ func (c *Conn) serverHandshake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
c.clientFinishedIsFirst = true
|
c.clientFinishedIsFirst = true
|
||||||
|
c.buffering = true
|
||||||
if err := hs.sendSessionTicket(); err != nil {
|
if err := hs.sendSessionTicket(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := hs.sendFinished(nil); err != nil {
|
if err := hs.sendFinished(nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if _, err := c.flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.handshakeComplete = true
|
c.handshakeComplete = true
|
||||||
|
|
||||||
@ -430,6 +438,10 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := c.flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
var pub crypto.PublicKey // public key for client auth, if any
|
var pub crypto.PublicKey // public key for client auth, if any
|
||||||
|
|
||||||
msg, err := c.readHandshake()
|
msg, err := c.readHandshake()
|
||||||
|
Loading…
Reference in New Issue
Block a user