diff --git a/13.go b/13.go index 88c0c83..814d454 100644 --- a/13.go +++ b/13.go @@ -120,6 +120,9 @@ CurvePreferenceLoop: serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret") c.out.setCipher(c.vers, serverCipher) + if c.hand.Len() > 0 { + return c.sendAlert(alertUnexpectedMessage) + } if hs.hello13Enc.earlyData { c.in.setCipher(c.vers, earlyClientCipher) c.phase = readingEarlyData @@ -162,6 +165,9 @@ func (hs *serverHandshakeState) readClientFinished13() error { hs.finishedHash13.Write(clientFinished.marshal()) c.hs = nil // Discard the server handshake state + if c.hand.Len() > 0 { + return c.sendAlert(alertUnexpectedMessage) + } c.in.setCipher(c.vers, hs.appClientCipher) c.in.traceErr, c.out.traceErr = nil, nil c.phase = handshakeConfirmed @@ -224,6 +230,10 @@ func (c *Conn) handleEndOfEarlyData() { return } c.phase = waitingClientFinished + if c.hand.Len() > 0 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + return + } c.in.setCipher(c.vers, c.hs.hsClientCipher) } diff --git a/handshake_server.go b/handshake_server.go index 618589f..03b6ea2 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -132,6 +132,9 @@ func (c *Conn) serverHandshake() error { return err } } + if c.hand.Len() > 0 { + return c.sendAlert(alertUnexpectedMessage) + } c.phase = handshakeConfirmed atomic.StoreInt32(&c.handshakeConfirmed, 1) c.handshakeComplete = true