diff --git a/common.go b/common.go index d5fb3de..d4b0286 100644 --- a/common.go +++ b/common.go @@ -29,10 +29,11 @@ const ( ) const ( - maxPlaintext = 16384 // maximum plaintext payload length - maxCiphertext = 16384 + 2048 // maximum ciphertext payload length - recordHeaderLen = 5 // record header length - maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + recordHeaderLen = 5 // record header length + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + maxWarnAlertCount = 5 // maximum number of consecutive warning alerts minVersion = VersionTLS10 maxVersion = VersionTLS12 diff --git a/conn.go b/conn.go index 22017f5..31c5053 100644 --- a/conn.go +++ b/conn.go @@ -94,6 +94,10 @@ type Conn struct { bytesSent int64 packetsSent int64 + // warnCount counts the number of consecutive warning alerts received + // by Conn.readRecord. Protected by in.Mutex. + warnCount int + // activeCall is an atomic int32; the low bit is whether Close has // been called. the rest of the bits are the number of goroutines // in Conn.Write. @@ -658,6 +662,11 @@ Again: return c.in.setErrorLocked(err) } + if typ != recordTypeAlert && len(data) > 0 { + // this is a valid non-alert message: reset the count of alerts + c.warnCount = 0 + } + switch typ { default: c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) @@ -675,6 +684,13 @@ Again: case alertLevelWarning: // drop on the floor c.in.freeBlock(b) + + c.warnCount++ + if c.warnCount > maxWarnAlertCount { + c.sendAlert(alertUnexpectedMessage) + return c.in.setErrorLocked(errors.New("tls: too many warn alerts")) + } + goto Again case alertLevelError: c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) diff --git a/tls_test.go b/tls_test.go index 86812f0..97934cc 100644 --- a/tls_test.go +++ b/tls_test.go @@ -566,6 +566,58 @@ func TestConnCloseWrite(t *testing.T) { } } +func TestWarningAlertFlood(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + server := func() error { + sconn, err := ln.Accept() + if err != nil { + return fmt.Errorf("accept: %v", err) + } + defer sconn.Close() + + serverConfig := testConfig.Clone() + srv := Server(sconn, serverConfig) + if err := srv.Handshake(); err != nil { + return fmt.Errorf("handshake: %v", err) + } + defer srv.Close() + + _, err = ioutil.ReadAll(srv) + if err == nil { + return errors.New("unexpected lack of error from server") + } + const expected = "too many warn" + if str := err.Error(); !strings.Contains(str, expected) { + return fmt.Errorf("expected error containing %q, but saw: %s", expected, str) + } + + return nil + } + + errChan := make(chan error, 1) + go func() { errChan <- server() }() + + clientConfig := testConfig.Clone() + conn, err := Dial("tcp", ln.Addr().String(), clientConfig) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + if err := conn.Handshake(); err != nil { + t.Fatal(err) + } + + for i := 0; i < maxWarnAlertCount+1; i++ { + conn.sendAlert(alertNoRenegotiation) + } + + if err := <-errChan; err != nil { + t.Fatal(err) + } +} + func TestCloneFuncFields(t *testing.T) { const expectedCount = 5 called := 0