crypto/tls: implement TLS 1.3 server 0-RTT
This commit is contained in:
parent
1117f76fcc
commit
f8c15889af
@ -30,8 +30,8 @@ install:
|
||||
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh INSTALL $CLIENT $REVISION; fi
|
||||
|
||||
script:
|
||||
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA
|
||||
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA
|
||||
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT; fi
|
||||
- if [ "$MODE" = "interop" ] && [ "$CLIENT" = "tstclnt" ]; then ./_dev/interop.sh 0-RTT $CLIENT; fi
|
||||
- if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi
|
||||
|
||||
after_script:
|
||||
|
103
13.go
103
13.go
@ -60,34 +60,33 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
|
||||
}
|
||||
c.didResume = isPSK
|
||||
|
||||
hs.finishedHash13 = hash.New()
|
||||
hs.finishedHash13.Write(hs.clientHello.marshal())
|
||||
|
||||
handshakeCtx := hs.finishedHash13.Sum(nil)
|
||||
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")
|
||||
|
||||
ecdheSecret := deriveECDHESecret(ks, privateKey)
|
||||
if ecdheSecret == nil {
|
||||
c.sendAlert(alertIllegalParameter)
|
||||
return errors.New("tls: bad ECDHE client share")
|
||||
}
|
||||
|
||||
hs.finishedHash13 = hash.New()
|
||||
hs.finishedHash13.Write(hs.clientHello.marshal())
|
||||
hs.finishedHash13.Write(hs.hello13.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
|
||||
handshakeCtx := hs.finishedHash13.Sum(nil)
|
||||
|
||||
cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize)
|
||||
cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen)
|
||||
cIV := hkdfExpandLabel(hash, cHandshakeTS, nil, "iv", 12)
|
||||
sHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "server handshake traffic secret", hashSize)
|
||||
sKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "key", hs.suite.keyLen)
|
||||
sIV := hkdfExpandLabel(hash, sHandshakeTS, nil, "iv", 12)
|
||||
|
||||
clientCipher := hs.suite.aead(cKey, cIV)
|
||||
c.in.setCipher(c.vers, clientCipher)
|
||||
serverCipher := hs.suite.aead(sKey, sIV)
|
||||
handshakeCtx = hs.finishedHash13.Sum(nil)
|
||||
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
|
||||
hs.hsClientCipher = clientCipher
|
||||
serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret")
|
||||
c.out.setCipher(c.vers, serverCipher)
|
||||
|
||||
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
|
||||
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
|
||||
|
||||
hs.finishedHash13.Write(hs.hello13Enc.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
|
||||
return err
|
||||
@ -99,9 +98,6 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
|
||||
}
|
||||
}
|
||||
|
||||
serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize)
|
||||
hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize)
|
||||
|
||||
verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
|
||||
serverFinished := &finishedMsg{
|
||||
verifyData: verifyData,
|
||||
@ -113,19 +109,20 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
|
||||
|
||||
hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
|
||||
handshakeCtx = hs.finishedHash13.Sum(nil)
|
||||
|
||||
cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize)
|
||||
cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen)
|
||||
cIV = hkdfExpandLabel(hash, cTrafficSecret0, nil, "iv", 12)
|
||||
sTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "server application traffic secret", hashSize)
|
||||
sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen)
|
||||
sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12)
|
||||
|
||||
hs.clientCipher = hs.suite.aead(cKey, cIV)
|
||||
serverCipher = hs.suite.aead(sKey, sIV)
|
||||
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
|
||||
serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
|
||||
c.out.setCipher(c.vers, serverCipher)
|
||||
|
||||
c.phase = waitingClientFinished
|
||||
if hs.hello13Enc.earlyData {
|
||||
c.in.setCipher(c.vers, earlyClientCipher)
|
||||
c.phase = readingEarlyData
|
||||
} else if hs.clientHello.earlyData {
|
||||
c.in.setCipher(c.vers, hs.hsClientCipher)
|
||||
c.phase = discardingEarlyData
|
||||
} else {
|
||||
c.in.setCipher(c.vers, hs.hsClientCipher)
|
||||
c.phase = waitingClientFinished
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -157,11 +154,10 @@ func (hs *serverHandshakeState) readClientFinished13() error {
|
||||
}
|
||||
hs.finishedHash13.Write(clientFinished.marshal())
|
||||
|
||||
c.in.setCipher(c.vers, hs.clientCipher)
|
||||
|
||||
// Discard the server handshake state
|
||||
c.hs = nil
|
||||
c.phase = handshakeComplete
|
||||
c.hs = nil // Discard the server handshake state
|
||||
c.phase = handshakeConfirmed
|
||||
c.in.setCipher(c.vers, hs.appClientCipher)
|
||||
c.in.traceErr, c.out.traceErr = nil, nil
|
||||
|
||||
return hs.sendSessionTicket13()
|
||||
}
|
||||
@ -209,6 +205,15 @@ func (hs *serverHandshakeState) sendCertificate13() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) handleEndOfEarlyData() {
|
||||
if c.phase != readingEarlyData || c.vers < VersionTLS13 {
|
||||
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
return
|
||||
}
|
||||
c.phase = waitingClientFinished
|
||||
c.in.setCipher(c.vers, c.hs.hsClientCipher)
|
||||
}
|
||||
|
||||
// selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify
|
||||
// based on the certificate type and client supported schemes. If no overlap is found,
|
||||
// a fallback is selected.
|
||||
@ -377,6 +382,14 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
|
||||
hash := hashForSuite(hs.suite)
|
||||
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
|
||||
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
|
||||
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
|
||||
return hs.suite.aead(key, iv), trafficSecret
|
||||
}
|
||||
|
||||
// Maximum allowed mismatch between the stated age of a ticket
|
||||
// and the server-observed one. See
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
|
||||
@ -418,7 +431,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
|
||||
if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance {
|
||||
continue
|
||||
}
|
||||
if s.hash != uint16(hash) {
|
||||
|
||||
// This enforces the stricter 0-RTT requirements on all ticket uses.
|
||||
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
|
||||
// we can give them up in the edge case of changed suite or ALPN.
|
||||
if s.suite != hs.suite.id {
|
||||
continue
|
||||
}
|
||||
if s.alpnProtocol != hs.hello13Enc.alpnProtocol {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -433,6 +453,9 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
|
||||
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
|
||||
hs.hello13.psk = true
|
||||
hs.hello13.pskIdentity = uint16(i)
|
||||
if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData {
|
||||
hs.hello13Enc.earlyData = true
|
||||
}
|
||||
return earlySecret, true
|
||||
}
|
||||
}
|
||||
@ -467,8 +490,8 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
|
||||
return err
|
||||
}
|
||||
sessionState := &sessionState13{
|
||||
vers: c.vers,
|
||||
hash: uint16(hash),
|
||||
vers: c.vers,
|
||||
suite: hs.suite.id,
|
||||
ageAdd: uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 |
|
||||
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
|
||||
createdAt: uint64(time.Now().Unix()),
|
||||
@ -481,9 +504,11 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
|
||||
return err
|
||||
}
|
||||
ticketMsg := &newSessionTicketMsg13{
|
||||
lifetime: 21600, // TODO(filippo)
|
||||
ageAdd: sessionState.ageAdd,
|
||||
ticket: ticket,
|
||||
lifetime: 24 * 3600, // TODO(filippo)
|
||||
maxEarlyDataLength: c.config.Max0RTTDataSize,
|
||||
withEarlyDataInfo: c.config.Max0RTTDataSize > 0,
|
||||
ageAdd: sessionState.ageAdd,
|
||||
ticket: ticket,
|
||||
}
|
||||
if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil {
|
||||
return err
|
||||
|
@ -20,7 +20,9 @@ ifeq ($(shell go env CGO_ENABLED),1)
|
||||
endif
|
||||
@touch "$@"
|
||||
|
||||
GO_COMMIT := 5782050a487e002acfd14a3f4c2c815c7854928c
|
||||
# Note: when changing this, if it doesn't change the Go version
|
||||
# (it should), you need to run make clean.
|
||||
GO_COMMIT := 5d4f37266d324e48f67b3bb82c6090f1aa94c013
|
||||
|
||||
.PHONY: go
|
||||
go: go/.ok_$(GO_COMMIT)_$(GOENV)
|
||||
|
@ -11,8 +11,25 @@ if [ "$1" = "INSTALL" ]; then
|
||||
|
||||
elif [ "$1" = "RUN" ]; then
|
||||
IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver)
|
||||
docker run --rm tls-tris:$2 $IP:$3 | tee output.txt
|
||||
grep "Hello TLS 1.3" output.txt | grep -v "resumed"
|
||||
grep "Hello TLS 1.3" output.txt | grep "resumed"
|
||||
|
||||
docker run --rm tls-tris:$2 $IP:1443 | tee output.txt # RSA
|
||||
grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT"
|
||||
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
|
||||
|
||||
docker run --rm tls-tris:$2 $IP:2443 | tee output.txt # ECDSA
|
||||
grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT"
|
||||
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
|
||||
|
||||
elif [ "$1" = "0-RTT" ]; then
|
||||
IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver)
|
||||
|
||||
docker run --rm tls-tris:$2 $IP:3443 | tee output.txt # rejecting 0-RTT
|
||||
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
|
||||
|
||||
docker run --rm tls-tris:$2 $IP:4443 | tee output.txt # accepting 0-RTT
|
||||
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep "0-RTT"
|
||||
|
||||
docker run --rm tls-tris:$2 $IP:5443 | tee output.txt # confirming 0-RTT
|
||||
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"
|
||||
|
||||
fi
|
||||
|
@ -2,10 +2,13 @@ FROM scratch
|
||||
|
||||
ENV TLSDEBUG error
|
||||
|
||||
EXPOSE 443
|
||||
EXPOSE 1443
|
||||
EXPOSE 2443
|
||||
EXPOSE 3443
|
||||
EXPOSE 4443
|
||||
EXPOSE 5443
|
||||
|
||||
# GOOS=linux ../go.sh build -v -i .
|
||||
ADD tris-localserver ./
|
||||
|
||||
CMD [ "./tris-localserver", "0.0.0.0:443", "0.0.0.0:4443" ]
|
||||
CMD [ "./tris-localserver", "0.0.0.0:1443", "0.0.0.0:2443", "0.0.0.0:3443", "0.0.0.0:4443", "0.0.0.0:5443" ]
|
||||
|
@ -17,13 +17,49 @@ var tlsVersionToName = map[uint16]string{
|
||||
tls.VersionTLS13Draft18: "1.3 (draft 18)",
|
||||
}
|
||||
|
||||
func startServer(addr string, rsa, offer0RTT, accept0RTT bool) {
|
||||
cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey))
|
||||
if rsa {
|
||||
cert, err = tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey))
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
var Max0RTTDataSize uint32
|
||||
if offer0RTT {
|
||||
Max0RTTDataSize = 100 * 1024
|
||||
}
|
||||
s := &http.Server{
|
||||
Addr: addr,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
Max0RTTDataSize: Max0RTTDataSize,
|
||||
Accept0RTTData: accept0RTT,
|
||||
},
|
||||
}
|
||||
log.Fatal(s.ListenAndServeTLS("", ""))
|
||||
}
|
||||
|
||||
var confirmingAddr string
|
||||
|
||||
func main() {
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
tlsConn := r.Context().Value(http.TLSConnContextKey).(*tls.Conn)
|
||||
server := r.Context().Value(http.ServerContextKey).(*http.Server)
|
||||
if server.Addr == confirmingAddr {
|
||||
if err := tlsConn.ConfirmHandshake(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
resumed := ""
|
||||
if r.TLS.DidResume {
|
||||
resumed = " [resumed]"
|
||||
}
|
||||
fmt.Fprintf(w, "<!DOCTYPE html><p>Hello TLS %s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed)
|
||||
with0RTT := ""
|
||||
if !tlsConn.ConnectionState().HandshakeConfirmed {
|
||||
with0RTT = " [0-RTT]"
|
||||
}
|
||||
fmt.Fprintf(w, "<!DOCTYPE html><p>Hello TLS %s%s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed, with0RTT)
|
||||
})
|
||||
|
||||
http.HandleFunc("/ch", func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -31,36 +67,17 @@ func main() {
|
||||
fmt.Fprintf(w, "Client Hello packet (%d bytes):\n%s", len(r.TLS.ClientHello), hex.Dump(r.TLS.ClientHello))
|
||||
})
|
||||
|
||||
go func() {
|
||||
if len(os.Args) < 3 {
|
||||
return
|
||||
}
|
||||
cert, err := tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
s := &http.Server{
|
||||
Addr: os.Args[2],
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
PreferServerCipherSuites: true,
|
||||
},
|
||||
}
|
||||
log.Fatal(s.ListenAndServeTLS("", ""))
|
||||
}()
|
||||
|
||||
cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
switch len(os.Args) {
|
||||
case 2:
|
||||
startServer(os.Args[1], true, true, true)
|
||||
case 6:
|
||||
confirmingAddr = os.Args[5]
|
||||
go startServer(os.Args[1], false, false, false) // first port: ECDSA (and no 0-RTT)
|
||||
go startServer(os.Args[2], true, false, true) // second port: RSA (and accept 0-RTT but not offer it)
|
||||
go startServer(os.Args[3], false, true, false) // third port: offer and reject 0-RTT
|
||||
go startServer(os.Args[4], false, true, true) // fourth port: offer and accept 0-RTT
|
||||
startServer(os.Args[5], false, true, true) // fifth port: offer and accept 0-RTT but confirm
|
||||
}
|
||||
s := &http.Server{
|
||||
Addr: os.Args[1],
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
PreferServerCipherSuites: true,
|
||||
},
|
||||
}
|
||||
log.Fatal(s.ListenAndServeTLS("", ""))
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -5,4 +5,4 @@ shift
|
||||
HOST="${ADDR[0]}"
|
||||
PORT="${ADDR[1]}"
|
||||
|
||||
exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 "$@"
|
||||
exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 -Z "$@"
|
||||
|
1
alert.go
1
alert.go
@ -16,6 +16,7 @@ const (
|
||||
|
||||
const (
|
||||
alertCloseNotify alert = 0
|
||||
alertEndOfEarlyData alert = 1
|
||||
alertUnexpectedMessage alert = 10
|
||||
alertBadRecordMAC alert = 20
|
||||
alertDecryptionFailed alert = 21
|
||||
|
41
common.go
41
common.go
@ -85,6 +85,7 @@ const (
|
||||
extensionSessionTicket uint16 = 35
|
||||
extensionKeyShare uint16 = 40
|
||||
extensionPreSharedKey uint16 = 41
|
||||
extensionEarlyData uint16 = 42
|
||||
extensionSupportedVersions uint16 = 43
|
||||
extensionPSKKeyExchangeModes uint16 = 45
|
||||
extensionTicketEarlyDataInfo uint16 = 46
|
||||
@ -213,6 +214,10 @@ type ConnectionState struct {
|
||||
// been standardized and implemented.
|
||||
TLSUnique []byte
|
||||
|
||||
// HandshakeConfirmed is true once all data returned by Read
|
||||
// (past and future) is guaranteed not to be replayed.
|
||||
HandshakeConfirmed bool
|
||||
|
||||
ClientHello []byte // ClientHello packet
|
||||
}
|
||||
|
||||
@ -322,6 +327,18 @@ type ClientHelloInfo struct {
|
||||
// from, or write to, this connection; that will cause the TLS
|
||||
// connection to fail.
|
||||
Conn net.Conn
|
||||
|
||||
// Offered0RTTData is true if the client announced that it will send
|
||||
// 0-RTT data. If the server Config.Accept0RTTData is true, and the
|
||||
// client offered a session ticket valid for that purpose, it will
|
||||
// be notified that the 0-RTT data is accepted and it will be made
|
||||
// immediately available for Read.
|
||||
Offered0RTTData bool
|
||||
|
||||
// The Fingerprint is an sequence of bytes unique to this Client Hello.
|
||||
// It can be used to prevent or mitigate 0-RTT data replays as it's
|
||||
// guaranteed that a replayed connection will have the same Fingerprint.
|
||||
Fingerprint []byte
|
||||
}
|
||||
|
||||
// CertificateRequestInfo contains information from a server's
|
||||
@ -548,6 +565,28 @@ type Config struct {
|
||||
// used for debugging.
|
||||
KeyLogWriter io.Writer
|
||||
|
||||
// If Max0RTTDataSize is not zero, the client will be allowed to use
|
||||
// session tickets to send at most this number of bytes of 0-RTT data.
|
||||
// 0-RTT data is subject to replay and has memory DoS implications.
|
||||
// The server will later be able to refuse the 0-RTT data with
|
||||
// Accept0RTTData, or wait for the client to prove that it's not
|
||||
// replayed with Conn.ConfirmHandshake.
|
||||
//
|
||||
// It has no meaning on the client.
|
||||
//
|
||||
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3.
|
||||
Max0RTTDataSize uint32
|
||||
|
||||
// Accept0RTTData makes the 0-RTT data received from the client
|
||||
// immediately available to Read. 0-RTT data is subject to replay.
|
||||
// Use Conn.ConfirmHandshake to wait until the data is known not
|
||||
// to be replayed after reading it.
|
||||
//
|
||||
// It has no meaning on the client.
|
||||
//
|
||||
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3.
|
||||
Accept0RTTData bool
|
||||
|
||||
serverInitOnce sync.Once // guards calling (*Config).serverInit
|
||||
|
||||
// mutex protects sessionTicketKeys.
|
||||
@ -622,6 +661,8 @@ func (c *Config) Clone() *Config {
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
Renegotiation: c.Renegotiation,
|
||||
KeyLogWriter: c.KeyLogWriter,
|
||||
Accept0RTTData: c.Accept0RTTData,
|
||||
Max0RTTDataSize: c.Max0RTTDataSize,
|
||||
sessionTicketKeys: sessionTicketKeys,
|
||||
}
|
||||
}
|
||||
|
193
conn.go
193
conn.go
@ -27,20 +27,21 @@ type Conn struct {
|
||||
conn net.Conn
|
||||
isClient bool
|
||||
|
||||
phase handshakePhase
|
||||
|
||||
// constant after handshake; protected by handshakeMutex
|
||||
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
|
||||
// handshakeCond, if not nil, indicates that a goroutine is committed
|
||||
// to running the handshake for this Conn. Other goroutines that need
|
||||
// to wait for the handshake can wait on this, under handshakeMutex.
|
||||
handshakeCond *sync.Cond
|
||||
handshakeErr error // error resulting from handshake
|
||||
connID []byte // Random connection id
|
||||
clientHello []byte // ClientHello packet contents
|
||||
vers uint16 // TLS version
|
||||
haveVers bool // version has been negotiated
|
||||
config *Config // configuration passed to constructor
|
||||
// The transition from handshakeRunning to the next phase is covered by
|
||||
// handshakeMutex. All others by in.Mutex.
|
||||
phase handshakeStatus
|
||||
handshakeErr error // error resulting from handshake
|
||||
connID []byte // Random connection id
|
||||
clientHello []byte // ClientHello packet contents
|
||||
vers uint16 // TLS version
|
||||
haveVers bool // version has been negotiated
|
||||
config *Config // configuration passed to constructor
|
||||
// handshakes counts the number of handshakes performed on the
|
||||
// connection so far. If renegotiation is disabled then this is either
|
||||
// zero or one.
|
||||
@ -103,16 +104,22 @@ type Conn struct {
|
||||
// TLS 1.3 needs the server state until it reaches the Client Finished
|
||||
hs *serverHandshakeState
|
||||
|
||||
// earlyDataBytes is the number of bytes of early data received so
|
||||
// far. Tracked to enforce max_early_data_size.
|
||||
earlyDataBytes int64
|
||||
|
||||
tmp [16]byte
|
||||
}
|
||||
|
||||
type handshakePhase int
|
||||
type handshakeStatus int
|
||||
|
||||
const (
|
||||
earlyHandshake handshakePhase = iota
|
||||
handshakeRunning handshakeStatus = iota
|
||||
discardingEarlyData
|
||||
readingEarlyData
|
||||
waitingClientFinished
|
||||
readingClientFinished
|
||||
handshakeComplete
|
||||
handshakeConfirmed
|
||||
)
|
||||
|
||||
// Access to net.Conn methods.
|
||||
@ -548,6 +555,9 @@ func (b *block) readFromUntil(r io.Reader, n int) error {
|
||||
func (b *block) Read(p []byte) (n int, err error) {
|
||||
n = copy(p, b.data[b.off:])
|
||||
b.off += n
|
||||
if b.off >= len(b.data) {
|
||||
err = io.EOF
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -606,6 +616,7 @@ func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
|
||||
// readRecord reads the next TLS record from the connection
|
||||
// and updates the record layer state.
|
||||
// c.in.Mutex <= L; c.input == nil.
|
||||
// c.input can still be nil after a call, retry if so.
|
||||
func (c *Conn) readRecord(want recordType) error {
|
||||
// Caller must be in sync with connection:
|
||||
// handshake data if handshake not yet completed,
|
||||
@ -615,18 +626,17 @@ func (c *Conn) readRecord(want recordType) error {
|
||||
c.sendAlert(alertInternalError)
|
||||
return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
|
||||
case recordTypeHandshake, recordTypeChangeCipherSpec:
|
||||
if c.phase != earlyHandshake && c.phase != readingClientFinished {
|
||||
if c.phase != handshakeRunning && c.phase != readingClientFinished {
|
||||
c.sendAlert(alertInternalError)
|
||||
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
|
||||
}
|
||||
case recordTypeApplicationData:
|
||||
if c.phase == earlyHandshake || c.phase == earlyHandshake {
|
||||
if c.phase == handshakeRunning || c.phase == readingClientFinished {
|
||||
c.sendAlert(alertInternalError)
|
||||
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
|
||||
}
|
||||
}
|
||||
|
||||
Again:
|
||||
if c.rawInput == nil {
|
||||
c.rawInput = c.in.newBlock()
|
||||
}
|
||||
@ -686,7 +696,15 @@ Again:
|
||||
// Process message.
|
||||
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
|
||||
ok, off, alertValue := c.in.decrypt(b)
|
||||
if !ok {
|
||||
switch {
|
||||
case !ok && c.phase == discardingEarlyData:
|
||||
// If the client said that it's sending early data and we did not
|
||||
// accept it, we are expected to fail decryption.
|
||||
c.in.freeBlock(b)
|
||||
return nil
|
||||
case ok && c.phase == discardingEarlyData:
|
||||
c.phase = waitingClientFinished
|
||||
case !ok:
|
||||
c.in.freeBlock(b)
|
||||
return c.in.setErrorLocked(c.sendAlert(alertValue))
|
||||
}
|
||||
@ -730,11 +748,15 @@ Again:
|
||||
c.in.setErrorLocked(io.EOF)
|
||||
break
|
||||
}
|
||||
if alert(data[1]) == alertEndOfEarlyData {
|
||||
c.handleEndOfEarlyData()
|
||||
break
|
||||
}
|
||||
switch data[0] {
|
||||
case alertLevelWarning:
|
||||
// drop on the floor
|
||||
c.in.freeBlock(b)
|
||||
goto Again
|
||||
return nil
|
||||
case alertLevelError:
|
||||
c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
|
||||
default:
|
||||
@ -742,7 +764,7 @@ Again:
|
||||
}
|
||||
|
||||
case recordTypeChangeCipherSpec:
|
||||
if typ != want || len(data) != 1 || data[0] != 1 {
|
||||
if typ != want || len(data) != 1 || data[0] != 1 || c.vers >= VersionTLS13 {
|
||||
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
break
|
||||
}
|
||||
@ -752,11 +774,7 @@ Again:
|
||||
}
|
||||
|
||||
case recordTypeApplicationData:
|
||||
if c.phase == waitingClientFinished {
|
||||
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
break
|
||||
}
|
||||
if typ != want {
|
||||
if typ != want || c.phase == waitingClientFinished {
|
||||
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
|
||||
break
|
||||
}
|
||||
@ -775,7 +793,6 @@ Again:
|
||||
c.in.setErrorLocked(err)
|
||||
break
|
||||
}
|
||||
goto Again
|
||||
}
|
||||
}
|
||||
|
||||
@ -1131,7 +1148,7 @@ func (c *Conn) Write(b []byte) (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if c.phase == earlyHandshake {
|
||||
if c.phase == handshakeRunning {
|
||||
return 0, alertInternalError
|
||||
}
|
||||
|
||||
@ -1181,6 +1198,10 @@ func (c *Conn) handleRenegotiation() error {
|
||||
return c.sendAlert(alertNoRenegotiation)
|
||||
}
|
||||
|
||||
if c.vers >= VersionTLS13 {
|
||||
return c.sendAlert(alertNoRenegotiation)
|
||||
}
|
||||
|
||||
switch c.config.Renegotiation {
|
||||
case RenegotiateNever:
|
||||
return c.sendAlert(alertNoRenegotiation)
|
||||
@ -1198,13 +1219,102 @@ func (c *Conn) handleRenegotiation() error {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
c.phase = earlyHandshake
|
||||
c.phase = handshakeRunning
|
||||
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
|
||||
c.handshakes++
|
||||
}
|
||||
return c.handshakeErr
|
||||
}
|
||||
|
||||
// ConfirmHandshake waits for the handshake to reach a point at which
|
||||
// the connection is certainly not replayed. That is, after receiving
|
||||
// the Client Finished.
|
||||
//
|
||||
// If ConfirmHandshake returns an error and until ConfirmHandshake
|
||||
// returns, the 0-RTT data should not be trusted not to be replayed.
|
||||
//
|
||||
// This is only meaningful in TLS 1.3 when Accept0RTTData is true and the
|
||||
// client sent valid 0-RTT data. In any other case it's equivalent to
|
||||
// calling Handshake.
|
||||
func (c *Conn) ConfirmHandshake() error {
|
||||
if err := c.Handshake(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.in.Lock()
|
||||
defer c.in.Unlock()
|
||||
|
||||
if c.phase == handshakeConfirmed {
|
||||
return nil
|
||||
}
|
||||
|
||||
var input *block
|
||||
if c.phase == readingEarlyData || c.input != nil {
|
||||
buf := &bytes.Buffer{}
|
||||
if _, err := buf.ReadFrom(earlyDataReader{c}); err != nil {
|
||||
c.in.setErrorLocked(err)
|
||||
return err
|
||||
}
|
||||
input = &block{data: buf.Bytes()}
|
||||
}
|
||||
|
||||
for c.phase != handshakeConfirmed {
|
||||
if err := c.readRecord(recordTypeApplicationData); err != nil {
|
||||
c.in.setErrorLocked(err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if c.phase != handshakeConfirmed {
|
||||
panic("should have reached handshakeConfirmed state")
|
||||
}
|
||||
if c.input != nil {
|
||||
panic("should not have read past the Client Finished")
|
||||
}
|
||||
|
||||
c.input = input
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// earlyDataReader wraps a Conn with in locked and reads only early data,
|
||||
// both buffered and still on the wire.
|
||||
type earlyDataReader struct {
|
||||
c *Conn
|
||||
}
|
||||
|
||||
func (r earlyDataReader) Read(b []byte) (n int, err error) {
|
||||
c := r.c
|
||||
|
||||
if c.phase == handshakeConfirmed {
|
||||
// c.input might not be early data
|
||||
panic("earlyDataReader called at handshakeConfirmed")
|
||||
}
|
||||
|
||||
for c.input == nil && c.in.err == nil && c.phase == readingEarlyData {
|
||||
if err := c.readRecord(recordTypeApplicationData); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := c.in.err; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if c.input != nil {
|
||||
n, err = c.input.Read(b)
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
c.in.freeBlock(c.input)
|
||||
c.input = nil
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil && c.phase != readingEarlyData && c.input == nil {
|
||||
err = io.EOF
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Read can be made to time out and return a net.Error with Timeout() == true
|
||||
// after a fixed time limit; see SetDeadline and SetReadDeadline.
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
@ -1242,7 +1352,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
n, err = c.input.Read(b)
|
||||
if c.input.off >= len(c.input.data) {
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
c.in.freeBlock(c.input)
|
||||
c.input = nil
|
||||
}
|
||||
@ -1300,7 +1411,17 @@ func (c *Conn) Close() error {
|
||||
var alertErr error
|
||||
|
||||
c.handshakeMutex.Lock()
|
||||
if c.phase != earlyHandshake {
|
||||
for c.phase == readingEarlyData {
|
||||
if err := c.readRecord(recordTypeApplicationData); err != nil {
|
||||
alertErr = err
|
||||
}
|
||||
}
|
||||
if alertErr == nil && c.phase == waitingClientFinished {
|
||||
if err := c.hs.readClientFinished13(); err != nil {
|
||||
alertErr = err
|
||||
}
|
||||
}
|
||||
if alertErr == nil && c.phase != handshakeRunning {
|
||||
alertErr = c.closeNotify()
|
||||
}
|
||||
c.handshakeMutex.Unlock()
|
||||
@ -1319,7 +1440,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
|
||||
func (c *Conn) CloseWrite() error {
|
||||
c.handshakeMutex.Lock()
|
||||
defer c.handshakeMutex.Unlock()
|
||||
if c.phase == earlyHandshake {
|
||||
if c.phase == handshakeRunning {
|
||||
return errEarlyCloseWrite
|
||||
}
|
||||
|
||||
@ -1341,8 +1462,11 @@ func (c *Conn) closeNotify() error {
|
||||
// protocol if it has not yet been run.
|
||||
// Most uses of this package need not call Handshake
|
||||
// explicitly: the first Read or Write will call it automatically.
|
||||
//
|
||||
// In TLS 1.3 Handshake returns after the client and server first flights,
|
||||
// without waiting for the Client Finished.
|
||||
func (c *Conn) Handshake() error {
|
||||
// c.handshakeErr and c.phase == earlyHandshake are protected by
|
||||
// c.handshakeErr and c.phase == handshakeRunning are protected by
|
||||
// c.handshakeMutex. In order to perform a handshake, we need to lock
|
||||
// c.in also and c.handshakeMutex must be locked after c.in.
|
||||
//
|
||||
@ -1371,7 +1495,7 @@ func (c *Conn) Handshake() error {
|
||||
if err := c.handshakeErr; err != nil {
|
||||
return err
|
||||
}
|
||||
if c.phase != earlyHandshake {
|
||||
if c.phase != handshakeRunning {
|
||||
return nil
|
||||
}
|
||||
if c.handshakeCond == nil {
|
||||
@ -1393,7 +1517,7 @@ func (c *Conn) Handshake() error {
|
||||
|
||||
// The handshake cannot have completed when handshakeMutex was unlocked
|
||||
// because this goroutine set handshakeCond.
|
||||
if c.handshakeErr != nil || c.phase != earlyHandshake {
|
||||
if c.handshakeErr != nil || c.phase != handshakeRunning {
|
||||
panic("handshake should not have been able to complete after handshakeCond was set")
|
||||
}
|
||||
|
||||
@ -1415,7 +1539,7 @@ func (c *Conn) Handshake() error {
|
||||
c.flush()
|
||||
}
|
||||
|
||||
if c.handshakeErr == nil && c.phase == earlyHandshake {
|
||||
if c.handshakeErr == nil && c.phase == handshakeRunning {
|
||||
panic("handshake should have had a result.")
|
||||
}
|
||||
|
||||
@ -1433,7 +1557,7 @@ func (c *Conn) ConnectionState() ConnectionState {
|
||||
defer c.handshakeMutex.Unlock()
|
||||
|
||||
var state ConnectionState
|
||||
state.HandshakeComplete = c.phase != earlyHandshake
|
||||
state.HandshakeComplete = c.phase != handshakeRunning
|
||||
state.ServerName = c.serverName
|
||||
|
||||
if state.HandshakeComplete {
|
||||
@ -1448,6 +1572,7 @@ func (c *Conn) ConnectionState() ConnectionState {
|
||||
state.VerifiedChains = c.verifiedChains
|
||||
state.SignedCertificateTimestamps = c.scts
|
||||
state.OCSPResponse = c.ocspResponse
|
||||
state.HandshakeConfirmed = c.phase == handshakeConfirmed
|
||||
if !c.didResume {
|
||||
if c.clientFinishedIsFirst {
|
||||
state.TLSUnique = c.clientFinished[:]
|
||||
@ -1478,7 +1603,7 @@ func (c *Conn) VerifyHostname(host string) error {
|
||||
if !c.isClient {
|
||||
return errors.New("tls: VerifyHostname called on TLS server connection")
|
||||
}
|
||||
if c.phase == earlyHandshake {
|
||||
if c.phase == handshakeRunning {
|
||||
return errors.New("tls: handshake has not yet been performed")
|
||||
}
|
||||
if len(c.verifiedChains) == 0 {
|
||||
|
@ -251,7 +251,7 @@ NextCipherSuite:
|
||||
}
|
||||
|
||||
c.didResume = isResume
|
||||
c.phase = handshakeComplete
|
||||
c.phase = handshakeConfirmed
|
||||
c.cipherSuite = suite.id
|
||||
return nil
|
||||
}
|
||||
|
@ -64,12 +64,12 @@ func (i opensslInput) Read(buf []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
// opensslOutputSink is an io.Writer that receives the stdout and stderr from
|
||||
// an `openssl` process and sends a value to handshakeComplete when it sees a
|
||||
// an `openssl` process and sends a value to handshakeConfirmed when it sees a
|
||||
// log message from a completed server handshake.
|
||||
type opensslOutputSink struct {
|
||||
handshakeComplete chan struct{}
|
||||
all []byte
|
||||
line []byte
|
||||
handshakeConfirmed chan struct{}
|
||||
all []byte
|
||||
line []byte
|
||||
}
|
||||
|
||||
func newOpensslOutputSink() *opensslOutputSink {
|
||||
@ -91,7 +91,7 @@ func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) {
|
||||
o.handshakeComplete <- struct{}{}
|
||||
o.handshakeConfirmed <- struct{}{}
|
||||
}
|
||||
o.line = o.line[i+1:]
|
||||
}
|
||||
@ -315,9 +315,9 @@ func (test *clientTest) run(t *testing.T, write bool) {
|
||||
|
||||
for i := 1; i <= test.numRenegotiations; i++ {
|
||||
// The initial handshake will generate a
|
||||
// handshakeComplete signal which needs to be quashed.
|
||||
// handshakeConfirmed signal which needs to be quashed.
|
||||
if i == 1 && write {
|
||||
<-stdout.handshakeComplete
|
||||
<-stdout.handshakeConfirmed
|
||||
}
|
||||
|
||||
// OpenSSL will try to interleave application data and
|
||||
@ -364,7 +364,7 @@ func (test *clientTest) run(t *testing.T, write bool) {
|
||||
}()
|
||||
|
||||
if write && test.renegotiationExpectedToFail != i {
|
||||
<-stdout.handshakeComplete
|
||||
<-stdout.handshakeConfirmed
|
||||
stdin <- opensslSendSentinel
|
||||
}
|
||||
<-signalChan
|
||||
|
@ -33,6 +33,7 @@ type clientHelloMsg struct {
|
||||
supportedVersions []uint16
|
||||
psks []psk
|
||||
pskKeyExchangeModes []uint8
|
||||
earlyData bool
|
||||
}
|
||||
|
||||
func (m *clientHelloMsg) equal(i interface{}) bool {
|
||||
@ -60,7 +61,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
|
||||
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
|
||||
eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
|
||||
eqKeyShares(m.keyShares, m1.keyShares) &&
|
||||
eqUint16s(m.supportedVersions, m1.supportedVersions)
|
||||
eqUint16s(m.supportedVersions, m1.supportedVersions) &&
|
||||
m.earlyData == m1.earlyData
|
||||
}
|
||||
|
||||
func (m *clientHelloMsg) marshal() []byte {
|
||||
@ -127,6 +129,9 @@ func (m *clientHelloMsg) marshal() []byte {
|
||||
extensionsLength += 1 + 2*len(m.supportedVersions)
|
||||
numExtensions++
|
||||
}
|
||||
if m.earlyData {
|
||||
numExtensions++
|
||||
}
|
||||
if numExtensions > 0 {
|
||||
extensionsLength += 4 * numExtensions
|
||||
length += 2 + extensionsLength
|
||||
@ -350,6 +355,11 @@ func (m *clientHelloMsg) marshal() []byte {
|
||||
z = z[2:]
|
||||
}
|
||||
}
|
||||
if m.earlyData {
|
||||
z[0] = byte(extensionEarlyData >> 8)
|
||||
z[1] = byte(extensionEarlyData)
|
||||
z = z[4:]
|
||||
}
|
||||
|
||||
m.raw = x
|
||||
|
||||
@ -413,6 +423,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
m.supportedVersions = nil
|
||||
m.psks = nil
|
||||
m.pskKeyExchangeModes = nil
|
||||
m.earlyData = false
|
||||
|
||||
if len(data) == 0 {
|
||||
// ClientHello is optionally followed by extension data
|
||||
@ -668,6 +679,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
|
||||
return false
|
||||
}
|
||||
m.pskKeyExchangeModes = data[1:length]
|
||||
case extensionEarlyData:
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
|
||||
m.earlyData = true
|
||||
}
|
||||
data = data[length:]
|
||||
bindersOffset += length
|
||||
@ -1144,6 +1158,7 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
|
||||
type encryptedExtensionsMsg struct {
|
||||
raw []byte
|
||||
alpnProtocol string
|
||||
earlyData bool
|
||||
}
|
||||
|
||||
func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
|
||||
@ -1153,7 +1168,8 @@ func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
|
||||
}
|
||||
|
||||
return bytes.Equal(m.raw, m1.raw) &&
|
||||
m.alpnProtocol == m1.alpnProtocol
|
||||
m.alpnProtocol == m1.alpnProtocol &&
|
||||
m.earlyData == m1.earlyData
|
||||
}
|
||||
|
||||
func (m *encryptedExtensionsMsg) marshal() []byte {
|
||||
@ -1163,6 +1179,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte {
|
||||
|
||||
length := 2
|
||||
|
||||
if m.earlyData {
|
||||
length += 4
|
||||
}
|
||||
alpnLen := len(m.alpnProtocol)
|
||||
if alpnLen > 0 {
|
||||
if alpnLen >= 256 {
|
||||
@ -1196,6 +1215,12 @@ func (m *encryptedExtensionsMsg) marshal() []byte {
|
||||
z = z[7+alpnLen:]
|
||||
}
|
||||
|
||||
if m.earlyData {
|
||||
z[0] = byte(extensionEarlyData >> 8)
|
||||
z[1] = byte(extensionEarlyData)
|
||||
z = z[4:]
|
||||
}
|
||||
|
||||
m.raw = x
|
||||
return x
|
||||
}
|
||||
@ -1205,41 +1230,55 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
|
||||
return false
|
||||
}
|
||||
m.raw = data
|
||||
l := int(data[4])<<8 | int(data[5])
|
||||
if l != len(data)-6 {
|
||||
return false
|
||||
}
|
||||
|
||||
m.alpnProtocol = ""
|
||||
if l == 0 {
|
||||
return true
|
||||
m.earlyData = false
|
||||
|
||||
extensionsLength := int(data[4])<<8 | int(data[5])
|
||||
data = data[6:]
|
||||
if len(data) != extensionsLength {
|
||||
return false
|
||||
}
|
||||
|
||||
d := data[6:]
|
||||
if len(d) < 5 {
|
||||
return false
|
||||
for len(data) != 0 {
|
||||
if len(data) < 4 {
|
||||
return false
|
||||
}
|
||||
extension := uint16(data[0])<<8 | uint16(data[1])
|
||||
length := int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
if len(data) < length {
|
||||
return false
|
||||
}
|
||||
|
||||
switch extension {
|
||||
case extensionALPN:
|
||||
d := data[:length]
|
||||
if len(d) < 3 {
|
||||
return false
|
||||
}
|
||||
l := int(d[0])<<8 | int(d[1])
|
||||
if l != len(d)-2 {
|
||||
return false
|
||||
}
|
||||
d = d[2:]
|
||||
l = int(d[0])
|
||||
if l != len(d)-1 {
|
||||
return false
|
||||
}
|
||||
d = d[1:]
|
||||
if len(d) == 0 {
|
||||
// ALPN protocols must not be empty.
|
||||
return false
|
||||
}
|
||||
m.alpnProtocol = string(d)
|
||||
case extensionEarlyData:
|
||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
|
||||
m.earlyData = true
|
||||
}
|
||||
|
||||
data = data[length:]
|
||||
}
|
||||
if uint16(d[0])<<8|uint16(d[1]) != extensionALPN {
|
||||
return false
|
||||
}
|
||||
l = int(d[2])<<8 | int(d[3])
|
||||
if l != len(d)-4 {
|
||||
return false
|
||||
}
|
||||
l = int(d[4])<<8 | int(d[5])
|
||||
if l != len(d)-6 {
|
||||
return false
|
||||
}
|
||||
d = d[6:]
|
||||
l = int(d[0])
|
||||
if l != len(d)-1 {
|
||||
return false
|
||||
}
|
||||
d = d[1:]
|
||||
if len(d) == 0 {
|
||||
// ALPN protocols must not be empty.
|
||||
return false
|
||||
}
|
||||
m.alpnProtocol = string(d)
|
||||
|
||||
return true
|
||||
}
|
||||
|
@ -164,6 +164,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
for i := range m.supportedVersions {
|
||||
m.supportedVersions[i] = uint16(rand.Intn(30000))
|
||||
}
|
||||
if rand.Intn(10) > 5 {
|
||||
m.earlyData = true
|
||||
}
|
||||
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
@ -222,7 +225,12 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
|
||||
func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
m := &encryptedExtensionsMsg{}
|
||||
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
|
||||
if rand.Intn(10) > 5 {
|
||||
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
|
||||
}
|
||||
if rand.Intn(10) > 5 {
|
||||
m.earlyData = true
|
||||
}
|
||||
|
||||
return reflect.ValueOf(m)
|
||||
}
|
||||
@ -328,10 +336,11 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
|
||||
s := &sessionState13{}
|
||||
s.vers = uint16(rand.Intn(10000))
|
||||
s.hash = uint16(rand.Intn(10000))
|
||||
s.suite = uint16(rand.Intn(10000))
|
||||
s.ageAdd = uint32(rand.Intn(0xffffffff))
|
||||
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
|
||||
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
|
||||
s.alpnProtocol = randomString(rand.Intn(100), rand)
|
||||
return reflect.ValueOf(s)
|
||||
}
|
||||
|
||||
|
@ -46,7 +46,8 @@ type serverHandshakeState struct {
|
||||
hello13Enc *encryptedExtensionsMsg
|
||||
finishedHash13 hash.Hash
|
||||
clientFinishedKey []byte
|
||||
clientCipher interface{}
|
||||
hsClientCipher interface{}
|
||||
appClientCipher interface{}
|
||||
}
|
||||
|
||||
// serverHandshake performs a TLS handshake as a server.
|
||||
@ -61,7 +62,6 @@ func (c *Conn) serverHandshake() error {
|
||||
}
|
||||
c.in.traceErr = hs.traceErr
|
||||
c.out.traceErr = hs.traceErr
|
||||
defer func() { c.in.traceErr, c.out.traceErr = nil, nil }()
|
||||
isResume, err := hs.readClientHello()
|
||||
if err != nil {
|
||||
return err
|
||||
@ -105,7 +105,7 @@ func (c *Conn) serverHandshake() error {
|
||||
return err
|
||||
}
|
||||
c.didResume = true
|
||||
c.phase = handshakeComplete
|
||||
c.phase = handshakeConfirmed
|
||||
} else {
|
||||
// The client didn't include a session ticket, or it wasn't
|
||||
// valid so we do a full handshake.
|
||||
@ -129,7 +129,7 @@ func (c *Conn) serverHandshake() error {
|
||||
if _, err := c.flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.phase = handshakeComplete
|
||||
c.phase = handshakeConfirmed
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -910,6 +910,11 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
|
||||
signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
|
||||
}
|
||||
|
||||
var pskBinder []byte
|
||||
if len(hs.clientHello.psks) > 0 {
|
||||
pskBinder = hs.clientHello.psks[0].binder
|
||||
}
|
||||
|
||||
hs.cachedClientHelloInfo = &ClientHelloInfo{
|
||||
CipherSuites: hs.clientHello.cipherSuites,
|
||||
ServerName: hs.clientHello.serverName,
|
||||
@ -919,6 +924,8 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
|
||||
SupportedProtos: hs.clientHello.alpnProtocols,
|
||||
SupportedVersions: supportedVersions,
|
||||
Conn: hs.c.conn,
|
||||
Offered0RTTData: hs.clientHello.earlyData,
|
||||
Fingerprint: pskBinder,
|
||||
}
|
||||
|
||||
return hs.cachedClientHelloInfo
|
||||
|
34
ticket.go
34
ticket.go
@ -131,10 +131,13 @@ func (s *sessionState) unmarshal(data []byte) bool {
|
||||
|
||||
type sessionState13 struct {
|
||||
vers uint16
|
||||
hash uint16 // crypto.Hash value
|
||||
suite uint16
|
||||
ageAdd uint32
|
||||
createdAt uint64
|
||||
resumptionSecret []byte
|
||||
alpnProtocol string
|
||||
// TODO(filippo): add and check SNI
|
||||
// TODO(filippo): add and check maxEarlyDataLength
|
||||
}
|
||||
|
||||
func (s *sessionState13) equal(i interface{}) bool {
|
||||
@ -144,19 +147,20 @@ func (s *sessionState13) equal(i interface{}) bool {
|
||||
}
|
||||
|
||||
return s.vers == s1.vers &&
|
||||
s.hash == s1.hash &&
|
||||
s.suite == s1.suite &&
|
||||
s.alpnProtocol == s1.alpnProtocol &&
|
||||
s.ageAdd == s1.ageAdd &&
|
||||
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
|
||||
}
|
||||
|
||||
func (s *sessionState13) marshal() []byte {
|
||||
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret)
|
||||
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol)
|
||||
|
||||
x := make([]byte, length)
|
||||
x[0] = byte(s.vers >> 8)
|
||||
x[1] = byte(s.vers)
|
||||
x[2] = byte(s.hash >> 8)
|
||||
x[3] = byte(s.hash)
|
||||
x[2] = byte(s.suite >> 8)
|
||||
x[3] = byte(s.suite)
|
||||
x[4] = byte(s.ageAdd >> 24)
|
||||
x[5] = byte(s.ageAdd >> 16)
|
||||
x[6] = byte(s.ageAdd >> 8)
|
||||
@ -171,8 +175,11 @@ func (s *sessionState13) marshal() []byte {
|
||||
x[15] = byte(s.createdAt)
|
||||
x[16] = byte(len(s.resumptionSecret) >> 8)
|
||||
x[17] = byte(len(s.resumptionSecret))
|
||||
|
||||
copy(x[18:], s.resumptionSecret)
|
||||
z := x[18+len(s.resumptionSecret):]
|
||||
z[0] = byte(len(s.alpnProtocol) >> 8)
|
||||
z[1] = byte(len(s.alpnProtocol))
|
||||
copy(z[2:], s.alpnProtocol)
|
||||
|
||||
return x
|
||||
}
|
||||
@ -183,14 +190,19 @@ func (s *sessionState13) unmarshal(data []byte) bool {
|
||||
}
|
||||
|
||||
s.vers = uint16(data[0])<<8 | uint16(data[1])
|
||||
s.hash = uint16(data[2])<<8 | uint16(data[3])
|
||||
s.suite = uint16(data[2])<<8 | uint16(data[3])
|
||||
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
|
||||
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
|
||||
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
|
||||
l := uint16(data[16])<<8 | uint16(data[17])
|
||||
s.resumptionSecret = data[18:]
|
||||
|
||||
return int(l) == len(s.resumptionSecret)
|
||||
l := int(data[16])<<8 | int(data[17])
|
||||
if len(data) < 18+l+2 {
|
||||
return false
|
||||
}
|
||||
s.resumptionSecret = data[18 : 18+l]
|
||||
z := data[18+l:]
|
||||
l = int(z[0])<<8 | int(z[1])
|
||||
s.alpnProtocol = string(z[2:])
|
||||
return l == len(s.alpnProtocol)
|
||||
}
|
||||
|
||||
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {
|
||||
|
@ -641,7 +641,7 @@ func TestCloneNonFuncFields(t *testing.T) {
|
||||
f.Set(reflect.ValueOf("b"))
|
||||
case "ClientAuth":
|
||||
f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
|
||||
case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
|
||||
case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "Accept0RTTData":
|
||||
f.Set(reflect.ValueOf(true))
|
||||
case "MinVersion", "MaxVersion":
|
||||
f.Set(reflect.ValueOf(uint16(VersionTLS12)))
|
||||
@ -654,6 +654,8 @@ func TestCloneNonFuncFields(t *testing.T) {
|
||||
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
|
||||
case "Renegotiation":
|
||||
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
|
||||
case "Max0RTTDataSize":
|
||||
f.Set(reflect.ValueOf(uint32(0)))
|
||||
default:
|
||||
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user