GetConfigForClient allows the tls.Config to be updated on a per-client basis. Fixes #16066. Fixes #15707. Fixes #15699. Change-Id: I2c675a443d557f969441226729f98502b38901ea Reviewed-on: https://go-review.googlesource.com/30790 Run-TryBot: Adam Langley <agl@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>v1.2.3
@@ -303,7 +303,27 @@ type Config struct { | |||
// If GetCertificate is nil or returns nil, then the certificate is | |||
// retrieved from NameToCertificate. If NameToCertificate is nil, the | |||
// first element of Certificates will be used. | |||
GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error) | |||
GetCertificate func(*ClientHelloInfo) (*Certificate, error) | |||
// GetConfigForClient, if not nil, is called after a ClientHello is | |||
// received from a client. It may return a non-nil Config in order to | |||
// change the Config that will be used to handle this connection. If | |||
// the returned Config is nil, the original Config will be used. The | |||
// Config returned by this callback may not be subsequently modified. | |||
// | |||
// If GetConfigForClient is nil, the Config passed to Server() will be | |||
// used for all connections. | |||
// | |||
// Uniquely for the fields in the returned Config, session ticket keys | |||
// will be duplicated from the original Config if not set. | |||
// Specifically, if SetSessionTicketKeys was called on the original | |||
// config but not on the returned config then the ticket keys from the | |||
// original config will be copied into the new config before use. | |||
// Otherwise, if SessionTicketKey was set in the original config but | |||
// not in the returned config then it will be copied into the returned | |||
// config before use. If neither of those cases applies then the key | |||
// material from the returned config will be used for session tickets. | |||
GetConfigForClient func(*ClientHelloInfo) (*Config, error) | |||
// RootCAs defines the set of root certificate authorities | |||
// that clients use when verifying server certificates. | |||
@@ -398,13 +418,17 @@ type Config struct { | |||
serverInitOnce sync.Once // guards calling (*Config).serverInit | |||
// mutex protects sessionTicketKeys | |||
// mutex protects sessionTicketKeys and originalConfig. | |||
mutex sync.RWMutex | |||
// sessionTicketKeys contains zero or more ticket keys. If the length | |||
// is zero, SessionTicketsDisabled must be true. The first key is used | |||
// for new tickets and any subsequent keys can be used to decrypt old | |||
// tickets. | |||
sessionTicketKeys []ticketKey | |||
// originalConfig is set to the Config that was passed to Server if | |||
// this Config is returned by a GetConfigForClient callback. It's used | |||
// by serverInit in order to copy session ticket keys if needed. | |||
originalConfig *Config | |||
} | |||
// ticketKeyNameLen is the number of bytes of identifier that is prepended to | |||
@@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) { | |||
// Clone returns a shallow clone of c. | |||
// Only the exported fields are copied. | |||
func (c *Config) Clone() *Config { | |||
var sessionTicketKeys []ticketKey | |||
c.mutex.RLock() | |||
sessionTicketKeys = c.sessionTicketKeys | |||
c.mutex.RUnlock() | |||
return &Config{ | |||
Rand: c.Rand, | |||
Time: c.Time, | |||
Certificates: c.Certificates, | |||
NameToCertificate: c.NameToCertificate, | |||
GetCertificate: c.GetCertificate, | |||
GetConfigForClient: c.GetConfigForClient, | |||
RootCAs: c.RootCAs, | |||
NextProtos: c.NextProtos, | |||
ServerName: c.ServerName, | |||
@@ -457,6 +487,8 @@ func (c *Config) Clone() *Config { | |||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, | |||
Renegotiation: c.Renegotiation, | |||
KeyLogWriter: c.KeyLogWriter, | |||
sessionTicketKeys: sessionTicketKeys, | |||
// originalConfig is deliberately not duplicated. | |||
} | |||
} | |||
@@ -465,6 +497,11 @@ func (c *Config) serverInit() { | |||
return | |||
} | |||
var originalConfig *Config | |||
c.mutex.Lock() | |||
originalConfig, c.originalConfig = c.originalConfig, nil | |||
c.mutex.Unlock() | |||
alreadySet := false | |||
for _, b := range c.SessionTicketKey { | |||
if b != 0 { | |||
@@ -474,13 +511,21 @@ func (c *Config) serverInit() { | |||
} | |||
if !alreadySet { | |||
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { | |||
if originalConfig != nil { | |||
copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:]) | |||
} else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { | |||
c.SessionTicketsDisabled = true | |||
return | |||
} | |||
} | |||
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} | |||
if originalConfig != nil { | |||
originalConfig.mutex.RLock() | |||
c.sessionTicketKeys = originalConfig.sessionTicketKeys | |||
originalConfig.mutex.RUnlock() | |||
} else { | |||
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} | |||
} | |||
} | |||
func (c *Config) ticketKeys() []ticketKey { | |||
@@ -37,11 +37,9 @@ type serverHandshakeState struct { | |||
// serverHandshake performs a TLS handshake as a server. | |||
// c.out.Mutex <= L; c.handshakeMutex <= L. | |||
func (c *Conn) serverHandshake() error { | |||
config := c.config | |||
// If this is the first server handshake, we generate a random key to | |||
// encrypt the tickets with. | |||
config.serverInitOnce.Do(config.serverInit) | |||
c.config.serverInitOnce.Do(c.config.serverInit) | |||
hs := serverHandshakeState{ | |||
c: c, | |||
@@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error { | |||
// readClientHello reads a ClientHello message from the client and decides | |||
// whether we will perform session resumption. | |||
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { | |||
config := hs.c.config | |||
c := hs.c | |||
msg, err := c.readHandshake() | |||
@@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { | |||
c.sendAlert(alertUnexpectedMessage) | |||
return false, unexpectedMessageError(hs.clientHello, msg) | |||
} | |||
c.vers, ok = config.mutualVersion(hs.clientHello.vers) | |||
clientHelloInfo := &ClientHelloInfo{ | |||
CipherSuites: hs.clientHello.cipherSuites, | |||
ServerName: hs.clientHello.serverName, | |||
SupportedCurves: hs.clientHello.supportedCurves, | |||
SupportedPoints: hs.clientHello.supportedPoints, | |||
} | |||
if c.config.GetConfigForClient != nil { | |||
if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil { | |||
c.sendAlert(alertInternalError) | |||
return false, err | |||
} else if newConfig != nil { | |||
newConfig.mutex.Lock() | |||
newConfig.originalConfig = c.config | |||
newConfig.mutex.Unlock() | |||
newConfig.serverInitOnce.Do(newConfig.serverInit) | |||
c.config = newConfig | |||
} | |||
} | |||
c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) | |||
if !ok { | |||
c.sendAlert(alertProtocolVersion) | |||
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) | |||
@@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { | |||
hs.hello = new(serverHelloMsg) | |||
supportedCurve := false | |||
preferredCurves := config.curvePreferences() | |||
preferredCurves := c.config.curvePreferences() | |||
Curves: | |||
for _, curve := range hs.clientHello.supportedCurves { | |||
for _, supported := range preferredCurves { | |||
@@ -171,7 +190,7 @@ Curves: | |||
hs.hello.vers = c.vers | |||
hs.hello.random = make([]byte, 32) | |||
_, err = io.ReadFull(config.rand(), hs.hello.random) | |||
_, err = io.ReadFull(c.config.rand(), hs.hello.random) | |||
if err != nil { | |||
c.sendAlert(alertInternalError) | |||
return false, err | |||
@@ -196,20 +215,15 @@ Curves: | |||
} else { | |||
// Although sending an empty NPN extension is reasonable, Firefox has | |||
// had a bug around this. Best to send nothing at all if | |||
// config.NextProtos is empty. See | |||
// c.config.NextProtos is empty. See | |||
// https://golang.org/issue/5445. | |||
if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 { | |||
if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 { | |||
hs.hello.nextProtoNeg = true | |||
hs.hello.nextProtos = config.NextProtos | |||
hs.hello.nextProtos = c.config.NextProtos | |||
} | |||
} | |||
hs.cert, err = config.getCertificate(&ClientHelloInfo{ | |||
CipherSuites: hs.clientHello.cipherSuites, | |||
ServerName: hs.clientHello.serverName, | |||
SupportedCurves: hs.clientHello.supportedCurves, | |||
SupportedPoints: hs.clientHello.supportedPoints, | |||
}) | |||
hs.cert, err = c.config.getCertificate(clientHelloInfo) | |||
if err != nil { | |||
c.sendAlert(alertInternalError) | |||
return false, err | |||
@@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error { | |||
} | |||
func (hs *serverHandshakeState) doFullHandshake() error { | |||
config := hs.c.config | |||
c := hs.c | |||
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { | |||
hs.hello.ocspStapling = true | |||
} | |||
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled | |||
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled | |||
hs.hello.cipherSuite = hs.suite.id | |||
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) | |||
if config.ClientAuth == NoClientCert { | |||
if c.config.ClientAuth == NoClientCert { | |||
// No need to keep a full record of the handshake if client | |||
// certificates won't be used. | |||
hs.finishedHash.discardHandshakeBuffer() | |||
@@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
} | |||
keyAgreement := hs.suite.ka(c.vers) | |||
skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello) | |||
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) | |||
if err != nil { | |||
c.sendAlert(alertHandshakeFailure) | |||
return err | |||
@@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
} | |||
} | |||
if config.ClientAuth >= RequestClientCert { | |||
if c.config.ClientAuth >= RequestClientCert { | |||
// Request a client certificate | |||
certReq := new(certificateRequestMsg) | |||
certReq.certificateTypes = []byte{ | |||
@@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
// to our request. When we know the CAs we trust, then | |||
// we can send them down, so that the client can choose | |||
// an appropriate certificate to give to us. | |||
if config.ClientCAs != nil { | |||
certReq.certificateAuthorities = config.ClientCAs.Subjects() | |||
if c.config.ClientCAs != nil { | |||
certReq.certificateAuthorities = c.config.ClientCAs.Subjects() | |||
} | |||
hs.finishedHash.Write(certReq.marshal()) | |||
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { | |||
@@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
var ok bool | |||
// If we requested a client certificate, then the client must send a | |||
// certificate message, even if it's empty. | |||
if config.ClientAuth >= RequestClientCert { | |||
if c.config.ClientAuth >= RequestClientCert { | |||
if certMsg, ok = msg.(*certificateMsg); !ok { | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(certMsg, msg) | |||
@@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
if len(certMsg.certificates) == 0 { | |||
// The client didn't actually send a certificate | |||
switch config.ClientAuth { | |||
switch c.config.ClientAuth { | |||
case RequireAnyClientCert, RequireAndVerifyClientCert: | |||
c.sendAlert(alertBadCertificate) | |||
return errors.New("tls: client didn't provide a certificate") | |||
@@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
} | |||
hs.finishedHash.Write(ckx.marshal()) | |||
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers) | |||
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) | |||
if err != nil { | |||
c.sendAlert(alertHandshakeFailure) | |||
return err | |||
} | |||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) | |||
if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil { | |||
if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil { | |||
c.sendAlert(alertInternalError) | |||
return err | |||
} | |||
@@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) { | |||
} | |||
} | |||
var getConfigForClientTests = []struct { | |||
setup func(config *Config) | |||
callback func(clientHello *ClientHelloInfo) (*Config, error) | |||
errorSubstring string | |||
verify func(config *Config) error | |||
}{ | |||
{ | |||
nil, | |||
func(clientHello *ClientHelloInfo) (*Config, error) { | |||
return nil, nil | |||
}, | |||
"", | |||
nil, | |||
}, | |||
{ | |||
nil, | |||
func(clientHello *ClientHelloInfo) (*Config, error) { | |||
return nil, errors.New("should bubble up") | |||
}, | |||
"should bubble up", | |||
nil, | |||
}, | |||
{ | |||
nil, | |||
func(clientHello *ClientHelloInfo) (*Config, error) { | |||
config := testConfig.Clone() | |||
// Setting a maximum version of TLS 1.1 should cause | |||
// the handshake to fail. | |||
config.MaxVersion = VersionTLS11 | |||
return config, nil | |||
}, | |||
"version 301 when expecting version 302", | |||
nil, | |||
}, | |||
{ | |||
func(config *Config) { | |||
for i := range config.SessionTicketKey { | |||
config.SessionTicketKey[i] = byte(i) | |||
} | |||
config.sessionTicketKeys = nil | |||
}, | |||
func(clientHello *ClientHelloInfo) (*Config, error) { | |||
config := testConfig.Clone() | |||
for i := range config.SessionTicketKey { | |||
config.SessionTicketKey[i] = 0 | |||
} | |||
config.sessionTicketKeys = nil | |||
return config, nil | |||
}, | |||
"", | |||
func(config *Config) error { | |||
// The value of SessionTicketKey should have been | |||
// duplicated into the per-connection Config. | |||
for i := range config.SessionTicketKey { | |||
if b := config.SessionTicketKey[i]; b != byte(i) { | |||
return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b) | |||
} | |||
} | |||
return nil | |||
}, | |||
}, | |||
{ | |||
func(config *Config) { | |||
var dummyKey [32]byte | |||
for i := range dummyKey { | |||
dummyKey[i] = byte(i) | |||
} | |||
config.SetSessionTicketKeys([][32]byte{dummyKey}) | |||
}, | |||
func(clientHello *ClientHelloInfo) (*Config, error) { | |||
config := testConfig.Clone() | |||
config.sessionTicketKeys = nil | |||
return config, nil | |||
}, | |||
"", | |||
func(config *Config) error { | |||
// The session ticket keys should have been duplicated | |||
// into the per-connection Config. | |||
if l := len(config.sessionTicketKeys); l != 1 { | |||
return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l) | |||
} | |||
return nil | |||
}, | |||
}, | |||
} | |||
func TestGetConfigForClient(t *testing.T) { | |||
serverConfig := testConfig.Clone() | |||
clientConfig := testConfig.Clone() | |||
clientConfig.MinVersion = VersionTLS12 | |||
for i, test := range getConfigForClientTests { | |||
if test.setup != nil { | |||
test.setup(serverConfig) | |||
} | |||
var configReturned *Config | |||
serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) { | |||
config, err := test.callback(clientHello) | |||
configReturned = config | |||
return config, err | |||
} | |||
c, s := net.Pipe() | |||
done := make(chan error) | |||
go func() { | |||
defer s.Close() | |||
done <- Server(s, serverConfig).Handshake() | |||
}() | |||
clientErr := Client(c, clientConfig).Handshake() | |||
c.Close() | |||
serverErr := <-done | |||
if len(test.errorSubstring) == 0 { | |||
if serverErr != nil || clientErr != nil { | |||
t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr) | |||
} | |||
if test.verify != nil { | |||
if err := test.verify(configReturned); err != nil { | |||
t.Errorf("#%d: verify returned error: %v", i, err) | |||
} | |||
} | |||
} else { | |||
if serverErr == nil { | |||
t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring) | |||
} else if !strings.Contains(serverErr.Error(), test.errorSubstring) { | |||
t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr) | |||
} | |||
} | |||
} | |||
} | |||
func bigFromString(s string) *big.Int { | |||
ret := new(big.Int) | |||
ret.SetString(s, 10) | |||
@@ -477,7 +477,7 @@ func TestClone(t *testing.T) { | |||
case "Rand": | |||
f.Set(reflect.ValueOf(io.Reader(os.Stdin))) | |||
continue | |||
case "Time", "GetCertificate": | |||
case "Time", "GetCertificate", "GetConfigForClient": | |||
// DeepEqual can't compare functions. | |||
continue | |||
case "Certificates": | |||