diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index 8658fd2a..969c50ab 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -877,14 +877,11 @@ Again: b = nil case recordTypeHandshake: + // Allow handshake data while reading application data to + // trigger post-handshake messages. // TODO(rsc): Should at least pick off connection close. - if typ != want { - // A client might need to process a HelloRequest from - // the server, thus receiving a handshake message when - // application data is expected is ok. - if !c.isClient || want != recordTypeApplicationData { - return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) - } + if typ != want && want != recordTypeApplicationData { + return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) } c.hand.Write(data) } @@ -1316,23 +1313,53 @@ func (c *Conn) Write(b []byte) (int, error) { return n + m, c.out.setErrorLocked(err) } -func (c *Conn) handleRenegotiation() error { - c.handshakeComplete = false - if !c.isClient { - panic("renegotiation should only happen for a client") - } - +func (c *Conn) handlePostHandshakeMessage() error { msg, err := c.readHandshake() if err != nil { return err } - _, ok := msg.(*helloRequestMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return alertUnexpectedMessage + + if c.vers < VersionTLS13 { + if !c.isClient { + c.sendAlert(alertUnexpectedMessage) + return errors.New("tls: unexpected post-handshake message") + } + + _, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return alertUnexpectedMessage + } + + c.handshakeComplete = false + return c.Handshake() } - return c.Handshake() + if c.isClient { + if newSessionTicket, ok := msg.(*newSessionTicketMsg); ok { + if c.config.ClientSessionCache == nil || newSessionTicket.ticketLifetime == 0 { + return nil + } + + session := &ClientSessionState{ + sessionTicket: newSessionTicket.ticket, + vers: c.vers, + cipherSuite: c.cipherSuite.id, + masterSecret: c.resumptionSecret, + serverCertificates: c.peerCertificates, + sctList: c.sctList, + ocspResponse: c.ocspResponse, + } + + cacheKey := clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + c.config.ClientSessionCache.Put(cacheKey, session) + return nil + } + } + + // TODO(davidben): Add support for KeyUpdate. + c.sendAlert(alertUnexpectedMessage) + return alertUnexpectedMessage } func (c *Conn) Renegotiate() error { @@ -1369,9 +1396,9 @@ func (c *Conn) Read(b []byte) (n int, err error) { return 0, err } if c.hand.Len() > 0 { - // We received handshake bytes, indicating the - // start of a renegotiation. - if err := c.handleRenegotiation(); err != nil { + // We received handshake bytes, indicating a + // post-handshake message. + if err := c.handlePostHandshakeMessage(); err != nil { return 0, err } continue diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index 50f3fa82..7718447b 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go @@ -220,19 +220,21 @@ NextCipherSuite: } if session != nil { - if session.sessionTicket != nil { - hello.sessionTicket = session.sessionTicket - if c.config.Bugs.CorruptTicket { - hello.sessionTicket = make([]byte, len(session.sessionTicket)) - copy(hello.sessionTicket, session.sessionTicket) - if len(hello.sessionTicket) > 0 { - offset := 40 - if offset > len(hello.sessionTicket) { - offset = len(hello.sessionTicket) - 1 - } - hello.sessionTicket[offset] ^= 0x40 - } + ticket := session.sessionTicket + if c.config.Bugs.CorruptTicket && len(ticket) > 0 { + ticket = make([]byte, len(session.sessionTicket)) + copy(ticket, session.sessionTicket) + offset := 40 + if offset >= len(ticket) { + offset = len(ticket) - 1 } + ticket[offset] ^= 0x40 + } + + if session.vers >= VersionTLS13 { + // TODO(davidben): Offer TLS 1.3 tickets. + } else if ticket != nil { + hello.sessionTicket = ticket // A random session ID is used to detect when the // server accepted the ticket and is resuming a session // (see RFC 5077). @@ -768,9 +770,9 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers) c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers) - // TODO(davidben): Derive and save the resumption master secret for receiving tickets. // TODO(davidben): Save the traffic secret for KeyUpdate. c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel) + c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel) return nil }