From 4b0d17eca388a2e3a6b0ef9aa767a380d8ce425b Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Sun, 6 Nov 2016 09:01:12 -0800 Subject: [PATCH] crypto/tls: implement TLS 1.3 minimal server --- 13.go | 334 ++++++++++++++++++++++++++++++++ _dev/tris-localserver/server.go | 11 +- common.go | 11 +- conn.go | 8 + handshake_server.go | 86 ++++++-- hkdf.go | 58 ++++++ key_agreement.go | 10 +- prf.go | 7 + 8 files changed, 489 insertions(+), 36 deletions(-) create mode 100644 13.go create mode 100644 hkdf.go diff --git a/13.go b/13.go new file mode 100644 index 0000000..dd6d392 --- /dev/null +++ b/13.go @@ -0,0 +1,334 @@ +package tls + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/hmac" + "crypto/rsa" + "crypto/subtle" + "errors" + "io" + + "golang_org/x/crypto/curve25519" +) + +func (hs *serverHandshakeState) doTLS13Handshake() error { + config := hs.c.config + c := hs.c + + hs.c.cipherSuite, hs.hello13.cipherSuite = hs.suite.id, hs.suite.id + + // When picking the group for the handshake, priority is given to groups + // that the client provided a keyShare for, so to avoid a round-trip. + // After that the order of CurvePreferences is respected. + var ks keyShare + for _, curveID := range config.curvePreferences() { + for _, keyShare := range hs.clientHello.keyShares { + if curveID == keyShare.group { + ks = keyShare + break + } + } + } + if ks.group == 0 { + c.sendAlert(alertInternalError) + return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo) + } + + privateKey, serverKS, err := config.generateKeyShare(ks.group) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello13.keyShare = serverKS + + hash := crypto.SHA256 + if hs.suite.flags&suiteSHA384 != 0 { + hash = crypto.SHA384 + } + hashSize := hash.Size() + + ecdheSecret := deriveECDHESecret(ks, privateKey) + if ecdheSecret == nil { + c.sendAlert(alertIllegalParameter) + return errors.New("tls: bad ECDHE client share") + } + + hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.finishedHash.Write(hs.clientHello.marshal()) + hs.finishedHash.Write(hs.hello13.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello13.marshal()); err != nil { + return err + } + + earlySecret := hkdfExtract(hash, nil, nil) + handshakeSecret := hkdfExtract(hash, ecdheSecret, earlySecret) + + handshakeCtx := hs.finishedHash.Sum() + + cHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "client handshake traffic secret", hashSize) + cKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "key", hs.suite.keyLen) + cIV := hkdfExpandLabel(hash, cHandshakeTS, nil, "iv", 12) + sHandshakeTS := hkdfExpandLabel(hash, handshakeSecret, handshakeCtx, "server handshake traffic secret", hashSize) + sKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "key", hs.suite.keyLen) + sIV := hkdfExpandLabel(hash, sHandshakeTS, nil, "iv", 12) + + clientCipher := hs.suite.aead(cKey, cIV) + c.in.setCipher(c.vers, clientCipher) + serverCipher := hs.suite.aead(sKey, sIV) + c.out.setCipher(c.vers, serverCipher) + + hs.finishedHash.Write(hs.hello13Enc.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, hs.hello13Enc.marshal()); err != nil { + return err + } + + certMsg := &certificateMsg13{ + certificates: hs.cert.Certificate, + } + hs.finishedHash.Write(certMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { + return err + } + + sigScheme, err := hs.selectTLS13SignatureScheme() + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + sigHash := hashForSignatureScheme(sigScheme) + opts := crypto.SignerOpts(sigHash) + if signatureSchemeIsPSS(sigScheme) { + opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: sigHash} + } + + toSign := prepareDigitallySigned(sigHash, "TLS 1.3, server CertificateVerify", hs.finishedHash.Sum()) + signature, err := hs.cert.PrivateKey.(crypto.Signer).Sign(config.rand(), toSign[:], opts) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + + verifyMsg := &certificateVerifyMsg{ + hasSignatureAndHash: true, + signatureAndHash: sigSchemeToSigAndHash(sigScheme), + signature: signature, + } + hs.finishedHash.Write(verifyMsg.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, verifyMsg.marshal()); err != nil { + return err + } + + serverFinishedKey := hkdfExpandLabel(hash, sHandshakeTS, nil, "finished", hashSize) + clientFinishedKey := hkdfExpandLabel(hash, cHandshakeTS, nil, "finished", hashSize) + + h := hmac.New(hash.New, serverFinishedKey) + h.Write(hs.finishedHash.Sum()) + verifyData := h.Sum(nil) + serverFinished := &finishedMsg{ + verifyData: verifyData, + } + hs.finishedHash.Write(serverFinished.marshal()) + if _, err := c.writeRecord(recordTypeHandshake, serverFinished.marshal()); err != nil { + return err + } + + if _, err := c.flush(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + h = hmac.New(hash.New, clientFinishedKey) + h.Write(hs.finishedHash.Sum()) + expectedVerifyData := h.Sum(nil) + if len(expectedVerifyData) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + + masterSecret := hkdfExtract(hash, nil, handshakeSecret) + handshakeCtx = hs.finishedHash.Sum() + + cTrafficSecret0 := hkdfExpandLabel(hash, masterSecret, handshakeCtx, "client application traffic secret", hashSize) + cKey = hkdfExpandLabel(hash, cTrafficSecret0, nil, "key", hs.suite.keyLen) + cIV = hkdfExpandLabel(hash, cTrafficSecret0, nil, "iv", 12) + sTrafficSecret0 := hkdfExpandLabel(hash, masterSecret, handshakeCtx, "server application traffic secret", hashSize) + sKey = hkdfExpandLabel(hash, sTrafficSecret0, nil, "key", hs.suite.keyLen) + sIV = hkdfExpandLabel(hash, sTrafficSecret0, nil, "iv", 12) + + clientCipher = hs.suite.aead(cKey, cIV) + c.in.setCipher(c.vers, clientCipher) + serverCipher = hs.suite.aead(sKey, sIV) + c.out.setCipher(c.vers, serverCipher) + + return nil +} + +// selectTLS13SignatureScheme chooses the SignatureScheme for the CertificateVerify +// based on the certificate type and client supported schemes. If no overlap is found, +// a fallback is selected. +// +// See https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.4.1.2 +func (hs *serverHandshakeState) selectTLS13SignatureScheme() (sigScheme SignatureScheme, err error) { + var supportedSchemes []SignatureScheme + signer, ok := hs.cert.PrivateKey.(crypto.Signer) + if !ok { + return 0, errors.New("tls: certificate private key does not implement crypto.Signer") + } + pk := signer.Public() + if _, ok := pk.(*rsa.PublicKey); ok { + sigScheme = PSSWithSHA256 + supportedSchemes = []SignatureScheme{PSSWithSHA256, PSSWithSHA384, PSSWithSHA512} + } else if pk, ok := pk.(*ecdsa.PublicKey); ok { + switch pk.Curve { + case elliptic.P256(): + sigScheme = ECDSAWithP256AndSHA256 + supportedSchemes = []SignatureScheme{ECDSAWithP256AndSHA256} + case elliptic.P384(): + sigScheme = ECDSAWithP384AndSHA384 + supportedSchemes = []SignatureScheme{ECDSAWithP384AndSHA384} + case elliptic.P521(): + sigScheme = ECDSAWithP521AndSHA512 + supportedSchemes = []SignatureScheme{ECDSAWithP521AndSHA512} + default: + return 0, errors.New("tls: unknown ECDSA certificate curve") + } + } else { + return 0, errors.New("tls: unknown certificate key type") + } + + for _, ss := range supportedSchemes { + for _, cs := range hs.clientHello.signatureAndHashes { + if ss == sigAndHashToSigScheme(cs) { + return ss, nil + } + } + } + + return sigScheme, nil +} + +func sigSchemeToSigAndHash(s SignatureScheme) (sah signatureAndHash) { + sah.hash = byte(s >> 8) + sah.signature = byte(s) + return +} + +func sigAndHashToSigScheme(sah signatureAndHash) SignatureScheme { + return SignatureScheme(sah.hash)<<8 | SignatureScheme(sah.signature) +} + +func signatureSchemeIsPSS(s SignatureScheme) bool { + return s == PSSWithSHA256 || s == PSSWithSHA384 || s == PSSWithSHA512 +} + +// hashForSignatureScheme returns the Hash used by a SignatureScheme which is +// supported by selectTLS13SignatureScheme. +func hashForSignatureScheme(ss SignatureScheme) crypto.Hash { + switch ss { + case PSSWithSHA256, ECDSAWithP256AndSHA256: + return crypto.SHA256 + case PSSWithSHA384, ECDSAWithP384AndSHA384: + return crypto.SHA384 + case PSSWithSHA512, ECDSAWithP521AndSHA512: + return crypto.SHA512 + default: + panic("unsupported SignatureScheme passed to hashForSignatureScheme") + } +} + +func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byte { + message := bytes.Repeat([]byte{32}, 64) + message = append(message, context...) + message = append(message, 0) + message = append(message, data...) + h := hash.New() + h.Write(message) + return h.Sum(nil) +} + +func (c *Config) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) { + if curveID == X25519 { + var scalar, public [32]byte + if _, err := io.ReadFull(c.rand(), scalar[:]); err != nil { + return nil, keyShare{}, err + } + + curve25519.ScalarBaseMult(&public, &scalar) + return scalar[:], keyShare{group: curveID, data: public[:]}, nil + } + + curve, ok := curveForCurveID(curveID) + if !ok { + return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve") + } + + privateKey, x, y, err := elliptic.GenerateKey(curve, c.rand()) + if err != nil { + return nil, keyShare{}, err + } + ecdhePublic := elliptic.Marshal(curve, x, y) + + return privateKey, keyShare{group: curveID, data: ecdhePublic}, nil +} + +func deriveECDHESecret(ks keyShare, pk []byte) []byte { + if ks.group == X25519 { + if len(ks.data) != 32 { + return nil + } + + var theirPublic, sharedKey, scalar [32]byte + copy(theirPublic[:], ks.data) + copy(scalar[:], pk) + curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) + return sharedKey[:] + } + + curve, ok := curveForCurveID(ks.group) + if !ok { + return nil + } + x, y := elliptic.Unmarshal(curve, ks.data) + if x == nil { + return nil + } + x, _ = curve.ScalarMult(x, y, pk) + xBytes := x.Bytes() + curveSize := (curve.Params().BitSize + 8 - 1) >> 3 + if len(xBytes) == curveSize { + return xBytes + } + buf := make([]byte, curveSize) + copy(buf[len(buf)-len(xBytes):], xBytes) + return buf +} + +func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte { + hkdfLabel := make([]byte, 4+len("TLS 1.3, ")+len(label)+len(hashValue)) + hkdfLabel[0] = byte(L >> 8) + hkdfLabel[1] = byte(L) + hkdfLabel[2] = byte(len("TLS 1.3, ") + len(label)) + copy(hkdfLabel[3:], "TLS 1.3, ") + z := hkdfLabel[3+len("TLS 1.3, "):] + copy(z, label) + z = z[len(label):] + z[0] = byte(len(hashValue)) + copy(z[1:], hashValue) + + return hkdfExpand(hash, secret, hkdfLabel, L) +} diff --git a/_dev/tris-localserver/server.go b/_dev/tris-localserver/server.go index afc0872..6c3fe2a 100644 --- a/_dev/tris-localserver/server.go +++ b/_dev/tris-localserver/server.go @@ -10,12 +10,11 @@ import ( ) var tlsVersionToName = map[uint16]string{ - tls.VersionTLS10: "1.0", - tls.VersionTLS11: "1.1", - tls.VersionTLS12: "1.2", - tls.VersionTLS13: "1.3", - 0x7f00 | 16: "1.3 (draft 16)", - 0x7f00 | 18: "1.3 (draft 18)", + tls.VersionTLS10: "1.0", + tls.VersionTLS11: "1.1", + tls.VersionTLS12: "1.2", + tls.VersionTLS13: "1.3", + tls.VersionTLS13Draft18: "1.3 (draft 18)", } func main() { diff --git a/common.go b/common.go index 6146192..f9eae48 100644 --- a/common.go +++ b/common.go @@ -22,11 +22,12 @@ import ( ) const ( - VersionSSL30 = 0x0300 - VersionTLS10 = 0x0301 - VersionTLS11 = 0x0302 - VersionTLS12 = 0x0303 - VersionTLS13 = 0x0304 + VersionSSL30 = 0x0300 + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 + VersionTLS13 = 0x0304 + VersionTLS13Draft18 = 0x7f00 | 18 ) const ( diff --git a/conn.go b/conn.go index c53edc9..ea0643c 100644 --- a/conn.go +++ b/conn.go @@ -185,6 +185,14 @@ func (hc *halfConn) changeCipherSpec() error { return nil } +func (hc *halfConn) setCipher(version uint16, cipher interface{}) { + hc.version = version + hc.cipher = cipher + for i := range hc.seq { + hc.seq[i] = 0 + } +} + // incSeq increments the sequence number. func (hc *halfConn) incSeq() { for i := 7; i >= 0; i-- { diff --git a/handshake_server.go b/handshake_server.go index b97b3aa..6d120a9 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -22,6 +22,8 @@ type serverHandshakeState struct { c *Conn clientHello *clientHelloMsg hello *serverHelloMsg + hello13 *serverHelloMsg13 + hello13Enc *encryptedExtensionsMsg suite *cipherSuite ellipticOk bool ecdsaOk bool @@ -52,7 +54,11 @@ func (c *Conn) serverHandshake() error { // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 c.buffering = true - if isResume { + if hs.hello13 != nil { + if err := hs.doTLS13Handshake(); err != nil { + return err + } + } else if isResume { // The client has included a session ticket and so we do an abbreviated handshake. if err := hs.doResumeHandshake(); err != nil { return err @@ -134,14 +140,31 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { } } - c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) - if !ok { - c.sendAlert(alertProtocolVersion) - return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) + var keyShares []CurveID + for _, ks := range hs.clientHello.keyShares { + keyShares = append(keyShares, ks.group) } - c.haveVers = true - hs.hello = new(serverHelloMsg) + if hs.clientHello.supportedVersions != nil { + for _, v := range hs.clientHello.supportedVersions { + if (v >= c.config.minVersion() && v <= c.config.maxVersion()) || + v == VersionTLS13Draft18 { + c.vers = v + break + } + } + if c.vers == 0 { + c.sendAlert(alertProtocolVersion) + return false, fmt.Errorf("tls: none of the client versions (%x) are supported", hs.clientHello.supportedVersions) + } + } else { + c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) + if !ok { + c.sendAlert(alertProtocolVersion) + return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) + } + } + c.haveVers = true supportedCurve := false preferredCurves := c.config.curvePreferences() @@ -162,6 +185,8 @@ Curves: break } } + // TLS 1.3 has removed point format negotiation. + supportedPointFormat = supportedPointFormat || c.vers >= VersionTLS13 hs.ellipticOk = supportedCurve && supportedPointFormat foundCompression := false @@ -177,13 +202,9 @@ Curves: c.sendAlert(alertHandshakeFailure) return false, errors.New("tls: client does not support uncompressed connections") } - - hs.hello.vers = c.vers - hs.hello.random = make([]byte, 32) - _, err = io.ReadFull(c.config.rand(), hs.hello.random) - if err != nil { - c.sendAlert(alertInternalError) - return false, err + if len(hs.clientHello.compressionMethods) != 1 && c.vers >= VersionTLS13 { + c.sendAlert(alertIllegalParameter) + return false, errors.New("tls: 1.3 client offered compression") } if len(hs.clientHello.secureRenegotiation) != 0 { @@ -191,15 +212,40 @@ Curves: return false, errors.New("tls: initial handshake had non-empty renegotiation extension") } - hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported - hs.hello.compressionMethod = compressionNone + if c.vers < VersionTLS13 { + hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers + hs.hello.random = make([]byte, 32) + _, err = io.ReadFull(c.config.rand(), hs.hello.random) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported + hs.hello.compressionMethod = compressionNone + } else { + hs.hello13 = new(serverHelloMsg13) + hs.hello13Enc = new(encryptedExtensionsMsg) + hs.hello13.vers = c.vers + hs.hello13.random = make([]byte, 32) + _, err = io.ReadFull(c.config.rand(), hs.hello13.random) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + } + if len(hs.clientHello.serverName) > 0 { c.serverName = hs.clientHello.serverName } if len(hs.clientHello.alpnProtocols) > 0 { if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { - hs.hello.alpnProtocol = selectedProto + if hs.hello != nil { + hs.hello.alpnProtocol = selectedProto + } else { + hs.hello13Enc.alpnProtocol = selectedProto + } c.clientProtocol = selectedProto } } else { @@ -207,7 +253,7 @@ Curves: // had a bug around this. Best to send nothing at all if // c.config.NextProtos is empty. See // https://golang.org/issue/5445. - if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 { + if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 && c.vers < VersionTLS13 { hs.hello.nextProtoNeg = true hs.hello.nextProtos = c.config.NextProtos } @@ -218,7 +264,7 @@ Curves: c.sendAlert(alertInternalError) return false, err } - if hs.clientHello.scts { + if hs.clientHello.scts && hs.hello != nil { // TODO: TLS 1.3 SCTs hs.hello.scts = hs.cert.SignedCertificateTimestamps } @@ -243,7 +289,7 @@ Curves: } } - if hs.checkForResumption() { + if c.vers != VersionTLS13 && hs.checkForResumption() { return true, nil } diff --git a/hkdf.go b/hkdf.go new file mode 100644 index 0000000..86db10b --- /dev/null +++ b/hkdf.go @@ -0,0 +1,58 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +// Mostly derived from golang.org/x/crypto/hkdf, but with an exposed +// Extract API. +// +// HKDF is a cryptographic key derivation function (KDF) with the goal of +// expanding limited input keying material into one or more cryptographically +// strong secret keys. +// +// RFC 5869: https://tools.ietf.org/html/rfc5869 + +import ( + "crypto" + "crypto/hmac" +) + +func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte { + var ( + expander = hmac.New(hash.New, prk) + res = make([]byte, l) + counter = byte(1) + prev []byte + ) + + if l > 255*expander.Size() { + panic("hkdf: requested too much output") + } + + p := res + for len(p) > 0 { + expander.Reset() + expander.Write(prev) + expander.Write(info) + expander.Write([]byte{counter}) + prev = expander.Sum(prev[:0]) + counter++ + n := copy(p, prev) + p = p[n:] + } + + return res +} + +func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte { + if salt == nil { + salt = make([]byte, hash.Size()) + } + if secret == nil { + secret = make([]byte, hash.Size()) + } + extractor := hmac.New(hash.New, salt) + extractor.Write(secret) + return extractor.Sum(nil) +} diff --git a/key_agreement.go b/key_agreement.go index 1b27c04..00e651c 100644 --- a/key_agreement.go +++ b/key_agreement.go @@ -323,14 +323,14 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert if x == nil { return nil, errClientKeyExchange } - if !curve.IsOnCurve(x, y) { - return nil, errClientKeyExchange - } x, _ = curve.ScalarMult(x, y, ka.privateKey) - preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3) + curveSize := (curve.Params().BitSize + 7) >> 3 xBytes := x.Bytes() + if len(xBytes) == curveSize { + return xBytes, nil + } + preMasterSecret := make([]byte, curveSize) copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) - return preMasterSecret, nil } diff --git a/prf.go b/prf.go index 5833fc1..e397bb9 100644 --- a/prf.go +++ b/prf.go @@ -122,6 +122,13 @@ var clientFinishedLabel = []byte("client finished") var serverFinishedLabel = []byte("server finished") func prfAndHashForVersion(version uint16, suite *cipherSuite) (func(result, secret, label, seed []byte), crypto.Hash) { + if version >= VersionTLS13 { + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384), crypto.SHA384 + } + return prf12(sha256.New), crypto.SHA256 + } + switch version { case VersionSSL30: return prf30, crypto.Hash(0)