LGTM=bradfitz R=golang-codereviews, bradfitz CC=golang-codereviews https://golang.org/cl/60580046v1.2.3
@@ -9,6 +9,7 @@ import ( | |||
"crypto" | |||
"crypto/rand" | |||
"crypto/x509" | |||
"fmt" | |||
"io" | |||
"math/big" | |||
"strings" | |||
@@ -540,3 +541,7 @@ func initDefaultCipherSuites() { | |||
varDefaultCipherSuites[i] = suite.id | |||
} | |||
} | |||
func unexpectedMessageError(wanted, got interface{}) error { | |||
return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) | |||
} |
@@ -12,6 +12,7 @@ import ( | |||
"crypto/subtle" | |||
"crypto/x509" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net" | |||
"sync" | |||
@@ -518,14 +519,17 @@ func (c *Conn) readRecord(want recordType) error { | |||
// else application data. (We don't support renegotiation.) | |||
switch want { | |||
default: | |||
return c.sendAlert(alertInternalError) | |||
c.sendAlert(alertInternalError) | |||
return errors.New("tls: unknown record type requested") | |||
case recordTypeHandshake, recordTypeChangeCipherSpec: | |||
if c.handshakeComplete { | |||
return c.sendAlert(alertInternalError) | |||
c.sendAlert(alertInternalError) | |||
return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete") | |||
} | |||
case recordTypeApplicationData: | |||
if !c.handshakeComplete { | |||
return c.sendAlert(alertInternalError) | |||
c.sendAlert(alertInternalError) | |||
return errors.New("tls: application data record requested before handshake complete") | |||
} | |||
} | |||
@@ -562,10 +566,12 @@ Again: | |||
vers := uint16(b.data[1])<<8 | uint16(b.data[2]) | |||
n := int(b.data[3])<<8 | int(b.data[4]) | |||
if c.haveVers && vers != c.vers { | |||
return c.sendAlert(alertProtocolVersion) | |||
c.sendAlert(alertProtocolVersion) | |||
return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers) | |||
} | |||
if n > maxCiphertext { | |||
return c.sendAlert(alertRecordOverflow) | |||
c.sendAlert(alertRecordOverflow) | |||
return fmt.Errorf("tls: oversized record received with length %d", n) | |||
} | |||
if !c.haveVers { | |||
// First message, be extra suspicious: | |||
@@ -577,7 +583,8 @@ Again: | |||
// well under a kilobyte. If the length is >= 12 kB, | |||
// it's probably not real. | |||
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return fmt.Errorf("tls: first record does not look like a TLS handshake") | |||
} | |||
} | |||
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { | |||
@@ -990,10 +997,10 @@ func (c *Conn) VerifyHostname(host string) error { | |||
c.handshakeMutex.Lock() | |||
defer c.handshakeMutex.Unlock() | |||
if !c.isClient { | |||
return errors.New("VerifyHostname called on TLS server connection") | |||
return errors.New("tls: VerifyHostname called on TLS server connection") | |||
} | |||
if !c.handshakeComplete { | |||
return errors.New("TLS handshake has not yet been performed") | |||
return errors.New("tls: handshake has not yet been performed") | |||
} | |||
return c.peerCertificates[0].VerifyHostname(host) | |||
} |
@@ -12,6 +12,7 @@ import ( | |||
"crypto/x509" | |||
"encoding/asn1" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"net" | |||
"strconv" | |||
@@ -126,20 +127,23 @@ NextCipherSuite: | |||
} | |||
serverHello, ok := msg.(*serverHelloMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(serverHello, msg) | |||
} | |||
vers, ok := c.config.mutualVersion(serverHello.vers) | |||
if !ok || vers < VersionTLS10 { | |||
// TLS 1.0 is the minimum version supported as a client. | |||
return c.sendAlert(alertProtocolVersion) | |||
c.sendAlert(alertProtocolVersion) | |||
return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) | |||
} | |||
c.vers = vers | |||
c.haveVers = true | |||
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) | |||
if suite == nil { | |||
return c.sendAlert(alertHandshakeFailure) | |||
c.sendAlert(alertHandshakeFailure) | |||
return fmt.Errorf("tls: server selected an unsupported cipher suite") | |||
} | |||
hs := &clientHandshakeState{ | |||
@@ -209,7 +213,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||
} | |||
certMsg, ok := msg.(*certificateMsg) | |||
if !ok || len(certMsg.certificates) == 0 { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(certMsg, msg) | |||
} | |||
hs.finishedHash.Write(certMsg.marshal()) | |||
@@ -218,7 +223,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||
cert, err := x509.ParseCertificate(asn1Data) | |||
if err != nil { | |||
c.sendAlert(alertBadCertificate) | |||
return errors.New("failed to parse certificate from server: " + err.Error()) | |||
return errors.New("tls: failed to parse certificate from server: " + err.Error()) | |||
} | |||
certs[i] = cert | |||
} | |||
@@ -248,7 +253,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||
case *rsa.PublicKey, *ecdsa.PublicKey: | |||
break | |||
default: | |||
return c.sendAlert(alertUnsupportedCertificate) | |||
c.sendAlert(alertUnsupportedCertificate) | |||
return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey) | |||
} | |||
c.peerCertificates = certs | |||
@@ -260,7 +266,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||
} | |||
cs, ok := msg.(*certificateStatusMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(cs, msg) | |||
} | |||
hs.finishedHash.Write(cs.marshal()) | |||
@@ -371,7 +378,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||
shd, ok := msg.(*serverHelloDoneMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(shd, msg) | |||
} | |||
hs.finishedHash.Write(shd.marshal()) | |||
@@ -421,7 +429,8 @@ func (hs *clientHandshakeState) doFullHandshake() error { | |||
err = errors.New("unknown private key type") | |||
} | |||
if err != nil { | |||
return c.sendAlert(alertInternalError) | |||
c.sendAlert(alertInternalError) | |||
return errors.New("tls: failed to sign handshake with client certificate: " + err.Error()) | |||
} | |||
certVerify.signature = signed | |||
@@ -466,12 +475,13 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) { | |||
c := hs.c | |||
if hs.serverHello.compressionMethod != compressionNone { | |||
return false, c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return false, errors.New("tls: server selected unsupported compression format") | |||
} | |||
if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg { | |||
c.sendAlert(alertHandshakeFailure) | |||
return false, errors.New("server advertised unrequested NPN") | |||
return false, errors.New("server advertised unrequested NPN extension") | |||
} | |||
if hs.serverResumedSession() { | |||
@@ -497,13 +507,15 @@ func (hs *clientHandshakeState) readFinished() error { | |||
} | |||
serverFinished, ok := msg.(*finishedMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(serverFinished, msg) | |||
} | |||
verify := hs.finishedHash.serverSum(hs.masterSecret) | |||
if len(verify) != len(serverFinished.verifyData) || | |||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { | |||
return c.sendAlert(alertHandshakeFailure) | |||
c.sendAlert(alertHandshakeFailure) | |||
return errors.New("tls: server's Finished message was incorrect") | |||
} | |||
hs.finishedHash.Write(serverFinished.marshal()) | |||
return nil | |||
@@ -521,7 +533,8 @@ func (hs *clientHandshakeState) readSessionTicket() error { | |||
} | |||
sessionTicketMsg, ok := msg.(*newSessionTicketMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(sessionTicketMsg, msg) | |||
} | |||
hs.finishedHash.Write(sessionTicketMsg.marshal()) | |||
@@ -12,6 +12,7 @@ import ( | |||
"crypto/x509" | |||
"encoding/asn1" | |||
"errors" | |||
"fmt" | |||
"io" | |||
) | |||
@@ -100,11 +101,13 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { | |||
var ok bool | |||
hs.clientHello, ok = msg.(*clientHelloMsg) | |||
if !ok { | |||
return false, c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return false, unexpectedMessageError(hs.clientHello, msg) | |||
} | |||
c.vers, ok = config.mutualVersion(hs.clientHello.vers) | |||
if !ok { | |||
return false, c.sendAlert(alertProtocolVersion) | |||
c.sendAlert(alertProtocolVersion) | |||
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) | |||
} | |||
c.haveVers = true | |||
@@ -142,14 +145,16 @@ Curves: | |||
} | |||
if !foundCompression { | |||
return false, c.sendAlert(alertHandshakeFailure) | |||
c.sendAlert(alertHandshakeFailure) | |||
return false, errors.New("tls: client does not support uncompressed connections") | |||
} | |||
hs.hello.vers = c.vers | |||
hs.hello.random = make([]byte, 32) | |||
_, err = io.ReadFull(config.rand(), hs.hello.random) | |||
if err != nil { | |||
return false, c.sendAlert(alertInternalError) | |||
c.sendAlert(alertInternalError) | |||
return false, err | |||
} | |||
hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation | |||
hs.hello.compressionMethod = compressionNone | |||
@@ -166,7 +171,8 @@ Curves: | |||
} | |||
if len(config.Certificates) == 0 { | |||
return false, c.sendAlert(alertInternalError) | |||
c.sendAlert(alertInternalError) | |||
return false, errors.New("tls: no certificates configured") | |||
} | |||
hs.cert = &config.Certificates[0] | |||
if len(hs.clientHello.serverName) > 0 { | |||
@@ -195,7 +201,8 @@ Curves: | |||
} | |||
if hs.suite == nil { | |||
return false, c.sendAlert(alertHandshakeFailure) | |||
c.sendAlert(alertHandshakeFailure) | |||
return false, errors.New("tls: no cipher suite supported by both client and server") | |||
} | |||
return false, nil | |||
@@ -345,7 +352,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
// certificate message, even if it's empty. | |||
if config.ClientAuth >= RequestClientCert { | |||
if certMsg, ok = msg.(*certificateMsg); !ok { | |||
return c.sendAlert(alertHandshakeFailure) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(certMsg, msg) | |||
} | |||
hs.finishedHash.Write(certMsg.marshal()) | |||
@@ -372,7 +380,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
// Get client key exchange | |||
ckx, ok := msg.(*clientKeyExchangeMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(ckx, msg) | |||
} | |||
hs.finishedHash.Write(ckx.marshal()) | |||
@@ -389,7 +398,8 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||
} | |||
certVerify, ok := msg.(*certificateVerifyMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(certVerify, msg) | |||
} | |||
switch key := pub.(type) { | |||
@@ -469,7 +479,8 @@ func (hs *serverHandshakeState) readFinished() error { | |||
} | |||
nextProto, ok := msg.(*nextProtoMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(nextProto, msg) | |||
} | |||
hs.finishedHash.Write(nextProto.marshal()) | |||
c.clientProtocol = nextProto.proto | |||
@@ -481,13 +492,15 @@ func (hs *serverHandshakeState) readFinished() error { | |||
} | |||
clientFinished, ok := msg.(*finishedMsg) | |||
if !ok { | |||
return c.sendAlert(alertUnexpectedMessage) | |||
c.sendAlert(alertUnexpectedMessage) | |||
return unexpectedMessageError(clientFinished, msg) | |||
} | |||
verify := hs.finishedHash.clientSum(hs.masterSecret) | |||
if len(verify) != len(clientFinished.verifyData) || | |||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { | |||
return c.sendAlert(alertHandshakeFailure) | |||
c.sendAlert(alertHandshakeFailure) | |||
return errors.New("tls: client's Finished message is incorrect") | |||
} | |||
hs.finishedHash.Write(clientFinished.marshal()) | |||
@@ -590,7 +603,8 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c | |||
case *ecdsa.PublicKey, *rsa.PublicKey: | |||
pub = key | |||
default: | |||
return nil, c.sendAlert(alertUnsupportedCertificate) | |||
c.sendAlert(alertUnsupportedCertificate) | |||
return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey) | |||
} | |||
c.peerCertificates = certs | |||
return pub, nil | |||
@@ -20,6 +20,7 @@ import ( | |||
"os" | |||
"os/exec" | |||
"path/filepath" | |||
"strings" | |||
"testing" | |||
"time" | |||
) | |||
@@ -53,7 +54,7 @@ func init() { | |||
testConfig.BuildNameToCertificate() | |||
} | |||
func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) { | |||
func testClientHelloFailure(t *testing.T, m handshakeMessage, expectedSubStr string) { | |||
// Create in-memory network connection, | |||
// send message to server. Should return | |||
// expected error. | |||
@@ -68,20 +69,20 @@ func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) { | |||
}() | |||
err := Server(s, testConfig).Handshake() | |||
s.Close() | |||
if e, ok := err.(*net.OpError); !ok || e.Err != expected { | |||
t.Errorf("Got error: %s; expected: %s", err, expected) | |||
if err == nil || !strings.Contains(err.Error(), expectedSubStr) { | |||
t.Errorf("Got error: %s; expected to match substring '%s'", err, expectedSubStr) | |||
} | |||
} | |||
func TestSimpleError(t *testing.T) { | |||
testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage) | |||
testClientHelloFailure(t, &serverHelloDoneMsg{}, "unexpected handshake message") | |||
} | |||
var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205} | |||
func TestRejectBadProtocolVersion(t *testing.T) { | |||
for _, v := range badProtocolVersions { | |||
testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion) | |||
testClientHelloFailure(t, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version") | |||
} | |||
} | |||
@@ -91,7 +92,7 @@ func TestNoSuiteOverlap(t *testing.T) { | |||
cipherSuites: []uint16{0xff00}, | |||
compressionMethods: []uint8{0}, | |||
} | |||
testClientHelloFailure(t, clientHello, alertHandshakeFailure) | |||
testClientHelloFailure(t, clientHello, "no cipher suite supported by both client and server") | |||
} | |||
func TestNoCompressionOverlap(t *testing.T) { | |||
@@ -100,7 +101,7 @@ func TestNoCompressionOverlap(t *testing.T) { | |||
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, | |||
compressionMethods: []uint8{0xff}, | |||
} | |||
testClientHelloFailure(t, clientHello, alertHandshakeFailure) | |||
testClientHelloFailure(t, clientHello, "client does not support uncompressed connections") | |||
} | |||
func TestTLS12OnlyCipherSuites(t *testing.T) { | |||
@@ -19,6 +19,9 @@ import ( | |||
"math/big" | |||
) | |||
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") | |||
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") | |||
// rsaKeyAgreement implements the standard TLS key agreement where the client | |||
// encrypts the pre-master secret to the server's public key. | |||
type rsaKeyAgreement struct{} | |||
@@ -35,14 +38,14 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi | |||
} | |||
if len(ckx.ciphertext) < 2 { | |||
return nil, errors.New("bad ClientKeyExchange") | |||
return nil, errClientKeyExchange | |||
} | |||
ciphertext := ckx.ciphertext | |||
if version != VersionSSL30 { | |||
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) | |||
if ciphertextLen != len(ckx.ciphertext)-2 { | |||
return nil, errors.New("bad ClientKeyExchange") | |||
return nil, errClientKeyExchange | |||
} | |||
ciphertext = ckx.ciphertext[2:] | |||
} | |||
@@ -61,7 +64,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi | |||
} | |||
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { | |||
return errors.New("unexpected ServerKeyExchange") | |||
return errors.New("tls: unexpected ServerKeyExchange") | |||
} | |||
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { | |||
@@ -271,11 +274,11 @@ Curve: | |||
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { | |||
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { | |||
return nil, errors.New("bad ClientKeyExchange") | |||
return nil, errClientKeyExchange | |||
} | |||
x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:]) | |||
if x == nil { | |||
return nil, errors.New("bad ClientKeyExchange") | |||
return nil, errClientKeyExchange | |||
} | |||
x, _ = ka.curve.ScalarMult(x, y, ka.privateKey) | |||
preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3) | |||
@@ -285,8 +288,6 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert | |||
return preMasterSecret, nil | |||
} | |||
var errServerKeyExchange = errors.New("invalid ServerKeyExchange") | |||
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { | |||
if len(skx.key) < 4 { | |||
return errServerKeyExchange | |||