diff --git a/common.go b/common.go index ec3e997..84481a2 100644 --- a/common.go +++ b/common.go @@ -489,10 +489,10 @@ func (c *Config) BuildNameToCertificate() { type Certificate struct { Certificate [][]byte // PrivateKey contains the private key corresponding to the public key - // in Leaf. For a server, this must implement either crypto.Decrypter - // (implemented by RSA private keys) or crypto.Signer (which includes - // RSA and ECDSA private keys). For a client doing client authentication, - // this can be any type that implements crypto.Signer. + // in Leaf. For a server, this must implement crypto.Signer and/or + // crypto.Decrypter, with an RSA or ECDSA PublicKey. For a client + // (performing client authentication), this must be a crypto.Signer + // with an RSA or ECDSA PublicKey. PrivateKey crypto.PrivateKey // OCSPStaple contains an optional OCSP response which will be served // to clients that request it. diff --git a/handshake_server.go b/handshake_server.go index 8b31c7c..7fc1b5f 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -25,9 +25,8 @@ type serverHandshakeState struct { suite *cipherSuite ellipticOk bool ecdsaOk bool - rsaOk bool - signOk bool - decryptOk bool + rsaDecryptOk bool + rsaSignOk bool sessionState *sessionState finishedHash finishedHash masterSecret []byte @@ -201,22 +200,20 @@ Curves: } if priv, ok := hs.cert.PrivateKey.(crypto.Signer); ok { - hs.signOk = true switch priv.Public().(type) { case *ecdsa.PublicKey: hs.ecdsaOk = true case *rsa.PublicKey: - hs.rsaOk = true + hs.rsaSignOk = true default: c.sendAlert(alertInternalError) return false, fmt.Errorf("crypto/tls: unsupported signing key type (%T)", priv.Public()) } } if priv, ok := hs.cert.PrivateKey.(crypto.Decrypter); ok { - hs.decryptOk = true switch priv.Public().(type) { case *rsa.PublicKey: - hs.rsaOk = true + hs.rsaDecryptOk = true default: c.sendAlert(alertInternalError) return false, fmt.Errorf("crypto/tls: unsupported decryption key type (%T)", priv.Public()) @@ -692,17 +689,17 @@ func (hs *serverHandshakeState) setCipherSuite(id uint16, supportedCipherSuites // Don't select a ciphersuite which we can't // support for this client. if candidate.flags&suiteECDHE != 0 { - if !hs.ellipticOk || !hs.signOk { + if !hs.ellipticOk { continue } if candidate.flags&suiteECDSA != 0 { if !hs.ecdsaOk { continue } - } else if !hs.rsaOk { + } else if !hs.rsaSignOk { continue } - } else if !hs.decryptOk || !hs.rsaOk { + } else if !hs.rsaDecryptOk { continue } if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 { diff --git a/handshake_server_test.go b/handshake_server_test.go index af5cadb..ed0248f 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -63,6 +63,10 @@ func init() { testConfig.BuildNameToCertificate() } +func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) { + testClientHelloFailure(t, serverConfig, m, "") +} + func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) { // Create in-memory network connection, // send message to server. Should return @@ -78,7 +82,11 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa }() err := Server(s, serverConfig).Handshake() s.Close() - if err == nil || !strings.Contains(err.Error(), expectedSubStr) { + if len(expectedSubStr) == 0 { + if err != nil && err != io.EOF { + t.Errorf("Got error: %s; expected to succeed", err, expectedSubStr) + } + } else if err == nil || !strings.Contains(err.Error(), expectedSubStr) { t.Errorf("Got error: %s; expected to match substring '%s'", err, expectedSubStr) } } @@ -126,6 +134,55 @@ func TestNoRC4ByDefault(t *testing.T) { testClientHelloFailure(t, &serverConfig, clientHello, "no cipher suite supported by both client and server") } +func TestDontSelectECDSAWithRSAKey(t *testing.T) { + // Test that, even when both sides support an ECDSA cipher suite, it + // won't be selected if the server's private key doesn't support it. + clientHello := &clientHelloMsg{ + vers: 0x0301, + cipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, + compressionMethods: []uint8{0}, + supportedCurves: []CurveID{CurveP256}, + supportedPoints: []uint8{pointFormatUncompressed}, + } + serverConfig := *testConfig + serverConfig.CipherSuites = clientHello.cipherSuites + serverConfig.Certificates = make([]Certificate, 1) + serverConfig.Certificates[0].Certificate = [][]byte{testECDSACertificate} + serverConfig.Certificates[0].PrivateKey = testECDSAPrivateKey + serverConfig.BuildNameToCertificate() + // First test that it *does* work when the server's key is ECDSA. + testClientHello(t, &serverConfig, clientHello) + + // Now test that switching to an RSA key causes the expected error (and + // not an internal error about a signing failure). + serverConfig.Certificates = testConfig.Certificates + testClientHelloFailure(t, &serverConfig, clientHello, "no cipher suite supported by both client and server") +} + +func TestDontSelectRSAWithECDSAKey(t *testing.T) { + // Test that, even when both sides support an RSA cipher suite, it + // won't be selected if the server's private key doesn't support it. + clientHello := &clientHelloMsg{ + vers: 0x0301, + cipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, + compressionMethods: []uint8{0}, + supportedCurves: []CurveID{CurveP256}, + supportedPoints: []uint8{pointFormatUncompressed}, + } + serverConfig := *testConfig + serverConfig.CipherSuites = clientHello.cipherSuites + // First test that it *does* work when the server's key is RSA. + testClientHello(t, &serverConfig, clientHello) + + // Now test that switching to an ECDSA key causes the expected error + // (and not an internal error about a signing failure). + serverConfig.Certificates = make([]Certificate, 1) + serverConfig.Certificates[0].Certificate = [][]byte{testECDSACertificate} + serverConfig.Certificates[0].PrivateKey = testECDSAPrivateKey + serverConfig.BuildNameToCertificate() + testClientHelloFailure(t, &serverConfig, clientHello, "no cipher suite supported by both client and server") +} + func TestRenegotiationExtension(t *testing.T) { clientHello := &clientHelloMsg{ vers: VersionTLS12,