crypto/tls: add Config.GetConfigForClient
GetConfigForClient allows the tls.Config to be updated on a per-client basis. Fixes #16066. Fixes #15707. Fixes #15699. Change-Id: I2c675a443d557f969441226729f98502b38901ea Reviewed-on: https://go-review.googlesource.com/30790 Run-TryBot: Adam Langley <agl@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
6f5a445fc3
commit
826c39c6ef
53
common.go
53
common.go
@ -303,7 +303,27 @@ type Config struct {
|
|||||||
// If GetCertificate is nil or returns nil, then the certificate is
|
// If GetCertificate is nil or returns nil, then the certificate is
|
||||||
// retrieved from NameToCertificate. If NameToCertificate is nil, the
|
// retrieved from NameToCertificate. If NameToCertificate is nil, the
|
||||||
// first element of Certificates will be used.
|
// first element of Certificates will be used.
|
||||||
GetCertificate func(clientHello *ClientHelloInfo) (*Certificate, error)
|
GetCertificate func(*ClientHelloInfo) (*Certificate, error)
|
||||||
|
|
||||||
|
// GetConfigForClient, if not nil, is called after a ClientHello is
|
||||||
|
// received from a client. It may return a non-nil Config in order to
|
||||||
|
// change the Config that will be used to handle this connection. If
|
||||||
|
// the returned Config is nil, the original Config will be used. The
|
||||||
|
// Config returned by this callback may not be subsequently modified.
|
||||||
|
//
|
||||||
|
// If GetConfigForClient is nil, the Config passed to Server() will be
|
||||||
|
// used for all connections.
|
||||||
|
//
|
||||||
|
// Uniquely for the fields in the returned Config, session ticket keys
|
||||||
|
// will be duplicated from the original Config if not set.
|
||||||
|
// Specifically, if SetSessionTicketKeys was called on the original
|
||||||
|
// config but not on the returned config then the ticket keys from the
|
||||||
|
// original config will be copied into the new config before use.
|
||||||
|
// Otherwise, if SessionTicketKey was set in the original config but
|
||||||
|
// not in the returned config then it will be copied into the returned
|
||||||
|
// config before use. If neither of those cases applies then the key
|
||||||
|
// material from the returned config will be used for session tickets.
|
||||||
|
GetConfigForClient func(*ClientHelloInfo) (*Config, error)
|
||||||
|
|
||||||
// RootCAs defines the set of root certificate authorities
|
// RootCAs defines the set of root certificate authorities
|
||||||
// that clients use when verifying server certificates.
|
// that clients use when verifying server certificates.
|
||||||
@ -398,13 +418,17 @@ type Config struct {
|
|||||||
|
|
||||||
serverInitOnce sync.Once // guards calling (*Config).serverInit
|
serverInitOnce sync.Once // guards calling (*Config).serverInit
|
||||||
|
|
||||||
// mutex protects sessionTicketKeys
|
// mutex protects sessionTicketKeys and originalConfig.
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
// sessionTicketKeys contains zero or more ticket keys. If the length
|
// sessionTicketKeys contains zero or more ticket keys. If the length
|
||||||
// is zero, SessionTicketsDisabled must be true. The first key is used
|
// is zero, SessionTicketsDisabled must be true. The first key is used
|
||||||
// for new tickets and any subsequent keys can be used to decrypt old
|
// for new tickets and any subsequent keys can be used to decrypt old
|
||||||
// tickets.
|
// tickets.
|
||||||
sessionTicketKeys []ticketKey
|
sessionTicketKeys []ticketKey
|
||||||
|
// originalConfig is set to the Config that was passed to Server if
|
||||||
|
// this Config is returned by a GetConfigForClient callback. It's used
|
||||||
|
// by serverInit in order to copy session ticket keys if needed.
|
||||||
|
originalConfig *Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
|
// ticketKeyNameLen is the number of bytes of identifier that is prepended to
|
||||||
@ -434,12 +458,18 @@ func ticketKeyFromBytes(b [32]byte) (key ticketKey) {
|
|||||||
// Clone returns a shallow clone of c.
|
// Clone returns a shallow clone of c.
|
||||||
// Only the exported fields are copied.
|
// Only the exported fields are copied.
|
||||||
func (c *Config) Clone() *Config {
|
func (c *Config) Clone() *Config {
|
||||||
|
var sessionTicketKeys []ticketKey
|
||||||
|
c.mutex.RLock()
|
||||||
|
sessionTicketKeys = c.sessionTicketKeys
|
||||||
|
c.mutex.RUnlock()
|
||||||
|
|
||||||
return &Config{
|
return &Config{
|
||||||
Rand: c.Rand,
|
Rand: c.Rand,
|
||||||
Time: c.Time,
|
Time: c.Time,
|
||||||
Certificates: c.Certificates,
|
Certificates: c.Certificates,
|
||||||
NameToCertificate: c.NameToCertificate,
|
NameToCertificate: c.NameToCertificate,
|
||||||
GetCertificate: c.GetCertificate,
|
GetCertificate: c.GetCertificate,
|
||||||
|
GetConfigForClient: c.GetConfigForClient,
|
||||||
RootCAs: c.RootCAs,
|
RootCAs: c.RootCAs,
|
||||||
NextProtos: c.NextProtos,
|
NextProtos: c.NextProtos,
|
||||||
ServerName: c.ServerName,
|
ServerName: c.ServerName,
|
||||||
@ -457,6 +487,8 @@ func (c *Config) Clone() *Config {
|
|||||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||||
Renegotiation: c.Renegotiation,
|
Renegotiation: c.Renegotiation,
|
||||||
KeyLogWriter: c.KeyLogWriter,
|
KeyLogWriter: c.KeyLogWriter,
|
||||||
|
sessionTicketKeys: sessionTicketKeys,
|
||||||
|
// originalConfig is deliberately not duplicated.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -465,6 +497,11 @@ func (c *Config) serverInit() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var originalConfig *Config
|
||||||
|
c.mutex.Lock()
|
||||||
|
originalConfig, c.originalConfig = c.originalConfig, nil
|
||||||
|
c.mutex.Unlock()
|
||||||
|
|
||||||
alreadySet := false
|
alreadySet := false
|
||||||
for _, b := range c.SessionTicketKey {
|
for _, b := range c.SessionTicketKey {
|
||||||
if b != 0 {
|
if b != 0 {
|
||||||
@ -474,13 +511,21 @@ func (c *Config) serverInit() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !alreadySet {
|
if !alreadySet {
|
||||||
if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
|
if originalConfig != nil {
|
||||||
|
copy(c.SessionTicketKey[:], originalConfig.SessionTicketKey[:])
|
||||||
|
} else if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil {
|
||||||
c.SessionTicketsDisabled = true
|
c.SessionTicketsDisabled = true
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
|
if originalConfig != nil {
|
||||||
|
originalConfig.mutex.RLock()
|
||||||
|
c.sessionTicketKeys = originalConfig.sessionTicketKeys
|
||||||
|
originalConfig.mutex.RUnlock()
|
||||||
|
} else {
|
||||||
|
c.sessionTicketKeys = []ticketKey{ticketKeyFromBytes(c.SessionTicketKey)}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) ticketKeys() []ticketKey {
|
func (c *Config) ticketKeys() []ticketKey {
|
||||||
|
@ -37,11 +37,9 @@ type serverHandshakeState struct {
|
|||||||
// serverHandshake performs a TLS handshake as a server.
|
// serverHandshake performs a TLS handshake as a server.
|
||||||
// c.out.Mutex <= L; c.handshakeMutex <= L.
|
// c.out.Mutex <= L; c.handshakeMutex <= L.
|
||||||
func (c *Conn) serverHandshake() error {
|
func (c *Conn) serverHandshake() error {
|
||||||
config := c.config
|
|
||||||
|
|
||||||
// If this is the first server handshake, we generate a random key to
|
// If this is the first server handshake, we generate a random key to
|
||||||
// encrypt the tickets with.
|
// encrypt the tickets with.
|
||||||
config.serverInitOnce.Do(config.serverInit)
|
c.config.serverInitOnce.Do(c.config.serverInit)
|
||||||
|
|
||||||
hs := serverHandshakeState{
|
hs := serverHandshakeState{
|
||||||
c: c,
|
c: c,
|
||||||
@ -112,7 +110,6 @@ func (c *Conn) serverHandshake() error {
|
|||||||
// readClientHello reads a ClientHello message from the client and decides
|
// readClientHello reads a ClientHello message from the client and decides
|
||||||
// whether we will perform session resumption.
|
// whether we will perform session resumption.
|
||||||
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
|
func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
|
||||||
config := hs.c.config
|
|
||||||
c := hs.c
|
c := hs.c
|
||||||
|
|
||||||
msg, err := c.readHandshake()
|
msg, err := c.readHandshake()
|
||||||
@ -125,7 +122,29 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
|
|||||||
c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
return false, unexpectedMessageError(hs.clientHello, msg)
|
return false, unexpectedMessageError(hs.clientHello, msg)
|
||||||
}
|
}
|
||||||
c.vers, ok = config.mutualVersion(hs.clientHello.vers)
|
|
||||||
|
clientHelloInfo := &ClientHelloInfo{
|
||||||
|
CipherSuites: hs.clientHello.cipherSuites,
|
||||||
|
ServerName: hs.clientHello.serverName,
|
||||||
|
SupportedCurves: hs.clientHello.supportedCurves,
|
||||||
|
SupportedPoints: hs.clientHello.supportedPoints,
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.config.GetConfigForClient != nil {
|
||||||
|
if newConfig, err := c.config.GetConfigForClient(clientHelloInfo); err != nil {
|
||||||
|
c.sendAlert(alertInternalError)
|
||||||
|
return false, err
|
||||||
|
} else if newConfig != nil {
|
||||||
|
newConfig.mutex.Lock()
|
||||||
|
newConfig.originalConfig = c.config
|
||||||
|
newConfig.mutex.Unlock()
|
||||||
|
|
||||||
|
newConfig.serverInitOnce.Do(newConfig.serverInit)
|
||||||
|
c.config = newConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.sendAlert(alertProtocolVersion)
|
c.sendAlert(alertProtocolVersion)
|
||||||
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
|
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
|
||||||
@ -135,7 +154,7 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
|
|||||||
hs.hello = new(serverHelloMsg)
|
hs.hello = new(serverHelloMsg)
|
||||||
|
|
||||||
supportedCurve := false
|
supportedCurve := false
|
||||||
preferredCurves := config.curvePreferences()
|
preferredCurves := c.config.curvePreferences()
|
||||||
Curves:
|
Curves:
|
||||||
for _, curve := range hs.clientHello.supportedCurves {
|
for _, curve := range hs.clientHello.supportedCurves {
|
||||||
for _, supported := range preferredCurves {
|
for _, supported := range preferredCurves {
|
||||||
@ -171,7 +190,7 @@ Curves:
|
|||||||
|
|
||||||
hs.hello.vers = c.vers
|
hs.hello.vers = c.vers
|
||||||
hs.hello.random = make([]byte, 32)
|
hs.hello.random = make([]byte, 32)
|
||||||
_, err = io.ReadFull(config.rand(), hs.hello.random)
|
_, err = io.ReadFull(c.config.rand(), hs.hello.random)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
return false, err
|
return false, err
|
||||||
@ -196,20 +215,15 @@ Curves:
|
|||||||
} else {
|
} else {
|
||||||
// Although sending an empty NPN extension is reasonable, Firefox has
|
// Although sending an empty NPN extension is reasonable, Firefox has
|
||||||
// had a bug around this. Best to send nothing at all if
|
// had a bug around this. Best to send nothing at all if
|
||||||
// config.NextProtos is empty. See
|
// c.config.NextProtos is empty. See
|
||||||
// https://golang.org/issue/5445.
|
// https://golang.org/issue/5445.
|
||||||
if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 {
|
if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
|
||||||
hs.hello.nextProtoNeg = true
|
hs.hello.nextProtoNeg = true
|
||||||
hs.hello.nextProtos = config.NextProtos
|
hs.hello.nextProtos = c.config.NextProtos
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.cert, err = config.getCertificate(&ClientHelloInfo{
|
hs.cert, err = c.config.getCertificate(clientHelloInfo)
|
||||||
CipherSuites: hs.clientHello.cipherSuites,
|
|
||||||
ServerName: hs.clientHello.serverName,
|
|
||||||
SupportedCurves: hs.clientHello.supportedCurves,
|
|
||||||
SupportedPoints: hs.clientHello.supportedPoints,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
return false, err
|
return false, err
|
||||||
@ -354,18 +368,17 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hs *serverHandshakeState) doFullHandshake() error {
|
func (hs *serverHandshakeState) doFullHandshake() error {
|
||||||
config := hs.c.config
|
|
||||||
c := hs.c
|
c := hs.c
|
||||||
|
|
||||||
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
|
if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
|
||||||
hs.hello.ocspStapling = true
|
hs.hello.ocspStapling = true
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled
|
hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
|
||||||
hs.hello.cipherSuite = hs.suite.id
|
hs.hello.cipherSuite = hs.suite.id
|
||||||
|
|
||||||
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
|
hs.finishedHash = newFinishedHash(hs.c.vers, hs.suite)
|
||||||
if config.ClientAuth == NoClientCert {
|
if c.config.ClientAuth == NoClientCert {
|
||||||
// No need to keep a full record of the handshake if client
|
// No need to keep a full record of the handshake if client
|
||||||
// certificates won't be used.
|
// certificates won't be used.
|
||||||
hs.finishedHash.discardHandshakeBuffer()
|
hs.finishedHash.discardHandshakeBuffer()
|
||||||
@ -394,7 +407,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
keyAgreement := hs.suite.ka(c.vers)
|
keyAgreement := hs.suite.ka(c.vers)
|
||||||
skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello)
|
skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.clientHello, hs.hello)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
return err
|
return err
|
||||||
@ -406,7 +419,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.ClientAuth >= RequestClientCert {
|
if c.config.ClientAuth >= RequestClientCert {
|
||||||
// Request a client certificate
|
// Request a client certificate
|
||||||
certReq := new(certificateRequestMsg)
|
certReq := new(certificateRequestMsg)
|
||||||
certReq.certificateTypes = []byte{
|
certReq.certificateTypes = []byte{
|
||||||
@ -423,8 +436,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
// to our request. When we know the CAs we trust, then
|
// to our request. When we know the CAs we trust, then
|
||||||
// we can send them down, so that the client can choose
|
// we can send them down, so that the client can choose
|
||||||
// an appropriate certificate to give to us.
|
// an appropriate certificate to give to us.
|
||||||
if config.ClientCAs != nil {
|
if c.config.ClientCAs != nil {
|
||||||
certReq.certificateAuthorities = config.ClientCAs.Subjects()
|
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(certReq.marshal())
|
hs.finishedHash.Write(certReq.marshal())
|
||||||
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
|
if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
|
||||||
@ -452,7 +465,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
var ok bool
|
var ok bool
|
||||||
// If we requested a client certificate, then the client must send a
|
// If we requested a client certificate, then the client must send a
|
||||||
// certificate message, even if it's empty.
|
// certificate message, even if it's empty.
|
||||||
if config.ClientAuth >= RequestClientCert {
|
if c.config.ClientAuth >= RequestClientCert {
|
||||||
if certMsg, ok = msg.(*certificateMsg); !ok {
|
if certMsg, ok = msg.(*certificateMsg); !ok {
|
||||||
c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
return unexpectedMessageError(certMsg, msg)
|
return unexpectedMessageError(certMsg, msg)
|
||||||
@ -461,7 +474,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
|
|
||||||
if len(certMsg.certificates) == 0 {
|
if len(certMsg.certificates) == 0 {
|
||||||
// The client didn't actually send a certificate
|
// The client didn't actually send a certificate
|
||||||
switch config.ClientAuth {
|
switch c.config.ClientAuth {
|
||||||
case RequireAnyClientCert, RequireAndVerifyClientCert:
|
case RequireAnyClientCert, RequireAndVerifyClientCert:
|
||||||
c.sendAlert(alertBadCertificate)
|
c.sendAlert(alertBadCertificate)
|
||||||
return errors.New("tls: client didn't provide a certificate")
|
return errors.New("tls: client didn't provide a certificate")
|
||||||
@ -487,13 +500,13 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
hs.finishedHash.Write(ckx.marshal())
|
hs.finishedHash.Write(ckx.marshal())
|
||||||
|
|
||||||
preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers)
|
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
|
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
|
||||||
if err := config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
|
if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
|
||||||
c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1141,6 +1141,141 @@ func TestSNIGivenOnFailure(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var getConfigForClientTests = []struct {
|
||||||
|
setup func(config *Config)
|
||||||
|
callback func(clientHello *ClientHelloInfo) (*Config, error)
|
||||||
|
errorSubstring string
|
||||||
|
verify func(config *Config) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
nil,
|
||||||
|
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||||
|
return nil, nil
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nil,
|
||||||
|
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||||
|
return nil, errors.New("should bubble up")
|
||||||
|
},
|
||||||
|
"should bubble up",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
nil,
|
||||||
|
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
// Setting a maximum version of TLS 1.1 should cause
|
||||||
|
// the handshake to fail.
|
||||||
|
config.MaxVersion = VersionTLS11
|
||||||
|
return config, nil
|
||||||
|
},
|
||||||
|
"version 301 when expecting version 302",
|
||||||
|
nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
func(config *Config) {
|
||||||
|
for i := range config.SessionTicketKey {
|
||||||
|
config.SessionTicketKey[i] = byte(i)
|
||||||
|
}
|
||||||
|
config.sessionTicketKeys = nil
|
||||||
|
},
|
||||||
|
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
for i := range config.SessionTicketKey {
|
||||||
|
config.SessionTicketKey[i] = 0
|
||||||
|
}
|
||||||
|
config.sessionTicketKeys = nil
|
||||||
|
return config, nil
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
func(config *Config) error {
|
||||||
|
// The value of SessionTicketKey should have been
|
||||||
|
// duplicated into the per-connection Config.
|
||||||
|
for i := range config.SessionTicketKey {
|
||||||
|
if b := config.SessionTicketKey[i]; b != byte(i) {
|
||||||
|
return fmt.Errorf("SessionTicketKey was not duplicated from original Config: byte %d has value %d", i, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
func(config *Config) {
|
||||||
|
var dummyKey [32]byte
|
||||||
|
for i := range dummyKey {
|
||||||
|
dummyKey[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
config.SetSessionTicketKeys([][32]byte{dummyKey})
|
||||||
|
},
|
||||||
|
func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||||
|
config := testConfig.Clone()
|
||||||
|
config.sessionTicketKeys = nil
|
||||||
|
return config, nil
|
||||||
|
},
|
||||||
|
"",
|
||||||
|
func(config *Config) error {
|
||||||
|
// The session ticket keys should have been duplicated
|
||||||
|
// into the per-connection Config.
|
||||||
|
if l := len(config.sessionTicketKeys); l != 1 {
|
||||||
|
return fmt.Errorf("got len(sessionTicketKeys) == %d, wanted 1", l)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetConfigForClient(t *testing.T) {
|
||||||
|
serverConfig := testConfig.Clone()
|
||||||
|
clientConfig := testConfig.Clone()
|
||||||
|
clientConfig.MinVersion = VersionTLS12
|
||||||
|
|
||||||
|
for i, test := range getConfigForClientTests {
|
||||||
|
if test.setup != nil {
|
||||||
|
test.setup(serverConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
var configReturned *Config
|
||||||
|
serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
|
||||||
|
config, err := test.callback(clientHello)
|
||||||
|
configReturned = config
|
||||||
|
return config, err
|
||||||
|
}
|
||||||
|
c, s := net.Pipe()
|
||||||
|
done := make(chan error)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer s.Close()
|
||||||
|
done <- Server(s, serverConfig).Handshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientErr := Client(c, clientConfig).Handshake()
|
||||||
|
c.Close()
|
||||||
|
|
||||||
|
serverErr := <-done
|
||||||
|
|
||||||
|
if len(test.errorSubstring) == 0 {
|
||||||
|
if serverErr != nil || clientErr != nil {
|
||||||
|
t.Errorf("%#d: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
|
||||||
|
}
|
||||||
|
if test.verify != nil {
|
||||||
|
if err := test.verify(configReturned); err != nil {
|
||||||
|
t.Errorf("#%d: verify returned error: %v", i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if serverErr == nil {
|
||||||
|
t.Errorf("%#d: expected error containing %q but got no error", i, test.errorSubstring)
|
||||||
|
} else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
|
||||||
|
t.Errorf("%#d: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func bigFromString(s string) *big.Int {
|
func bigFromString(s string) *big.Int {
|
||||||
ret := new(big.Int)
|
ret := new(big.Int)
|
||||||
ret.SetString(s, 10)
|
ret.SetString(s, 10)
|
||||||
|
@ -477,7 +477,7 @@ func TestClone(t *testing.T) {
|
|||||||
case "Rand":
|
case "Rand":
|
||||||
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
|
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
|
||||||
continue
|
continue
|
||||||
case "Time", "GetCertificate":
|
case "Time", "GetCertificate", "GetConfigForClient":
|
||||||
// DeepEqual can't compare functions.
|
// DeepEqual can't compare functions.
|
||||||
continue
|
continue
|
||||||
case "Certificates":
|
case "Certificates":
|
||||||
|
Loading…
Reference in New Issue
Block a user