[sike] Refactor key agreement in TLS 1.3 [PATCH 1/2] (#153)

Previously there where two methods used for key agreemnt
tls.Conn::generateKeyShare and tls.Conn::deriveDHESecret. Both were
used on client and server side. Boolean flag is used in order to
differentiate between key agreement performed on client and on server
side. Which sucks badly.
In order to implement shared secret agreement with KEM it is better to
add method which implements server specific key agreement and provide
default implementation which reuses tls.Conn::generateKeyShare followed
by tls.Conn::deriveDHESecret.
Now, it is possible for most of the DH-style key agreements to reuse
default implementation and for KEM-style key agreement to provide server
specific implementation.
This commit is contained in:
Henry Case 2019-02-25 17:25:16 +00:00 committed by GitHub
parent 7619b84b13
commit a5d35123cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

162
13.go
View File

@ -61,8 +61,8 @@ type keySchedule13 struct {
config *Config // Used for KeyLogWriter callback, nil if keylogging is disabled. config *Config // Used for KeyLogWriter callback, nil if keylogging is disabled.
} }
// Interface implemented by DH key exchange strategies // Interface implemented by key exchange strategies
type dhKex interface { type kex interface {
// c - context of current TLS handshake, groupId - ID of an algorithm // c - context of current TLS handshake, groupId - ID of an algorithm
// (curve/field) being chosen for key agreement. Methods implmenting an // (curve/field) being chosen for key agreement. Methods implmenting an
// interface always assume that provided groupId is correct. // interface always assume that provided groupId is correct.
@ -70,27 +70,58 @@ type dhKex interface {
// In case of success, function returns secret key and ephemeral key. Otherwise // In case of success, function returns secret key and ephemeral key. Otherwise
// error is set. // error is set.
generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error)
// c - context of current TLS handshake, ks - public key received // keyAgreementClient declares an API for implementing shared secret agreement on
// from the other side of the connection, secretKey - is a private key // the client side. `c` is a context of current TLS handshake, `ks` is a public key
// used for DH key agreement. Function returns shared secret in case // received from the server, ``privateKey`` client private key.
// of success or empty slice otherwise. // Function returns shared secret in case of success or non nil error otherwise.
derive(c *Conn, ks keyShare, secretKey []byte) []byte keyAgreementClient(c *Conn, ks keyShare, privateKey []byte) ([]byte, error)
// keyAgreementServer declares an API for implementing shared secret agreement on
// the server side. `c` context of current TLS handshake, `ks` is a public key
// received from the client side of the connection, ``privateKey`` is a private key
// of a server.
// Function returns secret shared between parties and public value to exchange
// between parties. In case of failure `error` must be set.
keyAgreementServer(c *Conn, ks keyShare) ([]byte, keyShare, error)
}
// defaultServerKEX is an abstract class defining default, common behaviour on
// a server side.
type defaultServerKEX struct{}
// defaultServerKEX is an abstract class defining default implementation of
// server side key agreement. It generates ephemeral key and uses it together
// with client public part in order to calculate shared secret.
func (defaultServerKEX) keyAgreementServer(c *Conn, clientKS keyShare) ([]byte, keyShare, error) {
privateKey, publicKey, err := c.generateKeyShare(clientKS.group)
if err != nil {
c.sendAlert(alertInternalError)
return nil, keyShare{}, err
}
// Use same key agreement implementation as on the client side
ss, err := c.keyAgreementClient(clientKS, privateKey)
if err != nil {
c.sendAlert(alertIllegalParameter)
return nil, keyShare{}, err
}
return ss, publicKey, nil
} }
// Key Exchange strategies per curve type // Key Exchange strategies per curve type
type kexNist struct{} // Used by NIST curves; P-256, P-384, P-512 type kexNIST struct{ defaultServerKEX } // Used by NIST curves; P-256, P-384, P-512
type kexX25519 struct{} // Used by X25519 type kexX25519 struct{ defaultServerKEX } // Used by X25519
type kexSIDHp503 struct{} // Used by SIDH/P503 type kexSIDHp503 struct{ defaultServerKEX } // Used by SIDH/P503
type kexHybridSIDHp503X25519 struct { type kexHybridSIDHp503X25519 struct {
defaultServerKEX
classicKEX kexX25519 classicKEX kexX25519
pqKEX kexSIDHp503 pqKEX kexSIDHp503
} // Used by SIDH-ECDH hybrid scheme } // Used by SIDH-ECDH hybrid scheme
// Routing map for key exchange strategies // Routing map for key exchange strategies
var dhKexStrat = map[CurveID]dhKex{ var kexStrat = map[CurveID]kex{
CurveP256: &kexNist{}, CurveP256: &kexNIST{},
CurveP384: &kexNist{}, CurveP384: &kexNIST{},
CurveP521: &kexNist{}, CurveP521: &kexNIST{},
X25519: &kexX25519{}, X25519: &kexX25519{},
HybridSIDHp503Curve25519: &kexHybridSIDHp503X25519{}, HybridSIDHp503Curve25519: &kexHybridSIDHp503X25519{},
} }
@ -202,17 +233,6 @@ CurvePreferenceLoop:
} }
} }
} }
if ks.group == 0 {
c.sendAlert(alertInternalError)
return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo)
}
privateKey, serverKS, err := c.generateKeyShare(ks.group)
if err != nil {
c.sendAlert(alertInternalError)
return err
}
hs.hello.keyShare = serverKS
hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
hashSize := hash.Size() hashSize := hash.Size()
@ -232,14 +252,17 @@ CurvePreferenceLoop:
} }
hs.keySchedule.write(hs.clientHello.marshal()) hs.keySchedule.write(hs.clientHello.marshal())
earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient) earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient)
ecdheSecret := c.deriveDHESecret(ks, privateKey) if ks.group == 0 {
if ecdheSecret == nil { c.sendAlert(alertInternalError)
c.sendAlert(alertIllegalParameter) return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo)
return errors.New("tls: bad ECDHE client share")
} }
sharedSecret, serverKS, err := c.keyAgreementServer(ks)
if err != nil {
return err
}
hs.hello.keyShare = serverKS
hs.keySchedule.write(hs.hello.marshal()) hs.keySchedule.write(hs.hello.marshal())
if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
@ -251,7 +274,7 @@ CurvePreferenceLoop:
return err return err
} }
hs.keySchedule.setSecret(ecdheSecret) hs.keySchedule.setSecret(sharedSecret)
clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient) clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
hs.hsClientCipher = clientCipher hs.hsClientCipher = clientCipher
serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer) serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer)
@ -598,10 +621,12 @@ func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byt
return h.Sum(nil) return h.Sum(nil)
} }
// generateKeyShare generates keypair. Private key is returned as first argument, public key // generateKeyShare generates keypair. On success it returns private key and keyShare
// is returned in keyShare.data. keyshare.curveID stores ID of the scheme used. // structure with keyShare.group set to supported group ID (as per 4.2.7 in RFC 8446)
// and keyShare.data set to public key, third argument is nil. On failure, third returned
// value (an error) contains error message and first two values are undefined.
func (c *Conn) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) { func (c *Conn) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) {
if val, ok := dhKexStrat[curveID]; ok { if val, ok := kexStrat[curveID]; ok {
return val.generate(c, curveID) return val.generate(c, curveID)
} }
return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve") return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve")
@ -609,11 +634,20 @@ func (c *Conn) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) {
// DH key agreement. ks stores public key, secretKey stores private key used for ephemeral // DH key agreement. ks stores public key, secretKey stores private key used for ephemeral
// key agreement. Function returns shared secret in case of success or empty slice otherwise. // key agreement. Function returns shared secret in case of success or empty slice otherwise.
func (c *Conn) deriveDHESecret(ks keyShare, secretKey []byte) []byte { func (c *Conn) keyAgreementClient(ks keyShare, secretKey []byte) ([]byte, error) {
if val, ok := dhKexStrat[ks.group]; ok { if val, ok := kexStrat[ks.group]; ok {
return val.derive(c, ks, secretKey) return val.keyAgreementClient(c, ks, secretKey)
} }
return nil return nil, errors.New("tls: unsupported group")
}
// keyAgreementServer generates ephemeral keypair on the on the server side
// and then uses 'keyShare' (client public key) to derive shared secret
func (c *Conn) keyAgreementServer(clientKS keyShare) ([]byte, keyShare, error) {
if val, ok := kexStrat[clientKS.group]; ok {
return val.keyAgreementServer(c, clientKS)
}
return nil, keyShare{}, errors.New("unsupported group")
} }
func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte { func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
@ -989,14 +1023,14 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
// 0-RTT is not supported yet, so use an empty PSK. // 0-RTT is not supported yet, so use an empty PSK.
hs.keySchedule.setSecret(nil) hs.keySchedule.setSecret(nil)
ecdheSecret := c.deriveDHESecret(serverHello.keyShare, hs.privateKey) sharedSecret, err := c.keyAgreementClient(serverHello.keyShare, hs.privateKey)
if ecdheSecret == nil { if err != nil {
c.sendAlert(alertIllegalParameter) c.sendAlert(alertIllegalParameter)
return errors.New("tls: bad ECDHE server share") return err
} }
// Calculate handshake secrets. // Calculate handshake secrets.
hs.keySchedule.setSecret(ecdheSecret) hs.keySchedule.setSecret(sharedSecret)
clientCipher, clientHandshakeSecret := hs.keySchedule.prepareCipher(secretHandshakeClient) clientCipher, clientHandshakeSecret := hs.keySchedule.prepareCipher(secretHandshakeClient)
serverCipher, serverHandshakeSecret := hs.keySchedule.prepareCipher(secretHandshakeServer) serverCipher, serverHandshakeSecret := hs.keySchedule.prepareCipher(secretHandshakeServer)
if c.hand.Len() > 0 { if c.hand.Len() > 0 {
@ -1170,10 +1204,10 @@ func supportedSigAlgorithmsCert(schemes []SignatureScheme) (ret []SignatureSchem
return return
} }
// Functions below implement dhKex interface for different DH shared secret agreements // Functions below implement kex interface for different DH shared secret agreements
// KEX: P-256, P-384, P-512 KEX // KEX: P-256, P-384, P-512 KEX
func (kexNist) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, err error) { func (kexNIST) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, err error) {
// never fails // never fails
curve, _ := curveForCurveID(groupId) curve, _ := curveForCurveID(groupId)
private, x, y, err := elliptic.GenerateKey(curve, c.config.rand()) private, x, y, err := elliptic.GenerateKey(curve, c.config.rand())
@ -1184,22 +1218,22 @@ func (kexNist) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare,
ks.data = elliptic.Marshal(curve, x, y) ks.data = elliptic.Marshal(curve, x, y)
return return
} }
func (kexNist) derive(c *Conn, ks keyShare, secretKey []byte) []byte { func (kexNIST) keyAgreementClient(c *Conn, ks keyShare, secretKey []byte) ([]byte, error) {
// never fails // never fails
curve, _ := curveForCurveID(ks.group) curve, _ := curveForCurveID(ks.group)
x, y := elliptic.Unmarshal(curve, ks.data) x, y := elliptic.Unmarshal(curve, ks.data)
if x == nil { if x == nil {
return nil return nil, errors.New("tls: Point not on a curve")
} }
x, _ = curve.ScalarMult(x, y, secretKey) x, _ = curve.ScalarMult(x, y, secretKey)
xBytes := x.Bytes() xBytes := x.Bytes()
curveSize := (curve.Params().BitSize + 8 - 1) >> 3 curveSize := (curve.Params().BitSize + 8 - 1) >> 3
if len(xBytes) == curveSize { if len(xBytes) == curveSize {
return xBytes return xBytes, nil
} }
buf := make([]byte, curveSize) buf := make([]byte, curveSize)
copy(buf[len(buf)-len(xBytes):], xBytes) copy(buf[len(buf)-len(xBytes):], xBytes)
return buf return buf, nil
} }
// KEX: X25519 // KEX: X25519
@ -1212,15 +1246,15 @@ func (kexX25519) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) {
return scalar[:], keyShare{group: X25519, data: public[:]}, nil return scalar[:], keyShare{group: X25519, data: public[:]}, nil
} }
func (kexX25519) derive(c *Conn, ks keyShare, secretKey []byte) []byte { func (kexX25519) keyAgreementClient(c *Conn, ks keyShare, secretKey []byte) ([]byte, error) {
var theirPublic, sharedKey, scalar [x25519SharedSecretSz]byte var theirPublic, sharedKey, scalar [x25519SharedSecretSz]byte
if len(ks.data) != x25519SharedSecretSz { if len(ks.data) != x25519SharedSecretSz {
return nil return nil, errors.New("tls: wrong shared secret size")
} }
copy(theirPublic[:], ks.data) copy(theirPublic[:], ks.data)
copy(scalar[:], secretKey) copy(scalar[:], secretKey)
curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
return sharedKey[:] return sharedKey[:], nil
} }
// KEX: SIDH/503 // KEX: SIDH/503
@ -1234,27 +1268,27 @@ func (kexSIDHp503) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error)
return prvKey.Export(), keyShare{group: 0 /*UNUSED*/, data: pubKey.Export()}, nil return prvKey.Export(), keyShare{group: 0 /*UNUSED*/, data: pubKey.Export()}, nil
} }
func (kexSIDHp503) derive(c *Conn, ks keyShare, key []byte) []byte { func (kexSIDHp503) keyAgreementClient(c *Conn, ks keyShare, key []byte) ([]byte, error) {
var prvVariant, pubVariant = getSidhKeyVariant(c.isClient) var prvVariant, pubVariant = getSidhKeyVariant(c.isClient)
var prvKeySize = P503PrvKeySz var prvKeySize = P503PrvKeySz
if len(ks.data) != P503PubKeySz || len(key) != prvKeySize { if len(ks.data) != P503PubKeySz || len(key) != prvKeySize {
return nil return nil, errors.New("tls: wrong key size")
} }
prvKey := sidh.NewPrivateKey(sidh.FP_503, prvVariant) prvKey := sidh.NewPrivateKey(sidh.FP_503, prvVariant)
pubKey := sidh.NewPublicKey(sidh.FP_503, pubVariant) pubKey := sidh.NewPublicKey(sidh.FP_503, pubVariant)
if err := prvKey.Import(key); err != nil { if err := prvKey.Import(key); err != nil {
return nil return nil, errors.New("tls: internal error")
} }
if err := pubKey.Import(ks.data); err != nil { if err := pubKey.Import(ks.data); err != nil {
return nil return nil, errors.New("tls: internal error")
} }
// Never fails // Never fails
sharedKey, _ := sidh.DeriveSecret(prvKey, pubKey) sharedKey, _ := sidh.DeriveSecret(prvKey, pubKey)
return sharedKey return sharedKey, nil
} }
// KEX Hybrid SIDH/503-X25519 // KEX Hybrid SIDH/503-X25519
@ -1280,7 +1314,7 @@ func (kex *kexHybridSIDHp503X25519) generate(c *Conn, groupId CurveID) (private
return prvHybrid[:], keyShare{group: HybridSIDHp503Curve25519, data: pubHybrid[:]}, nil return prvHybrid[:], keyShare{group: HybridSIDHp503Curve25519, data: pubHybrid[:]}, nil
} }
func (kex *kexHybridSIDHp503X25519) derive(c *Conn, ks keyShare, key []byte) []byte { func (kex *kexHybridSIDHp503X25519) keyAgreementClient(c *Conn, ks keyShare, key []byte) ([]byte, error) {
var sharedKey [SIDHp503Curve25519SharedKeySz]byte var sharedKey [SIDHp503Curve25519SharedKeySz]byte
var ret []byte var ret []byte
var tmpKs keyShare var tmpKs keyShare
@ -1288,19 +1322,19 @@ func (kex *kexHybridSIDHp503X25519) derive(c *Conn, ks keyShare, key []byte) []b
// Key agreement for classic // Key agreement for classic
tmpKs.group = X25519 tmpKs.group = X25519
tmpKs.data = ks.data[:x25519SharedSecretSz] tmpKs.data = ks.data[:x25519SharedSecretSz]
ret = kex.classicKEX.derive(c, tmpKs, key[:x25519SharedSecretSz]) ret, err := kex.classicKEX.keyAgreementClient(c, tmpKs, key[:x25519SharedSecretSz])
if ret == nil { if err != nil {
return nil return nil, err
} }
copy(sharedKey[:], ret) copy(sharedKey[:], ret)
// Key agreement for PQ // Key agreement for PQ
tmpKs.group = 0 /*UNUSED*/ tmpKs.group = 0 /*UNUSED*/
tmpKs.data = ks.data[x25519SharedSecretSz:] tmpKs.data = ks.data[x25519SharedSecretSz:]
ret = kex.pqKEX.derive(c, tmpKs, key[x25519SharedSecretSz:]) ret, err = kex.pqKEX.keyAgreementClient(c, tmpKs, key[x25519SharedSecretSz:])
if ret == nil { if err != nil {
return nil return nil, err
} }
copy(sharedKey[x25519SharedSecretSz:], ret) copy(sharedKey[x25519SharedSecretSz:], ret)
return sharedKey[:] return sharedKey[:], nil
} }