Speed up test collection (#298)

* don't do filesystem operations during collection

* Greatly speed up test collection

* fixup! Greatly speed up test collection

* Silence junit warning

* fixup! Greatly speed up test collection
This commit is contained in:
Thom Wiggers 2020-06-22 04:10:07 +02:00 committed by Kris Kwiatkowski
parent 4604907c4c
commit 1f8b852e8f
3 changed files with 39 additions and 22 deletions

View File

@ -2,11 +2,13 @@ import atexit
import functools import functools
import logging import logging
import os import os
import secrets
import shutil import shutil
import string
import subprocess import subprocess
import sys import sys
import tempfile
import unittest import unittest
from functools import lru_cache
import pqclean import pqclean
@ -22,6 +24,12 @@ def cleanup_testcases():
TEST_TEMPDIRS = [] TEST_TEMPDIRS = []
ALPHABET = string.ascii_letters + string.digits + '_'
def mktmpdir(parent, prefix):
"""Returns a unique directory name"""
uniq = ''.join(secrets.choice(ALPHABET) for i in range(8))
return os.path.join(parent, "{}_{}".format(prefix, uniq))
def isolate_test_files(impl_path, test_prefix, def isolate_test_files(impl_path, test_prefix,
dir=os.path.join('..', 'testcases')): dir=os.path.join('..', 'testcases')):
@ -34,24 +42,22 @@ def isolate_test_files(impl_path, test_prefix,
os.mkdir(dir) os.mkdir(dir)
except FileExistsError: except FileExistsError:
pass pass
test_dir = tempfile.mkdtemp(prefix=test_prefix, dir=dir) test_dir = mktmpdir(dir, test_prefix)
test_dir = os.path.abspath(test_dir) test_dir = os.path.abspath(test_dir)
TEST_TEMPDIRS.append(test_dir) TEST_TEMPDIRS.append(test_dir)
# Create layers in folder structure
nested_dir = os.path.join(test_dir, 'crypto_bla')
os.mkdir(nested_dir)
nested_dir = os.path.join(nested_dir, 'scheme')
os.mkdir(nested_dir)
# Create test dependencies structure
os.mkdir(os.path.join(test_dir, 'test'))
# the implementation will go here. # the implementation will go here.
new_impl_dir = os.path.abspath(os.path.join(nested_dir, 'impl')) scheme_dir = os.path.join(test_dir, 'crypto_bla', 'scheme')
new_impl_dir = os.path.abspath(os.path.join(scheme_dir, 'impl'))
def initializer(): def initializer():
"""Isolate the files to be tested""" """Isolate the files to be tested"""
# Create layers in folder structure
os.makedirs(scheme_dir)
# Create test dependencies structure
os.mkdir(os.path.join(test_dir, 'test'))
# Copy common files (randombytes.c, aes.c, ...) # Copy common files (randombytes.c, aes.c, ...)
shutil.copytree( shutil.copytree(
os.path.join('..', 'common'), os.path.join(test_dir, 'common')) os.path.join('..', 'common'), os.path.join(test_dir, 'common'))
@ -160,6 +166,7 @@ def slow_test(f):
return wrapper return wrapper
@lru_cache(maxsize=None)
def ensure_available(executable): def ensure_available(executable):
""" """
Checks if a command is available. Checks if a command is available.
@ -278,18 +285,16 @@ def filtered_test(func):
return wrapper return wrapper
__CPUINFO = None @lru_cache(maxsize=1)
def get_cpu_info(): def get_cpu_info():
global __CPUINFO the_info = None
while __CPUINFO is None or 'flags' not in __CPUINFO: while the_info is None or 'flags' not in the_info:
import cpuinfo import cpuinfo
__CPUINFO = cpuinfo.get_cpu_info() the_info = cpuinfo.get_cpu_info()
# CPUINFO is unreliable on Travis CI Macs # CPUINFO is unreliable on Travis CI Macs
if 'CI' in os.environ and sys.platform == 'darwin': if 'CI' in os.environ and sys.platform == 'darwin':
__CPUINFO['flags'] = [ the_info['flags'] = [
'aes', 'apic', 'avx1.0', 'clfsh', 'cmov', 'cx16', 'cx8', 'de', 'aes', 'apic', 'avx1.0', 'clfsh', 'cmov', 'cx16', 'cx8', 'de',
'em64t', 'erms', 'f16c', 'fpu', 'fxsr', 'lahf', 'mca', 'mce', 'em64t', 'erms', 'f16c', 'fpu', 'fxsr', 'lahf', 'mca', 'mce',
'mmx', 'mon', 'msr', 'mtrr', 'osxsave', 'pae', 'pat', 'pcid', 'mmx', 'mon', 'msr', 'mtrr', 'osxsave', 'pae', 'pat', 'pcid',
@ -299,4 +304,4 @@ def get_cpu_info():
'tsc_thread_offset', 'tsci', 'tsctmr', 'vme', 'vmm', 'x2apic', 'tsc_thread_offset', 'tsci', 'tsctmr', 'vme', 'vmm', 'x2apic',
'xd', 'xsave'] 'xd', 'xsave']
return __CPUINFO return the_info

View File

@ -1,6 +1,7 @@
import glob import glob
import os import os
from typing import Optional from typing import Optional
from functools import lru_cache
import yaml import yaml
import platform import platform
@ -20,6 +21,7 @@ class Scheme:
return 'PQCLEAN_{}_'.format(self.name.upper()).replace('-', '') return 'PQCLEAN_{}_'.format(self.name.upper()).replace('-', '')
@staticmethod @staticmethod
@lru_cache(maxsize=None)
def by_name(scheme_name): def by_name(scheme_name):
for scheme in Scheme.all_schemes(): for scheme in Scheme.all_schemes():
if scheme.name == scheme_name: if scheme.name == scheme_name:
@ -27,6 +29,7 @@ class Scheme:
raise KeyError() raise KeyError()
@staticmethod @staticmethod
@lru_cache(maxsize=1)
def all_schemes(): def all_schemes():
schemes = [] schemes = []
schemes.extend(Scheme.all_schemes_of_type('kem')) schemes.extend(Scheme.all_schemes_of_type('kem'))
@ -34,6 +37,7 @@ class Scheme:
return schemes return schemes
@staticmethod @staticmethod
@lru_cache(maxsize=1)
def all_implementations(): def all_implementations():
implementations = [] implementations = []
for scheme in Scheme.all_schemes(): for scheme in Scheme.all_schemes():
@ -41,11 +45,13 @@ class Scheme:
return implementations return implementations
@staticmethod @staticmethod
@lru_cache(maxsize=1)
def all_supported_implementations(): def all_supported_implementations():
return [impl for impl in Scheme.all_implementations() return [impl for impl in Scheme.all_implementations()
if impl.supported_on_current_platform()] if impl.supported_on_current_platform()]
@staticmethod @staticmethod
@lru_cache(maxsize=32)
def all_schemes_of_type(type: str) -> list: def all_schemes_of_type(type: str) -> list:
schemes = [] schemes = []
p = os.path.join('..', 'crypto_' + type) p = os.path.join('..', 'crypto_' + type)
@ -60,12 +66,13 @@ class Scheme:
assert('Unknown type') assert('Unknown type')
return schemes return schemes
@lru_cache(maxsize=None)
def metadata(self): def metadata(self):
metafile = os.path.join(self.path(), 'META.yml') metafile = os.path.join(self.path(), 'META.yml')
try: try:
with open(metafile, encoding='utf-8') as f: with open(metafile, encoding='utf-8') as f:
metadata = yaml.safe_load(f.read()) metadata = yaml.safe_load(f)
return metadata return metadata
except Exception as e: except Exception as e:
print("Can't open {}: {}".format(metafile, e)) print("Can't open {}: {}".format(metafile, e))
return None return None
@ -80,6 +87,7 @@ class Implementation:
self.scheme = scheme self.scheme = scheme
self.name = name self.name = name
@lru_cache(maxsize=None)
def metadata(self): def metadata(self):
for i in self.scheme.metadata()['implementations']: for i in self.scheme.metadata()['implementations']:
if i['name'] == self.name: if i['name'] == self.name:
@ -104,6 +112,7 @@ class Implementation:
'*.o' if os.name != 'nt' else '*.obj')) '*.o' if os.name != 'nt' else '*.obj'))
@staticmethod @staticmethod
@lru_cache(maxsize=None)
def by_name(scheme_name, implementation_name): def by_name(scheme_name, implementation_name):
scheme = Scheme.by_name(scheme_name) scheme = Scheme.by_name(scheme_name)
for implementation in scheme.implementations: for implementation in scheme.implementations:
@ -112,6 +121,7 @@ class Implementation:
raise KeyError() raise KeyError()
@staticmethod @staticmethod
@lru_cache(maxsize=None)
def all_implementations(scheme: Scheme) -> list: def all_implementations(scheme: Scheme) -> list:
implementations = [] implementations = []
for d in os.listdir(scheme.path()): for d in os.listdir(scheme.path()):
@ -143,6 +153,7 @@ class Implementation:
return True return True
@lru_cache(maxsize=10000)
def supported_on_current_platform(self) -> bool: def supported_on_current_platform(self) -> bool:
if 'supported_platforms' not in self.metadata(): if 'supported_platforms' not in self.metadata():
return True return True

View File

@ -2,3 +2,4 @@
norecursedirs = .git * norecursedirs = .git *
empty_parameter_set_mark = fail_at_collect empty_parameter_set_mark = fail_at_collect
junit_log_passing_tests = False junit_log_passing_tests = False
junit_family=xunit2