diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 766d515c..bec97ddc 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go @@ -537,11 +537,7 @@ func (m *clientHelloMsg) marshal() []byte { srtpProtectionProfiles := useSrtpExt.addU16LengthPrefixed() for _, p := range m.srtpProtectionProfiles { - // An SRTPProtectionProfile is defined as uint8[2], - // not uint16. For some reason, we're storing it - // as a uint16. - srtpProtectionProfiles.addU8(byte(p >> 8)) - srtpProtectionProfiles.addU8(byte(p)) + srtpProtectionProfiles.addU16(p) } srtpMki := useSrtpExt.addU8LengthPrefixed() srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier)) @@ -967,76 +963,56 @@ func (m *serverHelloMsg) marshal() []byte { } func (m *serverHelloMsg) unmarshal(data []byte) bool { - if len(data) < 42 { + m.raw = data + reader := byteReader(data[4:]) + if !reader.readU16(&m.vers) || + !reader.readBytes(&m.random, 32) { return false } - m.raw = data - m.vers = uint16(data[4])<<8 | uint16(data[5]) vers, ok := wireToVersion(m.vers, m.isDTLS) if !ok { return false } - m.random = data[6:38] - data = data[38:] if vers < VersionTLS13 || isResumptionExperiment(m.vers) { - sessionIdLen := int(data[0]) - if sessionIdLen > 32 || len(data) < 1+sessionIdLen { + if !reader.readU8LengthPrefixedBytes(&m.sessionId) { return false } - m.sessionId = data[1 : 1+sessionIdLen] - data = data[1+sessionIdLen:] } - if len(data) < 2 { + if !reader.readU16(&m.cipherSuite) { return false } - m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] if vers < VersionTLS13 || isResumptionExperiment(m.vers) { - if len(data) < 1 { + if !reader.readU8(&m.compressionMethod) { return false } - m.compressionMethod = data[0] - data = data[1:] } - if len(data) == 0 && m.vers < VersionTLS13 { + if len(reader) == 0 && m.vers < VersionTLS13 { // Extension data is optional before TLS 1.3. m.extensions = serverExtensions{} m.omitExtensions = true return true } - if len(data) < 2 { - return false - } - extensionsLength := int(data[0])<<8 | int(data[1]) - data = data[2:] - if len(data) != extensionsLength { + var extensions byteReader + if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 { return false } // Parse out the version from supported_versions if available. if m.vers == VersionTLS12 { - vdata := data - for len(vdata) != 0 { - if len(vdata) < 4 { + extensionsCopy := extensions + for len(extensionsCopy) > 0 { + var extension uint16 + var body byteReader + if !extensionsCopy.readU16(&extension) || + !extensionsCopy.readU16LengthPrefixed(&body) { return false } - extension := uint16(vdata[0])<<8 | uint16(vdata[1]) - length := int(vdata[2])<<8 | int(vdata[3]) - vdata = vdata[4:] - - if len(vdata) < length { - return false - } - d := vdata[:length] - vdata = vdata[length:] - if extension == extensionSupportedVersions { - if len(d) < 2 { + if !body.readU16(&m.vers) || len(body) != 0 { return false } - m.vers = uint16(d[0])<<8 | uint16(d[1]) vers, ok = wireToVersion(m.vers, m.isDTLS) if !ok { return false @@ -1046,38 +1022,27 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { } if vers >= VersionTLS13 { - for len(data) != 0 { - if len(data) < 4 { + for len(extensions) > 0 { + var extension uint16 + var body byteReader + if !extensions.readU16(&extension) || + !extensions.readU16LengthPrefixed(&body) { return false } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - - if len(data) < length { - return false - } - d := data[:length] - data = data[length:] - switch extension { case extensionKeyShare: m.hasKeyShare = true - if len(d) < 4 { + var group uint16 + if !body.readU16(&group) || + !body.readU16LengthPrefixedBytes(&m.keyShare.keyExchange) || + len(body) != 0 { return false } - m.keyShare.group = CurveID(uint16(d[0])<<8 | uint16(d[1])) - keyExchLen := int(d[2])<<8 | int(d[3]) - if keyExchLen != len(d)-4 { - return false - } - m.keyShare.keyExchange = make([]byte, keyExchLen) - copy(m.keyShare.keyExchange, d[4:]) + m.keyShare.group = CurveID(group) case extensionPreSharedKey: - if len(d) != 2 { + if !body.readU16(&m.pskIdentity) || len(body) != 0 { return false } - m.pskIdentity = uint16(d[0])<<8 | uint16(d[1]) m.hasPSKIdentity = true case extensionSupportedVersions: if !isResumptionExperiment(m.vers) { @@ -1089,7 +1054,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return false } } - } else if !m.extensions.unmarshal(data, vers) { + } else if !m.extensions.unmarshal(extensions, vers) { return false } @@ -1121,23 +1086,12 @@ func (m *encryptedExtensionsMsg) marshal() []byte { func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { m.raw = data - if len(data) < 6 { + reader := byteReader(data[4:]) + var extensions byteReader + if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 { return false } - if data[0] != typeEncryptedExtensions { - return false - } - msgLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) != msgLen { - return false - } - extLen := int(data[0])<<8 | int(data[1]) - data = data[2:] - if extLen != len(data) { - return false - } - return m.extensions.unmarshal(data, VersionTLS13) + return m.extensions.unmarshal(extensions, VersionTLS13) } type serverExtensions struct { @@ -1223,8 +1177,7 @@ func (m *serverExtensions) marshal(extensions *byteBuilder) { extension := extensions.addU16LengthPrefixed() srtpProtectionProfiles := extension.addU16LengthPrefixed() - srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile >> 8)) - srtpProtectionProfiles.addU8(byte(m.srtpProtectionProfile)) + srtpProtectionProfiles.addU16(m.srtpProtectionProfile) srtpMki := extension.addU8LengthPrefixed() srtpMki.addBytes([]byte(m.srtpMasterKeyIdentifier)) } @@ -1288,96 +1241,77 @@ func (m *serverExtensions) marshal(extensions *byteBuilder) { } } -func (m *serverExtensions) unmarshal(data []byte, version uint16) bool { +func (m *serverExtensions) unmarshal(data byteReader, version uint16) bool { // Reset all fields. *m = serverExtensions{} - for len(data) != 0 { - if len(data) < 4 { + for len(data) > 0 { + var extension uint16 + var body byteReader + if !data.readU16(&extension) || + !data.readU16LengthPrefixed(&body) { return false } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { - return false - } - switch extension { case extensionNextProtoNeg: m.nextProtoNeg = true - d := data[:length] - for len(d) > 0 { - l := int(d[0]) - d = d[1:] - if l == 0 || l > len(d) { + for len(body) > 0 { + var protocol []byte + if !body.readU8LengthPrefixedBytes(&protocol) { return false } - m.nextProtos = append(m.nextProtos, string(d[:l])) - d = d[l:] + m.nextProtos = append(m.nextProtos, string(protocol)) } case extensionStatusRequest: - if length > 0 { + if len(body) != 0 { return false } m.ocspStapling = true case extensionSessionTicket: - if length > 0 { + if len(body) != 0 { return false } m.ticketSupported = true case extensionRenegotiationInfo: - if length < 1 || length != int(data[0])+1 { + if !body.readU8LengthPrefixedBytes(&m.secureRenegotiation) || len(body) != 0 { return false } - m.secureRenegotiation = data[1:length] case extensionALPN: - d := data[:length] - if len(d) < 3 { + var protocols, protocol byteReader + if !body.readU16LengthPrefixed(&protocols) || + len(body) != 0 || + !protocols.readU8LengthPrefixed(&protocol) || + len(protocols) != 0 { return false } - l := int(d[0])<<8 | int(d[1]) - if l != len(d)-2 { - return false - } - d = d[2:] - l = int(d[0]) - if l != len(d)-1 { - return false - } - d = d[1:] - m.alpnProtocol = string(d) - m.alpnProtocolEmpty = len(d) == 0 + m.alpnProtocol = string(protocol) + m.alpnProtocolEmpty = len(protocol) == 0 case extensionChannelID: - if length > 0 { + if len(body) != 0 { return false } m.channelIDRequested = true case extensionExtendedMasterSecret: - if length != 0 { + if len(body) != 0 { return false } m.extendedMasterSecret = true case extensionUseSRTP: - if length < 2+2+1 { + var profiles, mki byteReader + if !body.readU16LengthPrefixed(&profiles) || + !profiles.readU16(&m.srtpProtectionProfile) || + len(profiles) != 0 || + !body.readU8LengthPrefixed(&mki) || + len(body) != 0 { return false } - if data[0] != 0 || data[1] != 2 { - return false - } - m.srtpProtectionProfile = uint16(data[2])<<8 | uint16(data[3]) - d := data[4:length] - l := int(d[0]) - if l != len(d)-1 { - return false - } - m.srtpMasterKeyIdentifier = string(d[1:]) + m.srtpMasterKeyIdentifier = string(mki) case extensionSignedCertificateTimestamp: - m.sctList = data[:length] + m.sctList = []byte(body) case extensionCustom: - m.customExtension = string(data[:length]) + m.customExtension = string(body) case extensionServerName: - if length != 0 { + if len(body) != 0 { return false } m.serverNameAck = true @@ -1387,21 +1321,16 @@ func (m *serverExtensions) unmarshal(data []byte, version uint16) bool { return false } // http://tools.ietf.org/html/rfc4492#section-5.5.2 - if length < 1 { + if !body.readU8LengthPrefixedBytes(&m.supportedPoints) || len(body) != 0 { return false } - l := int(data[0]) - if length != l+1 { - return false - } - m.supportedPoints = data[1 : 1+l] case extensionSupportedCurves: // The server can only send supported_curves in TLS 1.3. if version < VersionTLS13 { return false } case extensionEarlyData: - if version < VersionTLS13 || length != 0 { + if version < VersionTLS13 || len(body) != 0 { return false } m.hasEarlyData = true @@ -1409,7 +1338,6 @@ func (m *serverExtensions) unmarshal(data []byte, version uint16) bool { // Unknown extensions are illegal from the server. return false } - data = data[length:] } return true @@ -1490,73 +1418,56 @@ func (m *helloRetryRequestMsg) marshal() []byte { func (m *helloRetryRequestMsg) unmarshal(data []byte) bool { m.raw = data - if len(data) < 8 { + reader := byteReader(data[4:]) + if !reader.readU16(&m.vers) { return false } - m.vers = uint16(data[4])<<8 | uint16(data[5]) - data = data[6:] if m.isServerHello { - if len(data) < 33 { + var random []byte + var compressionMethod byte + if !reader.readBytes(&random, 32) || + !reader.readU8LengthPrefixedBytes(&m.sessionId) || + !reader.readU16(&m.cipherSuite) || + !reader.readU8(&compressionMethod) || + compressionMethod != 0 { return false } - data = data[32:] // Random - sessionIdLen := int(data[0]) - if sessionIdLen > 32 || len(data) < 1+sessionIdLen+3 { - return false - } - m.sessionId = data[1 : 1+sessionIdLen] - data = data[1+sessionIdLen:] - m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - data = data[1:] // Compression Method - } else { - if isDraft21(m.vers) { - m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) - data = data[2:] - } - } - extLen := int(data[0])<<8 | int(data[1]) - data = data[2:] - if len(data) != extLen || len(data) == 0 { + } else if isDraft21(m.vers) && !reader.readU16(&m.cipherSuite) { return false } - for len(data) > 0 { - if len(data) < 4 { + var extensions byteReader + if !reader.readU16LengthPrefixed(&extensions) || len(reader) != 0 { + return false + } + for len(extensions) > 0 { + var extension uint16 + var body byteReader + if !extensions.readU16(&extension) || + !extensions.readU16LengthPrefixed(&body) { return false } - extension := uint16(data[0])<<8 | uint16(data[1]) - length := int(data[2])<<8 | int(data[3]) - data = data[4:] - if len(data) < length { - return false - } - switch extension { case extensionSupportedVersions: - if length != 2 || !m.isServerHello { + if !m.isServerHello || + !body.readU16(&m.vers) || + len(body) != 0 { return false } - m.vers = uint16(data[0])<<8 | uint16(data[1]) case extensionKeyShare: - if length != 2 { + var v uint16 + if !body.readU16(&v) || len(body) != 0 { return false } m.hasSelectedGroup = true - m.selectedGroup = CurveID(data[0])<<8 | CurveID(data[1]) + m.selectedGroup = CurveID(v) case extensionCookie: - if length < 2 { + if !body.readU16LengthPrefixedBytes(&m.cookie) || len(body) != 0 { return false } - cookieLen := int(data[0])<<8 | int(data[1]) - if 2+cookieLen != length { - return false - } - m.cookie = data[2 : 2+cookieLen] default: // Unknown extensions are illegal from the server. return false } - data = data[length:] } return true }