diff --git a/common.go b/common.go index 929c8ef..352df25 100644 --- a/common.go +++ b/common.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "crypto/sha512" "crypto/x509" + "errors" "fmt" "io" "math/big" @@ -270,10 +271,12 @@ type Config struct { NameToCertificate map[string]*Certificate // GetCertificate returns a Certificate based on the given - // ClientHelloInfo. If GetCertificate is nil or returns nil, then the - // certificate is retrieved from NameToCertificate. If - // NameToCertificate is nil, the first element of Certificates will be - // used. + // ClientHelloInfo. It will only be called if the client supplies SNI + // information or if Certificates is empty. + // + // If GetCertificate is nil or returns nil, then the certificate is + // retrieved from NameToCertificate. If NameToCertificate is nil, the + // first element of Certificates will be used. GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error) // RootCAs defines the set of root certificate authorities @@ -500,13 +503,18 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) { // getCertificate returns the best certificate for the given ClientHelloInfo, // defaulting to the first element of c.Certificates. func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { - if c.GetCertificate != nil { + if c.GetCertificate != nil && + (len(c.Certificates) == 0 || len(clientHello.ServerName) > 0) { cert, err := c.GetCertificate(clientHello) if cert != nil || err != nil { return cert, err } } + if len(c.Certificates) == 0 { + return nil, errors.New("crypto/tls: no certificates configured") + } + if len(c.Certificates) == 1 || c.NameToCertificate == nil { // There's only one choice, so no point doing any work. return &c.Certificates[0], nil diff --git a/handshake_server.go b/handshake_server.go index e6e0b02..26c3333 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -189,22 +189,14 @@ Curves: } } - if len(config.Certificates) == 0 { + if hs.cert, err = config.getCertificate(&ClientHelloInfo{ + CipherSuites: hs.clientHello.cipherSuites, + ServerName: hs.clientHello.serverName, + SupportedCurves: hs.clientHello.supportedCurves, + SupportedPoints: hs.clientHello.supportedPoints, + }); err != nil { c.sendAlert(alertInternalError) - return false, errors.New("tls: no certificates configured") - } - hs.cert = &config.Certificates[0] - if len(hs.clientHello.serverName) > 0 { - chi := &ClientHelloInfo{ - CipherSuites: hs.clientHello.cipherSuites, - ServerName: hs.clientHello.serverName, - SupportedCurves: hs.clientHello.supportedCurves, - SupportedPoints: hs.clientHello.supportedPoints, - } - if hs.cert, err = config.getCertificate(chi); err != nil { - c.sendAlert(alertInternalError) - return false, err - } + return false, err } if hs.clientHello.scts { hs.hello.scts = hs.cert.SignedCertificateTimestamps diff --git a/handshake_server_test.go b/handshake_server_test.go index 3183c6f..bbf105d 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -796,6 +796,36 @@ func TestHandshakeServerSNIGetCertificateError(t *testing.T) { testClientHelloFailure(t, &serverConfig, clientHello, errMsg) } +// TestHandshakeServerEmptyCertificates tests that GetCertificates is called in +// the case that Certificates is empty, even without SNI. +func TestHandshakeServerEmptyCertificates(t *testing.T) { + const errMsg = "TestHandshakeServerEmptyCertificates error" + + serverConfig := *testConfig + serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { + return nil, errors.New(errMsg) + } + serverConfig.Certificates = nil + + clientHello := &clientHelloMsg{ + vers: 0x0301, + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{0}, + } + testClientHelloFailure(t, &serverConfig, clientHello, errMsg) + + // With an empty Certificates and a nil GetCertificate, the server + // should always return a “no certificates” error. + serverConfig.GetCertificate = nil + + clientHello = &clientHelloMsg{ + vers: 0x0301, + cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + compressionMethods: []uint8{0}, + } + testClientHelloFailure(t, &serverConfig, clientHello, "no certificates") +} + // TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with // an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate. func TestCipherSuiteCertPreferenceECDSA(t *testing.T) {