crypto/tls: better error messages.
LGTM=bradfitz R=golang-codereviews, bradfitz CC=golang-codereviews https://golang.org/cl/60580046
This commit is contained in:
parent
8cf5d703de
commit
5a2aacff2f
@ -9,6 +9,7 @@ import (
|
|||||||
"crypto"
|
"crypto"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
@ -540,3 +541,7 @@ func initDefaultCipherSuites() {
|
|||||||
varDefaultCipherSuites[i] = suite.id
|
varDefaultCipherSuites[i] = suite.id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func unexpectedMessageError(wanted, got interface{}) error {
|
||||||
|
return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted)
|
||||||
|
}
|
||||||
|
23
conn.go
23
conn.go
@ -12,6 +12,7 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@ -518,14 +519,17 @@ func (c *Conn) readRecord(want recordType) error {
|
|||||||
// else application data. (We don't support renegotiation.)
|
// else application data. (We don't support renegotiation.)
|
||||||
switch want {
|
switch want {
|
||||||
default:
|
default:
|
||||||
return c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
|
return errors.New("tls: unknown record type requested")
|
||||||
case recordTypeHandshake, recordTypeChangeCipherSpec:
|
case recordTypeHandshake, recordTypeChangeCipherSpec:
|
||||||
if c.handshakeComplete {
|
if c.handshakeComplete {
|
||||||
return c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
|
return errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")
|
||||||
}
|
}
|
||||||
case recordTypeApplicationData:
|
case recordTypeApplicationData:
|
||||||
if !c.handshakeComplete {
|
if !c.handshakeComplete {
|
||||||
return c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
|
return errors.New("tls: application data record requested before handshake complete")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -562,10 +566,12 @@ Again:
|
|||||||
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
|
vers := uint16(b.data[1])<<8 | uint16(b.data[2])
|
||||||
n := int(b.data[3])<<8 | int(b.data[4])
|
n := int(b.data[3])<<8 | int(b.data[4])
|
||||||
if c.haveVers && vers != c.vers {
|
if c.haveVers && vers != c.vers {
|
||||||
return c.sendAlert(alertProtocolVersion)
|
c.sendAlert(alertProtocolVersion)
|
||||||
|
return fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)
|
||||||
}
|
}
|
||||||
if n > maxCiphertext {
|
if n > maxCiphertext {
|
||||||
return c.sendAlert(alertRecordOverflow)
|
c.sendAlert(alertRecordOverflow)
|
||||||
|
return fmt.Errorf("tls: oversized record received with length %d", n)
|
||||||
}
|
}
|
||||||
if !c.haveVers {
|
if !c.haveVers {
|
||||||
// First message, be extra suspicious:
|
// First message, be extra suspicious:
|
||||||
@ -577,7 +583,8 @@ Again:
|
|||||||
// well under a kilobyte. If the length is >= 12 kB,
|
// well under a kilobyte. If the length is >= 12 kB,
|
||||||
// it's probably not real.
|
// it's probably not real.
|
||||||
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
|
if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return fmt.Errorf("tls: first record does not look like a TLS handshake")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
|
if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
|
||||||
@ -990,10 +997,10 @@ func (c *Conn) VerifyHostname(host string) error {
|
|||||||
c.handshakeMutex.Lock()
|
c.handshakeMutex.Lock()
|
||||||
defer c.handshakeMutex.Unlock()
|
defer c.handshakeMutex.Unlock()
|
||||||
if !c.isClient {
|
if !c.isClient {
|
||||||
return errors.New("VerifyHostname called on TLS server connection")
|
return errors.New("tls: VerifyHostname called on TLS server connection")
|
||||||
}
|
}
|
||||||
if !c.handshakeComplete {
|
if !c.handshakeComplete {
|
||||||
return errors.New("TLS handshake has not yet been performed")
|
return errors.New("tls: handshake has not yet been performed")
|
||||||
}
|
}
|
||||||
return c.peerCertificates[0].VerifyHostname(host)
|
return c.peerCertificates[0].VerifyHostname(host)
|
||||||
}
|
}
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -126,20 +127,23 @@ NextCipherSuite:
|
|||||||
}
|
}
|
||||||
serverHello, ok := msg.(*serverHelloMsg)
|
serverHello, ok := msg.(*serverHelloMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(serverHello, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
vers, ok := c.config.mutualVersion(serverHello.vers)
|
vers, ok := c.config.mutualVersion(serverHello.vers)
|
||||||
if !ok || vers < VersionTLS10 {
|
if !ok || vers < VersionTLS10 {
|
||||||
// TLS 1.0 is the minimum version supported as a client.
|
// TLS 1.0 is the minimum version supported as a client.
|
||||||
return c.sendAlert(alertProtocolVersion)
|
c.sendAlert(alertProtocolVersion)
|
||||||
|
return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers)
|
||||||
}
|
}
|
||||||
c.vers = vers
|
c.vers = vers
|
||||||
c.haveVers = true
|
c.haveVers = true
|
||||||
|
|
||||||
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
|
suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite)
|
||||||
if suite == nil {
|
if suite == nil {
|
||||||
return c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
|
return fmt.Errorf("tls: server selected an unsupported cipher suite")
|
||||||
}
|
}
|
||||||
|
|
||||||
hs := &clientHandshakeState{
|
hs := &clientHandshakeState{
|
||||||
@ -209,7 +213,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
certMsg, ok := msg.(*certificateMsg)
|
certMsg, ok := msg.(*certificateMsg)
|
||||||
if !ok || len(certMsg.certificates) == 0 {
|
if !ok || len(certMsg.certificates) == 0 {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(certMsg, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(certMsg.marshal())
|
hs.finishedHash.Write(certMsg.marshal())
|
||||||
|
|
||||||
@ -218,7 +223,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
cert, err := x509.ParseCertificate(asn1Data)
|
cert, err := x509.ParseCertificate(asn1Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.sendAlert(alertBadCertificate)
|
c.sendAlert(alertBadCertificate)
|
||||||
return errors.New("failed to parse certificate from server: " + err.Error())
|
return errors.New("tls: failed to parse certificate from server: " + err.Error())
|
||||||
}
|
}
|
||||||
certs[i] = cert
|
certs[i] = cert
|
||||||
}
|
}
|
||||||
@ -248,7 +253,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
case *rsa.PublicKey, *ecdsa.PublicKey:
|
case *rsa.PublicKey, *ecdsa.PublicKey:
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
return c.sendAlert(alertUnsupportedCertificate)
|
c.sendAlert(alertUnsupportedCertificate)
|
||||||
|
return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.peerCertificates = certs
|
c.peerCertificates = certs
|
||||||
@ -260,7 +266,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
cs, ok := msg.(*certificateStatusMsg)
|
cs, ok := msg.(*certificateStatusMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(cs, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(cs.marshal())
|
hs.finishedHash.Write(cs.marshal())
|
||||||
|
|
||||||
@ -371,7 +378,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
|
|
||||||
shd, ok := msg.(*serverHelloDoneMsg)
|
shd, ok := msg.(*serverHelloDoneMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(shd, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(shd.marshal())
|
hs.finishedHash.Write(shd.marshal())
|
||||||
|
|
||||||
@ -421,7 +429,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
|||||||
err = errors.New("unknown private key type")
|
err = errors.New("unknown private key type")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
|
return errors.New("tls: failed to sign handshake with client certificate: " + err.Error())
|
||||||
}
|
}
|
||||||
certVerify.signature = signed
|
certVerify.signature = signed
|
||||||
|
|
||||||
@ -466,12 +475,13 @@ func (hs *clientHandshakeState) processServerHello() (bool, error) {
|
|||||||
c := hs.c
|
c := hs.c
|
||||||
|
|
||||||
if hs.serverHello.compressionMethod != compressionNone {
|
if hs.serverHello.compressionMethod != compressionNone {
|
||||||
return false, c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return false, errors.New("tls: server selected unsupported compression format")
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
|
if !hs.hello.nextProtoNeg && hs.serverHello.nextProtoNeg {
|
||||||
c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
return false, errors.New("server advertised unrequested NPN")
|
return false, errors.New("server advertised unrequested NPN extension")
|
||||||
}
|
}
|
||||||
|
|
||||||
if hs.serverResumedSession() {
|
if hs.serverResumedSession() {
|
||||||
@ -497,13 +507,15 @@ func (hs *clientHandshakeState) readFinished() error {
|
|||||||
}
|
}
|
||||||
serverFinished, ok := msg.(*finishedMsg)
|
serverFinished, ok := msg.(*finishedMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(serverFinished, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
verify := hs.finishedHash.serverSum(hs.masterSecret)
|
verify := hs.finishedHash.serverSum(hs.masterSecret)
|
||||||
if len(verify) != len(serverFinished.verifyData) ||
|
if len(verify) != len(serverFinished.verifyData) ||
|
||||||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
|
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
|
||||||
return c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
|
return errors.New("tls: server's Finished message was incorrect")
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(serverFinished.marshal())
|
hs.finishedHash.Write(serverFinished.marshal())
|
||||||
return nil
|
return nil
|
||||||
@ -521,7 +533,8 @@ func (hs *clientHandshakeState) readSessionTicket() error {
|
|||||||
}
|
}
|
||||||
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
|
sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(sessionTicketMsg, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(sessionTicketMsg.marshal())
|
hs.finishedHash.Write(sessionTicketMsg.marshal())
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/asn1"
|
"encoding/asn1"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,11 +101,13 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) {
|
|||||||
var ok bool
|
var ok bool
|
||||||
hs.clientHello, ok = msg.(*clientHelloMsg)
|
hs.clientHello, ok = msg.(*clientHelloMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return false, unexpectedMessageError(hs.clientHello, msg)
|
||||||
}
|
}
|
||||||
c.vers, ok = config.mutualVersion(hs.clientHello.vers)
|
c.vers, ok = config.mutualVersion(hs.clientHello.vers)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false, c.sendAlert(alertProtocolVersion)
|
c.sendAlert(alertProtocolVersion)
|
||||||
|
return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
|
||||||
}
|
}
|
||||||
c.haveVers = true
|
c.haveVers = true
|
||||||
|
|
||||||
@ -142,14 +145,16 @@ Curves:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !foundCompression {
|
if !foundCompression {
|
||||||
return false, c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
|
return false, errors.New("tls: client does not support uncompressed connections")
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.hello.vers = c.vers
|
hs.hello.vers = c.vers
|
||||||
hs.hello.random = make([]byte, 32)
|
hs.hello.random = make([]byte, 32)
|
||||||
_, err = io.ReadFull(config.rand(), hs.hello.random)
|
_, err = io.ReadFull(config.rand(), hs.hello.random)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation
|
hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation
|
||||||
hs.hello.compressionMethod = compressionNone
|
hs.hello.compressionMethod = compressionNone
|
||||||
@ -166,7 +171,8 @@ Curves:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Certificates) == 0 {
|
if len(config.Certificates) == 0 {
|
||||||
return false, c.sendAlert(alertInternalError)
|
c.sendAlert(alertInternalError)
|
||||||
|
return false, errors.New("tls: no certificates configured")
|
||||||
}
|
}
|
||||||
hs.cert = &config.Certificates[0]
|
hs.cert = &config.Certificates[0]
|
||||||
if len(hs.clientHello.serverName) > 0 {
|
if len(hs.clientHello.serverName) > 0 {
|
||||||
@ -195,7 +201,8 @@ Curves:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if hs.suite == nil {
|
if hs.suite == nil {
|
||||||
return false, c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
|
return false, errors.New("tls: no cipher suite supported by both client and server")
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
@ -345,7 +352,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
// certificate message, even if it's empty.
|
// certificate message, even if it's empty.
|
||||||
if config.ClientAuth >= RequestClientCert {
|
if config.ClientAuth >= RequestClientCert {
|
||||||
if certMsg, ok = msg.(*certificateMsg); !ok {
|
if certMsg, ok = msg.(*certificateMsg); !ok {
|
||||||
return c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(certMsg, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(certMsg.marshal())
|
hs.finishedHash.Write(certMsg.marshal())
|
||||||
|
|
||||||
@ -372,7 +380,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
// Get client key exchange
|
// Get client key exchange
|
||||||
ckx, ok := msg.(*clientKeyExchangeMsg)
|
ckx, ok := msg.(*clientKeyExchangeMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(ckx, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(ckx.marshal())
|
hs.finishedHash.Write(ckx.marshal())
|
||||||
|
|
||||||
@ -389,7 +398,8 @@ func (hs *serverHandshakeState) doFullHandshake() error {
|
|||||||
}
|
}
|
||||||
certVerify, ok := msg.(*certificateVerifyMsg)
|
certVerify, ok := msg.(*certificateVerifyMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(certVerify, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch key := pub.(type) {
|
switch key := pub.(type) {
|
||||||
@ -469,7 +479,8 @@ func (hs *serverHandshakeState) readFinished() error {
|
|||||||
}
|
}
|
||||||
nextProto, ok := msg.(*nextProtoMsg)
|
nextProto, ok := msg.(*nextProtoMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(nextProto, msg)
|
||||||
}
|
}
|
||||||
hs.finishedHash.Write(nextProto.marshal())
|
hs.finishedHash.Write(nextProto.marshal())
|
||||||
c.clientProtocol = nextProto.proto
|
c.clientProtocol = nextProto.proto
|
||||||
@ -481,13 +492,15 @@ func (hs *serverHandshakeState) readFinished() error {
|
|||||||
}
|
}
|
||||||
clientFinished, ok := msg.(*finishedMsg)
|
clientFinished, ok := msg.(*finishedMsg)
|
||||||
if !ok {
|
if !ok {
|
||||||
return c.sendAlert(alertUnexpectedMessage)
|
c.sendAlert(alertUnexpectedMessage)
|
||||||
|
return unexpectedMessageError(clientFinished, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
verify := hs.finishedHash.clientSum(hs.masterSecret)
|
verify := hs.finishedHash.clientSum(hs.masterSecret)
|
||||||
if len(verify) != len(clientFinished.verifyData) ||
|
if len(verify) != len(clientFinished.verifyData) ||
|
||||||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
|
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
|
||||||
return c.sendAlert(alertHandshakeFailure)
|
c.sendAlert(alertHandshakeFailure)
|
||||||
|
return errors.New("tls: client's Finished message is incorrect")
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.finishedHash.Write(clientFinished.marshal())
|
hs.finishedHash.Write(clientFinished.marshal())
|
||||||
@ -590,7 +603,8 @@ func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (c
|
|||||||
case *ecdsa.PublicKey, *rsa.PublicKey:
|
case *ecdsa.PublicKey, *rsa.PublicKey:
|
||||||
pub = key
|
pub = key
|
||||||
default:
|
default:
|
||||||
return nil, c.sendAlert(alertUnsupportedCertificate)
|
c.sendAlert(alertUnsupportedCertificate)
|
||||||
|
return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey)
|
||||||
}
|
}
|
||||||
c.peerCertificates = certs
|
c.peerCertificates = certs
|
||||||
return pub, nil
|
return pub, nil
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -53,7 +54,7 @@ func init() {
|
|||||||
testConfig.BuildNameToCertificate()
|
testConfig.BuildNameToCertificate()
|
||||||
}
|
}
|
||||||
|
|
||||||
func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) {
|
func testClientHelloFailure(t *testing.T, m handshakeMessage, expectedSubStr string) {
|
||||||
// Create in-memory network connection,
|
// Create in-memory network connection,
|
||||||
// send message to server. Should return
|
// send message to server. Should return
|
||||||
// expected error.
|
// expected error.
|
||||||
@ -68,20 +69,20 @@ func testClientHelloFailure(t *testing.T, m handshakeMessage, expected error) {
|
|||||||
}()
|
}()
|
||||||
err := Server(s, testConfig).Handshake()
|
err := Server(s, testConfig).Handshake()
|
||||||
s.Close()
|
s.Close()
|
||||||
if e, ok := err.(*net.OpError); !ok || e.Err != expected {
|
if err == nil || !strings.Contains(err.Error(), expectedSubStr) {
|
||||||
t.Errorf("Got error: %s; expected: %s", err, expected)
|
t.Errorf("Got error: %s; expected to match substring '%s'", err, expectedSubStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSimpleError(t *testing.T) {
|
func TestSimpleError(t *testing.T) {
|
||||||
testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage)
|
testClientHelloFailure(t, &serverHelloDoneMsg{}, "unexpected handshake message")
|
||||||
}
|
}
|
||||||
|
|
||||||
var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205}
|
var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205}
|
||||||
|
|
||||||
func TestRejectBadProtocolVersion(t *testing.T) {
|
func TestRejectBadProtocolVersion(t *testing.T) {
|
||||||
for _, v := range badProtocolVersions {
|
for _, v := range badProtocolVersions {
|
||||||
testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion)
|
testClientHelloFailure(t, &clientHelloMsg{vers: v}, "unsupported, maximum protocol version")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,7 +92,7 @@ func TestNoSuiteOverlap(t *testing.T) {
|
|||||||
cipherSuites: []uint16{0xff00},
|
cipherSuites: []uint16{0xff00},
|
||||||
compressionMethods: []uint8{0},
|
compressionMethods: []uint8{0},
|
||||||
}
|
}
|
||||||
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
testClientHelloFailure(t, clientHello, "no cipher suite supported by both client and server")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNoCompressionOverlap(t *testing.T) {
|
func TestNoCompressionOverlap(t *testing.T) {
|
||||||
@ -100,7 +101,7 @@ func TestNoCompressionOverlap(t *testing.T) {
|
|||||||
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
|
||||||
compressionMethods: []uint8{0xff},
|
compressionMethods: []uint8{0xff},
|
||||||
}
|
}
|
||||||
testClientHelloFailure(t, clientHello, alertHandshakeFailure)
|
testClientHelloFailure(t, clientHello, "client does not support uncompressed connections")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTLS12OnlyCipherSuites(t *testing.T) {
|
func TestTLS12OnlyCipherSuites(t *testing.T) {
|
||||||
|
@ -19,6 +19,9 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
|
||||||
|
var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
|
||||||
|
|
||||||
// rsaKeyAgreement implements the standard TLS key agreement where the client
|
// rsaKeyAgreement implements the standard TLS key agreement where the client
|
||||||
// encrypts the pre-master secret to the server's public key.
|
// encrypts the pre-master secret to the server's public key.
|
||||||
type rsaKeyAgreement struct{}
|
type rsaKeyAgreement struct{}
|
||||||
@ -35,14 +38,14 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(ckx.ciphertext) < 2 {
|
if len(ckx.ciphertext) < 2 {
|
||||||
return nil, errors.New("bad ClientKeyExchange")
|
return nil, errClientKeyExchange
|
||||||
}
|
}
|
||||||
|
|
||||||
ciphertext := ckx.ciphertext
|
ciphertext := ckx.ciphertext
|
||||||
if version != VersionSSL30 {
|
if version != VersionSSL30 {
|
||||||
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
|
ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
|
||||||
if ciphertextLen != len(ckx.ciphertext)-2 {
|
if ciphertextLen != len(ckx.ciphertext)-2 {
|
||||||
return nil, errors.New("bad ClientKeyExchange")
|
return nil, errClientKeyExchange
|
||||||
}
|
}
|
||||||
ciphertext = ckx.ciphertext[2:]
|
ciphertext = ckx.ciphertext[2:]
|
||||||
}
|
}
|
||||||
@ -61,7 +64,7 @@ func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certifi
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
||||||
return errors.New("unexpected ServerKeyExchange")
|
return errors.New("tls: unexpected ServerKeyExchange")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
|
||||||
@ -271,11 +274,11 @@ Curve:
|
|||||||
|
|
||||||
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
|
||||||
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
|
if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
|
||||||
return nil, errors.New("bad ClientKeyExchange")
|
return nil, errClientKeyExchange
|
||||||
}
|
}
|
||||||
x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:])
|
x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:])
|
||||||
if x == nil {
|
if x == nil {
|
||||||
return nil, errors.New("bad ClientKeyExchange")
|
return nil, errClientKeyExchange
|
||||||
}
|
}
|
||||||
x, _ = ka.curve.ScalarMult(x, y, ka.privateKey)
|
x, _ = ka.curve.ScalarMult(x, y, ka.privateKey)
|
||||||
preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3)
|
preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3)
|
||||||
@ -285,8 +288,6 @@ func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Cert
|
|||||||
return preMasterSecret, nil
|
return preMasterSecret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var errServerKeyExchange = errors.New("invalid ServerKeyExchange")
|
|
||||||
|
|
||||||
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
|
||||||
if len(skx.key) < 4 {
|
if len(skx.key) < 4 {
|
||||||
return errServerKeyExchange
|
return errServerKeyExchange
|
||||||
|
Loading…
Reference in New Issue
Block a user