diff --git a/handshake_client_test.go b/handshake_client_test.go index 5851f89..eaef8aa 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -1189,7 +1189,7 @@ func TestVerifyPeerCertificate(t *testing.T) { // callback should still be called but // validatedChains must be empty. if l := len(validatedChains); l != 0 { - return errors.New("got len(validatedChains) = 0, wanted zero") + return fmt.Errorf("got len(validatedChains) = %d, wanted zero", l) } *called = true return nil @@ -1438,19 +1438,23 @@ func TestTLS11SignatureSchemes(t *testing.T) { } var getClientCertificateTests = []struct { - setup func(*Config) + setup func(*Config, *Config) expectedClientError string verify func(*testing.T, int, *ConnectionState) }{ { - func(clientConfig *Config) { + func(clientConfig, serverConfig *Config) { // Returning a Certificate with no certificate data // should result in an empty message being sent to the // server. + serverConfig.ClientCAs = nil clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { if len(cri.SignatureSchemes) == 0 { panic("empty SignatureSchemes") } + if len(cri.AcceptableCAs) != 0 { + panic("AcceptableCAs should have been empty") + } return new(Certificate), nil } }, @@ -1462,7 +1466,7 @@ var getClientCertificateTests = []struct { }, }, { - func(clientConfig *Config) { + func(clientConfig, serverConfig *Config) { // With TLS 1.1, the SignatureSchemes should be // synthesised from the supported certificate types. clientConfig.MaxVersion = VersionTLS11 @@ -1481,7 +1485,7 @@ var getClientCertificateTests = []struct { }, }, { - func(clientConfig *Config) { + func(clientConfig, serverConfig *Config) { // Returning an error should abort the handshake with // that error. clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { @@ -1493,14 +1497,21 @@ var getClientCertificateTests = []struct { }, }, { - func(clientConfig *Config) { + func(clientConfig, serverConfig *Config) { clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { - return &testConfig.Certificates[0], nil + if len(cri.AcceptableCAs) == 0 { + panic("empty AcceptableCAs") + } + cert := &Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + return cert, nil } }, "", func(t *testing.T, testNum int, cs *ConnectionState) { - if l := len(cs.VerifiedChains); l != 0 { + if len(cs.VerifiedChains) == 0 { t.Errorf("#%d: expected some verified chains, but found none", testNum) } }, @@ -1515,13 +1526,15 @@ func TestGetClientCertificate(t *testing.T) { for i, test := range getClientCertificateTests { serverConfig := testConfig.Clone() - serverConfig.ClientAuth = RequestClientCert + serverConfig.ClientAuth = VerifyClientCertIfGiven serverConfig.RootCAs = x509.NewCertPool() serverConfig.RootCAs.AddCert(issuer) + serverConfig.ClientCAs = serverConfig.RootCAs + serverConfig.Time = func() time.Time { return time.Unix(1476984729, 0) } clientConfig := testConfig.Clone() - test.setup(clientConfig) + test.setup(clientConfig, serverConfig) type serverResult struct { cs ConnectionState @@ -1553,6 +1566,8 @@ func TestGetClientCertificate(t *testing.T) { t.Errorf("#%d: client error: %v", i, clientErr) } else if got := clientErr.Error(); got != test.expectedClientError { t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got) + } else { + test.verify(t, i, &result.cs) } } else if len(test.expectedClientError) > 0 { t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)