diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go index fd957813..1014dea5 100644 --- a/ssl/test/runner/common.go +++ b/ssl/test/runner/common.go @@ -68,6 +68,7 @@ const ( typeClientKeyExchange uint8 = 16 typeFinished uint8 = 20 typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 // draft-ietf-tls-tls13-13 typeNextProtocol uint8 = 67 // Not IANA assigned typeChannelID uint8 = 203 // Not IANA assigned ) diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go index 969c50ab..cefdde33 100644 --- a/ssl/test/runner/conn.go +++ b/ssl/test/runner/conn.go @@ -159,6 +159,9 @@ type halfConn struct { // used to save allocating a new buffer for each MAC. inDigestBuf, outDigestBuf []byte + trafficSecret []byte + keyUpdateGeneration int + config *Config } @@ -203,13 +206,23 @@ func (hc *halfConn) changeCipherSpec(config *Config) error { return nil } -// updateKeys sets the current cipher state. -func (hc *halfConn) updateKeys(cipher interface{}, version uint16) { +// useTrafficSecret sets the current cipher state for TLS 1.3. +func (hc *halfConn) useTrafficSecret(version uint16, suite *cipherSuite, secret, phase []byte, side trafficDirection) { hc.version = version - hc.cipher = cipher + hc.cipher = deriveTrafficAEAD(version, suite, secret, phase, side) + hc.trafficSecret = secret hc.incEpoch() } +func (hc *halfConn) doKeyUpdate(c *Conn, isOutgoing bool) { + side := serverWrite + if c.isClient == isOutgoing { + side = clientWrite + } + hc.useTrafficSecret(hc.version, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), hc.trafficSecret), applicationPhase, side) + hc.keyUpdateGeneration++ +} + // incSeq increments the sequence number. func (hc *halfConn) incSeq(isOutgoing bool) { limit := 0 @@ -1175,6 +1188,8 @@ func (c *Conn) readHandshake() (interface{}, error) { m = new(helloVerifyRequestMsg) case typeChannelID: m = new(channelIDMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) default: return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } @@ -1280,6 +1295,13 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, alertInternalError } + // Catch up with KeyUpdates from the peer. + for c.out.keyUpdateGeneration < c.in.keyUpdateGeneration { + if err := c.sendKeyUpdateLocked(); err != nil { + return 0, err + } + } + if c.config.Bugs.SendSpuriousAlert != 0 { c.sendAlertLocked(alertLevelError, c.config.Bugs.SendSpuriousAlert) } @@ -1357,6 +1379,11 @@ func (c *Conn) handlePostHandshakeMessage() error { } } + if _, ok := msg.(*keyUpdateMsg); ok { + c.in.doKeyUpdate(c, true) + return nil + } + // TODO(davidben): Add support for KeyUpdate. c.sendAlert(alertUnexpectedMessage) return alertUnexpectedMessage @@ -1648,3 +1675,21 @@ func (c *Conn) SendNewSessionTicket() error { _, err := c.writeRecord(recordTypeHandshake, m.marshal()) return err } + +func (c *Conn) SendKeyUpdate() error { + c.out.Lock() + defer c.out.Unlock() + return c.sendKeyUpdateLocked() +} + +func (c *Conn) sendKeyUpdateLocked() error { + m := new(keyUpdateMsg) + if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil { + return err + } + if err := c.flushHandshake(); err != nil { + return err + } + c.out.doKeyUpdate(c, false) + return nil +} diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index 7718447b..b32be0e8 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go @@ -594,8 +594,8 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { // Switch to handshake traffic keys. handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel) - c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite), c.vers) - c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite), c.vers) + c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite) + c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite) msg, err := c.readHandshake() if err != nil { @@ -767,10 +767,9 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { c.flushHandshake() // Switch to application data keys. - c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers) - c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers) + c.out.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite) + c.in.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite) - // TODO(davidben): Save the traffic secret for KeyUpdate. c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel) c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel) return nil diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 41a8fb22..8e73a3c6 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -1924,6 +1924,17 @@ func (*helloRequestMsg) unmarshal(data []byte) bool { return len(data) == 4 } +type keyUpdateMsg struct { +} + +func (*keyUpdateMsg) marshal() []byte { + return []byte{typeKeyUpdate, 0, 0, 0} +} + +func (*keyUpdateMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + func eqUint16s(x, y []uint16) bool { if len(x) != len(y) { return false diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index a660f72f..012c8364 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go @@ -490,8 +490,8 @@ Curves: // Switch to handshake traffic keys. handshakeTrafficSecret := hs.finishedHash.deriveSecret(handshakeSecret, handshakeTrafficLabel) - c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite), c.vers) - c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite), c.vers) + c.out.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, serverWrite) + c.in.useTrafficSecret(c.vers, hs.suite, handshakeTrafficSecret, handshakePhase, clientWrite) if hs.suite.flags&suitePSK != 0 { return errors.New("tls: PSK ciphers not implemented for TLS 1.3") @@ -591,7 +591,7 @@ Curves: // Switch to application data keys on write. In particular, any alerts // from the client certificate are sent over these keys. - c.out.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite), c.vers) + c.out.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, serverWrite) // If we requested a client certificate, then the client must send a // certificate message, even if it's empty. @@ -664,9 +664,8 @@ Curves: hs.writeClientHash(clientFinished.marshal()) // Switch to application data keys on read. - c.in.updateKeys(deriveTrafficAEAD(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite), c.vers) + c.in.useTrafficSecret(c.vers, hs.suite, trafficSecret, applicationPhase, clientWrite) - // TODO(davidben): Save the traffic secret for KeyUpdate. c.cipherSuite = hs.suite c.exporterSecret = hs.finishedHash.deriveSecret(masterSecret, exporterLabel) c.resumptionSecret = hs.finishedHash.deriveSecret(masterSecret, resumptionLabel) diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go index abee6e6f..220aa44a 100644 --- a/ssl/test/runner/prf.go +++ b/ssl/test/runner/prf.go @@ -493,3 +493,7 @@ func deriveTrafficAEAD(version uint16, suite *cipherSuite, secret, phase []byte, return suite.aead(version, key, iv) } + +func updateTrafficSecret(hash crypto.Hash, secret []byte) []byte { + return hkdfExpandLabel(hash, secret, applicationTrafficLabel, nil, hash.Size()) +}