diff --git a/common.go b/common.go index beca798..5b2c666 100644 --- a/common.go +++ b/common.go @@ -286,6 +286,21 @@ type ClientHelloInfo struct { Conn net.Conn } +// CertificateRequestInfo contains information from a server's +// CertificateRequest message, which is used to demand a certificate and proof +// of control from a client. +type CertificateRequestInfo struct { + // AcceptableCAs contains zero or more, DER-encoded, X.501 + // Distinguished Names. These are the names of root or intermediate CAs + // that the server wishes the returned certificate to be signed by. An + // empty slice indicates that the server has no preference. + AcceptableCAs [][]byte + + // SignatureSchemes lists the signature schemes that the server is + // willing to verify. + SignatureSchemes []SignatureScheme +} + // RenegotiationSupport enumerates the different levels of support for TLS // renegotiation. TLS renegotiation is the act of performing subsequent // handshakes on a connection after the first. This significantly complicates @@ -328,10 +343,11 @@ type Config struct { // If Time is nil, TLS uses time.Now. Time func() time.Time - // Certificates contains one or more certificate chains - // to present to the other side of the connection. - // Server configurations must include at least one certificate - // or else set GetCertificate. + // Certificates contains one or more certificate chains to present to + // the other side of the connection. Server configurations must include + // at least one certificate or else set GetCertificate. Clients doing + // client-authentication may set either Certificates or + // GetClientCertificate. Certificates []Certificate // NameToCertificate maps from a certificate name to an element of @@ -351,6 +367,21 @@ type Config struct { // first element of Certificates will be used. GetCertificate func(*ClientHelloInfo) (*Certificate, error) + // GetClientCertificate, if not nil, is called when a server requests a + // certificate from a client. If set, the contents of Certificates will + // be ignored. + // + // If GetClientCertificate returns an error, the handshake will be + // aborted and that error will be returned. Otherwise + // GetClientCertificate must return a non-nil Certificate. If + // Certificate.Certificate is empty then no certificate will be sent to + // the server. If this is unacceptable to the server then it may abort + // the handshake. + // + // GetClientCertificate may be called multiple times for the same + // connection if renegotiation occurs or if TLS 1.3 is in use. + GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error) + // GetConfigForClient, if not nil, is called after a ClientHello is // received from a client. It may return a non-nil Config in order to // change the Config that will be used to handle this connection. If diff --git a/handshake_client.go b/handshake_client.go index c331c65..89bdd59 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -199,7 +199,7 @@ NextCipherSuite: // Otherwise, in a full handshake, if we don't have any certificates // configured then we will never send a CertificateVerify message and // thus no signatures are needed in that case either. - if isResume || len(c.config.Certificates) == 0 { + if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) { hs.finishedHash.discardHandshakeBuffer() } @@ -377,71 +377,11 @@ func (hs *clientHandshakeState) doFullHandshake() error { certReq, ok := msg.(*certificateRequestMsg) if ok { certRequested = true - - // RFC 4346 on the certificateAuthorities field: - // A list of the distinguished names of acceptable certificate - // authorities. These distinguished names may specify a desired - // distinguished name for a root CA or for a subordinate CA; - // thus, this message can be used to describe both known roots - // and a desired authorization space. If the - // certificate_authorities list is empty then the client MAY - // send any certificate of the appropriate - // ClientCertificateType, unless there is some external - // arrangement to the contrary. - hs.finishedHash.Write(certReq.marshal()) - var rsaAvail, ecdsaAvail bool - for _, certType := range certReq.certificateTypes { - switch certType { - case certTypeRSASign: - rsaAvail = true - case certTypeECDSASign: - ecdsaAvail = true - } - } - - // We need to search our list of client certs for one - // where SignatureAlgorithm is acceptable to the server and the - // Issuer is in certReq.certificateAuthorities - findCert: - for i, chain := range c.config.Certificates { - if !rsaAvail && !ecdsaAvail { - continue - } - - for j, cert := range chain.Certificate { - x509Cert := chain.Leaf - // parse the certificate if this isn't the leaf - // node, or if chain.Leaf was nil - if j != 0 || x509Cert == nil { - if x509Cert, err = x509.ParseCertificate(cert); err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error()) - } - } - - switch { - case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA: - case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA: - default: - continue findCert - } - - if len(certReq.certificateAuthorities) == 0 { - // they gave us an empty list, so just take the - // first cert from c.config.Certificates - chainToSend = &chain - break findCert - } - - for _, ca := range certReq.certificateAuthorities { - if bytes.Equal(x509Cert.RawIssuer, ca) { - chainToSend = &chain - break findCert - } - } - } + if chainToSend, err = hs.getCertificate(certReq); err != nil { + c.sendAlert(alertInternalError) + return err } msg, err = c.readHandshake() @@ -462,9 +402,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { // certificate to send. if certRequested { certMsg = new(certificateMsg) - if chainToSend != nil { - certMsg.certificates = chainToSend.Certificate - } + certMsg.certificates = chainToSend.Certificate hs.finishedHash.Write(certMsg.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil { return err @@ -483,7 +421,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { } } - if chainToSend != nil { + if chainToSend != nil && len(chainToSend.Certificate) > 0 { certVerify := &certificateVerifyMsg{ hasSignatureAndHash: c.vers >= VersionTLS12, } @@ -727,6 +665,117 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error { return nil } +// tls11SignatureSchemes contains the signature schemes that we synthesise for +// a TLS <= 1.1 connection, based on the supported certificate types. +var tls11SignatureSchemes = []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1} + +const ( + // tls11SignatureSchemesNumECDSA is the number of initial elements of + // tls11SignatureSchemes that use ECDSA. + tls11SignatureSchemesNumECDSA = 3 + // tls11SignatureSchemesNumRSA is the number of trailing elements of + // tls11SignatureSchemes that use RSA. + tls11SignatureSchemesNumRSA = 4 +) + +func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (*Certificate, error) { + c := hs.c + + var rsaAvail, ecdsaAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case certTypeRSASign: + rsaAvail = true + case certTypeECDSASign: + ecdsaAvail = true + } + } + + if c.config.GetClientCertificate != nil { + var signatureSchemes []SignatureScheme + + if !certReq.hasSignatureAndHash { + // Prior to TLS 1.2, the signature schemes were not + // included in the certificate request message. In this + // case we use a plausible list based on the acceptable + // certificate types. + signatureSchemes = tls11SignatureSchemes + if !ecdsaAvail { + signatureSchemes = signatureSchemes[tls11SignatureSchemesNumECDSA:] + } + if !rsaAvail { + signatureSchemes = signatureSchemes[:len(signatureSchemes)-tls11SignatureSchemesNumRSA] + } + } else { + signatureSchemes = make([]SignatureScheme, 0, len(certReq.signatureAndHashes)) + for _, sah := range certReq.signatureAndHashes { + signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature)) + } + } + + return c.config.GetClientCertificate(&CertificateRequestInfo{ + AcceptableCAs: certReq.certificateAuthorities, + SignatureSchemes: signatureSchemes, + }) + } + + // RFC 4346 on the certificateAuthorities field: A list of the + // distinguished names of acceptable certificate authorities. + // These distinguished names may specify a desired + // distinguished name for a root CA or for a subordinate CA; + // thus, this message can be used to describe both known roots + // and a desired authorization space. If the + // certificate_authorities list is empty then the client MAY + // send any certificate of the appropriate + // ClientCertificateType, unless there is some external + // arrangement to the contrary. + + // We need to search our list of client certs for one + // where SignatureAlgorithm is acceptable to the server and the + // Issuer is in certReq.certificateAuthorities +findCert: + for i, chain := range c.config.Certificates { + if !rsaAvail && !ecdsaAvail { + continue + } + + for j, cert := range chain.Certificate { + x509Cert := chain.Leaf + // parse the certificate if this isn't the leaf + // node, or if chain.Leaf was nil + if j != 0 || x509Cert == nil { + var err error + if x509Cert, err = x509.ParseCertificate(cert); err != nil { + c.sendAlert(alertInternalError) + return nil, errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error()) + } + } + + switch { + case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA: + case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA: + default: + continue findCert + } + + if len(certReq.certificateAuthorities) == 0 { + // they gave us an empty list, so just take the + // first cert from c.config.Certificates + return &chain, nil + } + + for _, ca := range certReq.certificateAuthorities { + if bytes.Equal(x509Cert.RawIssuer, ca) { + return &chain, nil + } + } + } + } + + // No acceptable certificate found. Don't send a certificate. + return new(Certificate), nil +} + // clientSessionCacheKey returns a key used to cache sessionTickets that could // be used to resume previously negotiated TLS sessions with a server. func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { diff --git a/handshake_client_test.go b/handshake_client_test.go index 24d119e..d603915 100644 --- a/handshake_client_test.go +++ b/handshake_client_test.go @@ -1408,3 +1408,137 @@ func TestHandshakeRace(t *testing.T) { <-readDone } } + +func TestTLS11SignatureSchemes(t *testing.T) { + expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA + if expected != len(tls11SignatureSchemes) { + t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes)) + } +} + +var getClientCertificateTests = []struct { + setup func(*Config) + expectedClientError string + verify func(*testing.T, int, *ConnectionState) +}{ + { + func(clientConfig *Config) { + // Returning a Certificate with no certificate data + // should result in an empty message being sent to the + // server. + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.SignatureSchemes) == 0 { + panic("empty SignatureSchemes") + } + return new(Certificate), nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.PeerCertificates); l != 0 { + t.Errorf("#%d: expected no certificates but got %d", testNum, l) + } + }, + }, + { + func(clientConfig *Config) { + // With TLS 1.1, the SignatureSchemes should be + // synthesised from the supported certificate types. + clientConfig.MaxVersion = VersionTLS11 + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + if len(cri.SignatureSchemes) == 0 { + panic("empty SignatureSchemes") + } + return new(Certificate), nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.PeerCertificates); l != 0 { + t.Errorf("#%d: expected no certificates but got %d", testNum, l) + } + }, + }, + { + func(clientConfig *Config) { + // Returning an error should abort the handshake with + // that error. + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + return nil, errors.New("GetClientCertificate") + } + }, + "GetClientCertificate", + func(t *testing.T, testNum int, cs *ConnectionState) { + }, + }, + { + func(clientConfig *Config) { + clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) { + return &testConfig.Certificates[0], nil + } + }, + "", + func(t *testing.T, testNum int, cs *ConnectionState) { + if l := len(cs.VerifiedChains); l != 0 { + t.Errorf("#%d: expected some verified chains, but found none", testNum) + } + }, + }, +} + +func TestGetClientCertificate(t *testing.T) { + issuer, err := x509.ParseCertificate(testRSACertificateIssuer) + if err != nil { + panic(err) + } + + for i, test := range getClientCertificateTests { + serverConfig := testConfig.Clone() + serverConfig.ClientAuth = RequestClientCert + serverConfig.RootCAs = x509.NewCertPool() + serverConfig.RootCAs.AddCert(issuer) + + clientConfig := testConfig.Clone() + + test.setup(clientConfig) + + type serverResult struct { + cs ConnectionState + err error + } + + c, s := net.Pipe() + done := make(chan serverResult) + + go func() { + defer s.Close() + server := Server(s, serverConfig) + err := server.Handshake() + + var cs ConnectionState + if err == nil { + cs = server.ConnectionState() + } + done <- serverResult{cs, err} + }() + + clientErr := Client(c, clientConfig).Handshake() + c.Close() + + result := <-done + + if clientErr != nil { + if len(test.expectedClientError) == 0 { + t.Errorf("#%d: client error: %v", i, clientErr) + } else if got := clientErr.Error(); got != test.expectedClientError { + t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got) + } + } else if len(test.expectedClientError) > 0 { + t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError) + } else if err := result.err; err != nil { + t.Errorf("#%d: server error: %v", i, err) + } else { + test.verify(t, i, &result.cs) + } + } +} diff --git a/tls_test.go b/tls_test.go index b4fedd5..a0c0908 100644 --- a/tls_test.go +++ b/tls_test.go @@ -584,7 +584,7 @@ func TestClone(t *testing.T) { case "Rand": f.Set(reflect.ValueOf(io.Reader(os.Stdin))) continue - case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate": + case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate": // DeepEqual can't compare functions. continue case "Certificates":