Refactor the keyAgreement interface

It's sufficient to pass in the *tls.Certificate (resp.
*x509.Certificate) to the server functions (resp. client funcctions),
but not necessary; the existing keyAgreement implementations only makes
use of the private key (resp. public key). Moreover, this change is
necessary for implementing the delegated credentials extension, which
replaces the private key (resp. public key) used in the handshake.
This commit is contained in:
Christopher Patton 2018-06-21 09:34:19 -07:00
parent 3ff71dcdc5
commit 963d5877be
4 changed files with 22 additions and 23 deletions

View File

@ -5,6 +5,7 @@
package tls package tls
import ( import (
"crypto"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/des" "crypto/des"
@ -12,7 +13,6 @@ import (
"crypto/rc4" "crypto/rc4"
"crypto/sha1" "crypto/sha1"
"crypto/sha256" "crypto/sha256"
"crypto/x509"
"hash" "hash"
"golang_org/x/crypto/chacha20poly1305" "golang_org/x/crypto/chacha20poly1305"
@ -26,15 +26,15 @@ type keyAgreement interface {
// In the case that the key agreement protocol doesn't use a // In the case that the key agreement protocol doesn't use a
// ServerKeyExchange message, generateServerKeyExchange can return nil, // ServerKeyExchange message, generateServerKeyExchange can return nil,
// nil. // nil.
generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) generateServerKeyExchange(*Config, crypto.PrivateKey, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error)
processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) processClientKeyExchange(*Config, crypto.PrivateKey, *clientKeyExchangeMsg, uint16) ([]byte, error)
// On the client side, the next two methods are called in order. // On the client side, the next two methods are called in order.
// This method may not be called if the server doesn't send a // This method may not be called if the server doesn't send a
// ServerKeyExchange message. // ServerKeyExchange message.
processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, crypto.PublicKey, *serverKeyExchangeMsg) error
generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) generateClientKeyExchange(*Config, *clientHelloMsg, crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error)
} }
const ( const (

View File

@ -480,7 +480,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
skx, ok := msg.(*serverKeyExchangeMsg) skx, ok := msg.(*serverKeyExchangeMsg)
if ok { if ok {
hs.finishedHash.Write(skx.marshal()) hs.finishedHash.Write(skx.marshal())
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx) err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0].PublicKey, skx)
if err != nil { if err != nil {
c.sendAlert(alertUnexpectedMessage) c.sendAlert(alertUnexpectedMessage)
return err return err
@ -529,7 +529,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
} }
} }
preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0]) preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0].PublicKey)
if err != nil { if err != nil {
c.sendAlert(alertInternalError) c.sendAlert(alertInternalError)
return err return err

View File

@ -479,7 +479,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
keyAgreement := hs.suite.ka(c.vers) keyAgreement := hs.suite.ka(c.vers)
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert.PrivateKey, hs.clientHello, hs.hello)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure) c.sendAlert(alertHandshakeFailure)
return err return err
@ -572,7 +572,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
} }
hs.finishedHash.Write(ckx.marshal()) hs.finishedHash.Write(ckx.marshal())
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert.PrivateKey, ckx, c.vers)
if err != nil { if err != nil {
if err == errClientKeyExchange { if err == errClientKeyExchange {
c.sendAlert(alertDecodeError) c.sendAlert(alertDecodeError)

View File

@ -10,7 +10,6 @@ import (
"crypto/md5" "crypto/md5"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1"
"crypto/x509"
"errors" "errors"
"io" "io"
"math/big" "math/big"
@ -25,11 +24,11 @@ var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
// encrypts the pre-master secret to the server's public key. // encrypts the pre-master secret to the server's public key.
type rsaKeyAgreement struct{} type rsaKeyAgreement struct{}
func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, sk crypto.PrivateKey, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
return nil, nil return nil, nil
} }
func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, sk crypto.PrivateKey, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) < 2 { if len(ckx.ciphertext) < 2 {
return nil, errClientKeyExchange return nil, errClientKeyExchange
} }
@ -42,7 +41,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi
} }
ciphertext = ckx.ciphertext[2:] ciphertext = ckx.ciphertext[2:]
} }
priv, ok := cert.PrivateKey.(crypto.Decrypter) priv, ok := sk.(crypto.Decrypter)
if !ok { if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter") return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
} }
@ -60,11 +59,11 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi
return preMasterSecret, nil return preMasterSecret, nil
} }
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, pk crypto.PublicKey, skx *serverKeyExchangeMsg) error {
return errors.New("tls: unexpected ServerKeyExchange") return errors.New("tls: unexpected ServerKeyExchange")
} }
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, pk crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) {
preMasterSecret := make([]byte, 48) preMasterSecret := make([]byte, 48)
preMasterSecret[0] = byte(clientHello.vers >> 8) preMasterSecret[0] = byte(clientHello.vers >> 8)
preMasterSecret[1] = byte(clientHello.vers) preMasterSecret[1] = byte(clientHello.vers)
@ -73,7 +72,7 @@ func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello
return nil, nil, err return nil, nil, err
} }
encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret) encrypted, err := rsa.EncryptPKCS1v15(config.rand(), pk.(*rsa.PublicKey), preMasterSecret)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -156,7 +155,7 @@ type ecdheKeyAgreement struct {
x, y *big.Int x, y *big.Int
} }
func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, sk crypto.PrivateKey, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
preferredCurves := config.curvePreferences() preferredCurves := config.curvePreferences()
NextCandidate: NextCandidate:
@ -207,7 +206,7 @@ NextCandidate:
serverECDHParams[3] = byte(len(ecdhePublic)) serverECDHParams[3] = byte(len(ecdhePublic))
copy(serverECDHParams[4:], ecdhePublic) copy(serverECDHParams[4:], ecdhePublic)
priv, ok := cert.PrivateKey.(crypto.Signer) priv, ok := sk.(crypto.Signer)
if !ok { if !ok {
return nil, errors.New("tls: certificate private key does not implement crypto.Signer") return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
} }
@ -255,7 +254,7 @@ NextCandidate:
return skx, nil return skx, nil
} }
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, sk crypto.PrivateKey, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
return nil, errClientKeyExchange return nil, errClientKeyExchange
} }
@ -291,7 +290,7 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert
return preMasterSecret, nil return preMasterSecret, nil
} }
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, pk crypto.PublicKey, skx *serverKeyExchangeMsg) error {
if len(skx.key) < 4 { if len(skx.key) < 4 {
return errServerKeyExchange return errServerKeyExchange
} }
@ -337,7 +336,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
return errServerKeyExchange return errServerKeyExchange
} }
} }
_, sigType, hashFunc, err := pickSignatureAlgorithm(cert.PublicKey, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version) _, sigType, hashFunc, err := pickSignatureAlgorithm(pk, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version)
if err != nil { if err != nil {
return err return err
} }
@ -355,10 +354,10 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
if err != nil { if err != nil {
return err return err
} }
return verifyHandshakeSignature(sigType, cert.PublicKey, hashFunc, digest, sig) return verifyHandshakeSignature(sigType, pk, hashFunc, digest, sig)
} }
func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, pk crypto.PublicKey) ([]byte, *clientKeyExchangeMsg, error) {
if ka.curveid == 0 { if ka.curveid == 0 {
return nil, nil, errors.New("tls: missing ServerKeyExchange message") return nil, nil, errors.New("tls: missing ServerKeyExchange message")
} }