th5/13.go
Peter Wu 6f580251ca tris: use keySchedule13 for the server
Use the new keySchedule13 type instead of hash.Hash to avoid tracking
the hashContext and intermediate secrets manually.

checkPSK is modified not to return the calculated early secret, this is
internal to keySchedule13 now. The caller just learns whether it was
resumed using a PSK or not.
2017-09-21 15:37:34 +01:00

704 lines
20 KiB
Go

package tls
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/hmac"
"crypto/rsa"
"crypto/subtle"
"encoding/hex"
"errors"
"fmt"
"hash"
"io"
"log"
"os"
"runtime"
"runtime/debug"
"strings"
"sync/atomic"
"time"
"golang_org/x/crypto/curve25519"
)
// numSessionTickets is the number of different session tickets the
// server sends to a TLS 1.3 client, who will use each only once.
const numSessionTickets = 2
type secretLabel int
const (
secretResumptionPskBinder secretLabel = iota
secretEarlyClient
secretHandshakeClient
secretHandshakeServer
secretApplicationClient
secretApplicationServer
secretResumption
)
type keySchedule13 struct {
suite *cipherSuite
transcriptHash hash.Hash // uses the cipher suite hash algo
secret []byte // Current secret as used for Derive-Secret
handshakeCtx []byte // cached handshake context, invalidated on updates.
}
func newKeySchedule13(suite *cipherSuite) *keySchedule13 {
return &keySchedule13{
suite: suite,
transcriptHash: hashForSuite(suite).New(),
}
}
// setSecret sets the early/handshake/master secret based on the given secret
// (IKM). The salt is based on previous secrets (nil for the early secret).
func (ks *keySchedule13) setSecret(secret []byte) {
hash := hashForSuite(ks.suite)
salt := ks.secret
ks.secret = hkdfExtract(hash, secret, salt)
}
// write appends the data to the transcript hash context.
func (ks *keySchedule13) write(data []byte) {
ks.handshakeCtx = nil
ks.transcriptHash.Write(data)
}
func (ks *keySchedule13) getLabel(secretLabel secretLabel) (label string) {
switch secretLabel {
case secretResumptionPskBinder:
label = "resumption psk binder key"
case secretEarlyClient:
label = "client early traffic secret"
case secretHandshakeClient:
label = "client handshake traffic secret"
case secretHandshakeServer:
label = "server handshake traffic secret"
case secretApplicationClient:
label = "client application traffic secret"
case secretApplicationServer:
label = "server application traffic secret"
case secretResumption:
label = "resumption master secret"
}
return
}
// deriveSecret returns the secret derived from the handshake context and label.
func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
label := ks.getLabel(secretLabel)
if ks.handshakeCtx == nil {
ks.handshakeCtx = ks.transcriptHash.Sum(nil)
}
hash := hashForSuite(ks.suite)
return hkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
}
func (ks *keySchedule13) prepareCipher(secretLabel secretLabel) (interface{}, []byte) {
trafficSecret := ks.deriveSecret(secretLabel)
hash := hashForSuite(ks.suite)
key := hkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", 12)
return ks.suite.aead(key, iv), trafficSecret
}
func (hs *serverHandshakeState) doTLS13Handshake() error {
config := hs.c.config
c := hs.c
hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id
hs.c.clientHello = hs.clientHello.marshal()
// When picking the group for the handshake, priority is given to groups
// that the client provided a keyShare for, so to avoid a round-trip.
// After that the order of CurvePreferences is respected.
var ks keyShare
CurvePreferenceLoop:
for _, curveID := range config.curvePreferences() {
for _, keyShare := range hs.clientHello.keyShares {
if curveID == keyShare.group {
ks = keyShare
break CurvePreferenceLoop
}
}
}
if ks.group == 0 {
c.sendAlert(alertInternalError)
return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo)
}
if committer, ok := c.conn.(Committer); ok {
if err := committer.Commit(); err != nil {
return err
}
}
privateKey, serverKS, err := config.generateKeyShare(ks.group)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello13.keyShare = serverKS
hash := hashForSuite(hs.suite)
hashSize := hash.Size()
hs.keySchedule = newKeySchedule13(hs.suite)
// Check for PSK and update key schedule with new early secret key
isResumed, pskAlert := hs.checkPSK()
switch {
case pskAlert != alertSuccess:
c.sendAlert(pskAlert)
return errors.New("tls: invalid client PSK")
case !isResumed:
// apply an empty PSK if not resumed.
hs.keySchedule.setSecret(nil)
case isResumed:
c.didResume = true
}
hs.keySchedule.write(hs.clientHello.marshal())
earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient)
ecdheSecret := deriveECDHESecret(ks, privateKey)
if ecdheSecret == nil {
c.sendAlert(alertIllegalParameter)
return errors.New("tls: bad ECDHE client share")
}
hs.keySchedule.write(hs.hello13.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil {
return err
}
hs.keySchedule.setSecret(ecdheSecret)
clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
hs.hsClientCipher = clientCipher
serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer)
c.out.setCipher(c.vers, serverCipher)
serverFinishedKey := hkdfExpandLabel(hash, sTrafficSecret, nil, "finished", hashSize)
hs.clientFinishedKey = hkdfExpandLabel(hash, cTrafficSecret, nil, "finished", hashSize)
hs.keySchedule.write(hs.hello13Enc.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil {
return err
}
if !c.didResume {
if err := hs.sendCertificate13(); err != nil {
return err
}
}
verifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, serverFinishedKey)
serverFinished := &finishedMsg{
verifyData: verifyData,
}
hs.keySchedule.write(serverFinished.marshal())
if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil {
return err
}
hs.keySchedule.setSecret(nil) // derive master secret
hs.appClientCipher, _ = hs.keySchedule.prepareCipher(secretApplicationClient)
serverCipher, _ = hs.keySchedule.prepareCipher(secretApplicationServer)
c.out.setCipher(c.vers, serverCipher)
if c.hand.Len() > 0 {
return c.sendAlert(alertUnexpectedMessage)
}
if hs.hello13Enc.earlyData {
c.in.setCipher(c.vers, earlyClientCipher)
c.phase = readingEarlyData
} else if hs.clientHello.earlyData {
c.in.setCipher(c.vers, hs.hsClientCipher)
c.phase = discardingEarlyData
} else {
c.in.setCipher(c.vers, hs.hsClientCipher)
c.phase = waitingClientFinished
}
return nil
}
// readClientFinished13 is called when, on the second flight of the client,
// a handshake message is received. This might be immediately or after the
// early data. Once done it sends the session tickets. Under c.in lock.
func (hs *serverHandshakeState) readClientFinished13() error {
c := hs.c
c.phase = readingClientFinished
msg, err := c.readHandshake()
if err != nil {
return err
}
clientFinished, ok := msg.(*finishedMsg)
if !ok {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(clientFinished, msg)
}
hash := hashForSuite(hs.suite)
expectedVerifyData := hmacOfSum(hash, hs.keySchedule.transcriptHash, hs.clientFinishedKey)
if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertDecryptError)
return errors.New("tls: client's Finished message is incorrect")
}
hs.keySchedule.write(clientFinished.marshal())
c.hs = nil // Discard the server handshake state
if c.hand.Len() > 0 {
return c.sendAlert(alertUnexpectedMessage)
}
c.in.setCipher(c.vers, hs.appClientCipher)
c.in.traceErr, c.out.traceErr = nil, nil
c.phase = handshakeConfirmed
atomic.StoreInt32(&c.handshakeConfirmed, 1)
// Any read operation after handshakeRunning and before handshakeConfirmed
// will be holding this lock, which we release as soon as the confirmation
// happens, even if the Read call might do more work.
c.confirmMutex.Unlock()
return hs.sendSessionTicket13() // TODO: do in a goroutine
}
func (hs *serverHandshakeState) sendCertificate13() error {
c := hs.c
certEntries := []certificateEntry{}
for _, cert := range hs.cert.Certificate {
certEntries = append(certEntries, certificateEntry{data: cert})
}
if len(certEntries) > 0 && hs.clientHello.ocspStapling {
certEntries[0].ocspStaple = hs.cert.OCSPStaple
}
if len(certEntries) > 0 && hs.clientHello.scts {
certEntries[0].sctList = hs.cert.SignedCertificateTimestamps
}
certMsg := &certificateMsg13{certificates: certEntries}
hs.keySchedule.write(certMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
return err
}
sigScheme, err := hs.selectTLS13SignatureScheme()
if err != nil {
c.sendAlert(alertInternalError)
return err
}
sigHash := hashForSignatureScheme(sigScheme)
opts := crypto.SignerOpts(sigHash)
if signatureSchemeIsPSS(sigScheme) {
opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash}
}
toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.keySchedule.transcriptHash.Sum(nil))
signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(c.config.rand(), toSign[:], opts)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
verifyMsg := &certificateVerifyMsg{
hasSignatureAndHash: true,
signatureAndHash: sigSchemeToSigAndHash(sigScheme),
signature: signature,
}
hs.keySchedule.write(verifyMsg.marshal())
if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil {
return err
}
return nil
}
func (c *Conn) handleEndOfEarlyData() {
if c.phase != readingEarlyData || c.vers < VersionTLS13 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
return
}
c.phase = waitingClientFinished
if c.hand.Len() > 0 {
c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
return
}
c.in.setCipher(c.vers, c.hs.hsClientCipher)
}
// selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify
// based on the certificate type and client supported schemes. If no overlap is found,
// a fallback is selected.
//
// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.4.1.2
func (hs *serverHandshakeState) selectTLS13SignatureScheme() (sigScheme SignatureScheme, err error) {
var supportedSchemes []SignatureScheme
signer, ok := hs.cert.PrivateKey.(crypto.Signer)
if !ok {
return 0, errors.New("tls: certificate private key does not implement crypto.Signer")
}
pk := signer.Public()
if _, ok := pk.(*rsa.PublicKey); ok {
sigScheme = PSSWithSHA256
supportedSchemes = []SignatureScheme{PSSWithSHA256, PSSWithSHA384, PSSWithSHA512}
} else if pk, ok := pk.(*ecdsa.PublicKey); ok {
switch pk.Curve {
case elliptic.P256():
sigScheme = ECDSAWithP256AndSHA256
supportedSchemes = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigScheme = ECDSAWithP384AndSHA384
supportedSchemes = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigScheme = ECDSAWithP521AndSHA512
supportedSchemes = []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return 0, errors.New("tls: unknown ECDSA certificate curve")
}
} else {
return 0, errors.New("tls: unknown certificate key type")
}
for _, ss := range supportedSchemes {
for _, cs := range hs.clientHello.signatureAndHashes {
if ss == sigAndHashToSigScheme(cs) {
return ss, nil
}
}
}
return sigScheme, nil
}
func sigSchemeToSigAndHash(s SignatureScheme) (sah signatureAndHash) {
sah.hash = byte(s >> 8)
sah.signature = byte(s)
return
}
func sigAndHashToSigScheme(sah signatureAndHash) SignatureScheme {
return SignatureScheme(sah.hash)<<8 | SignatureScheme(sah.signature)
}
func signatureSchemeIsPSS(s SignatureScheme) bool {
return s == PSSWithSHA256 || s == PSSWithSHA384 || s == PSSWithSHA512
}
// hashForSignatureScheme returns the Hash used by a SignatureScheme which is
// supported by selectTLS13SignatureScheme.
func hashForSignatureScheme(ss SignatureScheme) crypto.Hash {
switch ss {
case PSSWithSHA256, ECDSAWithP256AndSHA256:
return crypto.SHA256
case PSSWithSHA384, ECDSAWithP384AndSHA384:
return crypto.SHA384
case PSSWithSHA512, ECDSAWithP521AndSHA512:
return crypto.SHA512
default:
panic("unsupported SignatureScheme passed to hashForSignatureScheme")
}
}
func hashForSuite(suite *cipherSuite) crypto.Hash {
if suite.flags&suiteSHA384 != 0 {
return crypto.SHA384
}
return crypto.SHA256
}
func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte {
message := bytes.Repeat([]byte{32}, 64)
message = append(message, context...)
message = append(message, 0)
message = append(message, data...)
h := hash.New()
h.Write(message)
return h.Sum(nil)
}
func (c *Config) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) {
if curveID == X25519 {
var scalar, public [32]byte
if _, err := io.ReadFull(c.rand(), scalar[:]); err != nil {
return nil, keyShare{}, err
}
curve25519.ScalarBaseMult(&public, &scalar)
return scalar[:], keyShare{group: curveID, data: public[:]}, nil
}
curve, ok := curveForCurveID(curveID)
if !ok {
return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve")
}
privateKey, x, y, err := elliptic.GenerateKey(curve, c.rand())
if err != nil {
return nil, keyShare{}, err
}
ecdhePublic := elliptic.Marshal(curve, x, y)
return privateKey, keyShare{group: curveID, data: ecdhePublic}, nil
}
func deriveECDHESecret(ks keyShare, secretKey []byte) []byte {
if ks.group == X25519 {
if len(ks.data) != 32 {
return nil
}
var theirPublic, sharedKey, scalar [32]byte
copy(theirPublic[:], ks.data)
copy(scalar[:], secretKey)
curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
return sharedKey[:]
}
curve, ok := curveForCurveID(ks.group)
if !ok {
return nil
}
x, y := elliptic.Unmarshal(curve, ks.data)
if x == nil {
return nil
}
x, _ = curve.ScalarMult(x, y, secretKey)
xBytes := x.Bytes()
curveSize := (curve.Params().BitSize + 8 - 1) >> 3
if len(xBytes) == curveSize {
return xBytes
}
buf := make([]byte, curveSize)
copy(buf[len(buf)-len(xBytes):], xBytes)
return buf
}
func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
hkdfLabel := make([]byte, 4+len("TLS 1.3, ")+len(label)+len(hashValue))
hkdfLabel[0] = byte(L >> 8)
hkdfLabel[1] = byte(L)
hkdfLabel[2] = byte(len("TLS 1.3, ") + len(label))
copy(hkdfLabel[3:], "TLS 1.3, ")
z := hkdfLabel[3+len("TLS 1.3, "):]
copy(z, label)
z = z[len(label):]
z[0] = byte(len(hashValue))
copy(z[1:], hashValue)
return hkdfExpand(hash, secret, hkdfLabel, L)
}
func hmacOfSum(f crypto.Hash, hash hash.Hash, key []byte) []byte {
h := hmac.New(f.New, key)
h.Write(hash.Sum(nil))
return h.Sum(nil)
}
// Maximum allowed mismatch between the stated age of a ticket
// and the server-observed one. See
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second
// checkPSK tries to resume using a PSK, returning true (and updating the
// early secret in the key schedule) if the PSK was used and false otherwise.
func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
if hs.c.config.SessionTicketsDisabled {
return false, alertSuccess
}
foundDHE := false
for _, mode := range hs.clientHello.pskKeyExchangeModes {
if mode == pskDHEKeyExchange {
foundDHE = true
break
}
}
if !foundDHE {
return false, alertSuccess
}
hash := hashForSuite(hs.suite)
hashSize := hash.Size()
for i := range hs.clientHello.psks {
sessionTicket := append([]uint8{}, hs.clientHello.psks[i].identity...)
if hs.c.config.SessionTicketSealer != nil {
var ok bool
sessionTicket, ok = hs.c.config.SessionTicketSealer.Unseal(hs.clientHelloInfo(), sessionTicket)
if !ok {
continue
}
} else {
sessionTicket, _ = hs.c.decryptTicket(sessionTicket)
if sessionTicket == nil {
continue
}
}
s := &sessionState13{}
if s.unmarshal(sessionTicket) != alertSuccess {
continue
}
if s.vers != hs.c.vers {
continue
}
clientAge := time.Duration(hs.clientHello.psks[i].obfTicketAge-s.ageAdd) * time.Millisecond
serverAge := time.Since(time.Unix(int64(s.createdAt), 0))
if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance {
// XXX: NSS is off spec and sends obfuscated_ticket_age as seconds
clientAge = time.Duration(hs.clientHello.psks[i].obfTicketAge-s.ageAdd) * time.Second
if clientAge-serverAge > ticketAgeSkewAllowance || clientAge-serverAge < -ticketAgeSkewAllowance {
continue
}
}
// This enforces the stricter 0-RTT requirements on all ticket uses.
// The benefit of using PSK+ECDHE without 0-RTT are small enough that
// we can give them up in the edge case of changed suite or ALPN or SNI.
if s.suite != hs.suite.id {
continue
}
if s.alpnProtocol != hs.c.clientProtocol {
continue
}
if s.SNI != hs.c.serverName {
continue
}
hs.keySchedule.setSecret(s.resumptionSecret)
binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
chHash := hash.New()
chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 {
return false, alertDecryptError
}
if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such.
return false, alertIllegalParameter
}
if hs.c.config.Accept0RTTData {
hs.c.binder = expectedBinder
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
hs.hello13Enc.earlyData = true
}
}
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
return true, alertSuccess
}
return false, alertSuccess
}
func (hs *serverHandshakeState) sendSessionTicket13() error {
c := hs.c
if c.config.SessionTicketsDisabled {
return nil
}
foundDHE := false
for _, mode := range hs.clientHello.pskKeyExchangeModes {
if mode == pskDHEKeyExchange {
foundDHE = true
break
}
}
if !foundDHE {
return nil
}
resumptionSecret := hs.keySchedule.deriveSecret(secretResumption)
ageAddBuf := make([]byte, 4)
sessionState := &sessionState13{
vers: c.vers,
suite: hs.suite.id,
createdAt: uint64(time.Now().Unix()),
resumptionSecret: resumptionSecret,
alpnProtocol: c.clientProtocol,
SNI: c.serverName,
maxEarlyDataLen: c.config.Max0RTTDataSize,
}
for i := 0; i < numSessionTickets; i++ {
if _, err := io.ReadFull(c.config.rand(), ageAddBuf); err != nil {
c.sendAlert(alertInternalError)
return err
}
sessionState.ageAdd = uint32(ageAddBuf[0])<<24 | uint32(ageAddBuf[1])<<16 |
uint32(ageAddBuf[2])<<8 | uint32(ageAddBuf[3])
ticket := sessionState.marshal()
var err error
if c.config.SessionTicketSealer != nil {
cs := c.ConnectionState()
ticket, err = c.config.SessionTicketSealer.Seal(&cs, ticket)
} else {
ticket, err = c.encryptTicket(ticket)
}
if err != nil {
c.sendAlert(alertInternalError)
return err
}
if ticket == nil {
continue
}
ticketMsg := &newSessionTicketMsg13{
lifetime: 24 * 3600, // TODO(filippo)
maxEarlyDataLength: c.config.Max0RTTDataSize,
withEarlyDataInfo: c.config.Max0RTTDataSize > 0,
ageAdd: sessionState.ageAdd,
ticket: ticket,
}
if _, err := c.writeRecord(recordTypeHandshake, ticketMsg.marshal()); err != nil {
return err
}
}
return nil
}
func (hs *serverHandshakeState) traceErr(err error) {
if err == nil {
return
}
if os.Getenv("TLSDEBUG") == "error" {
if hs != nil && hs.clientHello != nil {
os.Stderr.WriteString(hex.Dump(hs.clientHello.marshal()))
} else if err == io.EOF {
return // don't stack trace on EOF before CH
}
fmt.Fprintf(os.Stderr, "\n%s\n", debug.Stack())
}
if os.Getenv("TLSDEBUG") == "short" {
var pcs [4]uintptr
frames := runtime.CallersFrames(pcs[0:runtime.Callers(3, pcs[:])])
for {
frame, more := frames.Next()
if frame.Function != "crypto/tls.(*halfConn).setErrorLocked" &&
frame.Function != "crypto/tls.(*Conn).sendAlertLocked" &&
frame.Function != "crypto/tls.(*Conn).sendAlert" {
file := frame.File[strings.LastIndex(frame.File, "/")+1:]
log.Printf("%s:%d (%s): %v", file, frame.Line, frame.Function, err)
return
}
if !more {
break
}
}
}
}