diff --git a/13.go b/13.go index 15a655c..4d72c93 100644 --- a/13.go +++ b/13.go @@ -61,8 +61,8 @@ type keySchedule13 struct { config *Config // Used for KeyLogWriter callback, nil if keylogging is disabled. } -// Interface implemented by DH key exchange strategies -type dhKex interface { +// Interface implemented by key exchange strategies +type kex interface { // c - context of current TLS handshake, groupId - ID of an algorithm // (curve/field) being chosen for key agreement. Methods implmenting an // interface always assume that provided groupId is correct. @@ -70,27 +70,58 @@ type dhKex interface { // In case of success, function returns secret key and ephemeral key. Otherwise // error is set. generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) - // c - context of current TLS handshake, ks - public key received - // from the other side of the connection, secretKey - is a private key - // used for DH key agreement. Function returns shared secret in case - // of success or empty slice otherwise. - derive(c *Conn, ks keyShare, secretKey []byte) []byte + // keyAgreementClient declares an API for implementing shared secret agreement on + // the client side. `c` is a context of current TLS handshake, `ks` is a public key + // received from the server, ``privateKey`` client private key. + // Function returns shared secret in case of success or non nil error otherwise. + keyAgreementClient(c *Conn, ks keyShare, privateKey []byte) ([]byte, error) + // keyAgreementServer declares an API for implementing shared secret agreement on + // the server side. `c` context of current TLS handshake, `ks` is a public key + // received from the client side of the connection, ``privateKey`` is a private key + // of a server. + // Function returns secret shared between parties and public value to exchange + // between parties. In case of failure `error` must be set. + keyAgreementServer(c *Conn, ks keyShare) ([]byte, keyShare, error) +} + +// defaultServerKEX is an abstract class defining default, common behaviour on +// a server side. +type defaultServerKEX struct{} + +// defaultServerKEX is an abstract class defining default implementation of +// server side key agreement. It generates ephemeral key and uses it together +// with client public part in order to calculate shared secret. +func (defaultServerKEX) keyAgreementServer(c *Conn, clientKS keyShare) ([]byte, keyShare, error) { + privateKey, publicKey, err := c.generateKeyShare(clientKS.group) + if err != nil { + c.sendAlert(alertInternalError) + return nil, keyShare{}, err + } + + // Use same key agreement implementation as on the client side + ss, err := c.keyAgreementClient(clientKS, privateKey) + if err != nil { + c.sendAlert(alertIllegalParameter) + return nil, keyShare{}, err + } + return ss, publicKey, nil } // Key Exchange strategies per curve type -type kexNist struct{} // Used by NIST curves; P-256, P-384, P-512 -type kexX25519 struct{} // Used by X25519 -type kexSIDHp503 struct{} // Used by SIDH/P503 +type kexNIST struct{ defaultServerKEX } // Used by NIST curves; P-256, P-384, P-512 +type kexX25519 struct{ defaultServerKEX } // Used by X25519 +type kexSIDHp503 struct{ defaultServerKEX } // Used by SIDH/P503 type kexHybridSIDHp503X25519 struct { + defaultServerKEX classicKEX kexX25519 pqKEX kexSIDHp503 } // Used by SIDH-ECDH hybrid scheme // Routing map for key exchange strategies -var dhKexStrat = map[CurveID]dhKex{ - CurveP256: &kexNist{}, - CurveP384: &kexNist{}, - CurveP521: &kexNist{}, +var kexStrat = map[CurveID]kex{ + CurveP256: &kexNIST{}, + CurveP384: &kexNIST{}, + CurveP521: &kexNIST{}, X25519: &kexX25519{}, HybridSIDHp503Curve25519: &kexHybridSIDHp503X25519{}, } @@ -202,17 +233,6 @@ CurvePreferenceLoop: } } } - if ks.group == 0 { - c.sendAlert(alertInternalError) - return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo) - } - - privateKey, serverKS, err := c.generateKeyShare(ks.group) - if err != nil { - c.sendAlert(alertInternalError) - return err - } - hs.hello.keyShare = serverKS hash := hashForSuite(hs.suite) hashSize := hash.Size() @@ -232,14 +252,17 @@ CurvePreferenceLoop: } hs.keySchedule.write(hs.clientHello.marshal()) - earlyClientCipher, _ := hs.keySchedule.prepareCipher(secretEarlyClient) - ecdheSecret := c.deriveDHESecret(ks, privateKey) - if ecdheSecret == nil { - c.sendAlert(alertIllegalParameter) - return errors.New("tls: bad ECDHE client share") + if ks.group == 0 { + c.sendAlert(alertInternalError) + return errors.New("tls: HelloRetryRequest not implemented") // TODO(filippo) + } + sharedSecret, serverKS, err := c.keyAgreementServer(ks) + if err != nil { + return err } + hs.hello.keyShare = serverKS hs.keySchedule.write(hs.hello.marshal()) if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { @@ -251,7 +274,7 @@ CurvePreferenceLoop: return err } - hs.keySchedule.setSecret(ecdheSecret) + hs.keySchedule.setSecret(sharedSecret) clientCipher, cTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeClient) hs.hsClientCipher = clientCipher serverCipher, sTrafficSecret := hs.keySchedule.prepareCipher(secretHandshakeServer) @@ -598,10 +621,12 @@ func prepareDigitallySigned(hash crypto.Hash, context string, data []byte) []byt return h.Sum(nil) } -// generateKeyShare generates keypair. Private key is returned as first argument, public key -// is returned in keyShare.data. keyshare.curveID stores ID of the scheme used. +// generateKeyShare generates keypair. On success it returns private key and keyShare +// structure with keyShare.group set to supported group ID (as per 4.2.7 in RFC 8446) +// and keyShare.data set to public key, third argument is nil. On failure, third returned +// value (an error) contains error message and first two values are undefined. func (c *Conn) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) { - if val, ok := dhKexStrat[curveID]; ok { + if val, ok := kexStrat[curveID]; ok { return val.generate(c, curveID) } return nil, keyShare{}, errors.New("tls: preferredCurves includes unsupported curve") @@ -609,11 +634,20 @@ func (c *Conn) generateKeyShare(curveID CurveID) ([]byte, keyShare, error) { // DH key agreement. ks stores public key, secretKey stores private key used for ephemeral // key agreement. Function returns shared secret in case of success or empty slice otherwise. -func (c *Conn) deriveDHESecret(ks keyShare, secretKey []byte) []byte { - if val, ok := dhKexStrat[ks.group]; ok { - return val.derive(c, ks, secretKey) +func (c *Conn) keyAgreementClient(ks keyShare, secretKey []byte) ([]byte, error) { + if val, ok := kexStrat[ks.group]; ok { + return val.keyAgreementClient(c, ks, secretKey) } - return nil + return nil, errors.New("tls: unsupported group") +} + +// keyAgreementServer generates ephemeral keypair on the on the server side +// and then uses 'keyShare' (client public key) to derive shared secret +func (c *Conn) keyAgreementServer(clientKS keyShare) ([]byte, keyShare, error) { + if val, ok := kexStrat[clientKS.group]; ok { + return val.keyAgreementServer(c, clientKS) + } + return nil, keyShare{}, errors.New("unsupported group") } func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte { @@ -989,14 +1023,14 @@ func (hs *clientHandshakeState) doTLS13Handshake() error { // 0-RTT is not supported yet, so use an empty PSK. hs.keySchedule.setSecret(nil) - ecdheSecret := c.deriveDHESecret(serverHello.keyShare, hs.privateKey) - if ecdheSecret == nil { + sharedSecret, err := c.keyAgreementClient(serverHello.keyShare, hs.privateKey) + if err != nil { c.sendAlert(alertIllegalParameter) - return errors.New("tls: bad ECDHE server share") + return err } // Calculate handshake secrets. - hs.keySchedule.setSecret(ecdheSecret) + hs.keySchedule.setSecret(sharedSecret) clientCipher, clientHandshakeSecret := hs.keySchedule.prepareCipher(secretHandshakeClient) serverCipher, serverHandshakeSecret := hs.keySchedule.prepareCipher(secretHandshakeServer) if c.hand.Len() > 0 { @@ -1170,10 +1204,10 @@ func supportedSigAlgorithmsCert(schemes []SignatureScheme) (ret []SignatureSchem return } -// Functions below implement dhKex interface for different DH shared secret agreements +// Functions below implement kex interface for different DH shared secret agreements // KEX: P-256, P-384, P-512 KEX -func (kexNist) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, err error) { +func (kexNIST) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, err error) { // never fails curve, _ := curveForCurveID(groupId) private, x, y, err := elliptic.GenerateKey(curve, c.config.rand()) @@ -1184,22 +1218,22 @@ func (kexNist) generate(c *Conn, groupId CurveID) (private []byte, ks keyShare, ks.data = elliptic.Marshal(curve, x, y) return } -func (kexNist) derive(c *Conn, ks keyShare, secretKey []byte) []byte { +func (kexNIST) keyAgreementClient(c *Conn, ks keyShare, secretKey []byte) ([]byte, error) { // never fails curve, _ := curveForCurveID(ks.group) x, y := elliptic.Unmarshal(curve, ks.data) if x == nil { - return nil + return nil, errors.New("tls: Point not on a curve") } x, _ = curve.ScalarMult(x, y, secretKey) xBytes := x.Bytes() curveSize := (curve.Params().BitSize + 8 - 1) >> 3 if len(xBytes) == curveSize { - return xBytes + return xBytes, nil } buf := make([]byte, curveSize) copy(buf[len(buf)-len(xBytes):], xBytes) - return buf + return buf, nil } // KEX: X25519 @@ -1212,15 +1246,15 @@ func (kexX25519) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) { return scalar[:], keyShare{group: X25519, data: public[:]}, nil } -func (kexX25519) derive(c *Conn, ks keyShare, secretKey []byte) []byte { +func (kexX25519) keyAgreementClient(c *Conn, ks keyShare, secretKey []byte) ([]byte, error) { var theirPublic, sharedKey, scalar [x25519SharedSecretSz]byte if len(ks.data) != x25519SharedSecretSz { - return nil + return nil, errors.New("tls: wrong shared secret size") } copy(theirPublic[:], ks.data) copy(scalar[:], secretKey) curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic) - return sharedKey[:] + return sharedKey[:], nil } // KEX: SIDH/503 @@ -1234,27 +1268,27 @@ func (kexSIDHp503) generate(c *Conn, groupId CurveID) ([]byte, keyShare, error) return prvKey.Export(), keyShare{group: 0 /*UNUSED*/, data: pubKey.Export()}, nil } -func (kexSIDHp503) derive(c *Conn, ks keyShare, key []byte) []byte { +func (kexSIDHp503) keyAgreementClient(c *Conn, ks keyShare, key []byte) ([]byte, error) { var prvVariant, pubVariant = getSidhKeyVariant(c.isClient) var prvKeySize = P503PrvKeySz if len(ks.data) != P503PubKeySz || len(key) != prvKeySize { - return nil + return nil, errors.New("tls: wrong key size") } prvKey := sidh.NewPrivateKey(sidh.FP_503, prvVariant) pubKey := sidh.NewPublicKey(sidh.FP_503, pubVariant) if err := prvKey.Import(key); err != nil { - return nil + return nil, errors.New("tls: internal error") } if err := pubKey.Import(ks.data); err != nil { - return nil + return nil, errors.New("tls: internal error") } // Never fails sharedKey, _ := sidh.DeriveSecret(prvKey, pubKey) - return sharedKey + return sharedKey, nil } // KEX Hybrid SIDH/503-X25519 @@ -1280,7 +1314,7 @@ func (kex *kexHybridSIDHp503X25519) generate(c *Conn, groupId CurveID) (private return prvHybrid[:], keyShare{group: HybridSIDHp503Curve25519, data: pubHybrid[:]}, nil } -func (kex *kexHybridSIDHp503X25519) derive(c *Conn, ks keyShare, key []byte) []byte { +func (kex *kexHybridSIDHp503X25519) keyAgreementClient(c *Conn, ks keyShare, key []byte) ([]byte, error) { var sharedKey [SIDHp503Curve25519SharedKeySz]byte var ret []byte var tmpKs keyShare @@ -1288,19 +1322,19 @@ func (kex *kexHybridSIDHp503X25519) derive(c *Conn, ks keyShare, key []byte) []b // Key agreement for classic tmpKs.group = X25519 tmpKs.data = ks.data[:x25519SharedSecretSz] - ret = kex.classicKEX.derive(c, tmpKs, key[:x25519SharedSecretSz]) - if ret == nil { - return nil + ret, err := kex.classicKEX.keyAgreementClient(c, tmpKs, key[:x25519SharedSecretSz]) + if err != nil { + return nil, err } copy(sharedKey[:], ret) // Key agreement for PQ tmpKs.group = 0 /*UNUSED*/ tmpKs.data = ks.data[x25519SharedSecretSz:] - ret = kex.pqKEX.derive(c, tmpKs, key[x25519SharedSecretSz:]) - if ret == nil { - return nil + ret, err = kex.pqKEX.keyAgreementClient(c, tmpKs, key[x25519SharedSecretSz:]) + if err != nil { + return nil, err } copy(sharedKey[x25519SharedSecretSz:], ret) - return sharedKey[:] + return sharedKey[:], nil }