Revised version of https://golang.org/cl/81260045/ LGTM=agl R=golang-codereviews, gobot, agl, ox CC=golang-codereviews https://golang.org/cl/107400043v1.2.3
@@ -202,6 +202,32 @@ type ClientSessionCache interface { | |||
Put(sessionKey string, cs *ClientSessionState) | |||
} | |||
// ClientHelloInfo contains information from a ClientHello message in order to | |||
// guide certificate selection in the GetCertificate callback. | |||
type ClientHelloInfo struct { | |||
// CipherSuites lists the CipherSuites supported by the client (e.g. | |||
// TLS_RSA_WITH_RC4_128_SHA). | |||
CipherSuites []uint16 | |||
// ServerName indicates the name of the server requested by the client | |||
// in order to support virtual hosting. ServerName is only set if the | |||
// client is using SNI (see | |||
// http://tools.ietf.org/html/rfc4366#section-3.1). | |||
ServerName string | |||
// SupportedCurves lists the elliptic curves supported by the client. | |||
// SupportedCurves is set only if the Supported Elliptic Curves | |||
// Extension is being used (see | |||
// http://tools.ietf.org/html/rfc4492#section-5.1.1). | |||
SupportedCurves []CurveID | |||
// SupportedPoints lists the point formats supported by the client. | |||
// SupportedPoints is set only if the Supported Point Formats Extension | |||
// is being used (see | |||
// http://tools.ietf.org/html/rfc4492#section-5.1.2). | |||
SupportedPoints []uint8 | |||
} | |||
// A Config structure is used to configure a TLS client or server. | |||
// After one has been passed to a TLS function it must not be | |||
// modified. A Config may be reused; the tls package will also not | |||
@@ -230,6 +256,13 @@ type Config struct { | |||
// for all connections. | |||
NameToCertificate map[string]*Certificate | |||
// GetCertificate returns a Certificate based on the given | |||
// ClientHelloInfo. If GetCertificate is nil or returns nil, then the | |||
// certificate is retrieved from NameToCertificate. If | |||
// NameToCertificate is nil, the first element of Certificates will be | |||
// used. | |||
GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error) | |||
// RootCAs defines the set of root certificate authorities | |||
// that clients use when verifying server certificates. | |||
// If RootCAs is nil, TLS uses the host's root CA set. | |||
@@ -384,22 +417,28 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) { | |||
return vers, true | |||
} | |||
// getCertificateForName returns the best certificate for the given name, | |||
// defaulting to the first element of c.Certificates if there are no good | |||
// options. | |||
func (c *Config) getCertificateForName(name string) *Certificate { | |||
// getCertificate returns the best certificate for the given ClientHelloInfo, | |||
// defaulting to the first element of c.Certificates. | |||
func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { | |||
if c.GetCertificate != nil { | |||
cert, err := c.GetCertificate(clientHello) | |||
if cert != nil || err != nil { | |||
return cert, err | |||
} | |||
} | |||
if len(c.Certificates) == 1 || c.NameToCertificate == nil { | |||
// There's only one choice, so no point doing any work. | |||
return &c.Certificates[0] | |||
return &c.Certificates[0], nil | |||
} | |||
name = strings.ToLower(name) | |||
name := strings.ToLower(clientHello.ServerName) | |||
for len(name) > 0 && name[len(name)-1] == '.' { | |||
name = name[:len(name)-1] | |||
} | |||
if cert, ok := c.NameToCertificate[name]; ok { | |||
return cert | |||
return cert, nil | |||
} | |||
// try replacing labels in the name with wildcards until we get a | |||
@@ -409,12 +448,12 @@ func (c *Config) getCertificateForName(name string) *Certificate { | |||
labels[i] = "*" | |||
candidate := strings.Join(labels, ".") | |||
if cert, ok := c.NameToCertificate[candidate]; ok { | |||
return cert | |||
return cert, nil | |||
} | |||
} | |||
// If nothing matches, return the first certificate. | |||
return &c.Certificates[0] | |||
return &c.Certificates[0], nil | |||
} | |||
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate | |||
@@ -88,19 +88,31 @@ func TestCertificateSelection(t *testing.T) { | |||
return -1 | |||
} | |||
if n := pointerToIndex(config.getCertificateForName("example.com")); n != 0 { | |||
certificateForName := func(name string) *Certificate { | |||
clientHello := &ClientHelloInfo{ | |||
ServerName: name, | |||
} | |||
if cert, err := config.getCertificate(clientHello); err != nil { | |||
t.Errorf("unable to get certificate for name '%s': %s", name, err) | |||
return nil | |||
} else { | |||
return cert | |||
} | |||
} | |||
if n := pointerToIndex(certificateForName("example.com")); n != 0 { | |||
t.Errorf("example.com returned certificate %d, not 0", n) | |||
} | |||
if n := pointerToIndex(config.getCertificateForName("bar.example.com")); n != 1 { | |||
if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 { | |||
t.Errorf("bar.example.com returned certificate %d, not 1", n) | |||
} | |||
if n := pointerToIndex(config.getCertificateForName("foo.example.com")); n != 2 { | |||
if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 { | |||
t.Errorf("foo.example.com returned certificate %d, not 2", n) | |||
} | |||
if n := pointerToIndex(config.getCertificateForName("foo.bar.example.com")); n != 3 { | |||
if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 3 { | |||
t.Errorf("foo.bar.example.com returned certificate %d, not 3", n) | |||
} | |||
if n := pointerToIndex(config.getCertificateForName("foo.bar.baz.example.com")); n != 0 { | |||
if n := pointerToIndex(certificateForName("foo.bar.baz.example.com")); n != 0 { | |||
t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n) | |||
} | |||
} |
@@ -186,7 +186,16 @@ Curves: | |||
} | |||
hs.cert = &config.Certificates[0] | |||
if len(hs.clientHello.serverName) > 0 { | |||
hs.cert = config.getCertificateForName(hs.clientHello.serverName) | |||
chi := &ClientHelloInfo{ | |||
CipherSuites: hs.clientHello.cipherSuites, | |||
ServerName: hs.clientHello.serverName, | |||
SupportedCurves: hs.clientHello.supportedCurves, | |||
SupportedPoints: hs.clientHello.supportedPoints, | |||
} | |||
if hs.cert, err = config.getCertificate(chi); err != nil { | |||
c.sendAlert(alertInternalError) | |||
return false, err | |||
} | |||
} | |||
_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey) | |||
@@ -257,6 +257,9 @@ type serverTest struct { | |||
expectedPeerCerts []string | |||
// config, if not nil, contains a custom Config to use for this test. | |||
config *Config | |||
// expectAlert, if true, indicates that a fatal alert should be returned | |||
// when handshaking with the server. | |||
expectAlert bool | |||
// validate, if not nil, is a function that will be called with the | |||
// ConnectionState of the resulting connection. It returns false if the | |||
// ConnectionState is unacceptable. | |||
@@ -370,7 +373,9 @@ func (test *serverTest) run(t *testing.T, write bool) { | |||
if !write { | |||
flows, err := test.loadData() | |||
if err != nil { | |||
t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) | |||
if !test.expectAlert { | |||
t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath()) | |||
} | |||
} | |||
for i, b := range flows { | |||
if i%2 == 0 { | |||
@@ -379,11 +384,17 @@ func (test *serverTest) run(t *testing.T, write bool) { | |||
} | |||
bb := make([]byte, len(b)) | |||
n, err := io.ReadFull(clientConn, bb) | |||
if err != nil { | |||
t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) | |||
} | |||
if !bytes.Equal(b, bb) { | |||
t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) | |||
if test.expectAlert { | |||
if err == nil { | |||
t.Fatal("Expected read failure but read succeeded") | |||
} | |||
} else { | |||
if err != nil { | |||
t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b) | |||
} | |||
if !bytes.Equal(b, bb) { | |||
t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b) | |||
} | |||
} | |||
} | |||
clientConn.Close() | |||
@@ -562,6 +573,61 @@ func TestHandshakeServerSNI(t *testing.T) { | |||
runServerTestTLS12(t, test) | |||
} | |||
// TestHandshakeServerSNICertForName is similar to TestHandshakeServerSNI, but | |||
// tests the dynamic GetCertificate method | |||
func TestHandshakeServerSNIGetCertificate(t *testing.T) { | |||
config := *testConfig | |||
// Replace the NameToCertificate map with a GetCertificate function | |||
nameToCert := config.NameToCertificate | |||
config.NameToCertificate = nil | |||
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { | |||
cert, _ := nameToCert[clientHello.ServerName] | |||
return cert, nil | |||
} | |||
test := &serverTest{ | |||
name: "SNI", | |||
command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, | |||
config: &config, | |||
} | |||
runServerTestTLS12(t, test) | |||
} | |||
// TestHandshakeServerSNICertForNameNotFound is similar to | |||
// TestHandshakeServerSNICertForName, but tests to make sure that when the | |||
// GetCertificate method doesn't return a cert, we fall back to what's in | |||
// the NameToCertificate map. | |||
func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) { | |||
config := *testConfig | |||
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { | |||
return nil, nil | |||
} | |||
test := &serverTest{ | |||
name: "SNI", | |||
command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, | |||
config: &config, | |||
} | |||
runServerTestTLS12(t, test) | |||
} | |||
// TestHandshakeServerSNICertForNameError tests to make sure that errors in | |||
// GetCertificate result in a tls alert. | |||
func TestHandshakeServerSNIGetCertificateError(t *testing.T) { | |||
config := *testConfig | |||
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) { | |||
return nil, fmt.Errorf("Test error in GetCertificate") | |||
} | |||
test := &serverTest{ | |||
name: "SNI", | |||
command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"}, | |||
config: &config, | |||
expectAlert: true, | |||
} | |||
runServerTestTLS12(t, test) | |||
} | |||
// TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with | |||
// an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate. | |||
func TestCipherSuiteCertPreferenceECDSA(t *testing.T) { | |||