diff --git a/handshake_client.go b/handshake_client.go index a4ca5d3..f8db662 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -29,51 +29,38 @@ type clientHandshakeState struct { session *ClientSessionState } -// c.out.Mutex <= L; c.handshakeMutex <= L. -func (c *Conn) clientHandshake() error { - if c.config == nil { - c.config = defaultConfig() - } - - // This may be a renegotiation handshake, in which case some fields - // need to be reset. - c.didResume = false - - if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { - return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") +func makeClientHello(config *Config) (*clientHelloMsg, error) { + if len(config.ServerName) == 0 && !config.InsecureSkipVerify { + return nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") } nextProtosLength := 0 - for _, proto := range c.config.NextProtos { + for _, proto := range config.NextProtos { if l := len(proto); l == 0 || l > 255 { - return errors.New("tls: invalid NextProtos value") + return nil, errors.New("tls: invalid NextProtos value") } else { nextProtosLength += 1 + l } } + if nextProtosLength > 0xffff { - return errors.New("tls: NextProtos values too large") + return nil, errors.New("tls: NextProtos values too large") } hello := &clientHelloMsg{ - vers: c.config.maxVersion(), + vers: config.maxVersion(), compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), ocspStapling: true, scts: true, - serverName: hostnameInSNI(c.config.ServerName), - supportedCurves: c.config.curvePreferences(), + serverName: hostnameInSNI(config.ServerName), + supportedCurves: config.curvePreferences(), supportedPoints: []uint8{pointFormatUncompressed}, - nextProtoNeg: len(c.config.NextProtos) > 0, + nextProtoNeg: len(config.NextProtos) > 0, secureRenegotiationSupported: true, - alpnProtocols: c.config.NextProtos, + alpnProtocols: config.NextProtos, } - - if c.handshakes > 0 { - hello.secureRenegotiation = c.clientFinished[:] - } - - possibleCipherSuites := c.config.cipherSuites() + possibleCipherSuites := config.cipherSuites() hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) NextCipherSuite: @@ -92,16 +79,37 @@ NextCipherSuite: } } - _, err := io.ReadFull(c.config.rand(), hello.random) + _, err := io.ReadFull(config.rand(), hello.random) if err != nil { - c.sendAlert(alertInternalError) - return errors.New("tls: short read from Rand: " + err.Error()) + return nil, errors.New("tls: short read from Rand: " + err.Error()) } if hello.vers >= VersionTLS12 { hello.signatureAndHashes = supportedSignatureAlgorithms } + return hello, nil +} + +// c.out.Mutex <= L; c.handshakeMutex <= L. +func (c *Conn) clientHandshake() error { + if c.config == nil { + c.config = defaultConfig() + } + + // This may be a renegotiation handshake, in which case some fields + // need to be reset. + c.didResume = false + + hello, err := makeClientHello(c.config) + if err != nil { + return err + } + + if c.handshakes > 0 { + hello.secureRenegotiation = c.clientFinished[:] + } + var session *ClientSessionState var cacheKey string sessionCache := c.config.ClientSessionCache @@ -147,12 +155,36 @@ NextCipherSuite: // (see RFC 5077). hello.sessionId = make([]byte, 16) if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { - c.sendAlert(alertInternalError) return errors.New("tls: short read from Rand: " + err.Error()) } } - if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil { + hs := &clientHandshakeState{ + c: c, + hello: hello, + session: session, + } + + if err = hs.handshake(); err != nil { + return err + } + + // If we had a successful handshake and hs.session is different from + // the one already cached - cache a new one + if sessionCache != nil && hs.session != nil && session != hs.session { + sessionCache.Put(cacheKey, hs.session) + } + + return nil +} + +// Does the handshake, either a full one or resumes old session. +// Requires hs.c, hs.hello, and, optionally, hs.session to be set. +func (hs *clientHandshakeState) handshake() error { + c := hs.c + + // send ClientHello + if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil { return err } @@ -160,34 +192,19 @@ NextCipherSuite: if err != nil { return err } - serverHello, ok := msg.(*serverHelloMsg) - if !ok { - c.sendAlert(alertUnexpectedMessage) - return unexpectedMessageError(serverHello, msg) - } - vers, ok := c.config.mutualVersion(serverHello.vers) - if !ok || vers < VersionTLS10 { - // TLS 1.0 is the minimum version supported as a client. - c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) + var ok bool + if hs.serverHello, ok = msg.(*serverHelloMsg); !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(hs.serverHello, msg) } - c.vers = vers - c.haveVers = true - suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite) - if suite == nil { - c.sendAlert(alertHandshakeFailure) - return errors.New("tls: server chose an unconfigured cipher suite") + if err = hs.pickTLSVersion(); err != nil { + return err } - hs := &clientHandshakeState{ - c: c, - serverHello: serverHello, - hello: hello, - suite: suite, - finishedHash: newFinishedHash(c.vers, suite), - session: session, + if err = hs.pickCipherSuite(); err != nil { + return err } isResume, err := hs.processServerHello() @@ -195,6 +212,8 @@ NextCipherSuite: return err } + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + // No signatures of the handshake are needed in a resumption. // Otherwise, in a full handshake, if we don't have any certificates // configured then we will never send a CertificateVerify message and @@ -246,13 +265,33 @@ NextCipherSuite: } } - if sessionCache != nil && hs.session != nil && session != hs.session { - sessionCache.Put(cacheKey, hs.session) - } - c.didResume = isResume c.handshakeComplete = true - c.cipherSuite = suite.id + + return nil +} + +func (hs *clientHandshakeState) pickTLSVersion() error { + vers, ok := hs.c.config.mutualVersion(hs.serverHello.vers) + if !ok || vers < VersionTLS10 { + // TLS 1.0 is the minimum version supported as a client. + hs.c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", hs.serverHello.vers) + } + + hs.c.vers = vers + hs.c.haveVers = true + + return nil +} + +func (hs *clientHandshakeState) pickCipherSuite() error { + if hs.suite = mutualCipherSuite(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil { + hs.c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server chose an unconfigured cipher suite") + } + + hs.c.cipherSuite = hs.suite.id return nil }