diff --git a/test/katrunner/Cargo.lock b/test/katrunner/Cargo.lock index 7bb56bc8..d6f2e2ae 100644 --- a/test/katrunner/Cargo.lock +++ b/test/katrunner/Cargo.lock @@ -172,6 +172,7 @@ version = "0.1.0" dependencies = [ "hex", "katwalk", + "lazy_static", "pqc-sys", "rust-crypto", "threadpool", diff --git a/test/katrunner/Cargo.toml b/test/katrunner/Cargo.toml index a9201968..ee1d8744 100644 --- a/test/katrunner/Cargo.toml +++ b/test/katrunner/Cargo.toml @@ -9,4 +9,5 @@ katwalk = "0.0.3" pqc-sys = { path = "../../src/rustapi/pqc-sys" } hex = "0.4.2" threadpool = "1.8.1" -rust-crypto = "^0.2" \ No newline at end of file +rust-crypto = "^0.2" +lazy_static = "1.4.0" \ No newline at end of file diff --git a/test/katrunner/src/main.rs b/test/katrunner/src/main.rs index 975c1516..ec7886dc 100644 --- a/test/katrunner/src/main.rs +++ b/test/katrunner/src/main.rs @@ -6,6 +6,10 @@ use std::path::Path; use threadpool::ThreadPool; use std::convert::TryInto; use drbg::ctr::DrbgCtx; +use std::collections::HashMap; +use std::thread; +use std::sync::Mutex; +use lazy_static::lazy_static; mod drbg; @@ -32,16 +36,14 @@ macro_rules! REG_KEM { } } -const KAT_DIR : &'static str= "."; -type ExecFn = fn(&TestVector); -struct Register { - kat: katwalk::reader::Kat, - execfn: ExecFn, +// Stores one DRBG context per execution thread. DRBG +// is inserted in this map, just after thread starts +// and removed after thread is finished. Operation +// is synchronized. +lazy_static! { + static ref DRBGV: Mutex> = Mutex::new(HashMap::new()); } -// Define global DRBG object -static mut DRBG: DrbgCtx = DrbgCtx::new(); - // We have to provide the implementation for qrs_randombytes #[no_mangle] unsafe extern "C" fn randombytes( @@ -49,16 +51,29 @@ unsafe extern "C" fn randombytes( len: usize ) { let mut slice = std::slice::from_raw_parts_mut(data, len); - DRBG.get_random(&mut slice); + // get thread specific DRBG. + if let Some(drbg) = DRBGV.lock().unwrap().get_mut(&thread::current().id()) { + drbg.get_random(&mut slice); + } + +} + +type ExecFn = fn(&TestVector); +struct Register { + kat: katwalk::reader::Kat, + execfn: ExecFn, } fn test_sign_vector(el: &TestVector) { let mut pk = Vec::new(); let mut sk = Vec::new(); let mut sm = Vec::new(); - unsafe { - DRBG.init(el.sig.seed.as_slice(), Vec::new()); + if let Some(drbg) = DRBGV.lock().unwrap().get_mut(&thread::current().id()) { + drbg.init(el.sig.seed.as_slice(), Vec::new()); + } + + unsafe { // Check Verification // pqc doesn't use "envelope" API. From the other // hand in KATs for signature scheme, the signature @@ -105,8 +120,12 @@ fn test_kem_vector(el: &TestVector) { let mut ct = Vec::new(); let mut ss = Vec::new(); + if let Some(drbg) = DRBGV.lock().unwrap().get_mut(&thread::current().id()) { + drbg.init(el.sig.seed.as_slice(), Vec::new()); + } + unsafe { - DRBG.init(el.kem.seed.as_slice(), Vec::new()); + let p = pqc_kem_alg_by_id(el.scheme_id as u8); assert_ne!(p.is_null(), true); @@ -202,12 +221,14 @@ const KATS: &'static[Register] = &[ //REG_SIGN!(RAINBOWIIICLASSIC), ]; -fn execute(kat_dir: String) { +fn execute(kat_dir: String, thc: usize) { // Can't do multi-threads as DRBG context is global - let pool = ThreadPool::new(1); + let pool = ThreadPool::new(thc); for k in KATS.iter() { let tmp = kat_dir.clone(); pool.execute(move || { + DRBGV.lock().unwrap() + .insert(thread::current().id(), DrbgCtx::new()); let f = Path::new(&tmp.to_string()).join(k.kat.kat_file); let file = File::open(format!("{}", f.to_str().unwrap())); println!("Processing file: {}", Path::new(k.kat.kat_file).to_str().unwrap()); @@ -216,22 +237,32 @@ fn execute(kat_dir: String) { for el in KatReader::new(b, k.kat.scheme_type, k.kat.scheme_id) { (k.execfn)(&el); } + DRBGV.lock().unwrap() + .remove(&thread::current().id()); }); } pool.join(); } fn main() { - let kat_dir: String; let args: Vec = env::args().collect(); - if args.len() > 1 { - if args[1] == "--katdir" && args.len() == 3 { - kat_dir = args[2].to_string(); - } else { - panic!("Unrecognized argument"); - } - } else { - kat_dir = String::from(KAT_DIR); + let mut argmap = HashMap::new(); + + if args.len() % 2 == 0 { + panic!("Wrong number of arguments"); + } + + for i in (1..args.len()).step_by(2) { + argmap.insert(&args[i], &args[i+1]); } - execute(kat_dir); + + let thread_number: usize = match argmap.get(&"--threads".to_string()) { + Some(n) => n.to_string().parse::().unwrap(), + None => 4 /* by default 4 threads */, + }; + + match argmap.get(&"--katdir".to_string()) { + Some(kat_dir) => execute(kat_dir.to_string(), thread_number), + None => panic!("--katdir required") + }; }