diff --git a/.travis.yml b/.travis.yml index 34728f5..5af35eb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -30,8 +30,8 @@ install: - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh INSTALL $CLIENT $REVISION; fi script: - - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA - - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA + - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT; fi + - if [ "$MODE" = "interop" ] && [ "$CLIENT" = "tstclnt" ]; then ./_dev/interop.sh 0-RTT $CLIENT; fi - if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi after_script: diff --git a/13.go b/13.go index 3be0305..d8490bb 100644 --- a/13.go +++ b/13.go @@ -60,34 +60,33 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { } c.didResume = isPSK + hs.finishedHash13 = hash.New() + hs.finishedHash13.Write(hs.clientHello.marshal()) + + handshakeCtx := hs.finishedHash13.Sum(nil) + earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret") + ecdheSecret := deriveECDHESecret(ks, privateKey) if ecdheSecret == nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: bad ECDHE client share") } - hs.finishedHash13 = hash.New() - hs.finishedHash13.Write(hs.clientHello.marshal()) hs.finishedHash13.Write(hs.hello13.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { return err } handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) - handshakeCtx := hs.finishedHash13.Sum(nil) - - cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize) - cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen) - cIV := hkdfExpandLabel(hash, cHandshakeTS, nil, "iv", 12) - sHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "server handshake traffic secret", hashSize) - sKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "key", hs.suite.keyLen) - sIV := hkdfExpandLabel(hash, sHandshakeTS, nil, "iv", 12) - - clientCipher := hs.suite.aead(cKey, cIV) - c.in.setCipher(c.vers, clientCipher) - serverCipher := hs.suite.aead(sKey, sIV) + handshakeCtx = hs.finishedHash13.Sum(nil) + clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret") + hs.hsClientCipher = clientCipher + serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret") c.out.setCipher(c.vers, serverCipher) + serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize) + hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize) + hs.finishedHash13.Write(hs.hello13Enc.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { return err @@ -99,9 +98,6 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { } } - serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize) - hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize) - verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey) serverFinished := &finishedMsg{ verifyData: verifyData, @@ -113,19 +109,20 @@ func (hs *serverHandshakeState) doTLS13Handshake() error { hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) handshakeCtx = hs.finishedHash13.Sum(nil) - - cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize) - cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen) - cIV = hkdfExpandLabel(hash, cTrafficSecret0, nil, "iv", 12) - sTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "server application traffic secret", hashSize) - sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen) - sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12) - - hs.clientCipher = hs.suite.aead(cKey, cIV) - serverCipher = hs.suite.aead(sKey, sIV) + hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret") + serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret") c.out.setCipher(c.vers, serverCipher) - c.phase = waitingClientFinished + if hs.hello13Enc.earlyData { + c.in.setCipher(c.vers, earlyClientCipher) + c.phase = readingEarlyData + } else if hs.clientHello.earlyData { + c.in.setCipher(c.vers, hs.hsClientCipher) + c.phase = discardingEarlyData + } else { + c.in.setCipher(c.vers, hs.hsClientCipher) + c.phase = waitingClientFinished + } return nil } @@ -157,11 +154,10 @@ func (hs *serverHandshakeState) readClientFinished13() error { } hs.finishedHash13.Write(clientFinished.marshal()) - c.in.setCipher(c.vers, hs.clientCipher) - - // Discard the server handshake state - c.hs = nil - c.phase = handshakeComplete + c.hs = nil // Discard the server handshake state + c.phase = handshakeConfirmed + c.in.setCipher(c.vers, hs.appClientCipher) + c.in.traceErr, c.out.traceErr = nil, nil return hs.sendSessionTicket13() } @@ -209,6 +205,15 @@ func (hs *serverHandshakeState) sendCertificate13() error { return nil } +func (c *Conn) handleEndOfEarlyData() { + if c.phase != readingEarlyData || c.vers < VersionTLS13 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + return + } + c.phase = waitingClientFinished + c.in.setCipher(c.vers, c.hs.hsClientCipher) +} + // selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify // based on the certificate type and client supported schemes. If no overlap is found, // a fallback is selected. @@ -377,6 +382,14 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte { return h.Sum(nil) } +func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) { + hash := hashForSuite(hs.suite) + trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size()) + key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen) + iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12) + return hs.suite.aead(key, iv), trafficSecret +} + // Maximum allowed mismatch between the stated age of a ticket // and the server-observed one. See // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. @@ -418,7 +431,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance { continue } - if s.hash != uint16(hash) { + + // This enforces the stricter 0-RTT requirements on all ticket uses. + // The benefit of using PSK+ECDHE without 0-RTT are small enough that + // we can give them up in the edge case of changed suite or ALPN. + if s.suite != hs.suite.id { + continue + } + if s.alpnProtocol != hs.hello13Enc.alpnProtocol { continue } @@ -433,6 +453,9 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 { hs.hello13.psk = true hs.hello13.pskIdentity = uint16(i) + if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData { + hs.hello13Enc.earlyData = true + } return earlySecret, true } } @@ -467,8 +490,8 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { return err } sessionState := &sessionState13{ - vers: c.vers, - hash: uint16(hash), + vers: c.vers, + suite: hs.suite.id, ageAdd: uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 | uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]), createdAt: uint64(time.Now().Unix()), @@ -481,9 +504,11 @@ func (hs *serverHandshakeState) sendSessionTicket13() error { return err } ticketMsg := &newSessionTicketMsg13{ - lifetime: 21600, // TODO(filippo) - ageAdd: sessionState.ageAdd, - ticket: ticket, + lifetime: 24 * 3600, // TODO(filippo) + maxEarlyDataLength: c.config.Max0RTTDataSize, + withEarlyDataInfo: c.config.Max0RTTDataSize > 0, + ageAdd: sessionState.ageAdd, + ticket: ticket, } if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil { return err diff --git a/_dev/Makefile b/_dev/Makefile index c699600..799751a 100644 --- a/_dev/Makefile +++ b/_dev/Makefile @@ -20,7 +20,9 @@ ifeq ($(shell go env CGO_ENABLED),1) endif @touch "$@" -GO_COMMIT := 5782050a487e002acfd14a3f4c2c815c7854928c +# Note: when changing this, if it doesn't change the Go version +# (it should), you need to run make clean. +GO_COMMIT := 5d4f37266d324e48f67b3bb82c6090f1aa94c013 .PHONY: go go: go/.ok_$(GO_COMMIT)_$(GOENV) diff --git a/_dev/interop.sh b/_dev/interop.sh index 1d45ab3..8c6e17c 100755 --- a/_dev/interop.sh +++ b/_dev/interop.sh @@ -11,8 +11,25 @@ if [ "$1" = "INSTALL" ]; then elif [ "$1" = "RUN" ]; then IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver) - docker run --rm tls-tris:$2 $IP:$3 | tee output.txt - grep "Hello TLS 1.3" output.txt | grep -v "resumed" - grep "Hello TLS 1.3" output.txt | grep "resumed" + + docker run --rm tls-tris:$2 $IP:1443 | tee output.txt # RSA + grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT" + grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT" + + docker run --rm tls-tris:$2 $IP:2443 | tee output.txt # ECDSA + grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT" + grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT" + +elif [ "$1" = "0-RTT" ]; then + IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver) + + docker run --rm tls-tris:$2 $IP:3443 | tee output.txt # rejecting 0-RTT + grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT" + + docker run --rm tls-tris:$2 $IP:4443 | tee output.txt # accepting 0-RTT + grep "Hello TLS 1.3" output.txt | grep "resumed" | grep "0-RTT" + + docker run --rm tls-tris:$2 $IP:5443 | tee output.txt # confirming 0-RTT + grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT" fi diff --git a/_dev/tris-localserver/Dockerfile b/_dev/tris-localserver/Dockerfile index 4aba119..c2de2c5 100644 --- a/_dev/tris-localserver/Dockerfile +++ b/_dev/tris-localserver/Dockerfile @@ -2,10 +2,13 @@ FROM scratch ENV TLSDEBUG error -EXPOSE 443 +EXPOSE 1443 +EXPOSE 2443 +EXPOSE 3443 EXPOSE 4443 +EXPOSE 5443 # GOOS=linux ../go.sh build -v -i . ADD tris-localserver ./ -CMD [ "./tris-localserver", "0.0.0.0:443", "0.0.0.0:4443" ] +CMD [ "./tris-localserver", "0.0.0.0:1443", "0.0.0.0:2443", "0.0.0.0:3443", "0.0.0.0:4443", "0.0.0.0:5443" ] diff --git a/_dev/tris-localserver/server.go b/_dev/tris-localserver/server.go index 89aadd3..72ec58c 100644 --- a/_dev/tris-localserver/server.go +++ b/_dev/tris-localserver/server.go @@ -17,13 +17,49 @@ var tlsVersionToName = map[uint16]string{ tls.VersionTLS13Draft18: "1.3 (draft 18)", } +func startServer(addr string, rsa, offer0RTT, accept0RTT bool) { + cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey)) + if rsa { + cert, err = tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey)) + } + if err != nil { + log.Fatal(err) + } + var Max0RTTDataSize uint32 + if offer0RTT { + Max0RTTDataSize = 100 * 1024 + } + s := &http.Server{ + Addr: addr, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + Max0RTTDataSize: Max0RTTDataSize, + Accept0RTTData: accept0RTT, + }, + } + log.Fatal(s.ListenAndServeTLS("", "")) +} + +var confirmingAddr string + func main() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + tlsConn := r.Context().Value(http.TLSConnContextKey).(*tls.Conn) + server := r.Context().Value(http.ServerContextKey).(*http.Server) + if server.Addr == confirmingAddr { + if err := tlsConn.ConfirmHandshake(); err != nil { + log.Fatal(err) + } + } resumed := "" if r.TLS.DidResume { resumed = " [resumed]" } - fmt.Fprintf(w, "

Hello TLS %s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed) + with0RTT := "" + if !tlsConn.ConnectionState().HandshakeConfirmed { + with0RTT = " [0-RTT]" + } + fmt.Fprintf(w, "

Hello TLS %s%s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed, with0RTT) }) http.HandleFunc("/ch", func(w http.ResponseWriter, r *http.Request) { @@ -31,36 +67,17 @@ func main() { fmt.Fprintf(w, "Client Hello packet (%d bytes):\n%s", len(r.TLS.ClientHello), hex.Dump(r.TLS.ClientHello)) }) - go func() { - if len(os.Args) < 3 { - return - } - cert, err := tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey)) - if err != nil { - log.Fatal(err) - } - s := &http.Server{ - Addr: os.Args[2], - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - PreferServerCipherSuites: true, - }, - } - log.Fatal(s.ListenAndServeTLS("", "")) - }() - - cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey)) - if err != nil { - log.Fatal(err) + switch len(os.Args) { + case 2: + startServer(os.Args[1], true, true, true) + case 6: + confirmingAddr = os.Args[5] + go startServer(os.Args[1], false, false, false) // first port: ECDSA (and no 0-RTT) + go startServer(os.Args[2], true, false, true) // second port: RSA (and accept 0-RTT but not offer it) + go startServer(os.Args[3], false, true, false) // third port: offer and reject 0-RTT + go startServer(os.Args[4], false, true, true) // fourth port: offer and accept 0-RTT + startServer(os.Args[5], false, true, true) // fifth port: offer and accept 0-RTT but confirm } - s := &http.Server{ - Addr: os.Args[1], - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - PreferServerCipherSuites: true, - }, - } - log.Fatal(s.ListenAndServeTLS("", "")) } const ( diff --git a/_dev/tstclnt/run.sh b/_dev/tstclnt/run.sh index 352bc9b..40e3d33 100755 --- a/_dev/tstclnt/run.sh +++ b/_dev/tstclnt/run.sh @@ -5,4 +5,4 @@ shift HOST="${ADDR[0]}" PORT="${ADDR[1]}" -exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 "$@" +exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 -Z "$@" diff --git a/alert.go b/alert.go index 4929868..e77631c 100644 --- a/alert.go +++ b/alert.go @@ -16,6 +16,7 @@ const ( const ( alertCloseNotify alert = 0 + alertEndOfEarlyData alert = 1 alertUnexpectedMessage alert = 10 alertBadRecordMAC alert = 20 alertDecryptionFailed alert = 21 diff --git a/common.go b/common.go index 457ac44..992a163 100644 --- a/common.go +++ b/common.go @@ -85,6 +85,7 @@ const ( extensionSessionTicket uint16 = 35 extensionKeyShare uint16 = 40 extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 extensionSupportedVersions uint16 = 43 extensionPSKKeyExchangeModes uint16 = 45 extensionTicketEarlyDataInfo uint16 = 46 @@ -213,6 +214,10 @@ type ConnectionState struct { // been standardized and implemented. TLSUnique []byte + // HandshakeConfirmed is true once all data returned by Read + // (past and future) is guaranteed not to be replayed. + HandshakeConfirmed bool + ClientHello []byte // ClientHello packet } @@ -322,6 +327,18 @@ type ClientHelloInfo struct { // from, or write to, this connection; that will cause the TLS // connection to fail. Conn net.Conn + + // Offered0RTTData is true if the client announced that it will send + // 0-RTT data. If the server Config.Accept0RTTData is true, and the + // client offered a session ticket valid for that purpose, it will + // be notified that the 0-RTT data is accepted and it will be made + // immediately available for Read. + Offered0RTTData bool + + // The Fingerprint is an sequence of bytes unique to this Client Hello. + // It can be used to prevent or mitigate 0-RTT data replays as it's + // guaranteed that a replayed connection will have the same Fingerprint. + Fingerprint []byte } // CertificateRequestInfo contains information from a server's @@ -548,6 +565,28 @@ type Config struct { // used for debugging. KeyLogWriter io.Writer + // If Max0RTTDataSize is not zero, the client will be allowed to use + // session tickets to send at most this number of bytes of 0-RTT data. + // 0-RTT data is subject to replay and has memory DoS implications. + // The server will later be able to refuse the 0-RTT data with + // Accept0RTTData, or wait for the client to prove that it's not + // replayed with Conn.ConfirmHandshake. + // + // It has no meaning on the client. + // + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3. + Max0RTTDataSize uint32 + + // Accept0RTTData makes the 0-RTT data received from the client + // immediately available to Read. 0-RTT data is subject to replay. + // Use Conn.ConfirmHandshake to wait until the data is known not + // to be replayed after reading it. + // + // It has no meaning on the client. + // + // See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3. + Accept0RTTData bool + serverInitOnce sync.Once // guards calling (*Config).serverInit // mutex protects sessionTicketKeys. @@ -622,6 +661,8 @@ func (c *Config) Clone() *Config { DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, Renegotiation: c.Renegotiation, KeyLogWriter: c.KeyLogWriter, + Accept0RTTData: c.Accept0RTTData, + Max0RTTDataSize: c.Max0RTTDataSize, sessionTicketKeys: sessionTicketKeys, } } diff --git a/conn.go b/conn.go index 1adf7ea..6afe10c 100644 --- a/conn.go +++ b/conn.go @@ -27,20 +27,21 @@ type Conn struct { conn net.Conn isClient bool - phase handshakePhase - // constant after handshake; protected by handshakeMutex handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex // handshakeCond, if not nil, indicates that a goroutine is committed // to running the handshake for this Conn. Other goroutines that need // to wait for the handshake can wait on this, under handshakeMutex. handshakeCond *sync.Cond - handshakeErr error // error resulting from handshake - connID []byte // Random connection id - clientHello []byte // ClientHello packet contents - vers uint16 // TLS version - haveVers bool // version has been negotiated - config *Config // configuration passed to constructor + // The transition from handshakeRunning to the next phase is covered by + // handshakeMutex. All others by in.Mutex. + phase handshakeStatus + handshakeErr error // error resulting from handshake + connID []byte // Random connection id + clientHello []byte // ClientHello packet contents + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor // handshakes counts the number of handshakes performed on the // connection so far. If renegotiation is disabled then this is either // zero or one. @@ -103,16 +104,22 @@ type Conn struct { // TLS 1.3 needs the server state until it reaches the Client Finished hs *serverHandshakeState + // earlyDataBytes is the number of bytes of early data received so + // far. Tracked to enforce max_early_data_size. + earlyDataBytes int64 + tmp [16]byte } -type handshakePhase int +type handshakeStatus int const ( - earlyHandshake handshakePhase = iota + handshakeRunning handshakeStatus = iota + discardingEarlyData + readingEarlyData waitingClientFinished readingClientFinished - handshakeComplete + handshakeConfirmed ) // Access to net.Conn methods. @@ -548,6 +555,9 @@ func (b *block) readFromUntil(r io.Reader, n int) error { func (b *block) Read(p []byte) (n int, err error) { n = copy(p, b.data[b.off:]) b.off += n + if b.off >= len(b.data) { + err = io.EOF + } return } @@ -606,6 +616,7 @@ func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) { // readRecord reads the next TLS record from the connection // and updates the record layer state. // c.in.Mutex <= L; c.input == nil. +// c.input can still be nil after a call, retry if so. func (c *Conn) readRecord(want recordType) error { // Caller must be in sync with connection: // handshake data if handshake not yet completed, @@ -615,18 +626,17 @@ func (c *Conn) readRecord(want recordType) error { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) case recordTypeHandshake, recordTypeChangeCipherSpec: - if c.phase != earlyHandshake && c.phase != readingClientFinished { + if c.phase != handshakeRunning && c.phase != readingClientFinished { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) } case recordTypeApplicationData: - if c.phase == earlyHandshake || c.phase == earlyHandshake { + if c.phase == handshakeRunning || c.phase == readingClientFinished { c.sendAlert(alertInternalError) return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) } } -Again: if c.rawInput == nil { c.rawInput = c.in.newBlock() } @@ -686,7 +696,15 @@ Again: // Process message. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) ok, off, alertValue := c.in.decrypt(b) - if !ok { + switch { + case !ok && c.phase == discardingEarlyData: + // If the client said that it's sending early data and we did not + // accept it, we are expected to fail decryption. + c.in.freeBlock(b) + return nil + case ok && c.phase == discardingEarlyData: + c.phase = waitingClientFinished + case !ok: c.in.freeBlock(b) return c.in.setErrorLocked(c.sendAlert(alertValue)) } @@ -730,11 +748,15 @@ Again: c.in.setErrorLocked(io.EOF) break } + if alert(data[1]) == alertEndOfEarlyData { + c.handleEndOfEarlyData() + break + } switch data[0] { case alertLevelWarning: // drop on the floor c.in.freeBlock(b) - goto Again + return nil case alertLevelError: c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) default: @@ -742,7 +764,7 @@ Again: } case recordTypeChangeCipherSpec: - if typ != want || len(data) != 1 || data[0] != 1 { + if typ != want || len(data) != 1 || data[0] != 1 || c.vers >= VersionTLS13 { c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) break } @@ -752,11 +774,7 @@ Again: } case recordTypeApplicationData: - if c.phase == waitingClientFinished { - c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) - break - } - if typ != want { + if typ != want || c.phase == waitingClientFinished { c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) break } @@ -775,7 +793,6 @@ Again: c.in.setErrorLocked(err) break } - goto Again } } @@ -1131,7 +1148,7 @@ func (c *Conn) Write(b []byte) (int, error) { return 0, err } - if c.phase == earlyHandshake { + if c.phase == handshakeRunning { return 0, alertInternalError } @@ -1181,6 +1198,10 @@ func (c *Conn) handleRenegotiation() error { return c.sendAlert(alertNoRenegotiation) } + if c.vers >= VersionTLS13 { + return c.sendAlert(alertNoRenegotiation) + } + switch c.config.Renegotiation { case RenegotiateNever: return c.sendAlert(alertNoRenegotiation) @@ -1198,13 +1219,102 @@ func (c *Conn) handleRenegotiation() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - c.phase = earlyHandshake + c.phase = handshakeRunning if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { c.handshakes++ } return c.handshakeErr } +// ConfirmHandshake waits for the handshake to reach a point at which +// the connection is certainly not replayed. That is, after receiving +// the Client Finished. +// +// If ConfirmHandshake returns an error and until ConfirmHandshake +// returns, the 0-RTT data should not be trusted not to be replayed. +// +// This is only meaningful in TLS 1.3 when Accept0RTTData is true and the +// client sent valid 0-RTT data. In any other case it's equivalent to +// calling Handshake. +func (c *Conn) ConfirmHandshake() error { + if err := c.Handshake(); err != nil { + return err + } + + c.in.Lock() + defer c.in.Unlock() + + if c.phase == handshakeConfirmed { + return nil + } + + var input *block + if c.phase == readingEarlyData || c.input != nil { + buf := &bytes.Buffer{} + if _, err := buf.ReadFrom(earlyDataReader{c}); err != nil { + c.in.setErrorLocked(err) + return err + } + input = &block{data: buf.Bytes()} + } + + for c.phase != handshakeConfirmed { + if err := c.readRecord(recordTypeApplicationData); err != nil { + c.in.setErrorLocked(err) + return err + } + } + + if c.phase != handshakeConfirmed { + panic("should have reached handshakeConfirmed state") + } + if c.input != nil { + panic("should not have read past the Client Finished") + } + + c.input = input + + return nil +} + +// earlyDataReader wraps a Conn with in locked and reads only early data, +// both buffered and still on the wire. +type earlyDataReader struct { + c *Conn +} + +func (r earlyDataReader) Read(b []byte) (n int, err error) { + c := r.c + + if c.phase == handshakeConfirmed { + // c.input might not be early data + panic("earlyDataReader called at handshakeConfirmed") + } + + for c.input == nil && c.in.err == nil && c.phase == readingEarlyData { + if err := c.readRecord(recordTypeApplicationData); err != nil { + return 0, err + } + } + if err := c.in.err; err != nil { + return 0, err + } + + if c.input != nil { + n, err = c.input.Read(b) + if err == io.EOF { + err = nil + c.in.freeBlock(c.input) + c.input = nil + } + } + + if err == nil && c.phase != readingEarlyData && c.input == nil { + err = io.EOF + } + return +} + // Read can be made to time out and return a net.Error with Timeout() == true // after a fixed time limit; see SetDeadline and SetReadDeadline. func (c *Conn) Read(b []byte) (n int, err error) { @@ -1242,7 +1352,8 @@ func (c *Conn) Read(b []byte) (n int, err error) { } n, err = c.input.Read(b) - if c.input.off >= len(c.input.data) { + if err == io.EOF { + err = nil c.in.freeBlock(c.input) c.input = nil } @@ -1300,7 +1411,17 @@ func (c *Conn) Close() error { var alertErr error c.handshakeMutex.Lock() - if c.phase != earlyHandshake { + for c.phase == readingEarlyData { + if err := c.readRecord(recordTypeApplicationData); err != nil { + alertErr = err + } + } + if alertErr == nil && c.phase == waitingClientFinished { + if err := c.hs.readClientFinished13(); err != nil { + alertErr = err + } + } + if alertErr == nil && c.phase != handshakeRunning { alertErr = c.closeNotify() } c.handshakeMutex.Unlock() @@ -1319,7 +1440,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com func (c *Conn) CloseWrite() error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() - if c.phase == earlyHandshake { + if c.phase == handshakeRunning { return errEarlyCloseWrite } @@ -1341,8 +1462,11 @@ func (c *Conn) closeNotify() error { // protocol if it has not yet been run. // Most uses of this package need not call Handshake // explicitly: the first Read or Write will call it automatically. +// +// In TLS 1.3 Handshake returns after the client and server first flights, +// without waiting for the Client Finished. func (c *Conn) Handshake() error { - // c.handshakeErr and c.phase == earlyHandshake are protected by + // c.handshakeErr and c.phase == handshakeRunning are protected by // c.handshakeMutex. In order to perform a handshake, we need to lock // c.in also and c.handshakeMutex must be locked after c.in. // @@ -1371,7 +1495,7 @@ func (c *Conn) Handshake() error { if err := c.handshakeErr; err != nil { return err } - if c.phase != earlyHandshake { + if c.phase != handshakeRunning { return nil } if c.handshakeCond == nil { @@ -1393,7 +1517,7 @@ func (c *Conn) Handshake() error { // The handshake cannot have completed when handshakeMutex was unlocked // because this goroutine set handshakeCond. - if c.handshakeErr != nil || c.phase != earlyHandshake { + if c.handshakeErr != nil || c.phase != handshakeRunning { panic("handshake should not have been able to complete after handshakeCond was set") } @@ -1415,7 +1539,7 @@ func (c *Conn) Handshake() error { c.flush() } - if c.handshakeErr == nil && c.phase == earlyHandshake { + if c.handshakeErr == nil && c.phase == handshakeRunning { panic("handshake should have had a result.") } @@ -1433,7 +1557,7 @@ func (c *Conn) ConnectionState() ConnectionState { defer c.handshakeMutex.Unlock() var state ConnectionState - state.HandshakeComplete = c.phase != earlyHandshake + state.HandshakeComplete = c.phase != handshakeRunning state.ServerName = c.serverName if state.HandshakeComplete { @@ -1448,6 +1572,7 @@ func (c *Conn) ConnectionState() ConnectionState { state.VerifiedChains = c.verifiedChains state.SignedCertificateTimestamps = c.scts state.OCSPResponse = c.ocspResponse + state.HandshakeConfirmed = c.phase == handshakeConfirmed if !c.didResume { if c.clientFinishedIsFirst { state.TLSUnique = c.clientFinished[:] @@ -1478,7 +1603,7 @@ func (c *Conn) VerifyHostname(host string) error { if !c.isClient { return errors.New("tls: VerifyHostname called on TLS server connection") } - if c.phase == earlyHandshake { + if c.phase == handshakeRunning { return errors.New("tls: handshake has not yet been performed") } if len(c.verifiedChains) == 0 { diff --git a/handshake_client.go b/handshake_client.go index 3f120e0..4ead1eb 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -251,7 +251,7 @@ NextCipherSuite: } c.didResume = isResume - c.phase = handshakeComplete + c.phase = handshakeConfirmed c.cipherSuite = suite.id return nil } diff --git a/handshake_client_test.go b/handshake_client_test.go index 5851f89..c839926 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -64,12 +64,12 @@ func (i opensslInput) Read(buf []byte) (n int, err error) { } // opensslOutputSink is an io.Writer that receives the stdout and stderr from -// an `openssl` process and sends a value to handshakeComplete when it sees a +// an `openssl` process and sends a value to handshakeConfirmed when it sees a // log message from a completed server handshake. type opensslOutputSink struct { - handshakeComplete chan struct{} - all []byte - line []byte + handshakeConfirmed chan struct{} + all []byte + line []byte } func newOpensslOutputSink() *opensslOutputSink { @@ -91,7 +91,7 @@ func (o *opensslOutputSink) Write(data []byte) (n int, err error) { } if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) { - o.handshakeComplete <- struct{}{} + o.handshakeConfirmed <- struct{}{} } o.line = o.line[i+1:] } @@ -315,9 +315,9 @@ func (test *clientTest) run(t *testing.T, write bool) { for i := 1; i <= test.numRenegotiations; i++ { // The initial handshake will generate a - // handshakeComplete signal which needs to be quashed. + // handshakeConfirmed signal which needs to be quashed. if i == 1 && write { - <-stdout.handshakeComplete + <-stdout.handshakeConfirmed } // OpenSSL will try to interleave application data and @@ -364,7 +364,7 @@ func (test *clientTest) run(t *testing.T, write bool) { }() if write && test.renegotiationExpectedToFail != i { - <-stdout.handshakeComplete + <-stdout.handshakeConfirmed stdin <- opensslSendSentinel } <-signalChan diff --git a/handshake_messages.go b/handshake_messages.go index 31d1513..d1a2b2f 100644 --- a/handshake_messages.go +++ b/handshake_messages.go @@ -33,6 +33,7 @@ type clientHelloMsg struct { supportedVersions []uint16 psks []psk pskKeyExchangeModes []uint8 + earlyData bool } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -60,7 +61,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool { bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && eqStrings(m.alpnProtocols, m1.alpnProtocols) && eqKeyShares(m.keyShares, m1.keyShares) && - eqUint16s(m.supportedVersions, m1.supportedVersions) + eqUint16s(m.supportedVersions, m1.supportedVersions) && + m.earlyData == m1.earlyData } func (m *clientHelloMsg) marshal() []byte { @@ -127,6 +129,9 @@ func (m *clientHelloMsg) marshal() []byte { extensionsLength += 1 + 2*len(m.supportedVersions) numExtensions++ } + if m.earlyData { + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -350,6 +355,11 @@ func (m *clientHelloMsg) marshal() []byte { z = z[2:] } } + if m.earlyData { + z[0] = byte(extensionEarlyData >> 8) + z[1] = byte(extensionEarlyData) + z = z[4:] + } m.raw = x @@ -413,6 +423,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.supportedVersions = nil m.psks = nil m.pskKeyExchangeModes = nil + m.earlyData = false if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -668,6 +679,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { return false } m.pskKeyExchangeModes = data[1:length] + case extensionEarlyData: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8 + m.earlyData = true } data = data[length:] bindersOffset += length @@ -1144,6 +1158,7 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool { type encryptedExtensionsMsg struct { raw []byte alpnProtocol string + earlyData bool } func (m *encryptedExtensionsMsg) equal(i interface{}) bool { @@ -1153,7 +1168,8 @@ func (m *encryptedExtensionsMsg) equal(i interface{}) bool { } return bytes.Equal(m.raw, m1.raw) && - m.alpnProtocol == m1.alpnProtocol + m.alpnProtocol == m1.alpnProtocol && + m.earlyData == m1.earlyData } func (m *encryptedExtensionsMsg) marshal() []byte { @@ -1163,6 +1179,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte { length := 2 + if m.earlyData { + length += 4 + } alpnLen := len(m.alpnProtocol) if alpnLen > 0 { if alpnLen >= 256 { @@ -1196,6 +1215,12 @@ func (m *encryptedExtensionsMsg) marshal() []byte { z = z[7+alpnLen:] } + if m.earlyData { + z[0] = byte(extensionEarlyData >> 8) + z[1] = byte(extensionEarlyData) + z = z[4:] + } + m.raw = x return x } @@ -1205,41 +1230,55 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { return false } m.raw = data - l := int(data[4])<<8 | int(data[5]) - if l != len(data)-6 { - return false - } + m.alpnProtocol = "" - if l == 0 { - return true - } + m.earlyData = false - d := data[6:] - if len(d) < 5 { - return false - } - if uint16(d[0])<<8|uint16(d[1]) != extensionALPN { - return false - } - l = int(d[2])<<8 | int(d[3]) - if l != len(d)-4 { - return false - } - l = int(d[4])<<8 | int(d[5]) - if l != len(d)-6 { - return false - } - d = d[6:] - l = int(d[0]) - if l != len(d)-1 { + extensionsLength := int(data[4])<<8 | int(data[5]) + data = data[6:] + if len(data) != extensionsLength { return false } - d = d[1:] - if len(d) == 0 { - // ALPN protocols must not be empty. - return false + + for len(data) != 0 { + if len(data) < 4 { + return false + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return false + } + + switch extension { + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return false + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return false + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return false + } + d = d[1:] + if len(d) == 0 { + // ALPN protocols must not be empty. + return false + } + m.alpnProtocol = string(d) + case extensionEarlyData: + // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8 + m.earlyData = true + } + + data = data[length:] } - m.alpnProtocol = string(d) return true } diff --git a/handshake_messages_test.go b/handshake_messages_test.go index 1dd0e3b..c0d8ac3 100644 --- a/handshake_messages_test.go +++ b/handshake_messages_test.go @@ -164,6 +164,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { for i := range m.supportedVersions { m.supportedVersions[i] = uint16(rand.Intn(30000)) } + if rand.Intn(10) > 5 { + m.earlyData = true + } return reflect.ValueOf(m) } @@ -222,7 +225,12 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value { func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &encryptedExtensionsMsg{} - m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } + if rand.Intn(10) > 5 { + m.earlyData = true + } return reflect.ValueOf(m) } @@ -328,10 +336,11 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { s := &sessionState13{} s.vers = uint16(rand.Intn(10000)) - s.hash = uint16(rand.Intn(10000)) + s.suite = uint16(rand.Intn(10000)) s.ageAdd = uint32(rand.Intn(0xffffffff)) s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) s.resumptionSecret = randomBytes(rand.Intn(100), rand) + s.alpnProtocol = randomString(rand.Intn(100), rand) return reflect.ValueOf(s) } diff --git a/handshake_server.go b/handshake_server.go index a66af85..0ba2ac7 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -46,7 +46,8 @@ type serverHandshakeState struct { hello13Enc *encryptedExtensionsMsg finishedHash13 hash.Hash clientFinishedKey []byte - clientCipher interface{} + hsClientCipher interface{} + appClientCipher interface{} } // serverHandshake performs a TLS handshake as a server. @@ -61,7 +62,6 @@ func (c *Conn) serverHandshake() error { } c.in.traceErr = hs.traceErr c.out.traceErr = hs.traceErr - defer func() { c.in.traceErr, c.out.traceErr = nil, nil }() isResume, err := hs.readClientHello() if err != nil { return err @@ -105,7 +105,7 @@ func (c *Conn) serverHandshake() error { return err } c.didResume = true - c.phase = handshakeComplete + c.phase = handshakeConfirmed } else { // The client didn't include a session ticket, or it wasn't // valid so we do a full handshake. @@ -129,7 +129,7 @@ func (c *Conn) serverHandshake() error { if _, err := c.flush(); err != nil { return err } - c.phase = handshakeComplete + c.phase = handshakeConfirmed } return nil @@ -910,6 +910,11 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo { signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature)) } + var pskBinder []byte + if len(hs.clientHello.psks) > 0 { + pskBinder = hs.clientHello.psks[0].binder + } + hs.cachedClientHelloInfo = &ClientHelloInfo{ CipherSuites: hs.clientHello.cipherSuites, ServerName: hs.clientHello.serverName, @@ -919,6 +924,8 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo { SupportedProtos: hs.clientHello.alpnProtocols, SupportedVersions: supportedVersions, Conn: hs.c.conn, + Offered0RTTData: hs.clientHello.earlyData, + Fingerprint: pskBinder, } return hs.cachedClientHelloInfo diff --git a/ticket.go b/ticket.go index 1215f41..df4c219 100644 --- a/ticket.go +++ b/ticket.go @@ -131,10 +131,13 @@ func (s *sessionState) unmarshal(data []byte) bool { type sessionState13 struct { vers uint16 - hash uint16 // crypto.Hash value + suite uint16 ageAdd uint32 createdAt uint64 resumptionSecret []byte + alpnProtocol string + // TODO(filippo): add and check SNI + // TODO(filippo): add and check maxEarlyDataLength } func (s *sessionState13) equal(i interface{}) bool { @@ -144,19 +147,20 @@ func (s *sessionState13) equal(i interface{}) bool { } return s.vers == s1.vers && - s.hash == s1.hash && + s.suite == s1.suite && + s.alpnProtocol == s1.alpnProtocol && s.ageAdd == s1.ageAdd && bytes.Equal(s.resumptionSecret, s1.resumptionSecret) } func (s *sessionState13) marshal() []byte { - length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol) x := make([]byte, length) x[0] = byte(s.vers >> 8) x[1] = byte(s.vers) - x[2] = byte(s.hash >> 8) - x[3] = byte(s.hash) + x[2] = byte(s.suite >> 8) + x[3] = byte(s.suite) x[4] = byte(s.ageAdd >> 24) x[5] = byte(s.ageAdd >> 16) x[6] = byte(s.ageAdd >> 8) @@ -171,8 +175,11 @@ func (s *sessionState13) marshal() []byte { x[15] = byte(s.createdAt) x[16] = byte(len(s.resumptionSecret) >> 8) x[17] = byte(len(s.resumptionSecret)) - copy(x[18:], s.resumptionSecret) + z := x[18+len(s.resumptionSecret):] + z[0] = byte(len(s.alpnProtocol) >> 8) + z[1] = byte(len(s.alpnProtocol)) + copy(z[2:], s.alpnProtocol) return x } @@ -183,14 +190,19 @@ func (s *sessionState13) unmarshal(data []byte) bool { } s.vers = uint16(data[0])<<8 | uint16(data[1]) - s.hash = uint16(data[2])<<8 | uint16(data[3]) + s.suite = uint16(data[2])<<8 | uint16(data[3]) s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 | uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15]) - l := uint16(data[16])<<8 | uint16(data[17]) - s.resumptionSecret = data[18:] - - return int(l) == len(s.resumptionSecret) + l := int(data[16])<<8 | int(data[17]) + if len(data) < 18+l+2 { + return false + } + s.resumptionSecret = data[18 : 18+l] + z := data[18+l:] + l = int(z[0])<<8 | int(z[1]) + s.alpnProtocol = string(z[2:]) + return l == len(s.alpnProtocol) } func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { diff --git a/tls_test.go b/tls_test.go index fd0127c..d352be8 100644 --- a/tls_test.go +++ b/tls_test.go @@ -641,7 +641,7 @@ func TestCloneNonFuncFields(t *testing.T) { f.Set(reflect.ValueOf("b")) case "ClientAuth": f.Set(reflect.ValueOf(VerifyClientCertIfGiven)) - case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites": + case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "Accept0RTTData": f.Set(reflect.ValueOf(true)) case "MinVersion", "MaxVersion": f.Set(reflect.ValueOf(uint16(VersionTLS12))) @@ -654,6 +654,8 @@ func TestCloneNonFuncFields(t *testing.T) { f.Set(reflect.ValueOf([]CurveID{CurveP256})) case "Renegotiation": f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) + case "Max0RTTDataSize": + f.Set(reflect.ValueOf(uint32(0))) default: t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) }