diff --git a/common.go b/common.go index 021f475..7199cd9 100644 --- a/common.go +++ b/common.go @@ -303,7 +303,27 @@ type Config struct { // If GetCertificate is nil or returns nil, then the certificate is // retrieved from NameToCertificate. If NameToCertificate is nil, the // first element of Certificates will be used. - GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error) + GetCertificate func(*ClientHelloInfo) (*Certificate, error) + + // GetConfigForClient, if not nil, is called after a ClientHello is + // received from a client. It may return a non-nil Config in order to + // change the Config that will be used to handle this connection. If + // the returned Config is nil, the original Config will be used. The + // Config returned by this callback may not be subsequently modified. + // + // If GetConfigForClient is nil, the Config passed to Server() will be + // used for all connections. + // + // Uniquely for the fields in the returned Config, session ticket keys + // will be duplicated from the original Config if not set. + // Specifically, if SetSessionTicketKeys was called on the original + // config but not on the returned config then the ticket keys from the + // original config will be copied into the new config before use. + // Otherwise, if SessionTicketKey was set in the original config but + // not in the returned config then it will be copied into the returned + // config before use. If neither of those cases applies then the key + // material from the returned config will be used for session tickets. + GetConfigForClient func(*ClientHelloInfo) (*Config, error) // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. @@ -398,13 +418,17 @@ type Config struct { serverInitOnce sync.Once // guards calling (*Config).serverInit - // mutex protects sessionTicketKeys + // mutex protects sessionTicketKeys and originalConfig. mutex sync.RWMutex // sessionTicketKeys contains zero or more ticket keys. If the length // is zero, SessionTicketsDisabled must be true. The first key is used // for new tickets and any subsequent keys can be used to decrypt old // tickets. sessionTicketKeys []ticketKey + // originalConfig is set to the Config that was passed to Server if + // this Config is returned by a GetConfigForClient callback. It's used + // by serverInit in order to copy session ticket keys if needed. + originalConfig *Config } // ticketKeyNameLen is the number of bytes of identifier that is prepended to @@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) { // Clone returns a shallow clone of c. // Only the exported fields are copied. func (c *Config) Clone() *Config { + var sessionTicketKeys []ticketKey + c.mutex.RLock() + sessionTicketKeys = c.sessionTicketKeys + c.mutex.RUnlock() + return &Config{ Rand: c.Rand, Time: c.Time, Certificates: c.Certificates, NameToCertificate: c.NameToCertificate, GetCertificate: c.GetCertificate, + GetConfigForClient: c.GetConfigForClient, RootCAs: c.RootCAs, NextProtos: c.NextProtos, ServerName: c.ServerName, @@ -457,6 +487,8 @@ func (c *Config) Clone() *Config { DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, Renegotiation: c.Renegotiation, KeyLogWriter: c.KeyLogWriter, + sessionTicketKeys: sessionTicketKeys, + // originalConfig is deliberately not duplicated. } } @@ -465,6 +497,11 @@ func (c *Config) serverInit() { return } + var originalConfig *Config + c.mutex.Lock() + originalConfig, c.originalConfig = c.originalConfig, nil + c.mutex.Unlock() + alreadySet := false for _, b := range c.SessionTicketKey { if b != 0 { @@ -474,13 +511,21 @@ func (c *Config) serverInit() { } if !alreadySet { - if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + if originalConfig != nil { + copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:]) + } else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { c.SessionTicketsDisabled = true return } } - c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} + if originalConfig != nil { + originalConfig.mutex.RLock() + c.sessionTicketKeys = originalConfig.sessionTicketKeys + originalConfig.mutex.RUnlock() + } else { + c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)} + } } func (c *Config) ticketKeys() []ticketKey { diff --git a/handshake_server.go b/handshake_server.go index 17141ae..630a99a 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -37,11 +37,9 @@ type serverHandshakeState struct { // serverHandshake performs a TLS handshake as a server. // c.out.Mutex <= L; c.handshakeMutex <= L. func (c *Conn) serverHandshake() error { - config := c.config - // If this is the first server handshake, we generate a random key to // encrypt the tickets with. - config.serverInitOnce.Do(config.serverInit) + c.config.serverInitOnce.Do(c.config.serverInit) hs := serverHandshakeState{ c: c, @@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error { // readClientHello reads a ClientHello message from the client and decides // whether we will perform session resumption. func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { - config := hs.c.config c := hs.c msg, err := c.readHandshake() @@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { c.sendAlert(alertUnexpectedMessage) return false, unexpectedMessageError(hs.clientHello, msg) } - c.vers, ok = config.mutualVersion(hs.clientHello.vers) + + clientHelloInfo := &ClientHelloInfo{ + CipherSuites: hs.clientHello.cipherSuites, + ServerName: hs.clientHello.serverName, + SupportedCurves: hs.clientHello.supportedCurves, + SupportedPoints: hs.clientHello.supportedPoints, + } + + if c.config.GetConfigForClient != nil { + if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil { + c.sendAlert(alertInternalError) + return false, err + } else if newConfig != nil { + newConfig.mutex.Lock() + newConfig.originalConfig = c.config + newConfig.mutex.Unlock() + + newConfig.serverInitOnce.Do(newConfig.serverInit) + c.config = newConfig + } + } + + c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) if !ok { c.sendAlert(alertProtocolVersion) return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) @@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { hs.hello = new(serverHelloMsg) supportedCurve := false - preferredCurves := config.curvePreferences() + preferredCurves := c.config.curvePreferences() Curves: for _, curve := range hs.clientHello.supportedCurves { for _, supported := range preferredCurves { @@ -171,7 +190,7 @@ Curves: hs.hello.vers = c.vers hs.hello.random = make([]byte, 32) - _, err = io.ReadFull(config.rand(), hs.hello.random) + _, err = io.ReadFull(c.config.rand(), hs.hello.random) if err != nil { c.sendAlert(alertInternalError) return false, err @@ -196,20 +215,15 @@ Curves: } else { // Although sending an empty NPN extension is reasonable, Firefox has // had a bug around this. Best to send nothing at all if - // config.NextProtos is empty. See + // c.config.NextProtos is empty. See // https://golang.org/issue/5445. - if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 { + if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 { hs.hello.nextProtoNeg = true - hs.hello.nextProtos = config.NextProtos + hs.hello.nextProtos = c.config.NextProtos } } - hs.cert, err = config.getCertificate(&ClientHelloInfo{ - CipherSuites: hs.clientHello.cipherSuites, - ServerName: hs.clientHello.serverName, - SupportedCurves: hs.clientHello.supportedCurves, - SupportedPoints: hs.clientHello.supportedPoints, - }) + hs.cert, err = c.config.getCertificate(clientHelloInfo) if err != nil { c.sendAlert(alertInternalError) return false, err @@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error { } func (hs *serverHandshakeState) doFullHandshake() error { - config := hs.c.config c := hs.c if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { hs.hello.ocspStapling = true } - hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled hs.hello.cipherSuite = hs.suite.id hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite) - if config.ClientAuth == NoClientCert { + if c.config.ClientAuth == NoClientCert { // No need to keep a full record of the handshake if client // certificates won't be used. hs.finishedHash.discardHandshakeBuffer() @@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { } keyAgreement := hs.suite.ka(c.vers) - skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello) + skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello) if err != nil { c.sendAlert(alertHandshakeFailure) return err @@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { } } - if config.ClientAuth >= RequestClientCert { + if c.config.ClientAuth >= RequestClientCert { // Request a client certificate certReq := new(certificateRequestMsg) certReq.certificateTypes = []byte{ @@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { // to our request. When we know the CAs we trust, then // we can send them down, so that the client can choose // an appropriate certificate to give to us. - if config.ClientCAs != nil { - certReq.certificateAuthorities = config.ClientCAs.Subjects() + if c.config.ClientCAs != nil { + certReq.certificateAuthorities = c.config.ClientCAs.Subjects() } hs.finishedHash.Write(certReq.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil { @@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { var ok bool // If we requested a client certificate, then the client must send a // certificate message, even if it's empty. - if config.ClientAuth >= RequestClientCert { + if c.config.ClientAuth >= RequestClientCert { if certMsg, ok = msg.(*certificateMsg); !ok { c.sendAlert(alertUnexpectedMessage) return unexpectedMessageError(certMsg, msg) @@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error { if len(certMsg.certificates) == 0 { // The client didn't actually send a certificate - switch config.ClientAuth { + switch c.config.ClientAuth { case RequireAnyClientCert, RequireAndVerifyClientCert: c.sendAlert(alertBadCertificate) return errors.New("tls: client didn't provide a certificate") @@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error { } hs.finishedHash.Write(ckx.marshal()) - preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers) + preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) if err != nil { c.sendAlert(alertHandshakeFailure) return err } hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) - if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil { + if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil { c.sendAlert(alertInternalError) return err } diff --git a/handshake_server_test.go b/handshake_server_test.go index 38d8275..94193bd 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) { } } +var getConfigForClientTests = []struct { + setup func(config *Config) + callback func(clientHello *ClientHelloInfo) (*Config, error) + errorSubstring string + verify func(config *Config) error +}{ + { + nil, + func(clientHello *ClientHelloInfo) (*Config, error) { + return nil, nil + }, + "", + nil, + }, + { + nil, + func(clientHello *ClientHelloInfo) (*Config, error) { + return nil, errors.New("should bubble up") + }, + "should bubble up", + nil, + }, + { + nil, + func(clientHello *ClientHelloInfo) (*Config, error) { + config := testConfig.Clone() + // Setting a maximum version of TLS 1.1 should cause + // the handshake to fail. + config.MaxVersion = VersionTLS11 + return config, nil + }, + "version 301 when expecting version 302", + nil, + }, + { + func(config *Config) { + for i := range config.SessionTicketKey { + config.SessionTicketKey[i] = byte(i) + } + config.sessionTicketKeys = nil + }, + func(clientHello *ClientHelloInfo) (*Config, error) { + config := testConfig.Clone() + for i := range config.SessionTicketKey { + config.SessionTicketKey[i] = 0 + } + config.sessionTicketKeys = nil + return config, nil + }, + "", + func(config *Config) error { + // The value of SessionTicketKey should have been + // duplicated into the per-connection Config. + for i := range config.SessionTicketKey { + if b := config.SessionTicketKey[i]; b != byte(i) { + return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b) + } + } + return nil + }, + }, + { + func(config *Config) { + var dummyKey [32]byte + for i := range dummyKey { + dummyKey[i] = byte(i) + } + + config.SetSessionTicketKeys([][32]byte{dummyKey}) + }, + func(clientHello *ClientHelloInfo) (*Config, error) { + config := testConfig.Clone() + config.sessionTicketKeys = nil + return config, nil + }, + "", + func(config *Config) error { + // The session ticket keys should have been duplicated + // into the per-connection Config. + if l := len(config.sessionTicketKeys); l != 1 { + return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l) + } + return nil + }, + }, +} + +func TestGetConfigForClient(t *testing.T) { + serverConfig := testConfig.Clone() + clientConfig := testConfig.Clone() + clientConfig.MinVersion = VersionTLS12 + + for i, test := range getConfigForClientTests { + if test.setup != nil { + test.setup(serverConfig) + } + + var configReturned *Config + serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) { + config, err := test.callback(clientHello) + configReturned = config + return config, err + } + c, s := net.Pipe() + done := make(chan error) + + go func() { + defer s.Close() + done <- Server(s, serverConfig).Handshake() + }() + + clientErr := Client(c, clientConfig).Handshake() + c.Close() + + serverErr := <-done + + if len(test.errorSubstring) == 0 { + if serverErr != nil || clientErr != nil { + t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr) + } + if test.verify != nil { + if err := test.verify(configReturned); err != nil { + t.Errorf("#%d: verify returned error: %v", i, err) + } + } + } else { + if serverErr == nil { + t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring) + } else if !strings.Contains(serverErr.Error(), test.errorSubstring) { + t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr) + } + } + } +} + func bigFromString(s string) *big.Int { ret := new(big.Int) ret.SetString(s, 10) diff --git a/tls_test.go b/tls_test.go index 8b8dfa4..d2a674d 100644 --- a/tls_test.go +++ b/tls_test.go @@ -477,7 +477,7 @@ func TestClone(t *testing.T) { case "Rand": f.Set(reflect.ValueOf(io.Reader(os.Stdin))) continue - case "Time", "GetCertificate": + case "Time", "GetCertificate", "GetConfigForClient": // DeepEqual can't compare functions. continue case "Certificates":