crypto/tls: add VerifyPeerCertificate to tls.Config
VerifyPeerCertificate returns an error if the peer should not be trusted. It will be called after the initial handshake and before any other verification checks on the cert or chain are performed. This provides the callee an opportunity to augment the certificate verification. If VerifyPeerCertificate is not nil and returns an error, then the handshake will fail. Fixes #16363 Change-Id: I6a22f199f0e81b6f5d5f37c54d85ab878216bb22 Reviewed-on: https://go-review.googlesource.com/26654 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
5b71240bb5
commit
a660d3e993
13
common.go
13
common.go
@ -325,6 +325,18 @@ type Config struct {
|
|||||||
// material from the returned config will be used for session tickets.
|
// material from the returned config will be used for session tickets.
|
||||||
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
|
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
|
||||||
|
|
||||||
|
// VerifyPeerCertificate, if not nil, is called after normal
|
||||||
|
// certificate verification by either a TLS client or server. It
|
||||||
|
// receives the raw ASN.1 certificates provided by the peer and also
|
||||||
|
// any verified chains that normal processing found. If it returns a
|
||||||
|
// non-nil error, the handshake is aborted and that error results.
|
||||||
|
//
|
||||||
|
// If normal verification fails then the handshake will abort before
|
||||||
|
// considering this callback. If normal verification is disabled by
|
||||||
|
// setting InsecureSkipVerify then this callback will be considered but
|
||||||
|
// the verifiedChains argument will always be nil.
|
||||||
|
VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error
|
||||||
|
|
||||||
// RootCAs defines the set of root certificate authorities
|
// RootCAs defines the set of root certificate authorities
|
||||||
// that clients use when verifying server certificates.
|
// that clients use when verifying server certificates.
|
||||||
// If RootCAs is nil, TLS uses the host's root CA set.
|
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||||
@ -474,6 +486,7 @@ func (c *Config) Clone() *Config {
|
|||||||
NameToCertificate: c.NameToCertificate,
|
NameToCertificate: c.NameToCertificate,
|
||||||
GetCertificate: c.GetCertificate,
|
GetCertificate: c.GetCertificate,
|
||||||
GetConfigForClient: c.GetConfigForClient,
|
GetConfigForClient: c.GetConfigForClient,
|
||||||
|
VerifyPeerCertificate: c.VerifyPeerCertificate,
|
||||||
RootCAs: c.RootCAs,
|
RootCAs: c.RootCAs,
|
||||||
NextProtos: c.NextProtos,
|
NextProtos: c.NextProtos,
|
||||||
ServerName: c.ServerName,
|
ServerName: c.ServerName,
|
||||||
|
@ -304,6 +304,13 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.config.VerifyPeerCertificate != nil {
|
||||||
|
if err := c.config.VerifyPeerCertificate(certMsg.certificates, c.verifiedChains); err != nil {
|
||||||
|
c.sendAlert(alertBadCertificate)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch certs[0].PublicKey.(type) {
|
switch certs[0].PublicKey.(type) {
|
||||||
case *rsa.PublicKey, *ecdsa.PublicKey:
|
case *rsa.PublicKey, *ecdsa.PublicKey:
|
||||||
break
|
break
|
||||||
|
@ -1067,6 +1067,160 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVerifyPeerCertificate(t *testing.T) {
|
||||||
|
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCAs := x509.NewCertPool()
|
||||||
|
rootCAs.AddCert(issuer)
|
||||||
|
|
||||||
|
now := func() time.Time { return time.Unix(1476984729, 0) }
|
||||||
|
|
||||||
|
sentinelErr := errors.New("TestVerifyPeerCertificate")
|
||||||
|
|
||||||
|
verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||||
|
if l := len(rawCerts); l != 1 {
|
||||||
|
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
|
||||||
|
}
|
||||||
|
if len(validatedChains) == 0 {
|
||||||
|
return errors.New("got len(validatedChains) = 0, wanted non-zero")
|
||||||
|
}
|
||||||
|
*called = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
configureServer func(*Config, *bool)
|
||||||
|
configureClient func(*Config, *bool)
|
||||||
|
validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
configureServer: func(config *Config, called *bool) {
|
||||||
|
config.InsecureSkipVerify = false
|
||||||
|
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||||
|
return verifyCallback(called, rawCerts, validatedChains)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
configureClient: func(config *Config, called *bool) {
|
||||||
|
config.InsecureSkipVerify = false
|
||||||
|
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||||
|
return verifyCallback(called, rawCerts, validatedChains)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||||
|
if clientErr != nil {
|
||||||
|
t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
|
||||||
|
}
|
||||||
|
if serverErr != nil {
|
||||||
|
t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
|
||||||
|
}
|
||||||
|
if !clientCalled {
|
||||||
|
t.Error("#%d: client did not call callback", testNo)
|
||||||
|
}
|
||||||
|
if !serverCalled {
|
||||||
|
t.Error("#%d: server did not call callback", testNo)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
configureServer: func(config *Config, called *bool) {
|
||||||
|
config.InsecureSkipVerify = false
|
||||||
|
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||||
|
return sentinelErr
|
||||||
|
}
|
||||||
|
},
|
||||||
|
configureClient: func(config *Config, called *bool) {
|
||||||
|
config.VerifyPeerCertificate = nil
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||||
|
if serverErr != sentinelErr {
|
||||||
|
t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
configureServer: func(config *Config, called *bool) {
|
||||||
|
config.InsecureSkipVerify = false
|
||||||
|
},
|
||||||
|
configureClient: func(config *Config, called *bool) {
|
||||||
|
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||||
|
return sentinelErr
|
||||||
|
}
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||||
|
if clientErr != sentinelErr {
|
||||||
|
t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
configureServer: func(config *Config, called *bool) {
|
||||||
|
config.InsecureSkipVerify = false
|
||||||
|
},
|
||||||
|
configureClient: func(config *Config, called *bool) {
|
||||||
|
config.InsecureSkipVerify = true
|
||||||
|
config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error {
|
||||||
|
if l := len(rawCerts); l != 1 {
|
||||||
|
return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l)
|
||||||
|
}
|
||||||
|
// With InsecureSkipVerify set, this
|
||||||
|
// callback should still be called but
|
||||||
|
// validatedChains must be empty.
|
||||||
|
if l := len(validatedChains); l != 0 {
|
||||||
|
return errors.New("got len(validatedChains) = 0, wanted zero")
|
||||||
|
}
|
||||||
|
*called = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) {
|
||||||
|
if clientErr != nil {
|
||||||
|
t.Errorf("#%d: client handshake failed: %v", testNo, clientErr)
|
||||||
|
}
|
||||||
|
if serverErr != nil {
|
||||||
|
t.Errorf("#%d: server handshake failed: %v", testNo, serverErr)
|
||||||
|
}
|
||||||
|
if !clientCalled {
|
||||||
|
t.Error("#%d: client did not call callback", testNo)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
c, s := net.Pipe()
|
||||||
|
done := make(chan error)
|
||||||
|
|
||||||
|
var clientCalled, serverCalled bool
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.ServerName = "example.golang"
|
||||||
|
config.ClientAuth = RequireAndVerifyClientCert
|
||||||
|
config.ClientCAs = rootCAs
|
||||||
|
config.Time = now
|
||||||
|
test.configureServer(config, &serverCalled)
|
||||||
|
|
||||||
|
err = Server(s, config).Handshake()
|
||||||
|
s.Close()
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.ServerName = "example.golang"
|
||||||
|
config.RootCAs = rootCAs
|
||||||
|
config.Time = now
|
||||||
|
test.configureClient(config, &clientCalled)
|
||||||
|
clientErr := Client(c, config).Handshake()
|
||||||
|
c.Close()
|
||||||
|
serverErr := <-done
|
||||||
|
|
||||||
|
test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// brokenConn wraps a net.Conn and causes all Writes after a certain number to
|
// brokenConn wraps a net.Conn and causes all Writes after a certain number to
|
||||||
// fail with brokenConnErr.
|
// fail with brokenConnErr.
|
||||||
type brokenConn struct {
|
type brokenConn struct {
|
||||||
|
@ -747,6 +747,13 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c
|
|||||||
c.verifiedChains = chains
|
c.verifiedChains = chains
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.config.VerifyPeerCertificate != nil {
|
||||||
|
if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
|
||||||
|
c.sendAlert(alertBadCertificate)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(certs) == 0 {
|
if len(certs) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -477,7 +477,7 @@ func TestClone(t *testing.T) {
|
|||||||
case "Rand":
|
case "Rand":
|
||||||
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
|
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
|
||||||
continue
|
continue
|
||||||
case "Time", "GetCertificate", "GetConfigForClient":
|
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate":
|
||||||
// DeepEqual can't compare functions.
|
// DeepEqual can't compare functions.
|
||||||
continue
|
continue
|
||||||
case "Certificates":
|
case "Certificates":
|
||||||
|
Loading…
Reference in New Issue
Block a user