crypto/tls: implement TLS 1.3 server 0-RTT

This commit is contained in:
Filippo Valsorda 2016-11-25 21:46:50 +00:00 committed by Peter Wu
parent 1117f76fcc
commit f8c15889af
17 changed files with 472 additions and 172 deletions

View File

@ -30,8 +30,8 @@ install:
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh INSTALL $CLIENT $REVISION; fi - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh INSTALL $CLIENT $REVISION; fi
script: script:
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA - if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT; fi
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA - if [ "$MODE" = "interop" ] && [ "$CLIENT" = "tstclnt" ]; then ./_dev/interop.sh 0-RTT $CLIENT; fi
- if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi - if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi
after_script: after_script:

95
13.go
View File

@ -60,34 +60,33 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
} }
c.didResume = isPSK c.didResume = isPSK
hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())
handshakeCtx := hs.finishedHash13.Sum(nil)
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
ecdheSecret := deriveECDHESecret(ks, privateKey) ecdheSecret := deriveECDHESecret(ks, privateKey)
if ecdheSecret == nil { if ecdheSecret == nil {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return errors.New("tls: bad ECDHE client share") return errors.New("tls: bad ECDHE client share")
} }
hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())
hs.finishedHash13.Write(hs.hello13.marshal()) hs.finishedHash13.Write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err return err
} }
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
handshakeCtx := hs.finishedHash13.Sum(nil) handshakeCtx = hs.finishedHash13.Sum(nil)
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize) hs.hsClientCipher = clientCipher
cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen) serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret")
cIV := hkdfExpandLabel(hash, cHandshakeTS, nil, "iv", 12)
sHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "server handshake traffic secret", hashSize)
sKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "key", hs.suite.keyLen)
sIV := hkdfExpandLabel(hash, sHandshakeTS, nil, "iv", 12)
clientCipher := hs.suite.aead(cKey, cIV)
c.in.setCipher(c.vers, clientCipher)
serverCipher := hs.suite.aead(sKey, sIV)
c.out.setCipher(c.vers, serverCipher) c.out.setCipher(c.vers, serverCipher)
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
hs.finishedHash13.Write(hs.hello13Enc.marshal()) hs.finishedHash13.Write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err return err
@ -99,9 +98,6 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
} }
} }
serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize)
verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey) verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
serverFinished := &finishedMsg{ serverFinished := &finishedMsg{
verifyData: verifyData, verifyData: verifyData,
@ -113,19 +109,20 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret) hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
handshakeCtx = hs.finishedHash13.Sum(nil) handshakeCtx = hs.finishedHash13.Sum(nil)
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize) serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen)
cIV = hkdfExpandLabel(hash, cTrafficSecret0, nil, "iv", 12)
sTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "server application traffic secret", hashSize)
sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen)
sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12)
hs.clientCipher = hs.suite.aead(cKey, cIV)
serverCipher = hs.suite.aead(sKey, sIV)
c.out.setCipher(c.vers, serverCipher) c.out.setCipher(c.vers, serverCipher)
if hs.hello13Enc.earlyData {
c.in.setCipher(c.vers, earlyClientCipher)
c.phase = readingEarlyData
} else if hs.clientHello.earlyData {
c.in.setCipher(c.vers, hs.hsClientCipher)
c.phase = discardingEarlyData
} else {
c.in.setCipher(c.vers, hs.hsClientCipher)
c.phase = waitingClientFinished c.phase = waitingClientFinished
}
return nil return nil
} }
@ -157,11 +154,10 @@ func (hs *serverHandshakeState) readClientFinished13() error {
} }
hs.finishedHash13.Write(clientFinished.marshal()) hs.finishedHash13.Write(clientFinished.marshal())
c.in.setCipher(c.vers, hs.clientCipher) c.hs = nil // Discard the server handshake state
c.phase = handshakeConfirmed
// Discard the server handshake state c.in.setCipher(c.vers, hs.appClientCipher)
c.hs = nil c.in.traceErr, c.out.traceErr = nil, nil
c.phase = handshakeComplete
return hs.sendSessionTicket13() return hs.sendSessionTicket13()
} }
@ -209,6 +205,15 @@ func (hs *serverHandshakeState) sendCertificate13() error {
return nil return nil
} }
func (c *Conn) handleEndOfEarlyData() {
if c.phase != readingEarlyData || c.vers < VersionTLS13 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
return
}
c.phase = waitingClientFinished
c.in.setCipher(c.vers, c.hs.hsClientCipher)
}
// selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify // selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify
// based on the certificate type and client supported schemes. If no overlap is found, // based on the certificate type and client supported schemes. If no overlap is found,
// a fallback is selected. // a fallback is selected.
@ -377,6 +382,14 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
return h.Sum(nil) return h.Sum(nil)
} }
func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
hash := hashForSuite(hs.suite)
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
return hs.suite.aead(key, iv), trafficSecret
}
// Maximum allowed mismatch between the stated age of a ticket // Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See // and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
@ -418,7 +431,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance { if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance {
continue continue
} }
if s.hash != uint16(hash) {
// This enforces the stricter 0-RTT requirements on all ticket uses.
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
// we can give them up in the edge case of changed suite or ALPN.
if s.suite != hs.suite.id {
continue
}
if s.alpnProtocol != hs.hello13Enc.alpnProtocol {
continue continue
} }
@ -433,6 +453,9 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 { if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
hs.hello13.psk = true hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i) hs.hello13.pskIdentity = uint16(i)
if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData {
hs.hello13Enc.earlyData = true
}
return earlySecret, true return earlySecret, true
} }
} }
@ -468,7 +491,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
} }
sessionState := &sessionState13{ sessionState := &sessionState13{
vers: c.vers, vers: c.vers,
hash: uint16(hash), suite: hs.suite.id,
ageAdd: uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 | ageAdd: uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 |
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]), uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
createdAt: uint64(time.Now().Unix()), createdAt: uint64(time.Now().Unix()),
@ -481,7 +504,9 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return err return err
} }
ticketMsg := &newSessionTicketMsg13{ ticketMsg := &newSessionTicketMsg13{
lifetime: 21600, // TODO(filippo) lifetime: 24 * 3600, // TODO(filippo)
maxEarlyDataLength: c.config.Max0RTTDataSize,
withEarlyDataInfo: c.config.Max0RTTDataSize > 0,
ageAdd: sessionState.ageAdd, ageAdd: sessionState.ageAdd,
ticket: ticket, ticket: ticket,
} }

View File

@ -20,7 +20,9 @@ ifeq ($(shell go env CGO_ENABLED),1)
endif endif
@touch "$@" @touch "$@"
GO_COMMIT := 5782050a487e002acfd14a3f4c2c815c7854928c # Note: when changing this, if it doesn't change the Go version
# (it should), you need to run make clean.
GO_COMMIT := 5d4f37266d324e48f67b3bb82c6090f1aa94c013
.PHONY: go .PHONY: go
go: go/.ok_$(GO_COMMIT)_$(GOENV) go: go/.ok_$(GO_COMMIT)_$(GOENV)

View File

@ -11,8 +11,25 @@ if [ "$1" = "INSTALL" ]; then
elif [ "$1" = "RUN" ]; then elif [ "$1" = "RUN" ]; then
IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver) IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver)
docker run --rm tls-tris:$2 $IP:$3 | tee output.txt
grep "Hello TLS 1.3" output.txt | grep -v "resumed" docker run --rm tls-tris:$2 $IP:1443 | tee output.txt # RSA
grep "Hello TLS 1.3" output.txt | grep "resumed" grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT"
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
docker run --rm tls-tris:$2 $IP:2443 | tee output.txt # ECDSA
grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT"
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
elif [ "$1" = "0-RTT" ]; then
IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver)
docker run --rm tls-tris:$2 $IP:3443 | tee output.txt # rejecting 0-RTT
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
docker run --rm tls-tris:$2 $IP:4443 | tee output.txt # accepting 0-RTT
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep "0-RTT"
docker run --rm tls-tris:$2 $IP:5443 | tee output.txt # confirming 0-RTT
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
fi fi

View File

@ -2,10 +2,13 @@ FROM scratch
ENV TLSDEBUG error ENV TLSDEBUG error
EXPOSE 443 EXPOSE 1443
EXPOSE 2443
EXPOSE 3443
EXPOSE 4443 EXPOSE 4443
EXPOSE 5443
# GOOS=linux ../go.sh build -v -i . # GOOS=linux ../go.sh build -v -i .
ADD tris-localserver ./ ADD tris-localserver ./
CMD [ "./tris-localserver", "0.0.0.0:443", "0.0.0.0:4443" ] CMD [ "./tris-localserver", "0.0.0.0:1443", "0.0.0.0:2443", "0.0.0.0:3443", "0.0.0.0:4443", "0.0.0.0:5443" ]

View File

@ -17,13 +17,49 @@ var tlsVersionToName = map[uint16]string{
tls.VersionTLS13Draft18: "1.3 (draft 18)", tls.VersionTLS13Draft18: "1.3 (draft 18)",
} }
func startServer(addr string, rsa, offer0RTT, accept0RTT bool) {
cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey))
if rsa {
cert, err = tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey))
}
if err != nil {
log.Fatal(err)
}
var Max0RTTDataSize uint32
if offer0RTT {
Max0RTTDataSize = 100 * 1024
}
s := &http.Server{
Addr: addr,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
Max0RTTDataSize: Max0RTTDataSize,
Accept0RTTData: accept0RTT,
},
}
log.Fatal(s.ListenAndServeTLS("", ""))
}
var confirmingAddr string
func main() { func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
tlsConn := r.Context().Value(http.TLSConnContextKey).(*tls.Conn)
server := r.Context().Value(http.ServerContextKey).(*http.Server)
if server.Addr == confirmingAddr {
if err := tlsConn.ConfirmHandshake(); err != nil {
log.Fatal(err)
}
}
resumed := "" resumed := ""
if r.TLS.DidResume { if r.TLS.DidResume {
resumed = " [resumed]" resumed = " [resumed]"
} }
fmt.Fprintf(w, "<!DOCTYPE html><p>Hello TLS %s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed) with0RTT := ""
if !tlsConn.ConnectionState().HandshakeConfirmed {
with0RTT = " [0-RTT]"
}
fmt.Fprintf(w, "<!DOCTYPE html><p>Hello TLS %s%s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed, with0RTT)
}) })
http.HandleFunc("/ch", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/ch", func(w http.ResponseWriter, r *http.Request) {
@ -31,36 +67,17 @@ func main() {
fmt.Fprintf(w, "Client Hello packet (%d bytes):\n%s", len(r.TLS.ClientHello), hex.Dump(r.TLS.ClientHello)) fmt.Fprintf(w, "Client Hello packet (%d bytes):\n%s", len(r.TLS.ClientHello), hex.Dump(r.TLS.ClientHello))
}) })
go func() { switch len(os.Args) {
if len(os.Args) < 3 { case 2:
return startServer(os.Args[1], true, true, true)
case 6:
confirmingAddr = os.Args[5]
go startServer(os.Args[1], false, false, false) // first port: ECDSA (and no 0-RTT)
go startServer(os.Args[2], true, false, true) // second port: RSA (and accept 0-RTT but not offer it)
go startServer(os.Args[3], false, true, false) // third port: offer and reject 0-RTT
go startServer(os.Args[4], false, true, true) // fourth port: offer and accept 0-RTT
startServer(os.Args[5], false, true, true) // fifth port: offer and accept 0-RTT but confirm
} }
cert, err := tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey))
if err != nil {
log.Fatal(err)
}
s := &http.Server{
Addr: os.Args[2],
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
},
}
log.Fatal(s.ListenAndServeTLS("", ""))
}()
cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey))
if err != nil {
log.Fatal(err)
}
s := &http.Server{
Addr: os.Args[1],
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
},
}
log.Fatal(s.ListenAndServeTLS("", ""))
} }
const ( const (

View File

@ -5,4 +5,4 @@ shift
HOST="${ADDR[0]}" HOST="${ADDR[0]}"
PORT="${ADDR[1]}" PORT="${ADDR[1]}"
exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 "$@" exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 -Z "$@"

View File

@ -16,6 +16,7 @@ const (
const ( const (
alertCloseNotify alert = 0 alertCloseNotify alert = 0
alertEndOfEarlyData alert = 1
alertUnexpectedMessage alert = 10 alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20 alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21 alertDecryptionFailed alert = 21

View File

@ -85,6 +85,7 @@ const (
extensionSessionTicket uint16 = 35 extensionSessionTicket uint16 = 35
extensionKeyShare uint16 = 40 extensionKeyShare uint16 = 40
extensionPreSharedKey uint16 = 41 extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43 extensionSupportedVersions uint16 = 43
extensionPSKKeyExchangeModes uint16 = 45 extensionPSKKeyExchangeModes uint16 = 45
extensionTicketEarlyDataInfo uint16 = 46 extensionTicketEarlyDataInfo uint16 = 46
@ -213,6 +214,10 @@ type ConnectionState struct {
// been standardized and implemented. // been standardized and implemented.
TLSUnique []byte TLSUnique []byte
// HandshakeConfirmed is true once all data returned by Read
// (past and future) is guaranteed not to be replayed.
HandshakeConfirmed bool
ClientHello []byte // ClientHello packet ClientHello []byte // ClientHello packet
} }
@ -322,6 +327,18 @@ type ClientHelloInfo struct {
// from, or write to, this connection; that will cause the TLS // from, or write to, this connection; that will cause the TLS
// connection to fail. // connection to fail.
Conn net.Conn Conn net.Conn
// Offered0RTTData is true if the client announced that it will send
// 0-RTT data. If the server Config.Accept0RTTData is true, and the
// client offered a session ticket valid for that purpose, it will
// be notified that the 0-RTT data is accepted and it will be made
// immediately available for Read.
Offered0RTTData bool
// The Fingerprint is an sequence of bytes unique to this Client Hello.
// It can be used to prevent or mitigate 0-RTT data replays as it's
// guaranteed that a replayed connection will have the same Fingerprint.
Fingerprint []byte
} }
// CertificateRequestInfo contains information from a server's // CertificateRequestInfo contains information from a server's
@ -548,6 +565,28 @@ type Config struct {
// used for debugging. // used for debugging.
KeyLogWriter io.Writer KeyLogWriter io.Writer
// If Max0RTTDataSize is not zero, the client will be allowed to use
// session tickets to send at most this number of bytes of 0-RTT data.
// 0-RTT data is subject to replay and has memory DoS implications.
// The server will later be able to refuse the 0-RTT data with
// Accept0RTTData, or wait for the client to prove that it's not
// replayed with Conn.ConfirmHandshake.
//
// It has no meaning on the client.
//
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3.
Max0RTTDataSize uint32
// Accept0RTTData makes the 0-RTT data received from the client
// immediately available to Read. 0-RTT data is subject to replay.
// Use Conn.ConfirmHandshake to wait until the data is known not
// to be replayed after reading it.
//
// It has no meaning on the client.
//
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3.
Accept0RTTData bool
serverInitOnce sync.Once // guards calling (*Config).serverInit serverInitOnce sync.Once // guards calling (*Config).serverInit
// mutex protects sessionTicketKeys. // mutex protects sessionTicketKeys.
@ -622,6 +661,8 @@ func (c *Config) Clone() *Config {
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation, Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter, KeyLogWriter: c.KeyLogWriter,
Accept0RTTData: c.Accept0RTTData,
Max0RTTDataSize: c.Max0RTTDataSize,
sessionTicketKeys: sessionTicketKeys, sessionTicketKeys: sessionTicketKeys,
} }
} }

181
conn.go
View File

@ -27,14 +27,15 @@ type Conn struct {
conn net.Conn conn net.Conn
isClient bool isClient bool
phase handshakePhase
// constant after handshake; protected by handshakeMutex // constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
// handshakeCond, if not nil, indicates that a goroutine is committed // handshakeCond, if not nil, indicates that a goroutine is committed
// to running the handshake for this Conn. Other goroutines that need // to running the handshake for this Conn. Other goroutines that need
// to wait for the handshake can wait on this, under handshakeMutex. // to wait for the handshake can wait on this, under handshakeMutex.
handshakeCond *sync.Cond handshakeCond *sync.Cond
// The transition from handshakeRunning to the next phase is covered by
// handshakeMutex. All others by in.Mutex.
phase handshakeStatus
handshakeErr error // error resulting from handshake handshakeErr error // error resulting from handshake
connID []byte // Random connection id connID []byte // Random connection id
clientHello []byte // ClientHello packet contents clientHello []byte // ClientHello packet contents
@ -103,16 +104,22 @@ type Conn struct {
// TLS 1.3 needs the server state until it reaches the Client Finished // TLS 1.3 needs the server state until it reaches the Client Finished
hs *serverHandshakeState hs *serverHandshakeState
// earlyDataBytes is the number of bytes of early data received so
// far. Tracked to enforce max_early_data_size.
earlyDataBytes int64
tmp [16]byte tmp [16]byte
} }
type handshakePhase int type handshakeStatus int
const ( const (
earlyHandshake handshakePhase = iota handshakeRunning handshakeStatus = iota
discardingEarlyData
readingEarlyData
waitingClientFinished waitingClientFinished
readingClientFinished readingClientFinished
handshakeComplete handshakeConfirmed
) )
// Access to net.Conn methods. // Access to net.Conn methods.
@ -548,6 +555,9 @@ func (b *block) readFromUntil(r io.Reader, n int) error {
func (b *block) Read(p []byte) (n int, err error) { func (b *block) Read(p []byte) (n int, err error) {
n = copy(p, b.data[b.off:]) n = copy(p, b.data[b.off:])
b.off += n b.off += n
if b.off >= len(b.data) {
err = io.EOF
}
return return
} }
@ -606,6 +616,7 @@ func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
// readRecord reads the next TLS record from the connection // readRecord reads the next TLS record from the connection
// and updates the record layer state. // and updates the record layer state.
// c.in.Mutex <= L; c.input == nil. // c.in.Mutex <= L; c.input == nil.
// c.input can still be nil after a call, retry if so.
func (c *Conn) readRecord(want recordType) error { func (c *Conn) readRecord(want recordType) error {
// Caller must be in sync with connection: // Caller must be in sync with connection:
// handshake data if handshake not yet completed, // handshake data if handshake not yet completed,
@ -615,18 +626,17 @@ func (c *Conn) readRecord(want recordType) error {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec: case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.phase != earlyHandshake && c.phase != readingClientFinished { if c.phase != handshakeRunning && c.phase != readingClientFinished {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake")) return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
} }
case recordTypeApplicationData: case recordTypeApplicationData:
if c.phase == earlyHandshake || c.phase == earlyHandshake { if c.phase == handshakeRunning || c.phase == readingClientFinished {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake")) return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
} }
} }
Again:
if c.rawInput == nil { if c.rawInput == nil {
c.rawInput = c.in.newBlock() c.rawInput = c.in.newBlock()
} }
@ -686,7 +696,15 @@ Again:
// Process message. // Process message.
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
ok, off, alertValue := c.in.decrypt(b) ok, off, alertValue := c.in.decrypt(b)
if !ok { switch {
case !ok && c.phase == discardingEarlyData:
// If the client said that it's sending early data and we did not
// accept it, we are expected to fail decryption.
c.in.freeBlock(b)
return nil
case ok && c.phase == discardingEarlyData:
c.phase = waitingClientFinished
case !ok:
c.in.freeBlock(b) c.in.freeBlock(b)
return c.in.setErrorLocked(c.sendAlert(alertValue)) return c.in.setErrorLocked(c.sendAlert(alertValue))
} }
@ -730,11 +748,15 @@ Again:
c.in.setErrorLocked(io.EOF) c.in.setErrorLocked(io.EOF)
break break
} }
if alert(data[1]) == alertEndOfEarlyData {
c.handleEndOfEarlyData()
break
}
switch data[0] { switch data[0] {
case alertLevelWarning: case alertLevelWarning:
// drop on the floor // drop on the floor
c.in.freeBlock(b) c.in.freeBlock(b)
goto Again return nil
case alertLevelError: case alertLevelError:
c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default: default:
@ -742,7 +764,7 @@ Again:
} }
case recordTypeChangeCipherSpec: case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 { if typ != want || len(data) != 1 || data[0] != 1 || c.vers >= VersionTLS13 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break break
} }
@ -752,11 +774,7 @@ Again:
} }
case recordTypeApplicationData: case recordTypeApplicationData:
if c.phase == waitingClientFinished { if typ != want || c.phase == waitingClientFinished {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if typ != want {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break break
} }
@ -775,7 +793,6 @@ Again:
c.in.setErrorLocked(err) c.in.setErrorLocked(err)
break break
} }
goto Again
} }
} }
@ -1131,7 +1148,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err return 0, err
} }
if c.phase == earlyHandshake { if c.phase == handshakeRunning {
return 0, alertInternalError return 0, alertInternalError
} }
@ -1181,6 +1198,10 @@ func (c *Conn) handleRenegotiation() error {
return c.sendAlert(alertNoRenegotiation) return c.sendAlert(alertNoRenegotiation)
} }
if c.vers >= VersionTLS13 {
return c.sendAlert(alertNoRenegotiation)
}
switch c.config.Renegotiation { switch c.config.Renegotiation {
case RenegotiateNever: case RenegotiateNever:
return c.sendAlert(alertNoRenegotiation) return c.sendAlert(alertNoRenegotiation)
@ -1198,13 +1219,102 @@ func (c *Conn) handleRenegotiation() error {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
c.phase = earlyHandshake c.phase = handshakeRunning
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil { if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++ c.handshakes++
} }
return c.handshakeErr return c.handshakeErr
} }
// ConfirmHandshake waits for the handshake to reach a point at which
// the connection is certainly not replayed. That is, after receiving
// the Client Finished.
//
// If ConfirmHandshake returns an error and until ConfirmHandshake
// returns, the 0-RTT data should not be trusted not to be replayed.
//
// This is only meaningful in TLS 1.3 when Accept0RTTData is true and the
// client sent valid 0-RTT data. In any other case it's equivalent to
// calling Handshake.
func (c *Conn) ConfirmHandshake() error {
if err := c.Handshake(); err != nil {
return err
}
c.in.Lock()
defer c.in.Unlock()
if c.phase == handshakeConfirmed {
return nil
}
var input *block
if c.phase == readingEarlyData || c.input != nil {
buf := &bytes.Buffer{}
if _, err := buf.ReadFrom(earlyDataReader{c}); err != nil {
c.in.setErrorLocked(err)
return err
}
input = &block{data: buf.Bytes()}
}
for c.phase != handshakeConfirmed {
if err := c.readRecord(recordTypeApplicationData); err != nil {
c.in.setErrorLocked(err)
return err
}
}
if c.phase != handshakeConfirmed {
panic("should have reached handshakeConfirmed state")
}
if c.input != nil {
panic("should not have read past the Client Finished")
}
c.input = input
return nil
}
// earlyDataReader wraps a Conn with in locked and reads only early data,
// both buffered and still on the wire.
type earlyDataReader struct {
c *Conn
}
func (r earlyDataReader) Read(b []byte) (n int, err error) {
c := r.c
if c.phase == handshakeConfirmed {
// c.input might not be early data
panic("earlyDataReader called at handshakeConfirmed")
}
for c.input == nil && c.in.err == nil && c.phase == readingEarlyData {
if err := c.readRecord(recordTypeApplicationData); err != nil {
return 0, err
}
}
if err := c.in.err; err != nil {
return 0, err
}
if c.input != nil {
n, err = c.input.Read(b)
if err == io.EOF {
err = nil
c.in.freeBlock(c.input)
c.input = nil
}
}
if err == nil && c.phase != readingEarlyData && c.input == nil {
err = io.EOF
}
return
}
// Read can be made to time out and return a net.Error with Timeout() == true // Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline. // after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
@ -1242,7 +1352,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
} }
n, err = c.input.Read(b) n, err = c.input.Read(b)
if c.input.off >= len(c.input.data) { if err == io.EOF {
err = nil
c.in.freeBlock(c.input) c.in.freeBlock(c.input)
c.input = nil c.input = nil
} }
@ -1300,7 +1411,17 @@ func (c *Conn) Close() error {
var alertErr error var alertErr error
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
if c.phase != earlyHandshake { for c.phase == readingEarlyData {
if err := c.readRecord(recordTypeApplicationData); err != nil {
alertErr = err
}
}
if alertErr == nil && c.phase == waitingClientFinished {
if err := c.hs.readClientFinished13(); err != nil {
alertErr = err
}
}
if alertErr == nil && c.phase != handshakeRunning {
alertErr = c.closeNotify() alertErr = c.closeNotify()
} }
c.handshakeMutex.Unlock() c.handshakeMutex.Unlock()
@ -1319,7 +1440,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
func (c *Conn) CloseWrite() error { func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock() c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
if c.phase == earlyHandshake { if c.phase == handshakeRunning {
return errEarlyCloseWrite return errEarlyCloseWrite
} }
@ -1341,8 +1462,11 @@ func (c *Conn) closeNotify() error {
// protocol if it has not yet been run. // protocol if it has not yet been run.
// Most uses of this package need not call Handshake // Most uses of this package need not call Handshake
// explicitly: the first Read or Write will call it automatically. // explicitly: the first Read or Write will call it automatically.
//
// In TLS 1.3 Handshake returns after the client and server first flights,
// without waiting for the Client Finished.
func (c *Conn) Handshake() error { func (c *Conn) Handshake() error {
// c.handshakeErr and c.phase == earlyHandshake are protected by // c.handshakeErr and c.phase == handshakeRunning are protected by
// c.handshakeMutex. In order to perform a handshake, we need to lock // c.handshakeMutex. In order to perform a handshake, we need to lock
// c.in also and c.handshakeMutex must be locked after c.in. // c.in also and c.handshakeMutex must be locked after c.in.
// //
@ -1371,7 +1495,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil { if err := c.handshakeErr; err != nil {
return err return err
} }
if c.phase != earlyHandshake { if c.phase != handshakeRunning {
return nil return nil
} }
if c.handshakeCond == nil { if c.handshakeCond == nil {
@ -1393,7 +1517,7 @@ func (c *Conn) Handshake() error {
// The handshake cannot have completed when handshakeMutex was unlocked // The handshake cannot have completed when handshakeMutex was unlocked
// because this goroutine set handshakeCond. // because this goroutine set handshakeCond.
if c.handshakeErr != nil || c.phase != earlyHandshake { if c.handshakeErr != nil || c.phase != handshakeRunning {
panic("handshake should not have been able to complete after handshakeCond was set") panic("handshake should not have been able to complete after handshakeCond was set")
} }
@ -1415,7 +1539,7 @@ func (c *Conn) Handshake() error {
c.flush() c.flush()
} }
if c.handshakeErr == nil && c.phase == earlyHandshake { if c.handshakeErr == nil && c.phase == handshakeRunning {
panic("handshake should have had a result.") panic("handshake should have had a result.")
} }
@ -1433,7 +1557,7 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock() defer c.handshakeMutex.Unlock()
var state ConnectionState var state ConnectionState
state.HandshakeComplete = c.phase != earlyHandshake state.HandshakeComplete = c.phase != handshakeRunning
state.ServerName = c.serverName state.ServerName = c.serverName
if state.HandshakeComplete { if state.HandshakeComplete {
@ -1448,6 +1572,7 @@ func (c *Conn) ConnectionState() ConnectionState {
state.VerifiedChains = c.verifiedChains state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse state.OCSPResponse = c.ocspResponse
state.HandshakeConfirmed = c.phase == handshakeConfirmed
if !c.didResume { if !c.didResume {
if c.clientFinishedIsFirst { if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:] state.TLSUnique = c.clientFinished[:]
@ -1478,7 +1603,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient { if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection") return errors.New("tls: VerifyHostname called on TLS server connection")
} }
if c.phase == earlyHandshake { if c.phase == handshakeRunning {
return errors.New("tls: handshake has not yet been performed") return errors.New("tls: handshake has not yet been performed")
} }
if len(c.verifiedChains) == 0 { if len(c.verifiedChains) == 0 {

View File

@ -251,7 +251,7 @@ NextCipherSuite:
} }
c.didResume = isResume c.didResume = isResume
c.phase = handshakeComplete c.phase = handshakeConfirmed
c.cipherSuite = suite.id c.cipherSuite = suite.id
return nil return nil
} }

View File

@ -64,10 +64,10 @@ func (i opensslInput) Read(buf []byte) (n int, err error) {
} }
// opensslOutputSink is an io.Writer that receives the stdout and stderr from // opensslOutputSink is an io.Writer that receives the stdout and stderr from
// an `openssl` process and sends a value to handshakeComplete when it sees a // an `openssl` process and sends a value to handshakeConfirmed when it sees a
// log message from a completed server handshake. // log message from a completed server handshake.
type opensslOutputSink struct { type opensslOutputSink struct {
handshakeComplete chan struct{} handshakeConfirmed chan struct{}
all []byte all []byte
line []byte line []byte
} }
@ -91,7 +91,7 @@ func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
} }
if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) { if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) {
o.handshakeComplete <- struct{}{} o.handshakeConfirmed <- struct{}{}
} }
o.line = o.line[i+1:] o.line = o.line[i+1:]
} }
@ -315,9 +315,9 @@ func (test *clientTest) run(t *testing.T, write bool) {
for i := 1; i <= test.numRenegotiations; i++ { for i := 1; i <= test.numRenegotiations; i++ {
// The initial handshake will generate a // The initial handshake will generate a
// handshakeComplete signal which needs to be quashed. // handshakeConfirmed signal which needs to be quashed.
if i == 1 && write { if i == 1 && write {
<-stdout.handshakeComplete <-stdout.handshakeConfirmed
} }
// OpenSSL will try to interleave application data and // OpenSSL will try to interleave application data and
@ -364,7 +364,7 @@ func (test *clientTest) run(t *testing.T, write bool) {
}() }()
if write && test.renegotiationExpectedToFail != i { if write && test.renegotiationExpectedToFail != i {
<-stdout.handshakeComplete <-stdout.handshakeConfirmed
stdin <- opensslSendSentinel stdin <- opensslSendSentinel
} }
<-signalChan <-signalChan

View File

@ -33,6 +33,7 @@ type clientHelloMsg struct {
supportedVersions []uint16 supportedVersions []uint16
psks []psk psks []psk
pskKeyExchangeModes []uint8 pskKeyExchangeModes []uint8
earlyData bool
} }
func (m *clientHelloMsg) equal(i interface{}) bool { func (m *clientHelloMsg) equal(i interface{}) bool {
@ -60,7 +61,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols) && eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
eqKeyShares(m.keyShares, m1.keyShares) && eqKeyShares(m.keyShares, m1.keyShares) &&
eqUint16s(m.supportedVersions, m1.supportedVersions) eqUint16s(m.supportedVersions, m1.supportedVersions) &&
m.earlyData == m1.earlyData
} }
func (m *clientHelloMsg) marshal() []byte { func (m *clientHelloMsg) marshal() []byte {
@ -127,6 +129,9 @@ func (m *clientHelloMsg) marshal() []byte {
extensionsLength += 1 + 2*len(m.supportedVersions) extensionsLength += 1 + 2*len(m.supportedVersions)
numExtensions++ numExtensions++
} }
if m.earlyData {
numExtensions++
}
if numExtensions > 0 { if numExtensions > 0 {
extensionsLength += 4 * numExtensions extensionsLength += 4 * numExtensions
length += 2 + extensionsLength length += 2 + extensionsLength
@ -350,6 +355,11 @@ func (m *clientHelloMsg) marshal() []byte {
z = z[2:] z = z[2:]
} }
} }
if m.earlyData {
z[0] = byte(extensionEarlyData >> 8)
z[1] = byte(extensionEarlyData)
z = z[4:]
}
m.raw = x m.raw = x
@ -413,6 +423,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.supportedVersions = nil m.supportedVersions = nil
m.psks = nil m.psks = nil
m.pskKeyExchangeModes = nil m.pskKeyExchangeModes = nil
m.earlyData = false
if len(data) == 0 { if len(data) == 0 {
// ClientHello is optionally followed by extension data // ClientHello is optionally followed by extension data
@ -668,6 +679,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false return false
} }
m.pskKeyExchangeModes = data[1:length] m.pskKeyExchangeModes = data[1:length]
case extensionEarlyData:
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
m.earlyData = true
} }
data = data[length:] data = data[length:]
bindersOffset += length bindersOffset += length
@ -1144,6 +1158,7 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
type encryptedExtensionsMsg struct { type encryptedExtensionsMsg struct {
raw []byte raw []byte
alpnProtocol string alpnProtocol string
earlyData bool
} }
func (m *encryptedExtensionsMsg) equal(i interface{}) bool { func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
@ -1153,7 +1168,8 @@ func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
} }
return bytes.Equal(m.raw, m1.raw) && return bytes.Equal(m.raw, m1.raw) &&
m.alpnProtocol == m1.alpnProtocol m.alpnProtocol == m1.alpnProtocol &&
m.earlyData == m1.earlyData
} }
func (m *encryptedExtensionsMsg) marshal() []byte { func (m *encryptedExtensionsMsg) marshal() []byte {
@ -1163,6 +1179,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte {
length := 2 length := 2
if m.earlyData {
length += 4
}
alpnLen := len(m.alpnProtocol) alpnLen := len(m.alpnProtocol)
if alpnLen > 0 { if alpnLen > 0 {
if alpnLen >= 256 { if alpnLen >= 256 {
@ -1196,6 +1215,12 @@ func (m *encryptedExtensionsMsg) marshal() []byte {
z = z[7+alpnLen:] z = z[7+alpnLen:]
} }
if m.earlyData {
z[0] = byte(extensionEarlyData >> 8)
z[1] = byte(extensionEarlyData)
z = z[4:]
}
m.raw = x m.raw = x
return x return x
} }
@ -1205,31 +1230,38 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
return false return false
} }
m.raw = data m.raw = data
l := int(data[4])<<8 | int(data[5])
if l != len(data)-6 {
return false
}
m.alpnProtocol = "" m.alpnProtocol = ""
if l == 0 { m.earlyData = false
return true
extensionsLength := int(data[4])<<8 | int(data[5])
data = data[6:]
if len(data) != extensionsLength {
return false
} }
d := data[6:] for len(data) != 0 {
if len(d) < 5 { if len(data) < 4 {
return false return false
} }
if uint16(d[0])<<8|uint16(d[1]) != extensionALPN { extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return false return false
} }
l = int(d[2])<<8 | int(d[3])
if l != len(d)-4 { switch extension {
case extensionALPN:
d := data[:length]
if len(d) < 3 {
return false return false
} }
l = int(d[4])<<8 | int(d[5]) l := int(d[0])<<8 | int(d[1])
if l != len(d)-6 { if l != len(d)-2 {
return false return false
} }
d = d[6:] d = d[2:]
l = int(d[0]) l = int(d[0])
if l != len(d)-1 { if l != len(d)-1 {
return false return false
@ -1240,6 +1272,13 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
return false return false
} }
m.alpnProtocol = string(d) m.alpnProtocol = string(d)
case extensionEarlyData:
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
m.earlyData = true
}
data = data[length:]
}
return true return true
} }

View File

@ -164,6 +164,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
for i := range m.supportedVersions { for i := range m.supportedVersions {
m.supportedVersions[i] = uint16(rand.Intn(30000)) m.supportedVersions[i] = uint16(rand.Intn(30000))
} }
if rand.Intn(10) > 5 {
m.earlyData = true
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
@ -222,7 +225,12 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &encryptedExtensionsMsg{} m := &encryptedExtensionsMsg{}
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand) m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
if rand.Intn(10) > 5 {
m.earlyData = true
}
return reflect.ValueOf(m) return reflect.ValueOf(m)
} }
@ -328,10 +336,11 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value { func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionState13{} s := &sessionState13{}
s.vers = uint16(rand.Intn(10000)) s.vers = uint16(rand.Intn(10000))
s.hash = uint16(rand.Intn(10000)) s.suite = uint16(rand.Intn(10000))
s.ageAdd = uint32(rand.Intn(0xffffffff)) s.ageAdd = uint32(rand.Intn(0xffffffff))
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff)) s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
s.resumptionSecret = randomBytes(rand.Intn(100), rand) s.resumptionSecret = randomBytes(rand.Intn(100), rand)
s.alpnProtocol = randomString(rand.Intn(100), rand)
return reflect.ValueOf(s) return reflect.ValueOf(s)
} }

View File

@ -46,7 +46,8 @@ type serverHandshakeState struct {
hello13Enc *encryptedExtensionsMsg hello13Enc *encryptedExtensionsMsg
finishedHash13 hash.Hash finishedHash13 hash.Hash
clientFinishedKey []byte clientFinishedKey []byte
clientCipher interface{} hsClientCipher interface{}
appClientCipher interface{}
} }
// serverHandshake performs a TLS handshake as a server. // serverHandshake performs a TLS handshake as a server.
@ -61,7 +62,6 @@ func (c *Conn) serverHandshake() error {
} }
c.in.traceErr = hs.traceErr c.in.traceErr = hs.traceErr
c.out.traceErr = hs.traceErr c.out.traceErr = hs.traceErr
defer func() { c.in.traceErr, c.out.traceErr = nil, nil }()
isResume, err := hs.readClientHello() isResume, err := hs.readClientHello()
if err != nil { if err != nil {
return err return err
@ -105,7 +105,7 @@ func (c *Conn) serverHandshake() error {
return err return err
} }
c.didResume = true c.didResume = true
c.phase = handshakeComplete c.phase = handshakeConfirmed
} else { } else {
// The client didn't include a session ticket, or it wasn't // The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake. // valid so we do a full handshake.
@ -129,7 +129,7 @@ func (c *Conn) serverHandshake() error {
if _, err := c.flush(); err != nil { if _, err := c.flush(); err != nil {
return err return err
} }
c.phase = handshakeComplete c.phase = handshakeConfirmed
} }
return nil return nil
@ -910,6 +910,11 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature)) signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
} }
var pskBinder []byte
if len(hs.clientHello.psks) > 0 {
pskBinder = hs.clientHello.psks[0].binder
}
hs.cachedClientHelloInfo = &ClientHelloInfo{ hs.cachedClientHelloInfo = &ClientHelloInfo{
CipherSuites: hs.clientHello.cipherSuites, CipherSuites: hs.clientHello.cipherSuites,
ServerName: hs.clientHello.serverName, ServerName: hs.clientHello.serverName,
@ -919,6 +924,8 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
SupportedProtos: hs.clientHello.alpnProtocols, SupportedProtos: hs.clientHello.alpnProtocols,
SupportedVersions: supportedVersions, SupportedVersions: supportedVersions,
Conn: hs.c.conn, Conn: hs.c.conn,
Offered0RTTData: hs.clientHello.earlyData,
Fingerprint: pskBinder,
} }
return hs.cachedClientHelloInfo return hs.cachedClientHelloInfo

View File

@ -131,10 +131,13 @@ func (s *sessionState) unmarshal(data []byte) bool {
type sessionState13 struct { type sessionState13 struct {
vers uint16 vers uint16
hash uint16 // crypto.Hash value suite uint16
ageAdd uint32 ageAdd uint32
createdAt uint64 createdAt uint64
resumptionSecret []byte resumptionSecret []byte
alpnProtocol string
// TODO(filippo): add and check SNI
// TODO(filippo): add and check maxEarlyDataLength
} }
func (s *sessionState13) equal(i interface{}) bool { func (s *sessionState13) equal(i interface{}) bool {
@ -144,19 +147,20 @@ func (s *sessionState13) equal(i interface{}) bool {
} }
return s.vers == s1.vers && return s.vers == s1.vers &&
s.hash == s1.hash && s.suite == s1.suite &&
s.alpnProtocol == s1.alpnProtocol &&
s.ageAdd == s1.ageAdd && s.ageAdd == s1.ageAdd &&
bytes.Equal(s.resumptionSecret, s1.resumptionSecret) bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
} }
func (s *sessionState13) marshal() []byte { func (s *sessionState13) marshal() []byte {
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol)
x := make([]byte, length) x := make([]byte, length)
x[0] = byte(s.vers >> 8) x[0] = byte(s.vers >> 8)
x[1] = byte(s.vers) x[1] = byte(s.vers)
x[2] = byte(s.hash >> 8) x[2] = byte(s.suite >> 8)
x[3] = byte(s.hash) x[3] = byte(s.suite)
x[4] = byte(s.ageAdd >> 24) x[4] = byte(s.ageAdd >> 24)
x[5] = byte(s.ageAdd >> 16) x[5] = byte(s.ageAdd >> 16)
x[6] = byte(s.ageAdd >> 8) x[6] = byte(s.ageAdd >> 8)
@ -171,8 +175,11 @@ func (s *sessionState13) marshal() []byte {
x[15] = byte(s.createdAt) x[15] = byte(s.createdAt)
x[16] = byte(len(s.resumptionSecret) >> 8) x[16] = byte(len(s.resumptionSecret) >> 8)
x[17] = byte(len(s.resumptionSecret)) x[17] = byte(len(s.resumptionSecret))
copy(x[18:], s.resumptionSecret) copy(x[18:], s.resumptionSecret)
z := x[18+len(s.resumptionSecret):]
z[0] = byte(len(s.alpnProtocol) >> 8)
z[1] = byte(len(s.alpnProtocol))
copy(z[2:], s.alpnProtocol)
return x return x
} }
@ -183,14 +190,19 @@ func (s *sessionState13) unmarshal(data []byte) bool {
} }
s.vers = uint16(data[0])<<8 | uint16(data[1]) s.vers = uint16(data[0])<<8 | uint16(data[1])
s.hash = uint16(data[2])<<8 | uint16(data[3]) s.suite = uint16(data[2])<<8 | uint16(data[3])
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 | s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15]) uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
l := uint16(data[16])<<8 | uint16(data[17]) l := int(data[16])<<8 | int(data[17])
s.resumptionSecret = data[18:] if len(data) < 18+l+2 {
return false
return int(l) == len(s.resumptionSecret) }
s.resumptionSecret = data[18 : 18+l]
z := data[18+l:]
l = int(z[0])<<8 | int(z[1])
s.alpnProtocol = string(z[2:])
return l == len(s.alpnProtocol)
} }
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {

View File

@ -641,7 +641,7 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf("b")) f.Set(reflect.ValueOf("b"))
case "ClientAuth": case "ClientAuth":
f.Set(reflect.ValueOf(VerifyClientCertIfGiven)) f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites": case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "Accept0RTTData":
f.Set(reflect.ValueOf(true)) f.Set(reflect.ValueOf(true))
case "MinVersion", "MaxVersion": case "MinVersion", "MaxVersion":
f.Set(reflect.ValueOf(uint16(VersionTLS12))) f.Set(reflect.ValueOf(uint16(VersionTLS12)))
@ -654,6 +654,8 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf([]CurveID{CurveP256})) f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "Renegotiation": case "Renegotiation":
f.Set(reflect.ValueOf(RenegotiateOnceAsClient)) f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
case "Max0RTTDataSize":
f.Set(reflect.ValueOf(uint32(0)))
default: default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn) t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
} }