瀏覽代碼

crypto/tls: implement TLS 1.3 server 0-RTT

v1.2.3
Filippo Valsorda 8 年之前
committed by Peter Wu
父節點
當前提交
f8c15889af
共有 17 個文件被更改,包括 471 次插入171 次删除
  1. +2
    -2
      .travis.yml
  2. +64
    -39
      13.go
  3. +3
    -1
      _dev/Makefile
  4. +20
    -3
      _dev/interop.sh
  5. +5
    -2
      _dev/tris-localserver/Dockerfile
  6. +47
    -30
      _dev/tris-localserver/server.go
  7. +1
    -1
      _dev/tstclnt/run.sh
  8. +1
    -0
      alert.go
  9. +41
    -0
      common.go
  10. +159
    -34
      conn.go
  11. +1
    -1
      handshake_client.go
  12. +8
    -8
      handshake_client_test.go
  13. +71
    -32
      handshake_messages.go
  14. +11
    -2
      handshake_messages_test.go
  15. +11
    -4
      handshake_server.go
  16. +23
    -11
      ticket.go
  17. +3
    -1
      tls_test.go

+ 2
- 2
.travis.yml 查看文件

@@ -30,8 +30,8 @@ install:
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh INSTALL $CLIENT $REVISION; fi

script:
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 443; fi # ECDSA
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT 4443; fi # RSA
- if [ "$MODE" = "interop" ]; then ./_dev/interop.sh RUN $CLIENT; fi
- if [ "$MODE" = "interop" ] && [ "$CLIENT" = "tstclnt" ]; then ./_dev/interop.sh 0-RTT $CLIENT; fi
- if [ "$MODE" = "gotest" ]; then ./_dev/go.sh test -race crypto/tls; fi

after_script:


+ 64
- 39
13.go 查看文件

@@ -60,34 +60,33 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
}
c.didResume = isPSK

hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())

handshakeCtx := hs.finishedHash13.Sum(nil)
earlyClientCipher, _ := hs.prepareCipher(handshakeCtx, earlySecret, "client early traffic secret")

ecdheSecret := deriveECDHESecret(ks, privateKey)
if ecdheSecret == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: bad ECDHE client share")
}

hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal())
hs.finishedHash13.Write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err
}

handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret)
handshakeCtx := hs.finishedHash13.Sum(nil)

cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize)
cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen)
cIV := hkdfExpandLabel(hash, cHandshakeTS, nil, "iv", 12)
sHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "server handshake traffic secret", hashSize)
sKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "key", hs.suite.keyLen)
sIV := hkdfExpandLabel(hash, sHandshakeTS, nil, "iv", 12)

clientCipher := hs.suite.aead(cKey, cIV)
c.in.setCipher(c.vers, clientCipher)
serverCipher := hs.suite.aead(sKey, sIV)
handshakeCtx = hs.finishedHash13.Sum(nil)
clientCipher, cTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "client handshake traffic secret")
hs.hsClientCipher = clientCipher
serverCipher, sTrafficSecret := hs.prepareCipher(handshakeCtx, handshakeSecret, "server handshake traffic secret")
c.out.setCipher(c.vers, serverCipher)

serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)

hs.finishedHash13.Write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err
@@ -99,9 +98,6 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {
}
}

serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize)

verifyData := hmacOfSum(hash, hs.finishedHash13, serverFinishedKey)
serverFinished := &finishedMsg{
verifyData: verifyData,
@@ -113,19 +109,20 @@ func (hs *serverHandshakeState) doTLS13Handshake() error {

hs.masterSecret = hkdfExtract(hash, nil, handshakeSecret)
handshakeCtx = hs.finishedHash13.Sum(nil)

cTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "client application traffic secret", hashSize)
cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen)
cIV = hkdfExpandLabel(hash, cTrafficSecret0, nil, "iv", 12)
sTrafficSecret0 := hkdfExpandLabel(hash, hs.masterSecret, handshakeCtx, "server application traffic secret", hashSize)
sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen)
sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12)

hs.clientCipher = hs.suite.aead(cKey, cIV)
serverCipher = hs.suite.aead(sKey, sIV)
hs.appClientCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "client application traffic secret")
serverCipher, _ = hs.prepareCipher(handshakeCtx, hs.masterSecret, "server application traffic secret")
c.out.setCipher(c.vers, serverCipher)

c.phase = waitingClientFinished
if hs.hello13Enc.earlyData {
c.in.setCipher(c.vers, earlyClientCipher)
c.phase = readingEarlyData
} else if hs.clientHello.earlyData {
c.in.setCipher(c.vers, hs.hsClientCipher)
c.phase = discardingEarlyData
} else {
c.in.setCipher(c.vers, hs.hsClientCipher)
c.phase = waitingClientFinished
}

return nil
}
@@ -157,11 +154,10 @@ func (hs *serverHandshakeState) readClientFinished13() error {
}
hs.finishedHash13.Write(clientFinished.marshal())

c.in.setCipher(c.vers, hs.clientCipher)

// Discard the server handshake state
c.hs = nil
c.phase = handshakeComplete
c.hs = nil // Discard the server handshake state
c.phase = handshakeConfirmed
c.in.setCipher(c.vers, hs.appClientCipher)
c.in.traceErr, c.out.traceErr = nil, nil

return hs.sendSessionTicket13()
}
@@ -209,6 +205,15 @@ func (hs *serverHandshakeState) sendCertificate13() error {
return nil
}

func (c *Conn) handleEndOfEarlyData() {
if c.phase != readingEarlyData || c.vers < VersionTLS13 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
return
}
c.phase = waitingClientFinished
c.in.setCipher(c.vers, c.hs.hsClientCipher)
}

// selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify
// based on the certificate type and client supported schemes. If no overlap is found,
// a fallback is selected.
@@ -377,6 +382,14 @@ func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
return h.Sum(nil)
}

func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label string) (interface{}, []byte) {
hash := hashForSuite(hs.suite)
trafficSecret := hkdfExpandLabel(hash, secret, handshakeCtx, label, hash.Size())
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", hs.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
return hs.suite.aead(key, iv), trafficSecret
}

// Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
@@ -418,7 +431,14 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance {
continue
}
if s.hash != uint16(hash) {

// This enforces the stricter 0-RTT requirements on all ticket uses.
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
// we can give them up in the edge case of changed suite or ALPN.
if s.suite != hs.suite.id {
continue
}
if s.alpnProtocol != hs.hello13Enc.alpnProtocol {
continue
}

@@ -433,6 +453,9 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
if i == 0 && hs.clientHello.earlyData && hs.c.config.Accept0RTTData {
hs.hello13Enc.earlyData = true
}
return earlySecret, true
}
}
@@ -467,8 +490,8 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return err
}
sessionState := &sessionState13{
vers: c.vers,
hash: uint16(hash),
vers: c.vers,
suite: hs.suite.id,
ageAdd: uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 |
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3]),
createdAt: uint64(time.Now().Unix()),
@@ -481,9 +504,11 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
return err
}
ticketMsg := &newSessionTicketMsg13{
lifetime: 21600, // TODO(filippo)
ageAdd: sessionState.ageAdd,
ticket: ticket,
lifetime: 24 * 3600, // TODO(filippo)
maxEarlyDataLength: c.config.Max0RTTDataSize,
withEarlyDataInfo: c.config.Max0RTTDataSize > 0,
ageAdd: sessionState.ageAdd,
ticket: ticket,
}
if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil {
return err


+ 3
- 1
_dev/Makefile 查看文件

@@ -20,7 +20,9 @@ ifeq ($(shell go env CGO_ENABLED),1)
endif
@touch "$@"

GO_COMMIT := 5782050a487e002acfd14a3f4c2c815c7854928c
# Note: when changing this, if it doesn't change the Go version
# (it should), you need to run make clean.
GO_COMMIT := 5d4f37266d324e48f67b3bb82c6090f1aa94c013

.PHONY: go
go: go/.ok_$(GO_COMMIT)_$(GOENV)


+ 20
- 3
_dev/interop.sh 查看文件

@@ -11,8 +11,25 @@ if [ "$1" = "INSTALL" ]; then

elif [ "$1" = "RUN" ]; then
IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver)
docker run --rm tls-tris:$2 $IP:$3 | tee output.txt
grep "Hello TLS 1.3" output.txt | grep -v "resumed"
grep "Hello TLS 1.3" output.txt | grep "resumed"

docker run --rm tls-tris:$2 $IP:1443 | tee output.txt # RSA
grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT"
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"

docker run --rm tls-tris:$2 $IP:2443 | tee output.txt # ECDSA
grep "Hello TLS 1.3" output.txt | grep -v "resumed" | grep -v "0-RTT"
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"

elif [ "$1" = "0-RTT" ]; then
IP=$(docker inspect -f '{{ .NetworkSettings.IPAddress }}' tris-localserver)

docker run --rm tls-tris:$2 $IP:3443 | tee output.txt # rejecting 0-RTT
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"

docker run --rm tls-tris:$2 $IP:4443 | tee output.txt # accepting 0-RTT
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep "0-RTT"

docker run --rm tls-tris:$2 $IP:5443 | tee output.txt # confirming 0-RTT
grep "Hello TLS 1.3" output.txt | grep "resumed" | grep -v "0-RTT"

fi

+ 5
- 2
_dev/tris-localserver/Dockerfile 查看文件

@@ -2,10 +2,13 @@ FROM scratch

ENV TLSDEBUG error

EXPOSE 443
EXPOSE 1443
EXPOSE 2443
EXPOSE 3443
EXPOSE 4443
EXPOSE 5443

# GOOS=linux ../go.sh build -v -i .
ADD tris-localserver ./

CMD [ "./tris-localserver", "0.0.0.0:443", "0.0.0.0:4443" ]
CMD [ "./tris-localserver", "0.0.0.0:1443", "0.0.0.0:2443", "0.0.0.0:3443", "0.0.0.0:4443", "0.0.0.0:5443" ]

+ 47
- 30
_dev/tris-localserver/server.go 查看文件

@@ -17,13 +17,49 @@ var tlsVersionToName = map[uint16]string{
tls.VersionTLS13Draft18: "1.3 (draft 18)",
}

func startServer(addr string, rsa, offer0RTT, accept0RTT bool) {
cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey))
if rsa {
cert, err = tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey))
}
if err != nil {
log.Fatal(err)
}
var Max0RTTDataSize uint32
if offer0RTT {
Max0RTTDataSize = 100 * 1024
}
s := &http.Server{
Addr: addr,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
Max0RTTDataSize: Max0RTTDataSize,
Accept0RTTData: accept0RTT,
},
}
log.Fatal(s.ListenAndServeTLS("", ""))
}

var confirmingAddr string

func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
tlsConn := r.Context().Value(http.TLSConnContextKey).(*tls.Conn)
server := r.Context().Value(http.ServerContextKey).(*http.Server)
if server.Addr == confirmingAddr {
if err := tlsConn.ConfirmHandshake(); err != nil {
log.Fatal(err)
}
}
resumed := ""
if r.TLS.DidResume {
resumed = " [resumed]"
}
fmt.Fprintf(w, "<!DOCTYPE html><p>Hello TLS %s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed)
with0RTT := ""
if !tlsConn.ConnectionState().HandshakeConfirmed {
with0RTT = " [0-RTT]"
}
fmt.Fprintf(w, "<!DOCTYPE html><p>Hello TLS %s%s%s _o/\n", tlsVersionToName[r.TLS.Version], resumed, with0RTT)
})

http.HandleFunc("/ch", func(w http.ResponseWriter, r *http.Request) {
@@ -31,36 +67,17 @@ func main() {
fmt.Fprintf(w, "Client Hello packet (%d bytes):\n%s", len(r.TLS.ClientHello), hex.Dump(r.TLS.ClientHello))
})

go func() {
if len(os.Args) < 3 {
return
}
cert, err := tls.X509KeyPair([]byte(rsaCert), []byte(rsaKey))
if err != nil {
log.Fatal(err)
}
s := &http.Server{
Addr: os.Args[2],
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
},
}
log.Fatal(s.ListenAndServeTLS("", ""))
}()

cert, err := tls.X509KeyPair([]byte(ecdsaCert), []byte(ecdsaKey))
if err != nil {
log.Fatal(err)
switch len(os.Args) {
case 2:
startServer(os.Args[1], true, true, true)
case 6:
confirmingAddr = os.Args[5]
go startServer(os.Args[1], false, false, false) // first port: ECDSA (and no 0-RTT)
go startServer(os.Args[2], true, false, true) // second port: RSA (and accept 0-RTT but not offer it)
go startServer(os.Args[3], false, true, false) // third port: offer and reject 0-RTT
go startServer(os.Args[4], false, true, true) // fourth port: offer and accept 0-RTT
startServer(os.Args[5], false, true, true) // fifth port: offer and accept 0-RTT but confirm
}
s := &http.Server{
Addr: os.Args[1],
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
},
}
log.Fatal(s.ListenAndServeTLS("", ""))
}

const (


+ 1
- 1
_dev/tstclnt/run.sh 查看文件

@@ -5,4 +5,4 @@ shift
HOST="${ADDR[0]}"
PORT="${ADDR[1]}"

exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 "$@"
exec /dist/OBJ-PATH/bin/tstclnt -D -V tls1.3:tls1.3 -o -O -h $HOST -p $PORT -v -A /httpreq.txt -L 2 -Z "$@"

+ 1
- 0
alert.go 查看文件

@@ -16,6 +16,7 @@ const (

const (
alertCloseNotify alert = 0
alertEndOfEarlyData alert = 1
alertUnexpectedMessage alert = 10
alertBadRecordMAC alert = 20
alertDecryptionFailed alert = 21


+ 41
- 0
common.go 查看文件

@@ -85,6 +85,7 @@ const (
extensionSessionTicket uint16 = 35
extensionKeyShare uint16 = 40
extensionPreSharedKey uint16 = 41
extensionEarlyData uint16 = 42
extensionSupportedVersions uint16 = 43
extensionPSKKeyExchangeModes uint16 = 45
extensionTicketEarlyDataInfo uint16 = 46
@@ -213,6 +214,10 @@ type ConnectionState struct {
// been standardized and implemented.
TLSUnique []byte

// HandshakeConfirmed is true once all data returned by Read
// (past and future) is guaranteed not to be replayed.
HandshakeConfirmed bool

ClientHello []byte // ClientHello packet
}

@@ -322,6 +327,18 @@ type ClientHelloInfo struct {
// from, or write to, this connection; that will cause the TLS
// connection to fail.
Conn net.Conn

// Offered0RTTData is true if the client announced that it will send
// 0-RTT data. If the server Config.Accept0RTTData is true, and the
// client offered a session ticket valid for that purpose, it will
// be notified that the 0-RTT data is accepted and it will be made
// immediately available for Read.
Offered0RTTData bool

// The Fingerprint is an sequence of bytes unique to this Client Hello.
// It can be used to prevent or mitigate 0-RTT data replays as it's
// guaranteed that a replayed connection will have the same Fingerprint.
Fingerprint []byte
}

// CertificateRequestInfo contains information from a server's
@@ -548,6 +565,28 @@ type Config struct {
// used for debugging.
KeyLogWriter io.Writer

// If Max0RTTDataSize is not zero, the client will be allowed to use
// session tickets to send at most this number of bytes of 0-RTT data.
// 0-RTT data is subject to replay and has memory DoS implications.
// The server will later be able to refuse the 0-RTT data with
// Accept0RTTData, or wait for the client to prove that it's not
// replayed with Conn.ConfirmHandshake.
//
// It has no meaning on the client.
//
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3.
Max0RTTDataSize uint32

// Accept0RTTData makes the 0-RTT data received from the client
// immediately available to Read. 0-RTT data is subject to replay.
// Use Conn.ConfirmHandshake to wait until the data is known not
// to be replayed after reading it.
//
// It has no meaning on the client.
//
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-2.3.
Accept0RTTData bool

serverInitOnce sync.Once // guards calling (*Config).serverInit

// mutex protects sessionTicketKeys.
@@ -622,6 +661,8 @@ func (c *Config) Clone() *Config {
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
KeyLogWriter: c.KeyLogWriter,
Accept0RTTData: c.Accept0RTTData,
Max0RTTDataSize: c.Max0RTTDataSize,
sessionTicketKeys: sessionTicketKeys,
}
}


+ 159
- 34
conn.go 查看文件

@@ -27,20 +27,21 @@ type Conn struct {
conn net.Conn
isClient bool

phase handshakePhase

// constant after handshake; protected by handshakeMutex
handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
// handshakeCond, if not nil, indicates that a goroutine is committed
// to running the handshake for this Conn. Other goroutines that need
// to wait for the handshake can wait on this, under handshakeMutex.
handshakeCond *sync.Cond
handshakeErr error // error resulting from handshake
connID []byte // Random connection id
clientHello []byte // ClientHello packet contents
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
// The transition from handshakeRunning to the next phase is covered by
// handshakeMutex. All others by in.Mutex.
phase handshakeStatus
handshakeErr error // error resulting from handshake
connID []byte // Random connection id
clientHello []byte // ClientHello packet contents
vers uint16 // TLS version
haveVers bool // version has been negotiated
config *Config // configuration passed to constructor
// handshakes counts the number of handshakes performed on the
// connection so far. If renegotiation is disabled then this is either
// zero or one.
@@ -103,16 +104,22 @@ type Conn struct {
// TLS 1.3 needs the server state until it reaches the Client Finished
hs *serverHandshakeState

// earlyDataBytes is the number of bytes of early data received so
// far. Tracked to enforce max_early_data_size.
earlyDataBytes int64

tmp [16]byte
}

type handshakePhase int
type handshakeStatus int

const (
earlyHandshake handshakePhase = iota
handshakeRunning handshakeStatus = iota
discardingEarlyData
readingEarlyData
waitingClientFinished
readingClientFinished
handshakeComplete
handshakeConfirmed
)

// Access to net.Conn methods.
@@ -548,6 +555,9 @@ func (b *block) readFromUntil(r io.Reader, n int) error {
func (b *block) Read(p []byte) (n int, err error) {
n = copy(p, b.data[b.off:])
b.off += n
if b.off >= len(b.data) {
err = io.EOF
}
return
}

@@ -606,6 +616,7 @@ func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
// readRecord reads the next TLS record from the connection
// and updates the record layer state.
// c.in.Mutex <= L; c.input == nil.
// c.input can still be nil after a call, retry if so.
func (c *Conn) readRecord(want recordType) error {
// Caller must be in sync with connection:
// handshake data if handshake not yet completed,
@@ -615,18 +626,17 @@ func (c *Conn) readRecord(want recordType) error {
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
case recordTypeHandshake, recordTypeChangeCipherSpec:
if c.phase != earlyHandshake && c.phase != readingClientFinished {
if c.phase != handshakeRunning && c.phase != readingClientFinished {
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
}
case recordTypeApplicationData:
if c.phase == earlyHandshake || c.phase == earlyHandshake {
if c.phase == handshakeRunning || c.phase == readingClientFinished {
c.sendAlert(alertInternalError)
return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
}
}

Again:
if c.rawInput == nil {
c.rawInput = c.in.newBlock()
}
@@ -686,7 +696,15 @@ Again:
// Process message.
b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
ok, off, alertValue := c.in.decrypt(b)
if !ok {
switch {
case !ok && c.phase == discardingEarlyData:
// If the client said that it's sending early data and we did not
// accept it, we are expected to fail decryption.
c.in.freeBlock(b)
return nil
case ok && c.phase == discardingEarlyData:
c.phase = waitingClientFinished
case !ok:
c.in.freeBlock(b)
return c.in.setErrorLocked(c.sendAlert(alertValue))
}
@@ -730,11 +748,15 @@ Again:
c.in.setErrorLocked(io.EOF)
break
}
if alert(data[1]) == alertEndOfEarlyData {
c.handleEndOfEarlyData()
break
}
switch data[0] {
case alertLevelWarning:
// drop on the floor
c.in.freeBlock(b)
goto Again
return nil
case alertLevelError:
c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
default:
@@ -742,7 +764,7 @@ Again:
}

case recordTypeChangeCipherSpec:
if typ != want || len(data) != 1 || data[0] != 1 {
if typ != want || len(data) != 1 || data[0] != 1 || c.vers >= VersionTLS13 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
@@ -752,11 +774,7 @@ Again:
}

case recordTypeApplicationData:
if c.phase == waitingClientFinished {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
if typ != want {
if typ != want || c.phase == waitingClientFinished {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
break
}
@@ -775,7 +793,6 @@ Again:
c.in.setErrorLocked(err)
break
}
goto Again
}
}

@@ -1131,7 +1148,7 @@ func (c *Conn) Write(b []byte) (int, error) {
return 0, err
}

if c.phase == earlyHandshake {
if c.phase == handshakeRunning {
return 0, alertInternalError
}

@@ -1181,6 +1198,10 @@ func (c *Conn) handleRenegotiation() error {
return c.sendAlert(alertNoRenegotiation)
}

if c.vers >= VersionTLS13 {
return c.sendAlert(alertNoRenegotiation)
}

switch c.config.Renegotiation {
case RenegotiateNever:
return c.sendAlert(alertNoRenegotiation)
@@ -1198,13 +1219,102 @@ func (c *Conn) handleRenegotiation() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()

c.phase = earlyHandshake
c.phase = handshakeRunning
if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
c.handshakes++
}
return c.handshakeErr
}

// ConfirmHandshake waits for the handshake to reach a point at which
// the connection is certainly not replayed. That is, after receiving
// the Client Finished.
//
// If ConfirmHandshake returns an error and until ConfirmHandshake
// returns, the 0-RTT data should not be trusted not to be replayed.
//
// This is only meaningful in TLS 1.3 when Accept0RTTData is true and the
// client sent valid 0-RTT data. In any other case it's equivalent to
// calling Handshake.
func (c *Conn) ConfirmHandshake() error {
if err := c.Handshake(); err != nil {
return err
}

c.in.Lock()
defer c.in.Unlock()

if c.phase == handshakeConfirmed {
return nil
}

var input *block
if c.phase == readingEarlyData || c.input != nil {
buf := &bytes.Buffer{}
if _, err := buf.ReadFrom(earlyDataReader{c}); err != nil {
c.in.setErrorLocked(err)
return err
}
input = &block{data: buf.Bytes()}
}

for c.phase != handshakeConfirmed {
if err := c.readRecord(recordTypeApplicationData); err != nil {
c.in.setErrorLocked(err)
return err
}
}

if c.phase != handshakeConfirmed {
panic("should have reached handshakeConfirmed state")
}
if c.input != nil {
panic("should not have read past the Client Finished")
}

c.input = input

return nil
}

// earlyDataReader wraps a Conn with in locked and reads only early data,
// both buffered and still on the wire.
type earlyDataReader struct {
c *Conn
}

func (r earlyDataReader) Read(b []byte) (n int, err error) {
c := r.c

if c.phase == handshakeConfirmed {
// c.input might not be early data
panic("earlyDataReader called at handshakeConfirmed")
}

for c.input == nil && c.in.err == nil && c.phase == readingEarlyData {
if err := c.readRecord(recordTypeApplicationData); err != nil {
return 0, err
}
}
if err := c.in.err; err != nil {
return 0, err
}

if c.input != nil {
n, err = c.input.Read(b)
if err == io.EOF {
err = nil
c.in.freeBlock(c.input)
c.input = nil
}
}

if err == nil && c.phase != readingEarlyData && c.input == nil {
err = io.EOF
}
return
}

// Read can be made to time out and return a net.Error with Timeout() == true
// after a fixed time limit; see SetDeadline and SetReadDeadline.
func (c *Conn) Read(b []byte) (n int, err error) {
@@ -1242,7 +1352,8 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}

n, err = c.input.Read(b)
if c.input.off >= len(c.input.data) {
if err == io.EOF {
err = nil
c.in.freeBlock(c.input)
c.input = nil
}
@@ -1300,7 +1411,17 @@ func (c *Conn) Close() error {
var alertErr error

c.handshakeMutex.Lock()
if c.phase != earlyHandshake {
for c.phase == readingEarlyData {
if err := c.readRecord(recordTypeApplicationData); err != nil {
alertErr = err
}
}
if alertErr == nil && c.phase == waitingClientFinished {
if err := c.hs.readClientFinished13(); err != nil {
alertErr = err
}
}
if alertErr == nil && c.phase != handshakeRunning {
alertErr = c.closeNotify()
}
c.handshakeMutex.Unlock()
@@ -1319,7 +1440,7 @@ var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake com
func (c *Conn) CloseWrite() error {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
if c.phase == earlyHandshake {
if c.phase == handshakeRunning {
return errEarlyCloseWrite
}

@@ -1341,8 +1462,11 @@ func (c *Conn) closeNotify() error {
// protocol if it has not yet been run.
// Most uses of this package need not call Handshake
// explicitly: the first Read or Write will call it automatically.
//
// In TLS 1.3 Handshake returns after the client and server first flights,
// without waiting for the Client Finished.
func (c *Conn) Handshake() error {
// c.handshakeErr and c.phase == earlyHandshake are protected by
// c.handshakeErr and c.phase == handshakeRunning are protected by
// c.handshakeMutex. In order to perform a handshake, we need to lock
// c.in also and c.handshakeMutex must be locked after c.in.
//
@@ -1371,7 +1495,7 @@ func (c *Conn) Handshake() error {
if err := c.handshakeErr; err != nil {
return err
}
if c.phase != earlyHandshake {
if c.phase != handshakeRunning {
return nil
}
if c.handshakeCond == nil {
@@ -1393,7 +1517,7 @@ func (c *Conn) Handshake() error {

// The handshake cannot have completed when handshakeMutex was unlocked
// because this goroutine set handshakeCond.
if c.handshakeErr != nil || c.phase != earlyHandshake {
if c.handshakeErr != nil || c.phase != handshakeRunning {
panic("handshake should not have been able to complete after handshakeCond was set")
}

@@ -1415,7 +1539,7 @@ func (c *Conn) Handshake() error {
c.flush()
}

if c.handshakeErr == nil && c.phase == earlyHandshake {
if c.handshakeErr == nil && c.phase == handshakeRunning {
panic("handshake should have had a result.")
}

@@ -1433,7 +1557,7 @@ func (c *Conn) ConnectionState() ConnectionState {
defer c.handshakeMutex.Unlock()

var state ConnectionState
state.HandshakeComplete = c.phase != earlyHandshake
state.HandshakeComplete = c.phase != handshakeRunning
state.ServerName = c.serverName

if state.HandshakeComplete {
@@ -1448,6 +1572,7 @@ func (c *Conn) ConnectionState() ConnectionState {
state.VerifiedChains = c.verifiedChains
state.SignedCertificateTimestamps = c.scts
state.OCSPResponse = c.ocspResponse
state.HandshakeConfirmed = c.phase == handshakeConfirmed
if !c.didResume {
if c.clientFinishedIsFirst {
state.TLSUnique = c.clientFinished[:]
@@ -1478,7 +1603,7 @@ func (c *Conn) VerifyHostname(host string) error {
if !c.isClient {
return errors.New("tls: VerifyHostname called on TLS server connection")
}
if c.phase == earlyHandshake {
if c.phase == handshakeRunning {
return errors.New("tls: handshake has not yet been performed")
}
if len(c.verifiedChains) == 0 {


+ 1
- 1
handshake_client.go 查看文件

@@ -251,7 +251,7 @@ NextCipherSuite:
}

c.didResume = isResume
c.phase = handshakeComplete
c.phase = handshakeConfirmed
c.cipherSuite = suite.id
return nil
}


+ 8
- 8
handshake_client_test.go 查看文件

@@ -64,12 +64,12 @@ func (i opensslInput) Read(buf []byte) (n int, err error) {
}

// opensslOutputSink is an io.Writer that receives the stdout and stderr from
// an `openssl` process and sends a value to handshakeComplete when it sees a
// an `openssl` process and sends a value to handshakeConfirmed when it sees a
// log message from a completed server handshake.
type opensslOutputSink struct {
handshakeComplete chan struct{}
all []byte
line []byte
handshakeConfirmed chan struct{}
all []byte
line []byte
}

func newOpensslOutputSink() *opensslOutputSink {
@@ -91,7 +91,7 @@ func (o *opensslOutputSink) Write(data []byte) (n int, err error) {
}

if bytes.Equal([]byte(opensslEndOfHandshake), o.line[:i]) {
o.handshakeComplete <- struct{}{}
o.handshakeConfirmed <- struct{}{}
}
o.line = o.line[i+1:]
}
@@ -315,9 +315,9 @@ func (test *clientTest) run(t *testing.T, write bool) {

for i := 1; i <= test.numRenegotiations; i++ {
// The initial handshake will generate a
// handshakeComplete signal which needs to be quashed.
// handshakeConfirmed signal which needs to be quashed.
if i == 1 && write {
<-stdout.handshakeComplete
<-stdout.handshakeConfirmed
}

// OpenSSL will try to interleave application data and
@@ -364,7 +364,7 @@ func (test *clientTest) run(t *testing.T, write bool) {
}()

if write && test.renegotiationExpectedToFail != i {
<-stdout.handshakeComplete
<-stdout.handshakeConfirmed
stdin <- opensslSendSentinel
}
<-signalChan


+ 71
- 32
handshake_messages.go 查看文件

@@ -33,6 +33,7 @@ type clientHelloMsg struct {
supportedVersions []uint16
psks []psk
pskKeyExchangeModes []uint8
earlyData bool
}

func (m *clientHelloMsg) equal(i interface{}) bool {
@@ -60,7 +61,8 @@ func (m *clientHelloMsg) equal(i interface{}) bool {
bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
eqStrings(m.alpnProtocols, m1.alpnProtocols) &&
eqKeyShares(m.keyShares, m1.keyShares) &&
eqUint16s(m.supportedVersions, m1.supportedVersions)
eqUint16s(m.supportedVersions, m1.supportedVersions) &&
m.earlyData == m1.earlyData
}

func (m *clientHelloMsg) marshal() []byte {
@@ -127,6 +129,9 @@ func (m *clientHelloMsg) marshal() []byte {
extensionsLength += 1 + 2*len(m.supportedVersions)
numExtensions++
}
if m.earlyData {
numExtensions++
}
if numExtensions > 0 {
extensionsLength += 4 * numExtensions
length += 2 + extensionsLength
@@ -350,6 +355,11 @@ func (m *clientHelloMsg) marshal() []byte {
z = z[2:]
}
}
if m.earlyData {
z[0] = byte(extensionEarlyData >> 8)
z[1] = byte(extensionEarlyData)
z = z[4:]
}

m.raw = x

@@ -413,6 +423,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
m.supportedVersions = nil
m.psks = nil
m.pskKeyExchangeModes = nil
m.earlyData = false

if len(data) == 0 {
// ClientHello is optionally followed by extension data
@@ -668,6 +679,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
return false
}
m.pskKeyExchangeModes = data[1:length]
case extensionEarlyData:
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
m.earlyData = true
}
data = data[length:]
bindersOffset += length
@@ -1144,6 +1158,7 @@ func (m *serverHelloMsg13) unmarshal(data []byte) bool {
type encryptedExtensionsMsg struct {
raw []byte
alpnProtocol string
earlyData bool
}

func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
@@ -1153,7 +1168,8 @@ func (m *encryptedExtensionsMsg) equal(i interface{}) bool {
}

return bytes.Equal(m.raw, m1.raw) &&
m.alpnProtocol == m1.alpnProtocol
m.alpnProtocol == m1.alpnProtocol &&
m.earlyData == m1.earlyData
}

func (m *encryptedExtensionsMsg) marshal() []byte {
@@ -1163,6 +1179,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte {

length := 2

if m.earlyData {
length += 4
}
alpnLen := len(m.alpnProtocol)
if alpnLen > 0 {
if alpnLen >= 256 {
@@ -1196,6 +1215,12 @@ func (m *encryptedExtensionsMsg) marshal() []byte {
z = z[7+alpnLen:]
}

if m.earlyData {
z[0] = byte(extensionEarlyData >> 8)
z[1] = byte(extensionEarlyData)
z = z[4:]
}

m.raw = x
return x
}
@@ -1205,41 +1230,55 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
return false
}
m.raw = data
l := int(data[4])<<8 | int(data[5])
if l != len(data)-6 {
return false
}

m.alpnProtocol = ""
if l == 0 {
return true
}
m.earlyData = false

d := data[6:]
if len(d) < 5 {
return false
}
if uint16(d[0])<<8|uint16(d[1]) != extensionALPN {
return false
}
l = int(d[2])<<8 | int(d[3])
if l != len(d)-4 {
return false
}
l = int(d[4])<<8 | int(d[5])
if l != len(d)-6 {
return false
}
d = d[6:]
l = int(d[0])
if l != len(d)-1 {
extensionsLength := int(data[4])<<8 | int(data[5])
data = data[6:]
if len(data) != extensionsLength {
return false
}
d = d[1:]
if len(d) == 0 {
// ALPN protocols must not be empty.
return false

for len(data) != 0 {
if len(data) < 4 {
return false
}
extension := uint16(data[0])<<8 | uint16(data[1])
length := int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < length {
return false
}

switch extension {
case extensionALPN:
d := data[:length]
if len(d) < 3 {
return false
}
l := int(d[0])<<8 | int(d[1])
if l != len(d)-2 {
return false
}
d = d[2:]
l = int(d[0])
if l != len(d)-1 {
return false
}
d = d[1:]
if len(d) == 0 {
// ALPN protocols must not be empty.
return false
}
m.alpnProtocol = string(d)
case extensionEarlyData:
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8
m.earlyData = true
}

data = data[length:]
}
m.alpnProtocol = string(d)

return true
}


+ 11
- 2
handshake_messages_test.go 查看文件

@@ -164,6 +164,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
for i := range m.supportedVersions {
m.supportedVersions[i] = uint16(rand.Intn(30000))
}
if rand.Intn(10) > 5 {
m.earlyData = true
}

return reflect.ValueOf(m)
}
@@ -222,7 +225,12 @@ func (*serverHelloMsg13) Generate(rand *rand.Rand, size int) reflect.Value {

func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value {
m := &encryptedExtensionsMsg{}
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
if rand.Intn(10) > 5 {
m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
}
if rand.Intn(10) > 5 {
m.earlyData = true
}

return reflect.ValueOf(m)
}
@@ -328,10 +336,11 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value {
func (*sessionState13) Generate(rand *rand.Rand, size int) reflect.Value {
s := &sessionState13{}
s.vers = uint16(rand.Intn(10000))
s.hash = uint16(rand.Intn(10000))
s.suite = uint16(rand.Intn(10000))
s.ageAdd = uint32(rand.Intn(0xffffffff))
s.createdAt = uint64(rand.Int63n(0xfffffffffffffff))
s.resumptionSecret = randomBytes(rand.Intn(100), rand)
s.alpnProtocol = randomString(rand.Intn(100), rand)
return reflect.ValueOf(s)
}



+ 11
- 4
handshake_server.go 查看文件

@@ -46,7 +46,8 @@ type serverHandshakeState struct {
hello13Enc *encryptedExtensionsMsg
finishedHash13 hash.Hash
clientFinishedKey []byte
clientCipher interface{}
hsClientCipher interface{}
appClientCipher interface{}
}

// serverHandshake performs a TLS handshake as a server.
@@ -61,7 +62,6 @@ func (c *Conn) serverHandshake() error {
}
c.in.traceErr = hs.traceErr
c.out.traceErr = hs.traceErr
defer func() { c.in.traceErr, c.out.traceErr = nil, nil }()
isResume, err := hs.readClientHello()
if err != nil {
return err
@@ -105,7 +105,7 @@ func (c *Conn) serverHandshake() error {
return err
}
c.didResume = true
c.phase = handshakeComplete
c.phase = handshakeConfirmed
} else {
// The client didn't include a session ticket, or it wasn't
// valid so we do a full handshake.
@@ -129,7 +129,7 @@ func (c *Conn) serverHandshake() error {
if _, err := c.flush(); err != nil {
return err
}
c.phase = handshakeComplete
c.phase = handshakeConfirmed
}

return nil
@@ -910,6 +910,11 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
}

var pskBinder []byte
if len(hs.clientHello.psks) > 0 {
pskBinder = hs.clientHello.psks[0].binder
}

hs.cachedClientHelloInfo = &ClientHelloInfo{
CipherSuites: hs.clientHello.cipherSuites,
ServerName: hs.clientHello.serverName,
@@ -919,6 +924,8 @@ func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo {
SupportedProtos: hs.clientHello.alpnProtocols,
SupportedVersions: supportedVersions,
Conn: hs.c.conn,
Offered0RTTData: hs.clientHello.earlyData,
Fingerprint: pskBinder,
}

return hs.cachedClientHelloInfo


+ 23
- 11
ticket.go 查看文件

@@ -131,10 +131,13 @@ func (s *sessionState) unmarshal(data []byte) bool {

type sessionState13 struct {
vers uint16
hash uint16 // crypto.Hash value
suite uint16
ageAdd uint32
createdAt uint64
resumptionSecret []byte
alpnProtocol string
// TODO(filippo): add and check SNI
// TODO(filippo): add and check maxEarlyDataLength
}

func (s *sessionState13) equal(i interface{}) bool {
@@ -144,19 +147,20 @@ func (s *sessionState13) equal(i interface{}) bool {
}

return s.vers == s1.vers &&
s.hash == s1.hash &&
s.suite == s1.suite &&
s.alpnProtocol == s1.alpnProtocol &&
s.ageAdd == s1.ageAdd &&
bytes.Equal(s.resumptionSecret, s1.resumptionSecret)
}

func (s *sessionState13) marshal() []byte {
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret)
length := 2 + 2 + 4 + 8 + 2 + len(s.resumptionSecret) + 2 + len(s.alpnProtocol)

x := make([]byte, length)
x[0] = byte(s.vers >> 8)
x[1] = byte(s.vers)
x[2] = byte(s.hash >> 8)
x[3] = byte(s.hash)
x[2] = byte(s.suite >> 8)
x[3] = byte(s.suite)
x[4] = byte(s.ageAdd >> 24)
x[5] = byte(s.ageAdd >> 16)
x[6] = byte(s.ageAdd >> 8)
@@ -171,8 +175,11 @@ func (s *sessionState13) marshal() []byte {
x[15] = byte(s.createdAt)
x[16] = byte(len(s.resumptionSecret) >> 8)
x[17] = byte(len(s.resumptionSecret))

copy(x[18:], s.resumptionSecret)
z := x[18+len(s.resumptionSecret):]
z[0] = byte(len(s.alpnProtocol) >> 8)
z[1] = byte(len(s.alpnProtocol))
copy(z[2:], s.alpnProtocol)

return x
}
@@ -183,14 +190,19 @@ func (s *sessionState13) unmarshal(data []byte) bool {
}

s.vers = uint16(data[0])<<8 | uint16(data[1])
s.hash = uint16(data[2])<<8 | uint16(data[3])
s.suite = uint16(data[2])<<8 | uint16(data[3])
s.ageAdd = uint32(data[4])<<24 | uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
s.createdAt = uint64(data[8])<<56 | uint64(data[9])<<48 | uint64(data[10])<<40 | uint64(data[11])<<32 |
uint64(data[12])<<24 | uint64(data[13])<<16 | uint64(data[14])<<8 | uint64(data[15])
l := uint16(data[16])<<8 | uint16(data[17])
s.resumptionSecret = data[18:]

return int(l) == len(s.resumptionSecret)
l := int(data[16])<<8 | int(data[17])
if len(data) < 18+l+2 {
return false
}
s.resumptionSecret = data[18 : 18+l]
z := data[18+l:]
l = int(z[0])<<8 | int(z[1])
s.alpnProtocol = string(z[2:])
return l == len(s.alpnProtocol)
}

func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {


+ 3
- 1
tls_test.go 查看文件

@@ -641,7 +641,7 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf("b"))
case "ClientAuth":
f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites", "Accept0RTTData":
f.Set(reflect.ValueOf(true))
case "MinVersion", "MaxVersion":
f.Set(reflect.ValueOf(uint16(VersionTLS12)))
@@ -654,6 +654,8 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
case "Renegotiation":
f.Set(reflect.ValueOf(RenegotiateOnceAsClient))
case "Max0RTTDataSize":
f.Set(reflect.ValueOf(uint32(0)))
default:
t.Errorf("all fields must be accounted for, but saw unknown field %q", fn)
}


Loading…
取消
儲存