diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..1e0cf23 --- /dev/null +++ b/auth.go @@ -0,0 +1,93 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "encoding/asn1" + "errors" + "fmt" +) + +// pickSignatureAlgorithm selects a signature algorithm that is compatible with +// the given public key and the list of algorithms from the peer and this side. +// +// The returned SignatureScheme codepoint is only meaningful for TLS 1.2, +// previous TLS versions have a fixed hash function. +func pickSignatureAlgorithm(pubkey crypto.PublicKey, peerSigAlgs, ourSigAlgs []SignatureScheme, tlsVersion uint16) (SignatureScheme, uint8, crypto.Hash, error) { + if tlsVersion < VersionTLS12 || len(peerSigAlgs) == 0 { + // If the client didn't specify any signature_algorithms + // extension then we can assume that it supports SHA1. See + // http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + switch pubkey.(type) { + case *rsa.PublicKey: + if tlsVersion < VersionTLS12 { + return 0, signatureRSA, crypto.MD5SHA1, nil + } else { + return PKCS1WithSHA1, signatureRSA, crypto.SHA1, nil + } + case *ecdsa.PublicKey: + return ECDSAWithSHA1, signatureECDSA, crypto.SHA1, nil + default: + return 0, 0, 0, fmt.Errorf("tls: unsupported public key: %T", pubkey) + } + } + for _, sigAlg := range peerSigAlgs { + if !isSupportedSignatureAlgorithm(sigAlg, ourSigAlgs) { + continue + } + hashAlg, err := lookupTLSHash(sigAlg) + if err != nil { + panic("tls: supported signature algorithm has an unknown hash function") + } + sigType := signatureFromSignatureScheme(sigAlg) + switch pubkey.(type) { + case *rsa.PublicKey: + if sigType == signatureRSA { + return sigAlg, sigType, hashAlg, nil + } + case *ecdsa.PublicKey: + if sigType == signatureECDSA { + return sigAlg, sigType, hashAlg, nil + } + } + } + return 0, 0, 0, errors.New("tls: peer doesn't support any common signature algorithms") +} + +// verifyHandshakeSignature verifies a signature against pre-hashed handshake +// contents. +func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, digest, sig []byte) error { + switch sigType { + case signatureECDSA: + pubKey, ok := pubkey.(*ecdsa.PublicKey) + if !ok { + return errors.New("tls: ECDSA signing requires a ECDSA public key") + } + ecdsaSig := new(ecdsaSignature) + if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { + return err + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errors.New("tls: ECDSA signature contained zero or negative values") + } + if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) { + return errors.New("tls: ECDSA verification failure") + } + case signatureRSA: + pubKey, ok := pubkey.(*rsa.PublicKey) + if !ok { + return errors.New("tls: RSA signing requires a RSA public key") + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil { + return err + } + default: + return errors.New("tls: unknown signature algorithm") + } + return nil +} diff --git a/handshake_client.go b/handshake_client.go index cdeacc0..e904518 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -485,26 +485,16 @@ func (hs *clientHandshakeState) doFullHandshake() error { return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey) } - var signatureType uint8 - switch key.Public().(type) { - case *ecdsa.PublicKey: - signatureType = signatureECDSA - case *rsa.PublicKey: - signatureType = signatureRSA - default: + signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(key.Public(), certReq.supportedSignatureAlgorithms, hs.hello.supportedSignatureAlgorithms, c.vers) + if err != nil { c.sendAlert(alertInternalError) - return fmt.Errorf("tls: failed to sign handshake with client certificate: unknown client certificate key type: %T", key) + return err } - // SignatureAndHashAlgorithm was introduced in TLS 1.2. if certVerify.hasSignatureAndHash { - certVerify.signatureAlgorithm, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.supportedSignatureAlgorithms, signatureType) - if err != nil { - c.sendAlert(alertInternalError) - return err - } + certVerify.signatureAlgorithm = signatureAlgorithm } - digest, hashFunc, err := hs.finishedHash.hashForClientCertificate(signatureType, certVerify.signatureAlgorithm, hs.masterSecret) + digest, err := hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret) if err != nil { c.sendAlert(alertInternalError) return err diff --git a/handshake_server.go b/handshake_server.go index f294a80..fc129ec 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -10,7 +10,6 @@ import ( "crypto/rsa" "crypto/subtle" "crypto/x509" - "encoding/asn1" "errors" "fmt" "io" @@ -605,59 +604,15 @@ func (hs *serverHandshakeState) doFullHandshake() error { } // Determine the signature type. - var signatureAlgorithm SignatureScheme - var sigType uint8 - if certVerify.hasSignatureAndHash { - signatureAlgorithm = certVerify.signatureAlgorithm - if !isSupportedSignatureAlgorithm(signatureAlgorithm, supportedSignatureAlgorithms) { - return errors.New("tls: unsupported hash function for client certificate") - } - sigType = signatureFromSignatureScheme(signatureAlgorithm) - } else { - // Before TLS 1.2 the signature algorithm was implicit - // from the key type, and only one hash per signature - // algorithm was possible. Leave signatureAlgorithm - // unset. - switch pub.(type) { - case *ecdsa.PublicKey: - sigType = signatureECDSA - case *rsa.PublicKey: - sigType = signatureRSA - } + _, sigType, hashFunc, err := pickSignatureAlgorithm(pub, []SignatureScheme{certVerify.signatureAlgorithm}, supportedSignatureAlgorithms, c.vers) + if err != nil { + c.sendAlert(alertIllegalParameter) + return err } - switch key := pub.(type) { - case *ecdsa.PublicKey: - if sigType != signatureECDSA { - err = errors.New("tls: bad signature type for client's ECDSA certificate") - break - } - ecdsaSig := new(ecdsaSignature) - if _, err = asn1.Unmarshal(certVerify.signature, ecdsaSig); err != nil { - break - } - if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { - err = errors.New("tls: ECDSA signature contained zero or negative values") - break - } - var digest []byte - if digest, _, err = hs.finishedHash.hashForClientCertificate(sigType, signatureAlgorithm, hs.masterSecret); err != nil { - break - } - if !ecdsa.Verify(key, digest, ecdsaSig.R, ecdsaSig.S) { - err = errors.New("tls: ECDSA verification failure") - } - case *rsa.PublicKey: - if sigType != signatureRSA { - err = errors.New("tls: bad signature type for client's RSA certificate") - break - } - var digest []byte - var hashFunc crypto.Hash - if digest, hashFunc, err = hs.finishedHash.hashForClientCertificate(sigType, signatureAlgorithm, hs.masterSecret); err != nil { - break - } - err = rsa.VerifyPKCS1v15(key, hashFunc, digest, certVerify.signature) + var digest []byte + if digest, err = hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret); err == nil { + err = verifyHandshakeSignature(sigType, pub, hashFunc, digest, certVerify.signature) } if err != nil { c.sendAlert(alertBadCertificate) diff --git a/key_agreement.go b/key_agreement.go index c825fc5..d161fe8 100644 --- a/key_agreement.go +++ b/key_agreement.go @@ -6,13 +6,11 @@ package tls import ( "crypto" - "crypto/ecdsa" "crypto/elliptic" "crypto/md5" "crypto/rsa" "crypto/sha1" "crypto/x509" - "encoding/asn1" "errors" "io" "math/big" @@ -110,58 +108,20 @@ func md5SHA1Hash(slices [][]byte) []byte { } // hashForServerKeyExchange hashes the given slices and returns their digest -// and the identifier of the hash function used. The signatureAlgorithm argument -// is only used for >= TLS 1.2 and identifies the hash function to use. -func hashForServerKeyExchange(sigType uint8, signatureAlgorithm SignatureScheme, version uint16, slices ...[]byte) ([]byte, crypto.Hash, error) { +// using the given hash function. +func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) ([]byte, error) { if version >= VersionTLS12 { - if !isSupportedSignatureAlgorithm(signatureAlgorithm, supportedSignatureAlgorithms) { - return nil, crypto.Hash(0), errors.New("tls: unsupported hash function used by peer") - } - hashFunc, err := lookupTLSHash(signatureAlgorithm) - if err != nil { - return nil, crypto.Hash(0), err - } h := hashFunc.New() for _, slice := range slices { h.Write(slice) } digest := h.Sum(nil) - return digest, hashFunc, nil + return digest, nil } if sigType == signatureECDSA { - return sha1Hash(slices), crypto.SHA1, nil + return sha1Hash(slices), nil } - return md5SHA1Hash(slices), crypto.MD5SHA1, nil -} - -// pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a -// ServerKeyExchange given the signature type being used and the client's -// advertised list of supported signature and hash combinations. -func pickTLS12HashForSignature(sigType uint8, clientList []SignatureScheme) (SignatureScheme, error) { - if len(clientList) == 0 { - // If the client didn't specify any signature_algorithms - // extension then we can assume that it supports SHA1. See - // http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 - switch sigType { - case signatureRSA: - return PKCS1WithSHA1, nil - case signatureECDSA: - return ECDSAWithSHA1, nil - default: - return 0, errors.New("tls: unknown signature algorithm") - } - } - - for _, sigAlg := range clientList { - if signatureFromSignatureScheme(sigAlg) != sigType { - continue - } - if isSupportedSignatureAlgorithm(sigAlg, supportedSignatureAlgorithms) { - return sigAlg, nil - } - } - - return 0, errors.New("tls: client doesn't support any common hash functions") + return md5SHA1Hash(slices), nil } func curveForCurveID(id CurveID) (elliptic.Curve, bool) { @@ -247,40 +207,25 @@ NextCandidate: serverECDHParams[3] = byte(len(ecdhePublic)) copy(serverECDHParams[4:], ecdhePublic) - var signatureAlgorithm SignatureScheme - - if ka.version >= VersionTLS12 { - var err error - signatureAlgorithm, err = pickTLS12HashForSignature(ka.sigType, clientHello.supportedSignatureAlgorithms) - if err != nil { - return nil, err - } - } - - digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, signatureAlgorithm, ka.version, clientHello.random, hello.random, serverECDHParams) - if err != nil { - return nil, err - } - priv, ok := cert.PrivateKey.(crypto.Signer) if !ok { return nil, errors.New("tls: certificate private key does not implement crypto.Signer") } - var sig []byte - switch ka.sigType { - case signatureECDSA: - _, ok := priv.Public().(*ecdsa.PublicKey) - if !ok { - return nil, errors.New("tls: ECDHE ECDSA requires an ECDSA server key") - } - case signatureRSA: - _, ok := priv.Public().(*rsa.PublicKey) - if !ok { - return nil, errors.New("tls: ECDHE RSA requires a RSA server key") - } - default: - return nil, errors.New("tls: unknown ECDHE signature algorithm") + + signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(priv.Public(), clientHello.supportedSignatureAlgorithms, supportedSignatureAlgorithms, ka.version) + if err != nil { + return nil, err } + if sigType != ka.sigType { + return nil, errors.New("tls: certificate cannot be used with the selected cipher suite") + } + + digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, hello.random, serverECDHParams) + if err != nil { + return nil, err + } + + var sig []byte sig, err = priv.Sign(config.rand(), digest, hashFunc) if err != nil { return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error()) @@ -383,53 +328,30 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell if ka.version >= VersionTLS12 { // handle SignatureAndHashAlgorithm signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1]) - if signatureFromSignatureScheme(signatureAlgorithm) != ka.sigType { - return errServerKeyExchange - } sig = sig[2:] if len(sig) < 2 { return errServerKeyExchange } } + _, sigType, hashFunc, err := pickSignatureAlgorithm(cert.PublicKey, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version) + if err != nil { + return err + } + if sigType != ka.sigType { + return errServerKeyExchange + } + sigLen := int(sig[0])<<8 | int(sig[1]) if sigLen+2 != len(sig) { return errServerKeyExchange } sig = sig[2:] - digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, signatureAlgorithm, ka.version, clientHello.random, serverHello.random, serverECDHParams) + digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, serverHello.random, serverECDHParams) if err != nil { return err } - switch ka.sigType { - case signatureECDSA: - pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey) - if !ok { - return errors.New("tls: ECDHE ECDSA requires a ECDSA server public key") - } - ecdsaSig := new(ecdsaSignature) - if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { - return err - } - if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { - return errors.New("tls: ECDSA signature contained zero or negative values") - } - if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) { - return errors.New("tls: ECDSA verification failure") - } - case signatureRSA: - pubKey, ok := cert.PublicKey.(*rsa.PublicKey) - if !ok { - return errors.New("tls: ECDHE RSA requires a RSA server public key") - } - if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil { - return err - } - default: - return errors.New("tls: unknown ECDHE signature algorithm") - } - - return nil + return verifyHandshakeSignature(sigType, cert.PublicKey, hashFunc, digest, sig) } func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { diff --git a/prf.go b/prf.go index 74438f8..53cf04a 100644 --- a/prf.go +++ b/prf.go @@ -309,50 +309,35 @@ func (h finishedHash) serverSum(masterSecret []byte) []byte { return out } -// selectClientCertSignatureAlgorithm returns a SignatureScheme to sign a -// client's CertificateVerify with, or an error if none can be found. -func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []SignatureScheme, sigType uint8) (SignatureScheme, error) { - for _, v := range serverList { - if signatureFromSignatureScheme(v) == sigType && isSupportedSignatureAlgorithm(v, supportedSignatureAlgorithms) { - return v, nil - } - } - return 0, errors.New("tls: no supported signature algorithm found for signing client certificate") -} - -// hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash -// id suitable for signing by a TLS client certificate. -func (h finishedHash) hashForClientCertificate(sigType uint8, signatureAlgorithm SignatureScheme, masterSecret []byte) ([]byte, crypto.Hash, error) { +// hashForClientCertificate returns a digest over the handshake messages so far, +// suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(sigType uint8, hashAlg crypto.Hash, masterSecret []byte) ([]byte, error) { if (h.version == VersionSSL30 || h.version >= VersionTLS12) && h.buffer == nil { panic("a handshake hash for a client-certificate was requested after discarding the handshake buffer") } if h.version == VersionSSL30 { if sigType != signatureRSA { - return nil, 0, errors.New("tls: unsupported signature type for client certificate") + return nil, errors.New("tls: unsupported signature type for client certificate") } md5Hash := md5.New() md5Hash.Write(h.buffer) sha1Hash := sha1.New() sha1Hash.Write(h.buffer) - return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), crypto.MD5SHA1, nil + return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), nil } if h.version >= VersionTLS12 { - hashAlg, err := lookupTLSHash(signatureAlgorithm) - if err != nil { - return nil, 0, err - } hash := hashAlg.New() hash.Write(h.buffer) - return hash.Sum(nil), hashAlg, nil + return hash.Sum(nil), nil } if sigType == signatureECDSA { - return h.server.Sum(nil), crypto.SHA1, nil + return h.server.Sum(nil), nil } - return h.Sum(), crypto.MD5SHA1, nil + return h.Sum(), nil } // discardHandshakeBuffer is called when there is no more need to