crypto/tls: Added dynamic alternative to NameToCertificate map for SNI
Revised version of https://golang.org/cl/81260045/ LGTM=agl R=golang-codereviews, gobot, agl, ox CC=golang-codereviews https://golang.org/cl/107400043
This commit is contained in:
parent
5e8d397065
commit
9e441ebf1d
61
common.go
61
common.go
@ -202,6 +202,32 @@ type ClientSessionCache interface {
|
|||||||
Put(sessionKey string, cs *ClientSessionState)
|
Put(sessionKey string, cs *ClientSessionState)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientHelloInfo contains information from a ClientHello message in order to
|
||||||
|
// guide certificate selection in the GetCertificate callback.
|
||||||
|
type ClientHelloInfo struct {
|
||||||
|
// CipherSuites lists the CipherSuites supported by the client (e.g.
|
||||||
|
// TLS_RSA_WITH_RC4_128_SHA).
|
||||||
|
CipherSuites []uint16
|
||||||
|
|
||||||
|
// ServerName indicates the name of the server requested by the client
|
||||||
|
// in order to support virtual hosting. ServerName is only set if the
|
||||||
|
// client is using SNI (see
|
||||||
|
// http://tools.ietf.org/html/rfc4366#section-3.1).
|
||||||
|
ServerName string
|
||||||
|
|
||||||
|
// SupportedCurves lists the elliptic curves supported by the client.
|
||||||
|
// SupportedCurves is set only if the Supported Elliptic Curves
|
||||||
|
// Extension is being used (see
|
||||||
|
// http://tools.ietf.org/html/rfc4492#section-5.1.1).
|
||||||
|
SupportedCurves []CurveID
|
||||||
|
|
||||||
|
// SupportedPoints lists the point formats supported by the client.
|
||||||
|
// SupportedPoints is set only if the Supported Point Formats Extension
|
||||||
|
// is being used (see
|
||||||
|
// http://tools.ietf.org/html/rfc4492#section-5.1.2).
|
||||||
|
SupportedPoints []uint8
|
||||||
|
}
|
||||||
|
|
||||||
// A Config structure is used to configure a TLS client or server.
|
// A Config structure is used to configure a TLS client or server.
|
||||||
// After one has been passed to a TLS function it must not be
|
// After one has been passed to a TLS function it must not be
|
||||||
// modified. A Config may be reused; the tls package will also not
|
// modified. A Config may be reused; the tls package will also not
|
||||||
@ -230,6 +256,13 @@ type Config struct {
|
|||||||
// for all connections.
|
// for all connections.
|
||||||
NameToCertificate map[string]*Certificate
|
NameToCertificate map[string]*Certificate
|
||||||
|
|
||||||
|
// GetCertificate returns a Certificate based on the given
|
||||||
|
// ClientHelloInfo. If GetCertificate is nil or returns nil, then the
|
||||||
|
// certificate is retrieved from NameToCertificate. If
|
||||||
|
// NameToCertificate is nil, the first element of Certificates will be
|
||||||
|
// used.
|
||||||
|
GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
|
||||||
|
|
||||||
// RootCAs defines the set of root certificate authorities
|
// RootCAs defines the set of root certificate authorities
|
||||||
// that clients use when verifying server certificates.
|
// that clients use when verifying server certificates.
|
||||||
// If RootCAs is nil, TLS uses the host's root CA set.
|
// If RootCAs is nil, TLS uses the host's root CA set.
|
||||||
@ -384,22 +417,28 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) {
|
|||||||
return vers, true
|
return vers, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCertificateForName returns the best certificate for the given name,
|
// getCertificate returns the best certificate for the given ClientHelloInfo,
|
||||||
// defaulting to the first element of c.Certificates if there are no good
|
// defaulting to the first element of c.Certificates.
|
||||||
// options.
|
func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) {
|
||||||
func (c *Config) getCertificateForName(name string) *Certificate {
|
if c.GetCertificate != nil {
|
||||||
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
|
cert, err := c.GetCertificate(clientHello)
|
||||||
// There's only one choice, so no point doing any work.
|
if cert != nil || err != nil {
|
||||||
return &c.Certificates[0]
|
return cert, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
name = strings.ToLower(name)
|
if len(c.Certificates) == 1 || c.NameToCertificate == nil {
|
||||||
|
// There's only one choice, so no point doing any work.
|
||||||
|
return &c.Certificates[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := strings.ToLower(clientHello.ServerName)
|
||||||
for len(name) > 0 && name[len(name)-1] == '.' {
|
for len(name) > 0 && name[len(name)-1] == '.' {
|
||||||
name = name[:len(name)-1]
|
name = name[:len(name)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
if cert, ok := c.NameToCertificate[name]; ok {
|
if cert, ok := c.NameToCertificate[name]; ok {
|
||||||
return cert
|
return cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// try replacing labels in the name with wildcards until we get a
|
// try replacing labels in the name with wildcards until we get a
|
||||||
@ -409,12 +448,12 @@ func (c *Config) getCertificateForName(name string) *Certificate {
|
|||||||
labels[i] = "*"
|
labels[i] = "*"
|
||||||
candidate := strings.Join(labels, ".")
|
candidate := strings.Join(labels, ".")
|
||||||
if cert, ok := c.NameToCertificate[candidate]; ok {
|
if cert, ok := c.NameToCertificate[candidate]; ok {
|
||||||
return cert
|
return cert, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If nothing matches, return the first certificate.
|
// If nothing matches, return the first certificate.
|
||||||
return &c.Certificates[0]
|
return &c.Certificates[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
|
// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate
|
||||||
|
22
conn_test.go
22
conn_test.go
@ -88,19 +88,31 @@ func TestCertificateSelection(t *testing.T) {
|
|||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
if n := pointerToIndex(config.getCertificateForName("example.com")); n != 0 {
|
certificateForName := func(name string) *Certificate {
|
||||||
|
clientHello := &ClientHelloInfo{
|
||||||
|
ServerName: name,
|
||||||
|
}
|
||||||
|
if cert, err := config.getCertificate(clientHello); err != nil {
|
||||||
|
t.Errorf("unable to get certificate for name '%s': %s", name, err)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return cert
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n := pointerToIndex(certificateForName("example.com")); n != 0 {
|
||||||
t.Errorf("example.com returned certificate %d, not 0", n)
|
t.Errorf("example.com returned certificate %d, not 0", n)
|
||||||
}
|
}
|
||||||
if n := pointerToIndex(config.getCertificateForName("bar.example.com")); n != 1 {
|
if n := pointerToIndex(certificateForName("bar.example.com")); n != 1 {
|
||||||
t.Errorf("bar.example.com returned certificate %d, not 1", n)
|
t.Errorf("bar.example.com returned certificate %d, not 1", n)
|
||||||
}
|
}
|
||||||
if n := pointerToIndex(config.getCertificateForName("foo.example.com")); n != 2 {
|
if n := pointerToIndex(certificateForName("foo.example.com")); n != 2 {
|
||||||
t.Errorf("foo.example.com returned certificate %d, not 2", n)
|
t.Errorf("foo.example.com returned certificate %d, not 2", n)
|
||||||
}
|
}
|
||||||
if n := pointerToIndex(config.getCertificateForName("foo.bar.example.com")); n != 3 {
|
if n := pointerToIndex(certificateForName("foo.bar.example.com")); n != 3 {
|
||||||
t.Errorf("foo.bar.example.com returned certificate %d, not 3", n)
|
t.Errorf("foo.bar.example.com returned certificate %d, not 3", n)
|
||||||
}
|
}
|
||||||
if n := pointerToIndex(config.getCertificateForName("foo.bar.baz.example.com")); n != 0 {
|
if n := pointerToIndex(certificateForName("foo.bar.baz.example.com")); n != 0 {
|
||||||
t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n)
|
t.Errorf("foo.bar.baz.example.com returned certificate %d, not 0", n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -186,7 +186,16 @@ Curves:
|
|||||||
}
|
}
|
||||||
hs.cert = &config.Certificates[0]
|
hs.cert = &config.Certificates[0]
|
||||||
if len(hs.clientHello.serverName) > 0 {
|
if len(hs.clientHello.serverName) > 0 {
|
||||||
hs.cert = config.getCertificateForName(hs.clientHello.serverName)
|
chi := &ClientHelloInfo{
|
||||||
|
CipherSuites: hs.clientHello.cipherSuites,
|
||||||
|
ServerName: hs.clientHello.serverName,
|
||||||
|
SupportedCurves: hs.clientHello.supportedCurves,
|
||||||
|
SupportedPoints: hs.clientHello.supportedPoints,
|
||||||
|
}
|
||||||
|
if hs.cert, err = config.getCertificate(chi); err != nil {
|
||||||
|
c.sendAlert(alertInternalError)
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
|
_, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey)
|
||||||
|
@ -257,6 +257,9 @@ type serverTest struct {
|
|||||||
expectedPeerCerts []string
|
expectedPeerCerts []string
|
||||||
// config, if not nil, contains a custom Config to use for this test.
|
// config, if not nil, contains a custom Config to use for this test.
|
||||||
config *Config
|
config *Config
|
||||||
|
// expectAlert, if true, indicates that a fatal alert should be returned
|
||||||
|
// when handshaking with the server.
|
||||||
|
expectAlert bool
|
||||||
// validate, if not nil, is a function that will be called with the
|
// validate, if not nil, is a function that will be called with the
|
||||||
// ConnectionState of the resulting connection. It returns false if the
|
// ConnectionState of the resulting connection. It returns false if the
|
||||||
// ConnectionState is unacceptable.
|
// ConnectionState is unacceptable.
|
||||||
@ -370,7 +373,9 @@ func (test *serverTest) run(t *testing.T, write bool) {
|
|||||||
if !write {
|
if !write {
|
||||||
flows, err := test.loadData()
|
flows, err := test.loadData()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
|
if !test.expectAlert {
|
||||||
|
t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for i, b := range flows {
|
for i, b := range flows {
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
@ -379,11 +384,17 @@ func (test *serverTest) run(t *testing.T, write bool) {
|
|||||||
}
|
}
|
||||||
bb := make([]byte, len(b))
|
bb := make([]byte, len(b))
|
||||||
n, err := io.ReadFull(clientConn, bb)
|
n, err := io.ReadFull(clientConn, bb)
|
||||||
if err != nil {
|
if test.expectAlert {
|
||||||
t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
|
if err == nil {
|
||||||
}
|
t.Fatal("Expected read failure but read succeeded")
|
||||||
if !bytes.Equal(b, bb) {
|
}
|
||||||
t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
|
} else {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(b, bb) {
|
||||||
|
t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
clientConn.Close()
|
clientConn.Close()
|
||||||
@ -562,6 +573,61 @@ func TestHandshakeServerSNI(t *testing.T) {
|
|||||||
runServerTestTLS12(t, test)
|
runServerTestTLS12(t, test)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestHandshakeServerSNICertForName is similar to TestHandshakeServerSNI, but
|
||||||
|
// tests the dynamic GetCertificate method
|
||||||
|
func TestHandshakeServerSNIGetCertificate(t *testing.T) {
|
||||||
|
config := *testConfig
|
||||||
|
|
||||||
|
// Replace the NameToCertificate map with a GetCertificate function
|
||||||
|
nameToCert := config.NameToCertificate
|
||||||
|
config.NameToCertificate = nil
|
||||||
|
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
|
||||||
|
cert, _ := nameToCert[clientHello.ServerName]
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
test := &serverTest{
|
||||||
|
name: "SNI",
|
||||||
|
command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
|
||||||
|
config: &config,
|
||||||
|
}
|
||||||
|
runServerTestTLS12(t, test)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandshakeServerSNICertForNameNotFound is similar to
|
||||||
|
// TestHandshakeServerSNICertForName, but tests to make sure that when the
|
||||||
|
// GetCertificate method doesn't return a cert, we fall back to what's in
|
||||||
|
// the NameToCertificate map.
|
||||||
|
func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) {
|
||||||
|
config := *testConfig
|
||||||
|
|
||||||
|
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
test := &serverTest{
|
||||||
|
name: "SNI",
|
||||||
|
command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
|
||||||
|
config: &config,
|
||||||
|
}
|
||||||
|
runServerTestTLS12(t, test)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandshakeServerSNICertForNameError tests to make sure that errors in
|
||||||
|
// GetCertificate result in a tls alert.
|
||||||
|
func TestHandshakeServerSNIGetCertificateError(t *testing.T) {
|
||||||
|
config := *testConfig
|
||||||
|
|
||||||
|
config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
|
||||||
|
return nil, fmt.Errorf("Test error in GetCertificate")
|
||||||
|
}
|
||||||
|
test := &serverTest{
|
||||||
|
name: "SNI",
|
||||||
|
command: []string{"openssl", "s_client", "-no_ticket", "-cipher", "AES128-SHA", "-servername", "snitest.com"},
|
||||||
|
config: &config,
|
||||||
|
expectAlert: true,
|
||||||
|
}
|
||||||
|
runServerTestTLS12(t, test)
|
||||||
|
}
|
||||||
|
|
||||||
// TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with
|
// TestCipherSuiteCertPreferance ensures that we select an RSA ciphersuite with
|
||||||
// an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate.
|
// an RSA certificate and an ECDSA ciphersuite with an ECDSA certificate.
|
||||||
func TestCipherSuiteCertPreferenceECDSA(t *testing.T) {
|
func TestCipherSuiteCertPreferenceECDSA(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user