Browse Source

crypto/tls: use correct alerts

BoGo: Resume-Server-PSKBinderFirstExtension
BoGo: Resume-Server-ExtraPSKBinder
BoGo: Resume-Server-ExtraIdentityNoBinder
BoGo: Renegotiate-Server-Forbidden
BoGo: NoNullCompression
BoGo: TrailingMessageData-*
v1.2.3
Filippo Valsorda 7 years ago
committed by Peter Wu
parent
commit
4191962f25
11 changed files with 255 additions and 225 deletions
  1. +31
    -25
      13.go
  2. +1
    -0
      alert.go
  3. +1
    -1
      common.go
  4. +2
    -2
      conn.go
  5. +1
    -1
      handshake_client.go
  6. +1
    -1
      handshake_client_test.go
  7. +186
    -170
      handshake_messages.go
  8. +6
    -6
      handshake_messages_test.go
  9. +8
    -4
      handshake_server.go
  10. +1
    -1
      handshake_server_test.go
  11. +17
    -14
      ticket.go

+ 31
- 25
13.go View File

@@ -56,11 +56,16 @@ CurvePreferenceLoop:
hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
hashSize := hash.Size() hashSize := hash.Size()


earlySecret, isPSK := hs.checkPSK()
if !isPSK {
earlySecret, pskAlert := hs.checkPSK()
switch {
case pskAlert != alertSuccess:
c.sendAlert(pskAlert)
return errors.New("tls: invalid client PSK")
case earlySecret == nil:
earlySecret = hkdfExtract(hash, nil, nil) earlySecret = hkdfExtract(hash, nil, nil)
case earlySecret != nil:
c.didResume = true
} }
c.didResume = isPSK


hs.finishedHash13 = hash.New() hs.finishedHash13 = hash.New()
hs.finishedHash13.Write(hs.clientHello.marshal()) hs.finishedHash13.Write(hs.clientHello.marshal())
@@ -94,7 +99,7 @@ CurvePreferenceLoop:
return err return err
} }


if !isPSK {
if !c.didResume {
if err := hs.sendCertificate13(); err != nil { if err := hs.sendCertificate13(); err != nil {
return err return err
} }
@@ -151,7 +156,7 @@ func (hs *serverHandshakeState) readClientFinished13() error {
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey)
if len(expectedVerifyData) != len(clientFinished.verifyData) || if len(expectedVerifyData) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
c.sendAlert(alertDecryptError)
return errors.New("tls: client's Finished message is incorrect") return errors.New("tls: client's Finished message is incorrect")
} }
hs.finishedHash13.Write(clientFinished.marshal()) hs.finishedHash13.Write(clientFinished.marshal())
@@ -403,9 +408,9 @@ func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2.
const ticketAgeSkewAllowance = 10 * time.Second const ticketAgeSkewAllowance = 10 * time.Second


func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) {
if hs.c.config.SessionTicketsDisabled { if hs.c.config.SessionTicketsDisabled {
return nil, false
return nil, alertSuccess
} }


foundDHE := false foundDHE := false
@@ -416,7 +421,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
} }
} }
if !foundDHE { if !foundDHE {
return nil, false
return nil, alertSuccess
} }


hash := hashForSuite(hs.suite) hash := hashForSuite(hs.suite)
@@ -428,7 +433,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
continue continue
} }
s := &sessionState13{} s := &sessionState13{}
if ok := s.unmarshal(serializedTicket); !ok {
if s.unmarshal(serializedTicket) != alertSuccess {
continue continue
} }
if s.vers != hs.c.vers { if s.vers != hs.c.vers {
@@ -465,26 +470,27 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) {
chHash.Write(hs.clientHello.rawTruncated) chHash.Write(hs.clientHello.rawTruncated)
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)


if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 {
if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such. We could close the connection
// here, but instead we just ignore the ticket and the 0-RTT data.
continue
}
if hs.c.config.Accept0RTTData {
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
hs.hello13Enc.earlyData = true
}
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 {
return nil, alertDecryptError
}

if i == 0 && hs.clientHello.earlyData {
// This is a ticket intended to be used for 0-RTT
if s.maxEarlyDataLen == 0 {
// But we had not tagged it as such.
return nil, alertIllegalParameter
}
if hs.c.config.Accept0RTTData {
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen)
hs.hello13Enc.earlyData = true
} }
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
return earlySecret, true
} }
hs.hello13.psk = true
hs.hello13.pskIdentity = uint16(i)
return earlySecret, alertSuccess
} }


return nil, false
return nil, alertSuccess
} }


func (hs *serverHandshakeState) sendSessionTicket13() error { func (hs *serverHandshakeState) sendSessionTicket13() error {


+ 1
- 0
alert.go View File

@@ -40,6 +40,7 @@ const (
alertUserCanceled alert = 90 alertUserCanceled alert = 90
alertNoRenegotiation alert = 100 alertNoRenegotiation alert = 100
alertNoApplicationProtocol alert = 120 alertNoApplicationProtocol alert = 120
alertSuccess alert = 255 // dummy value returned by unmarshal functions
) )


var alertText = map[alert]string{ var alertText = map[alert]string{


+ 1
- 1
common.go View File

@@ -939,7 +939,7 @@ type Certificate struct {


type handshakeMessage interface { type handshakeMessage interface {
marshal() []byte marshal() []byte
unmarshal([]byte) bool
unmarshal([]byte) alert
} }


// lruSessionCache is a ClientSessionCache implementation that uses an LRU // lruSessionCache is a ClientSessionCache implementation that uses an LRU


+ 2
- 2
conn.go View File

@@ -1126,8 +1126,8 @@ func (c *Conn) readHandshake() (interface{}, error) {
// so pass in a fresh copy that won't be overwritten. // so pass in a fresh copy that won't be overwritten.
data = append([]byte(nil), data...) data = append([]byte(nil), data...)


if !m.unmarshal(data) {
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
if unmarshalAlert := m.unmarshal(data); unmarshalAlert != alertSuccess {
return nil, c.in.setErrorLocked(c.sendAlert(unmarshalAlert))
} }
return m, nil return m, nil
} }


+ 1
- 1
handshake_client.go View File

@@ -602,7 +602,7 @@ func (hs *clientHandshakeState) readFinished(out []byte) error {
verify := hs.finishedHash.serverSum(hs.masterSecret) verify := hs.finishedHash.serverSum(hs.masterSecret)
if len(verify) != len(serverFinished.verifyData) || if len(verify) != len(serverFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
c.sendAlert(alertDecryptError)
return errors.New("tls: server's Finished message was incorrect") return errors.New("tls: server's Finished message was incorrect")
} }
hs.finishedHash.Write(serverFinished.marshal()) hs.finishedHash.Write(serverFinished.marshal())


+ 1
- 1
handshake_client_test.go View File

@@ -1024,7 +1024,7 @@ func TestHostnameInSNI(t *testing.T) {
s.Close() s.Close()


var m clientHelloMsg var m clientHelloMsg
if !m.unmarshal(record) {
if m.unmarshal(record) != alertSuccess {
t.Errorf("unmarshaling ClientHello for %q failed", tt.in) t.Errorf("unmarshaling ClientHello for %q failed", tt.in)
continue continue
} }


+ 186
- 170
handshake_messages.go
File diff suppressed because it is too large
View File


+ 6
- 6
handshake_messages_test.go View File

@@ -35,7 +35,7 @@ var tests = []interface{}{


type testMessage interface { type testMessage interface {
marshal() []byte marshal() []byte
unmarshal([]byte) bool
unmarshal([]byte) alert
equal(interface{}) bool equal(interface{}) bool
} }


@@ -59,7 +59,7 @@ func TestMarshalUnmarshal(t *testing.T) {
m1 := v.Interface().(testMessage) m1 := v.Interface().(testMessage)
marshaled := m1.marshal() marshaled := m1.marshal()
m2 := iface.(testMessage) m2 := iface.(testMessage)
if !m2.unmarshal(marshaled) {
if m2.unmarshal(marshaled) != alertSuccess {
t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled) t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled)
break break
} }
@@ -77,7 +77,7 @@ func TestMarshalUnmarshal(t *testing.T) {
// data is optional and the length of the // data is optional and the length of the
// Finished varies across versions. // Finished varies across versions.
for j := 0; j < len(marshaled); j++ { for j := 0; j < len(marshaled); j++ {
if m2.unmarshal(marshaled[0:j]) {
if m2.unmarshal(marshaled[0:j]) == alertSuccess {
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
break break
} }
@@ -364,7 +364,7 @@ func TestRejectEmptySCTList(t *testing.T) {
serverHelloBytes := serverHello.marshal() serverHelloBytes := serverHello.marshal()


var serverHelloCopy serverHelloMsg var serverHelloCopy serverHelloMsg
if !serverHelloCopy.unmarshal(serverHelloBytes) {
if serverHelloCopy.unmarshal(serverHelloBytes) != alertSuccess {
t.Fatal("Failed to unmarshal initial message") t.Fatal("Failed to unmarshal initial message")
} }


@@ -389,7 +389,7 @@ func TestRejectEmptySCTList(t *testing.T) {
serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8)
serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44))


if serverHelloCopy.unmarshal(serverHelloEmptySCT) {
if serverHelloCopy.unmarshal(serverHelloEmptySCT) == alertSuccess {
t.Fatal("Unmarshaled ServerHello with empty SCT list") t.Fatal("Unmarshaled ServerHello with empty SCT list")
} }
} }
@@ -407,7 +407,7 @@ func TestRejectEmptySCT(t *testing.T) {
serverHelloBytes := serverHello.marshal() serverHelloBytes := serverHello.marshal()


var serverHelloCopy serverHelloMsg var serverHelloCopy serverHelloMsg
if serverHelloCopy.unmarshal(serverHelloBytes) {
if serverHelloCopy.unmarshal(serverHelloBytes) == alertSuccess {
t.Fatal("Unmarshaled ServerHello with zero-length SCT") t.Fatal("Unmarshaled ServerHello with zero-length SCT")
} }
} }

+ 8
- 4
handshake_server.go View File

@@ -222,7 +222,7 @@ Curves:
} }


if !foundCompression { if !foundCompression {
c.sendAlert(alertHandshakeFailure)
c.sendAlert(alertIllegalParameter)
return false, errors.New("tls: client does not support uncompressed connections") return false, errors.New("tls: client does not support uncompressed connections")
} }
if len(hs.clientHello.compressionMethods) != 1 && c.vers >= VersionTLS13 { if len(hs.clientHello.compressionMethods) != 1 && c.vers >= VersionTLS13 {
@@ -370,7 +370,7 @@ func (hs *serverHandshakeState) checkForResumption() bool {
sessionTicket := append([]uint8{}, hs.clientHello.sessionTicket...) sessionTicket := append([]uint8{}, hs.clientHello.sessionTicket...)
serializedState, usedOldKey := c.decryptTicket(sessionTicket) serializedState, usedOldKey := c.decryptTicket(sessionTicket)
hs.sessionState = &sessionState{usedOldKey: usedOldKey} hs.sessionState = &sessionState{usedOldKey: usedOldKey}
if ok := hs.sessionState.unmarshal(serializedState); !ok {
if hs.sessionState.unmarshal(serializedState) != alertSuccess {
return false return false
} }


@@ -570,7 +570,11 @@ func (hs *serverHandshakeState) doFullHandshake() error {


preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil { if err != nil {
c.sendAlert(alertHandshakeFailure)
if err == errClientKeyExchange {
c.sendAlert(alertDecodeError)
} else {
c.sendAlert(alertInternalError)
}
return err return err
} }
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
@@ -721,7 +725,7 @@ func (hs *serverHandshakeState) readFinished(out []byte) error {
verify := hs.finishedHash.clientSum(hs.masterSecret) verify := hs.finishedHash.clientSum(hs.masterSecret)
if len(verify) != len(clientFinished.verifyData) || if len(verify) != len(clientFinished.verifyData) ||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
c.sendAlert(alertHandshakeFailure)
c.sendAlert(alertDecryptError)
return errors.New("tls: client's Finished message is incorrect") return errors.New("tls: client's Finished message is incorrect")
} }




+ 1
- 1
handshake_server_test.go View File

@@ -246,7 +246,7 @@ func TestRenegotiationExtension(t *testing.T) {
var serverHello serverHelloMsg var serverHello serverHelloMsg
// unmarshal expects to be given the handshake header, but // unmarshal expects to be given the handshake header, but
// serverHelloLen doesn't include it. // serverHelloLen doesn't include it.
if !serverHello.unmarshal(buf[5 : 9+serverHelloLen]) {
if serverHello.unmarshal(buf[5:9+serverHelloLen]) != alertSuccess {
t.Fatalf("Failed to parse ServerHello") t.Fatalf("Failed to parse ServerHello")
} }




+ 17
- 14
ticket.go View File

@@ -86,9 +86,9 @@ func (s *sessionState) marshal() []byte {
return ret return ret
} }


func (s *sessionState) unmarshal(data []byte) bool {
func (s *sessionState) unmarshal(data []byte) alert {
if len(data) < 8 { if len(data) < 8 {
return false
return alertDecodeError
} }


s.vers = uint16(data[0])<<8 | uint16(data[1]) s.vers = uint16(data[0])<<8 | uint16(data[1])
@@ -96,14 +96,14 @@ func (s *sessionState) unmarshal(data []byte) bool {
masterSecretLen := int(data[4])<<8 | int(data[5]) masterSecretLen := int(data[4])<<8 | int(data[5])
data = data[6:] data = data[6:]
if len(data) < masterSecretLen { if len(data) < masterSecretLen {
return false
return alertDecodeError
} }


s.masterSecret = data[:masterSecretLen] s.masterSecret = data[:masterSecretLen]
data = data[masterSecretLen:] data = data[masterSecretLen:]


if len(data) < 2 { if len(data) < 2 {
return false
return alertDecodeError
} }


numCerts := int(data[0])<<8 | int(data[1]) numCerts := int(data[0])<<8 | int(data[1])
@@ -112,21 +112,24 @@ func (s *sessionState) unmarshal(data []byte) bool {
s.certificates = make([][]byte, numCerts) s.certificates = make([][]byte, numCerts)
for i := range s.certificates { for i := range s.certificates {
if len(data) < 4 { if len(data) < 4 {
return false
return alertDecodeError
} }
certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
data = data[4:] data = data[4:]
if certLen < 0 { if certLen < 0 {
return false
return alertDecodeError
} }
if len(data) < certLen { if len(data) < certLen {
return false
return alertDecodeError
} }
s.certificates[i] = data[:certLen] s.certificates[i] = data[:certLen]
data = data[certLen:] data = data[certLen:]
} }


return len(data) == 0
if len(data) != 0 {
return alertDecodeError
}
return alertSuccess
} }


type sessionState13 struct { type sessionState13 struct {
@@ -195,9 +198,9 @@ func (s *sessionState13) marshal() []byte {
return x return x
} }


func (s *sessionState13) unmarshal(data []byte) bool {
func (s *sessionState13) unmarshal(data []byte) alert {
if len(data) < 24 { if len(data) < 24 {
return false
return alertDecodeError
} }


s.vers = uint16(data[0])<<8 | uint16(data[1]) s.vers = uint16(data[0])<<8 | uint16(data[1])
@@ -209,25 +212,25 @@ func (s *sessionState13) unmarshal(data []byte) bool {


l := int(data[20])<<8 | int(data[21]) l := int(data[20])<<8 | int(data[21])
if len(data) < 22+l+2 { if len(data) < 22+l+2 {
return false
return alertDecodeError
} }
s.resumptionSecret = data[22 : 22+l] s.resumptionSecret = data[22 : 22+l]
z := data[22+l:] z := data[22+l:]


l = int(z[0])<<8 | int(z[1]) l = int(z[0])<<8 | int(z[1])
if len(z) < 2+l+2 { if len(z) < 2+l+2 {
return false
return alertDecodeError
} }
s.alpnProtocol = string(z[2 : 2+l]) s.alpnProtocol = string(z[2 : 2+l])
z = z[2+l:] z = z[2+l:]


l = int(z[0])<<8 | int(z[1]) l = int(z[0])<<8 | int(z[1])
if len(z) != 2+l { if len(z) != 2+l {
return false
return alertDecodeError
} }
s.SNI = string(z[2 : 2+l]) s.SNI = string(z[2 : 2+l])


return true
return alertSuccess
} }


func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) {


Loading…
Cancel
Save