BoGo: Resume-Server-PSKBinderFirstExtension BoGo: Resume-Server-ExtraPSKBinder BoGo: Resume-Server-ExtraIdentityNoBinder BoGo: Renegotiate-Server-Forbidden BoGo: NoNullCompression BoGo: TrailingMessageData-*v1.2.3
@@ -56,11 +56,16 @@ CurvePreferenceLoop: | |||||
hash := hashForSuite(hs.suite) | hash := hashForSuite(hs.suite) | ||||
hashSize := hash.Size() | hashSize := hash.Size() | ||||
earlySecret, isPSK := hs.checkPSK() | |||||
if !isPSK { | |||||
earlySecret, pskAlert := hs.checkPSK() | |||||
switch { | |||||
case pskAlert != alertSuccess: | |||||
c.sendAlert(pskAlert) | |||||
return errors.New("tls: invalid client PSK") | |||||
case earlySecret == nil: | |||||
earlySecret = hkdfExtract(hash, nil, nil) | earlySecret = hkdfExtract(hash, nil, nil) | ||||
case earlySecret != nil: | |||||
c.didResume = true | |||||
} | } | ||||
c.didResume = isPSK | |||||
hs.finishedHash13 = hash.New() | hs.finishedHash13 = hash.New() | ||||
hs.finishedHash13.Write(hs.clientHello.marshal()) | hs.finishedHash13.Write(hs.clientHello.marshal()) | ||||
@@ -94,7 +99,7 @@ CurvePreferenceLoop: | |||||
return err | return err | ||||
} | } | ||||
if !isPSK { | |||||
if !c.didResume { | |||||
if err := hs.sendCertificate13(); err != nil { | if err := hs.sendCertificate13(); err != nil { | ||||
return err | return err | ||||
} | } | ||||
@@ -151,7 +156,7 @@ func (hs *serverHandshakeState) readClientFinished13() error { | |||||
expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) | expectedVerifyData := hmacOfSum(hash, hs.finishedHash13, hs.clientFinishedKey) | ||||
if len(expectedVerifyData) != len(clientFinished.verifyData) || | if len(expectedVerifyData) != len(clientFinished.verifyData) || | ||||
subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { | subtle.ConstantTimeCompare(expectedVerifyData, clientFinished.verifyData) != 1 { | ||||
c.sendAlert(alertHandshakeFailure) | |||||
c.sendAlert(alertDecryptError) | |||||
return errors.New("tls: client's Finished message is incorrect") | return errors.New("tls: client's Finished message is incorrect") | ||||
} | } | ||||
hs.finishedHash13.Write(clientFinished.marshal()) | hs.finishedHash13.Write(clientFinished.marshal()) | ||||
@@ -403,9 +408,9 @@ func (hs *serverHandshakeState) prepareCipher(handshakeCtx, secret []byte, label | |||||
// https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. | // https://tools.ietf.org/html/draft-ietf-tls-tls13-18#section-4.2.8.2. | ||||
const ticketAgeSkewAllowance = 10 * time.Second | const ticketAgeSkewAllowance = 10 * time.Second | ||||
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { | |||||
func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, alert alert) { | |||||
if hs.c.config.SessionTicketsDisabled { | if hs.c.config.SessionTicketsDisabled { | ||||
return nil, false | |||||
return nil, alertSuccess | |||||
} | } | ||||
foundDHE := false | foundDHE := false | ||||
@@ -416,7 +421,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { | |||||
} | } | ||||
} | } | ||||
if !foundDHE { | if !foundDHE { | ||||
return nil, false | |||||
return nil, alertSuccess | |||||
} | } | ||||
hash := hashForSuite(hs.suite) | hash := hashForSuite(hs.suite) | ||||
@@ -428,7 +433,7 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { | |||||
continue | continue | ||||
} | } | ||||
s := &sessionState13{} | s := &sessionState13{} | ||||
if ok := s.unmarshal(serializedTicket); !ok { | |||||
if s.unmarshal(serializedTicket) != alertSuccess { | |||||
continue | continue | ||||
} | } | ||||
if s.vers != hs.c.vers { | if s.vers != hs.c.vers { | ||||
@@ -465,26 +470,27 @@ func (hs *serverHandshakeState) checkPSK() (earlySecret []byte, ok bool) { | |||||
chHash.Write(hs.clientHello.rawTruncated) | chHash.Write(hs.clientHello.rawTruncated) | ||||
expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) | expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey) | ||||
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) == 1 { | |||||
if i == 0 && hs.clientHello.earlyData { | |||||
// This is a ticket intended to be used for 0-RTT | |||||
if s.maxEarlyDataLen == 0 { | |||||
// But we had not tagged it as such. We could close the connection | |||||
// here, but instead we just ignore the ticket and the 0-RTT data. | |||||
continue | |||||
} | |||||
if hs.c.config.Accept0RTTData { | |||||
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen) | |||||
hs.hello13Enc.earlyData = true | |||||
} | |||||
if subtle.ConstantTimeCompare(expectedBinder, hs.clientHello.psks[i].binder) != 1 { | |||||
return nil, alertDecryptError | |||||
} | |||||
if i == 0 && hs.clientHello.earlyData { | |||||
// This is a ticket intended to be used for 0-RTT | |||||
if s.maxEarlyDataLen == 0 { | |||||
// But we had not tagged it as such. | |||||
return nil, alertIllegalParameter | |||||
} | |||||
if hs.c.config.Accept0RTTData { | |||||
hs.c.ticketMaxEarlyData = int64(s.maxEarlyDataLen) | |||||
hs.hello13Enc.earlyData = true | |||||
} | } | ||||
hs.hello13.psk = true | |||||
hs.hello13.pskIdentity = uint16(i) | |||||
return earlySecret, true | |||||
} | } | ||||
hs.hello13.psk = true | |||||
hs.hello13.pskIdentity = uint16(i) | |||||
return earlySecret, alertSuccess | |||||
} | } | ||||
return nil, false | |||||
return nil, alertSuccess | |||||
} | } | ||||
func (hs *serverHandshakeState) sendSessionTicket13() error { | func (hs *serverHandshakeState) sendSessionTicket13() error { | ||||
@@ -40,6 +40,7 @@ const ( | |||||
alertUserCanceled alert = 90 | alertUserCanceled alert = 90 | ||||
alertNoRenegotiation alert = 100 | alertNoRenegotiation alert = 100 | ||||
alertNoApplicationProtocol alert = 120 | alertNoApplicationProtocol alert = 120 | ||||
alertSuccess alert = 255 // dummy value returned by unmarshal functions | |||||
) | ) | ||||
var alertText = map[alert]string{ | var alertText = map[alert]string{ | ||||
@@ -939,7 +939,7 @@ type Certificate struct { | |||||
type handshakeMessage interface { | type handshakeMessage interface { | ||||
marshal() []byte | marshal() []byte | ||||
unmarshal([]byte) bool | |||||
unmarshal([]byte) alert | |||||
} | } | ||||
// lruSessionCache is a ClientSessionCache implementation that uses an LRU | // lruSessionCache is a ClientSessionCache implementation that uses an LRU | ||||
@@ -1126,8 +1126,8 @@ func (c *Conn) readHandshake() (interface{}, error) { | |||||
// so pass in a fresh copy that won't be overwritten. | // so pass in a fresh copy that won't be overwritten. | ||||
data = append([]byte(nil), data...) | data = append([]byte(nil), data...) | ||||
if !m.unmarshal(data) { | |||||
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) | |||||
if unmarshalAlert := m.unmarshal(data); unmarshalAlert != alertSuccess { | |||||
return nil, c.in.setErrorLocked(c.sendAlert(unmarshalAlert)) | |||||
} | } | ||||
return m, nil | return m, nil | ||||
} | } | ||||
@@ -602,7 +602,7 @@ func (hs *clientHandshakeState) readFinished(out []byte) error { | |||||
verify := hs.finishedHash.serverSum(hs.masterSecret) | verify := hs.finishedHash.serverSum(hs.masterSecret) | ||||
if len(verify) != len(serverFinished.verifyData) || | if len(verify) != len(serverFinished.verifyData) || | ||||
subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { | subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { | ||||
c.sendAlert(alertHandshakeFailure) | |||||
c.sendAlert(alertDecryptError) | |||||
return errors.New("tls: server's Finished message was incorrect") | return errors.New("tls: server's Finished message was incorrect") | ||||
} | } | ||||
hs.finishedHash.Write(serverFinished.marshal()) | hs.finishedHash.Write(serverFinished.marshal()) | ||||
@@ -1024,7 +1024,7 @@ func TestHostnameInSNI(t *testing.T) { | |||||
s.Close() | s.Close() | ||||
var m clientHelloMsg | var m clientHelloMsg | ||||
if !m.unmarshal(record) { | |||||
if m.unmarshal(record) != alertSuccess { | |||||
t.Errorf("unmarshaling ClientHello for %q failed", tt.in) | t.Errorf("unmarshaling ClientHello for %q failed", tt.in) | ||||
continue | continue | ||||
} | } | ||||
@@ -35,7 +35,7 @@ var tests = []interface{}{ | |||||
type testMessage interface { | type testMessage interface { | ||||
marshal() []byte | marshal() []byte | ||||
unmarshal([]byte) bool | |||||
unmarshal([]byte) alert | |||||
equal(interface{}) bool | equal(interface{}) bool | ||||
} | } | ||||
@@ -59,7 +59,7 @@ func TestMarshalUnmarshal(t *testing.T) { | |||||
m1 := v.Interface().(testMessage) | m1 := v.Interface().(testMessage) | ||||
marshaled := m1.marshal() | marshaled := m1.marshal() | ||||
m2 := iface.(testMessage) | m2 := iface.(testMessage) | ||||
if !m2.unmarshal(marshaled) { | |||||
if m2.unmarshal(marshaled) != alertSuccess { | |||||
t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled) | t.Errorf("#%d.%d failed to unmarshal %#v %x", i, j, m1, marshaled) | ||||
break | break | ||||
} | } | ||||
@@ -77,7 +77,7 @@ func TestMarshalUnmarshal(t *testing.T) { | |||||
// data is optional and the length of the | // data is optional and the length of the | ||||
// Finished varies across versions. | // Finished varies across versions. | ||||
for j := 0; j < len(marshaled); j++ { | for j := 0; j < len(marshaled); j++ { | ||||
if m2.unmarshal(marshaled[0:j]) { | |||||
if m2.unmarshal(marshaled[0:j]) == alertSuccess { | |||||
t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) | t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1) | ||||
break | break | ||||
} | } | ||||
@@ -364,7 +364,7 @@ func TestRejectEmptySCTList(t *testing.T) { | |||||
serverHelloBytes := serverHello.marshal() | serverHelloBytes := serverHello.marshal() | ||||
var serverHelloCopy serverHelloMsg | var serverHelloCopy serverHelloMsg | ||||
if !serverHelloCopy.unmarshal(serverHelloBytes) { | |||||
if serverHelloCopy.unmarshal(serverHelloBytes) != alertSuccess { | |||||
t.Fatal("Failed to unmarshal initial message") | t.Fatal("Failed to unmarshal initial message") | ||||
} | } | ||||
@@ -389,7 +389,7 @@ func TestRejectEmptySCTList(t *testing.T) { | |||||
serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) | serverHelloEmptySCT[42] = byte((len(serverHelloEmptySCT) - 44) >> 8) | ||||
serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) | serverHelloEmptySCT[43] = byte((len(serverHelloEmptySCT) - 44)) | ||||
if serverHelloCopy.unmarshal(serverHelloEmptySCT) { | |||||
if serverHelloCopy.unmarshal(serverHelloEmptySCT) == alertSuccess { | |||||
t.Fatal("Unmarshaled ServerHello with empty SCT list") | t.Fatal("Unmarshaled ServerHello with empty SCT list") | ||||
} | } | ||||
} | } | ||||
@@ -407,7 +407,7 @@ func TestRejectEmptySCT(t *testing.T) { | |||||
serverHelloBytes := serverHello.marshal() | serverHelloBytes := serverHello.marshal() | ||||
var serverHelloCopy serverHelloMsg | var serverHelloCopy serverHelloMsg | ||||
if serverHelloCopy.unmarshal(serverHelloBytes) { | |||||
if serverHelloCopy.unmarshal(serverHelloBytes) == alertSuccess { | |||||
t.Fatal("Unmarshaled ServerHello with zero-length SCT") | t.Fatal("Unmarshaled ServerHello with zero-length SCT") | ||||
} | } | ||||
} | } |
@@ -222,7 +222,7 @@ Curves: | |||||
} | } | ||||
if !foundCompression { | if !foundCompression { | ||||
c.sendAlert(alertHandshakeFailure) | |||||
c.sendAlert(alertIllegalParameter) | |||||
return false, errors.New("tls: client does not support uncompressed connections") | return false, errors.New("tls: client does not support uncompressed connections") | ||||
} | } | ||||
if len(hs.clientHello.compressionMethods) != 1 && c.vers >= VersionTLS13 { | if len(hs.clientHello.compressionMethods) != 1 && c.vers >= VersionTLS13 { | ||||
@@ -370,7 +370,7 @@ func (hs *serverHandshakeState) checkForResumption() bool { | |||||
sessionTicket := append([]uint8{}, hs.clientHello.sessionTicket...) | sessionTicket := append([]uint8{}, hs.clientHello.sessionTicket...) | ||||
serializedState, usedOldKey := c.decryptTicket(sessionTicket) | serializedState, usedOldKey := c.decryptTicket(sessionTicket) | ||||
hs.sessionState = &sessionState{usedOldKey: usedOldKey} | hs.sessionState = &sessionState{usedOldKey: usedOldKey} | ||||
if ok := hs.sessionState.unmarshal(serializedState); !ok { | |||||
if hs.sessionState.unmarshal(serializedState) != alertSuccess { | |||||
return false | return false | ||||
} | } | ||||
@@ -570,7 +570,11 @@ func (hs *serverHandshakeState) doFullHandshake() error { | |||||
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) | preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers) | ||||
if err != nil { | if err != nil { | ||||
c.sendAlert(alertHandshakeFailure) | |||||
if err == errClientKeyExchange { | |||||
c.sendAlert(alertDecodeError) | |||||
} else { | |||||
c.sendAlert(alertInternalError) | |||||
} | |||||
return err | return err | ||||
} | } | ||||
hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) | hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) | ||||
@@ -721,7 +725,7 @@ func (hs *serverHandshakeState) readFinished(out []byte) error { | |||||
verify := hs.finishedHash.clientSum(hs.masterSecret) | verify := hs.finishedHash.clientSum(hs.masterSecret) | ||||
if len(verify) != len(clientFinished.verifyData) || | if len(verify) != len(clientFinished.verifyData) || | ||||
subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { | subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { | ||||
c.sendAlert(alertHandshakeFailure) | |||||
c.sendAlert(alertDecryptError) | |||||
return errors.New("tls: client's Finished message is incorrect") | return errors.New("tls: client's Finished message is incorrect") | ||||
} | } | ||||
@@ -246,7 +246,7 @@ func TestRenegotiationExtension(t *testing.T) { | |||||
var serverHello serverHelloMsg | var serverHello serverHelloMsg | ||||
// unmarshal expects to be given the handshake header, but | // unmarshal expects to be given the handshake header, but | ||||
// serverHelloLen doesn't include it. | // serverHelloLen doesn't include it. | ||||
if !serverHello.unmarshal(buf[5 : 9+serverHelloLen]) { | |||||
if serverHello.unmarshal(buf[5:9+serverHelloLen]) != alertSuccess { | |||||
t.Fatalf("Failed to parse ServerHello") | t.Fatalf("Failed to parse ServerHello") | ||||
} | } | ||||
@@ -86,9 +86,9 @@ func (s *sessionState) marshal() []byte { | |||||
return ret | return ret | ||||
} | } | ||||
func (s *sessionState) unmarshal(data []byte) bool { | |||||
func (s *sessionState) unmarshal(data []byte) alert { | |||||
if len(data) < 8 { | if len(data) < 8 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.vers = uint16(data[0])<<8 | uint16(data[1]) | s.vers = uint16(data[0])<<8 | uint16(data[1]) | ||||
@@ -96,14 +96,14 @@ func (s *sessionState) unmarshal(data []byte) bool { | |||||
masterSecretLen := int(data[4])<<8 | int(data[5]) | masterSecretLen := int(data[4])<<8 | int(data[5]) | ||||
data = data[6:] | data = data[6:] | ||||
if len(data) < masterSecretLen { | if len(data) < masterSecretLen { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.masterSecret = data[:masterSecretLen] | s.masterSecret = data[:masterSecretLen] | ||||
data = data[masterSecretLen:] | data = data[masterSecretLen:] | ||||
if len(data) < 2 { | if len(data) < 2 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
numCerts := int(data[0])<<8 | int(data[1]) | numCerts := int(data[0])<<8 | int(data[1]) | ||||
@@ -112,21 +112,24 @@ func (s *sessionState) unmarshal(data []byte) bool { | |||||
s.certificates = make([][]byte, numCerts) | s.certificates = make([][]byte, numCerts) | ||||
for i := range s.certificates { | for i := range s.certificates { | ||||
if len(data) < 4 { | if len(data) < 4 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) | certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) | ||||
data = data[4:] | data = data[4:] | ||||
if certLen < 0 { | if certLen < 0 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
if len(data) < certLen { | if len(data) < certLen { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.certificates[i] = data[:certLen] | s.certificates[i] = data[:certLen] | ||||
data = data[certLen:] | data = data[certLen:] | ||||
} | } | ||||
return len(data) == 0 | |||||
if len(data) != 0 { | |||||
return alertDecodeError | |||||
} | |||||
return alertSuccess | |||||
} | } | ||||
type sessionState13 struct { | type sessionState13 struct { | ||||
@@ -195,9 +198,9 @@ func (s *sessionState13) marshal() []byte { | |||||
return x | return x | ||||
} | } | ||||
func (s *sessionState13) unmarshal(data []byte) bool { | |||||
func (s *sessionState13) unmarshal(data []byte) alert { | |||||
if len(data) < 24 { | if len(data) < 24 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.vers = uint16(data[0])<<8 | uint16(data[1]) | s.vers = uint16(data[0])<<8 | uint16(data[1]) | ||||
@@ -209,25 +212,25 @@ func (s *sessionState13) unmarshal(data []byte) bool { | |||||
l := int(data[20])<<8 | int(data[21]) | l := int(data[20])<<8 | int(data[21]) | ||||
if len(data) < 22+l+2 { | if len(data) < 22+l+2 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.resumptionSecret = data[22 : 22+l] | s.resumptionSecret = data[22 : 22+l] | ||||
z := data[22+l:] | z := data[22+l:] | ||||
l = int(z[0])<<8 | int(z[1]) | l = int(z[0])<<8 | int(z[1]) | ||||
if len(z) < 2+l+2 { | if len(z) < 2+l+2 { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.alpnProtocol = string(z[2 : 2+l]) | s.alpnProtocol = string(z[2 : 2+l]) | ||||
z = z[2+l:] | z = z[2+l:] | ||||
l = int(z[0])<<8 | int(z[1]) | l = int(z[0])<<8 | int(z[1]) | ||||
if len(z) != 2+l { | if len(z) != 2+l { | ||||
return false | |||||
return alertDecodeError | |||||
} | } | ||||
s.SNI = string(z[2 : 2+l]) | s.SNI = string(z[2 : 2+l]) | ||||
return true | |||||
return alertSuccess | |||||
} | } | ||||
func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { | func (c *Conn) encryptTicket(serialized []byte) ([]byte, error) { | ||||