diff --git a/13.go b/13.go index e529517..58e534f 100644 --- a/13.go +++ b/13.go @@ -371,17 +371,23 @@ func (hs *serverHandshakeState) sendCertificate13() error { return nil } -func (c *Conn) handleEndOfEarlyData() { +func (c *Conn) handleEndOfEarlyData() error { if c.phase != readingEarlyData || c.vers < VersionTLS13 { - c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - return + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } - c.phase = waitingClientFinished - if c.hand.Len() > 0 { - c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - return + msg, err := c.readHandshake() + if err != nil { + return err + } + endOfEarlyData, ok := msg.(*endOfEarlyDataMsg) + // No handshake messages are allowed after EOD. + if !ok || c.hand.Len() > 0 { + return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } + c.hs.keySchedule.write(endOfEarlyData.marshal()) + c.phase = waitingClientFinished c.in.setCipher(c.vers, c.hs.hsClientCipher) + return nil } // selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify diff --git a/alert.go b/alert.go index a202981..bb9bb09 100644 --- a/alert.go +++ b/alert.go @@ -16,7 +16,6 @@ const ( const ( alertCloseNotify alert = 0 - alertEndOfEarlyData alert = 1 alertUnexpectedMessage alert = 10 alertBadRecordMAC alert = 20 alertDecryptionFailed alert = 21 diff --git a/common.go b/common.go index 90302f8..87ad18a 100644 --- a/common.go +++ b/common.go @@ -58,6 +58,7 @@ const ( typeClientHello uint8 = 1 typeServerHello uint8 = 2 typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 typeEncryptedExtensions uint8 = 8 typeCertificate uint8 = 11 typeServerKeyExchange uint8 = 12 @@ -90,7 +91,6 @@ const ( extensionEarlyData uint16 = 42 extensionSupportedVersions uint16 = 43 extensionPSKKeyExchangeModes uint16 = 45 - extensionTicketEarlyDataInfo uint16 = 46 extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionRenegotiationInfo uint16 = 0xff01 ) diff --git a/conn.go b/conn.go index d7ae0ab..a8d4ed2 100644 --- a/conn.go +++ b/conn.go @@ -779,10 +779,6 @@ Again: c.in.setErrorLocked(io.EOF) break } - if alert(data[1]) == alertEndOfEarlyData { - c.handleEndOfEarlyData() - break - } switch data[0] { case alertLevelWarning: // drop on the floor @@ -1136,6 +1132,8 @@ func (c *Conn) readHandshake() (interface{}, error) { } else { m = new(newSessionTicketMsg) } + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) case typeCertificate: if c.vers >= VersionTLS13 { m = new(certificateMsg13) @@ -1393,6 +1391,11 @@ func (r earlyDataReader) Read(b []byte) (n int, err error) { if err := c.readRecord(recordTypeApplicationData); err != nil { return 0, err } + if c.hand.Len() > 0 { + if err := c.handleEndOfEarlyData(); err != nil { + return 0, err + } + } } if err := c.in.err; err != nil { return 0, err @@ -1452,7 +1455,12 @@ func (c *Conn) Read(b []byte) (n int, err error) { return 0, err } if c.hand.Len() > 0 { - if c.phase == waitingClientFinished { + if c.phase == readingEarlyData || c.phase == waitingClientFinished { + if c.phase == readingEarlyData { + if err := c.handleEndOfEarlyData(); err != nil { + return 0, err + } + } // Server has received all early data, confirm // by reading the Client Finished message. if err := c.hs.readClientFinished13(true); err != nil { diff --git a/handshake_messages.go b/handshake_messages.go index e9c433c..a896cf8 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -2282,6 +2282,20 @@ func (m *newSessionTicketMsg13) unmarshal(data []byte) alert { return alertSuccess } +type endOfEarlyDataMsg struct { +} + +func (*endOfEarlyDataMsg) marshal() []byte { + return []byte{typeEndOfEarlyData, 0, 0, 0} +} + +func (*endOfEarlyDataMsg) unmarshal(data []byte) alert { + if len(data) != 4 { + return alertDecodeError + } + return alertSuccess +} + type helloRequestMsg struct { }