diff --git a/common.go b/common.go index 992a163..6ae3fca 100644 --- a/common.go +++ b/common.go @@ -785,11 +785,16 @@ func (c *Config) curvePreferences() []CurveID { } // mutualVersion returns the protocol version to use given the advertised -// version of the peer. +// version of the peer using the legacy non-extension methods. func (c *Config) mutualVersion(vers uint16) (uint16, bool) { minVersion := c.minVersion() maxVersion := c.maxVersion() + // Version 1.3 and higher are not negotiated via this mechanism. + if maxVersion > VersionTLS12 { + maxVersion = VersionTLS12 + } + if vers < minVersion { return 0, false } @@ -799,6 +804,33 @@ func (c *Config) mutualVersion(vers uint16) (uint16, bool) { return vers, true } +// pickVersion returns the protocol version to use given the advertised +// versions of the peer using the Supported Versions extension. +func (c *Config) pickVersion(supportedVersions []uint16) (uint16, bool) { + minVersion := c.minVersion() + maxVersion := c.maxVersion() + if c == nil || c.MaxVersion == 0 { + maxVersion = VersionTLS13 // override the default if pickVersion is used + } + + tls13Enabled := maxVersion >= VersionTLS13 + if maxVersion > VersionTLS12 { + maxVersion = VersionTLS12 + } + + var vers uint16 + for _, v := range supportedVersions { + if v >= minVersion && v <= maxVersion || + (tls13Enabled && v == VersionTLS13Draft18) { + if v > vers { + vers = v + } + } + } + + return vers, vers != 0 +} + // getCertificate returns the best certificate for the given ClientHelloInfo, // defaulting to the first element of c.Certificates. func (c *Config) getCertificate(clientHello *ClientHelloInfo) (*Certificate, error) { diff --git a/handshake_server.go b/handshake_server.go index c6427b2..7a9a428 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -171,14 +171,8 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { } if hs.clientHello.supportedVersions != nil { - for _, v := range hs.clientHello.supportedVersions { - if (v >= c.config.minVersion() && v <= c.config.maxVersion()) || - v == VersionTLS13Draft18 { - c.vers = v - break - } - } - if c.vers == 0 { + c.vers, ok = c.config.pickVersion(hs.clientHello.supportedVersions) + if !ok { c.sendAlert(alertProtocolVersion) return false, fmt.Errorf("tls: none of the client versions (%x) are supported", hs.clientHello.supportedVersions) }