diff --git a/common.go b/common.go index fc898e2..9078b63 100644 --- a/common.go +++ b/common.go @@ -325,6 +325,18 @@ type Config struct { // material from the returned config will be used for session tickets. GetConfigForClient func(*ClientHelloInfo) (*Config, error) + // VerifyPeerCertificate, if not nil, is called after normal + // certificate verification by either a TLS client or server. It + // receives the raw ASN.1 certificates provided by the peer and also + // any verified chains that normal processing found. If it returns a + // non-nil error, the handshake is aborted and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. If normal verification is disabled by + // setting InsecureSkipVerify then this callback will be considered but + // the verifiedChains argument will always be nil. + VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. @@ -474,6 +486,7 @@ func (c *Config) Clone() *Config { NameToCertificate: c.NameToCertificate, GetCertificate: c.GetCertificate, GetConfigForClient: c.GetConfigForClient, + VerifyPeerCertificate: c.VerifyPeerCertificate, RootCAs: c.RootCAs, NextProtos: c.NextProtos, ServerName: c.ServerName, diff --git a/handshake_client.go b/handshake_client.go index e42953a..c331c65 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -304,6 +304,13 @@ func (hs *clientHandshakeState) doFullHandshake() error { } } + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certMsg.certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + switch certs[0].PublicKey.(type) { case *rsa.PublicKey, *ecdsa.PublicKey: break diff --git a/handshake_client_test.go b/handshake_client_test.go index 3de1dfa..7bbeed0 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -1067,6 +1067,160 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) { } } +func TestVerifyPeerCertificate(t *testing.T) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + rootCAs := x509.NewCertPool() + rootCAs.AddCert(issuer) + + now := func() time.Time { return time.Unix(1476984729, 0) } + + sentinelErr := errors.New("TestVerifyPeerCertificate") + + verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + if l := len(rawCerts); l != 1 { + return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) + } + if len(validatedChains) == 0 { + return errors.New("got len(validatedChains) = 0, wanted non-zero") + } + *called = true + return nil + } + + tests := []struct { + configureServer func(*Config, *bool) + configureClient func(*Config, *bool) + validate func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) + }{ + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyCallback(called, rawCerts, validatedChains) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyCallback(called, rawCerts, validatedChains) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("#%d: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("#%d: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Error("#%d: client did not call callback", testNo) + } + if !serverCalled { + t.Error("#%d: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.VerifyPeerCertificate = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + }, + configureClient: func(config *Config, called *bool) { + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = true + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + if l := len(rawCerts); l != 1 { + return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) + } + // With InsecureSkipVerify set, this + // callback should still be called but + // validatedChains must be empty. + if l := len(validatedChains); l != 0 { + return errors.New("got len(validatedChains) = 0, wanted zero") + } + *called = true + return nil + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("#%d: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("#%d: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Error("#%d: client did not call callback", testNo) + } + }, + }, + } + + for i, test := range tests { + c, s := net.Pipe() + done := make(chan error) + + var clientCalled, serverCalled bool + + go func() { + config := testConfig.Clone() + config.ServerName = "example.golang" + config.ClientAuth = RequireAndVerifyClientCert + config.ClientCAs = rootCAs + config.Time = now + test.configureServer(config, &serverCalled) + + err = Server(s, config).Handshake() + s.Close() + done <- err + }() + + config := testConfig.Clone() + config.ServerName = "example.golang" + config.RootCAs = rootCAs + config.Time = now + test.configureClient(config, &clientCalled) + clientErr := Client(c, config).Handshake() + c.Close() + serverErr := <-done + + test.validate(t, i, clientCalled, serverCalled, clientErr, serverErr) + } +} + // brokenConn wraps a net.Conn and causes all Writes after a certain number to // fail with brokenConnErr. type brokenConn struct { diff --git a/handshake_server.go b/handshake_server.go index 630a99a..724ed71 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -747,6 +747,13 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c c.verifiedChains = chains } + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return nil, err + } + } + if len(certs) == 0 { return nil, nil } diff --git a/tls_test.go b/tls_test.go index 87cfa3f..99153a5 100644 --- a/tls_test.go +++ b/tls_test.go @@ -477,7 +477,7 @@ func TestClone(t *testing.T) { case "Rand": f.Set(reflect.ValueOf(io.Reader(os.Stdin))) continue - case "Time", "GetCertificate", "GetConfigForClient": + case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate": // DeepEqual can't compare functions. continue case "Certificates":