diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc index 48fe0529..dfb9c925 100644 --- a/ssl/handshake_client.cc +++ b/ssl/handshake_client.cc @@ -806,12 +806,80 @@ static int dtls1_get_hello_verify_request(SSL_HANDSHAKE *hs) { return 1; } +static int parse_server_version(SSL_HANDSHAKE *hs, uint16_t *out) { + SSL *const ssl = hs->ssl; + if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO && + ssl->s3->tmp.message_type != SSL3_MT_HELLO_RETRY_REQUEST) { + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); + return 0; + } + + CBS server_hello; + CBS_init(&server_hello, ssl->init_msg, ssl->init_num); + if (!CBS_get_u16(&server_hello, out)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR); + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); + return 0; + } + + /* The server version may also be in the supported_versions extension if + * applicable. */ + if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO || + *out != TLS1_2_VERSION) { + return 1; + } + + uint8_t sid_length; + if (!CBS_skip(&server_hello, SSL3_RANDOM_SIZE) || + !CBS_get_u8(&server_hello, &sid_length) || + !CBS_skip(&server_hello, sid_length + 2 /* cipher_suite */ + + 1 /* compression_method */)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR); + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); + return 0; + } + + /* The extensions block may not be present. */ + if (CBS_len(&server_hello) == 0) { + return 1; + } + + CBS extensions; + if (!CBS_get_u16_length_prefixed(&server_hello, &extensions) || + CBS_len(&server_hello) != 0) { + OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR); + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); + return 0; + } + + int have_supported_versions; + CBS supported_versions; + const SSL_EXTENSION_TYPE ext_types[] = { + {TLSEXT_TYPE_supported_versions, &have_supported_versions, + &supported_versions}, + }; + + uint8_t alert = SSL_AD_DECODE_ERROR; + if (!ssl_parse_extensions(&extensions, &alert, ext_types, + OPENSSL_ARRAY_SIZE(ext_types), + 1 /* ignore unknown */)) { + ssl3_send_alert(ssl, SSL3_AL_FATAL, alert); + return 0; + } + + if (have_supported_versions && + (!CBS_get_u16(&supported_versions, out) || + CBS_len(&supported_versions) != 0)) { + ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); + return 0; + } + + return 1; +} + static int ssl3_get_server_hello(SSL_HANDSHAKE *hs) { SSL *const ssl = hs->ssl; - CBS server_hello, server_random, session_id; - uint16_t server_version, cipher_suite; - uint8_t compression_method; - int ret = ssl->method->ssl_get_message(ssl); if (ret <= 0) { uint32_t err = ERR_peek_error(); @@ -828,62 +896,11 @@ static int ssl3_get_server_hello(SSL_HANDSHAKE *hs) { return ret; } - if (ssl->s3->tmp.message_type != SSL3_MT_SERVER_HELLO && - ssl->s3->tmp.message_type != SSL3_MT_HELLO_RETRY_REQUEST) { - ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE); - OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE); + uint16_t server_version; + if (!parse_server_version(hs, &server_version)) { return -1; } - CBS_init(&server_hello, ssl->init_msg, ssl->init_num); - - if (!CBS_get_u16(&server_hello, &server_version)) { - OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR); - ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); - return -1; - } - - /* Parse out server version from supported_versions if available. */ - if (ssl->s3->tmp.message_type == SSL3_MT_SERVER_HELLO && - server_version == TLS1_2_VERSION) { - CBS copy = server_hello; - CBS extensions; - uint8_t sid_length; - if (!CBS_skip(©, SSL3_RANDOM_SIZE) || - !CBS_get_u8(©, &sid_length) || - !CBS_skip(©, sid_length + 2 /* cipher_suite */ + - 1 /* compression_method */) || - !CBS_get_u16_length_prefixed(©, &extensions) || - CBS_len(©) != 0) { - OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR); - ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); - return -1; - } - - int have_supported_versions; - CBS supported_versions; - const SSL_EXTENSION_TYPE ext_types[] = { - {TLSEXT_TYPE_supported_versions, &have_supported_versions, - &supported_versions}, - }; - - uint8_t alert = SSL_AD_DECODE_ERROR; - if (!ssl_parse_extensions(&extensions, &alert, ext_types, - OPENSSL_ARRAY_SIZE(ext_types), - 1 /* ignore unknown */)) { - ssl3_send_alert(ssl, SSL3_AL_FATAL, alert); - return -1; - } - - if (have_supported_versions) { - if (!CBS_get_u16(&supported_versions, &server_version) || - CBS_len(&supported_versions) != 0) { - ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_DECODE_ERROR); - return -1; - } - } - } - if (!ssl_supports_version(hs, server_version)) { OPENSSL_PUT_ERROR(SSL, SSL_R_UNSUPPORTED_PROTOCOL); ssl3_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_PROTOCOL_VERSION); @@ -920,7 +937,12 @@ static int ssl3_get_server_hello(SSL_HANDSHAKE *hs) { return -1; } - if (!CBS_get_bytes(&server_hello, &server_random, SSL3_RANDOM_SIZE) || + CBS server_hello, server_random, session_id; + uint16_t cipher_suite; + uint8_t compression_method; + CBS_init(&server_hello, ssl->init_msg, ssl->init_num); + if (!CBS_skip(&server_hello, 2 /* version */) || + !CBS_get_bytes(&server_hello, &server_random, SSL3_RANDOM_SIZE) || !CBS_get_u8_length_prefixed(&server_hello, &session_id) || CBS_len(&session_id) > SSL3_SESSION_ID_SIZE || !CBS_get_u16(&server_hello, &cipher_suite) || diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go index a3c744cc..fd9fb3d5 100644 --- a/ssl/test/runner/common.go +++ b/ssl/test/runner/common.go @@ -1384,6 +1384,14 @@ type ProtocolBugs struct { // RejectUnsolicitedKeyUpdate, if true, causes all unsolicited // KeyUpdates from the peer to be rejected. RejectUnsolicitedKeyUpdate bool + + // OmitExtensions, if true, causes the extensions field in ClientHello + // and ServerHello messages to be omitted. + OmitExtensions bool + + // EmptyExtensions, if true, causese the extensions field in ClientHello + // and ServerHello messages to be present, but empty. + EmptyExtensions bool } func (c *Config) serverInit() { diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index cac1ebf5..05e7311c 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go @@ -84,7 +84,6 @@ func (c *Conn) clientHandshake() error { sctListSupported: !c.config.Bugs.NoSignedCertificateTimestamps, serverName: c.config.ServerName, supportedCurves: c.config.curvePreferences(), - pskKEModes: []byte{pskDHEKEMode}, supportedPoints: []uint8{pointFormatUncompressed}, nextProtoNeg: len(c.config.NextProtos) > 0, secureRenegotiation: []byte{}, @@ -97,6 +96,8 @@ func (c *Conn) clientHandshake() error { srtpMasterKeyIdentifier: c.config.Bugs.SRTPMasterKeyIdentifer, customExtension: c.config.Bugs.CustomExtension, pskBinderFirst: c.config.Bugs.PSKBinderFirst, + omitExtensions: c.config.Bugs.OmitExtensions, + emptyExtensions: c.config.Bugs.EmptyExtensions, } if maxVersion >= VersionTLS13 { @@ -104,6 +105,7 @@ func (c *Conn) clientHandshake() error { if !c.config.Bugs.OmitSupportedVersions { hello.supportedVersions = c.config.supportedVersions(c.isDTLS) } + hello.pskKEModes = []byte{pskDHEKEMode} } else { hello.vers = mapClientHelloVersion(maxVersion, c.isDTLS) } diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 1e36f081..4be873d8 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -7,6 +7,7 @@ package runner import ( "bytes" "encoding/binary" + "fmt" ) func writeLen(buf []byte, v, size int) { @@ -35,6 +36,11 @@ func (bb *byteBuilder) len() int { return len(*bb.buf) - bb.start - bb.prefixLen } +func (bb *byteBuilder) data() []byte { + bb.flush() + return (*bb.buf)[bb.start+bb.prefixLen:] +} + func (bb *byteBuilder) flush() { if bb.child == nil { return @@ -112,11 +118,11 @@ func (bb *byteBuilder) createChild(lengthPrefixSize int) *byteBuilder { } func (bb *byteBuilder) discardChild() { - if bb.child != nil { + if bb.child == nil { return } + *bb.buf = (*bb.buf)[:bb.child.start] bb.child = nil - *bb.buf = (*bb.buf)[:bb.start] } type keyShareEntry struct { @@ -167,6 +173,8 @@ type clientHelloMsg struct { customExtension string hasGREASEExtension bool pskBinderFirst bool + omitExtensions bool + emptyExtensions bool } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -212,7 +220,9 @@ func (m *clientHelloMsg) equal(i interface{}) bool { m.sctListSupported == m1.sctListSupported && m.customExtension == m1.customExtension && m.hasGREASEExtension == m1.hasGREASEExtension && - m.pskBinderFirst == m1.pskBinderFirst + m.pskBinderFirst == m1.pskBinderFirst && + m.omitExtensions == m1.omitExtensions && + m.emptyExtensions == m1.emptyExtensions } func (m *clientHelloMsg) marshal() []byte { @@ -444,8 +454,12 @@ func (m *clientHelloMsg) marshal() []byte { } } - if extensions.len() == 0 { + if m.omitExtensions || m.emptyExtensions { + // Silently erase any extensions which were sent. hello.discardChild() + if m.emptyExtensions { + hello.addU16(0) + } } m.raw = handshakeMsg.finish() @@ -828,6 +842,8 @@ type serverHelloMsg struct { compressionMethod uint8 customExtension string unencryptedALPN string + omitExtensions bool + emptyExtensions bool extensions serverExtensions } @@ -904,8 +920,17 @@ func (m *serverHelloMsg) marshal() []byte { } } else { m.extensions.marshal(extensions) - if extensions.len() == 0 { + if m.omitExtensions || m.emptyExtensions { + // Silently erasing server extensions will break the handshake. Instead, + // assert that tests which use this field also disable all features which + // would write an extension. + if extensions.len() != 0 { + panic(fmt.Sprintf("ServerHello unexpectedly contained extensions: %x, %+v", extensions.data(), m)) + } hello.discardChild() + if m.emptyExtensions { + hello.addU16(0) + } } } diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index 35d005e5..b31a562e 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go @@ -1068,6 +1068,8 @@ func (hs *serverHandshakeState) processClientHello() (isResume bool, err error) extensions: serverExtensions{ supportedVersion: config.Bugs.SendServerSupportedExtensionVersion, }, + omitExtensions: config.Bugs.OmitExtensions, + emptyExtensions: config.Bugs.EmptyExtensions, } hs.hello.random = make([]byte, 32) diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index 5c0906a7..29747db6 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go @@ -12074,6 +12074,75 @@ func addExtraHandshakeTests() { }) } +// Test that omitted and empty extensions blocks are tolerated. +func addOmitExtensionsTests() { + for _, ver := range tlsVersions { + if ver.version > VersionTLS12 { + continue + } + + testCases = append(testCases, testCase{ + testType: serverTest, + name: "OmitExtensions-ClientHello-" + ver.name, + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + OmitExtensions: true, + }, + }, + }) + + testCases = append(testCases, testCase{ + testType: serverTest, + name: "EmptyExtensions-ClientHello-" + ver.name, + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + EmptyExtensions: true, + }, + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "OmitExtensions-ServerHello-" + ver.name, + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + OmitExtensions: true, + // Disable all ServerHello extensions so + // OmitExtensions works. + NoExtendedMasterSecret: true, + NoRenegotiationInfo: true, + }, + }, + }) + + testCases = append(testCases, testCase{ + testType: clientTest, + name: "EmptyExtensions-ServerHello-" + ver.name, + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + EmptyExtensions: true, + // Disable all ServerHello extensions so + // EmptyExtensions works. + NoExtendedMasterSecret: true, + NoRenegotiationInfo: true, + }, + }, + }) + } +} + func worker(statusChan chan statusMsg, c chan *testCase, shimPath string, wg *sync.WaitGroup) { defer wg.Done() @@ -12201,6 +12270,7 @@ func main() { addRetainOnlySHA256ClientCertTests() addECDSAKeyUsageTests() addExtraHandshakeTests() + addOmitExtensionsTests() var wg sync.WaitGroup