use crate::handshake::binder::PskBinder; use crate::handshake::finished::Finished; use crate::{ProtocolError, config::TlsCipherSuite}; use digest::OutputSizeUser; use digest::generic_array::ArrayLength; use hmac::{Mac, SimpleHmac}; use sha2::Digest; use sha2::digest::generic_array::{GenericArray, typenum::Unsigned}; pub type HashOutputSize = <::Hash as OutputSizeUser>::OutputSize; pub type LabelBufferSize = ::LabelBufferSize; pub type IvArray = GenericArray::IvLen>; pub type KeyArray = GenericArray::KeyLen>; pub type HashArray = GenericArray>; type Hkdf = hkdf::Hkdf< ::Hash, SimpleHmac<::Hash>, >; enum Secret where CipherSuite: TlsCipherSuite, { Uninitialized, Initialized(Hkdf), } impl Secret where CipherSuite: TlsCipherSuite, { fn replace(&mut self, secret: Hkdf) { *self = Self::Initialized(secret); } fn as_ref(&self) -> Result<&Hkdf, ProtocolError> { match self { Secret::Initialized(secret) => Ok(secret), Secret::Uninitialized => Err(ProtocolError::InternalError), } } fn make_expanded_hkdf_label>( &self, label: &[u8], context_type: ContextType, ) -> Result, ProtocolError> { //info!("make label {:?} {}", label, len); let mut hkdf_label = heapless_typenum::Vec::>::new(); hkdf_label .extend_from_slice(&N::to_u16().to_be_bytes()) .map_err(|()| ProtocolError::InternalError)?; let label_len = 6 + label.len() as u8; hkdf_label .extend_from_slice(&label_len.to_be_bytes()) .map_err(|()| ProtocolError::InternalError)?; hkdf_label .extend_from_slice(b"tls13 ") .map_err(|()| ProtocolError::InternalError)?; hkdf_label .extend_from_slice(label) .map_err(|()| ProtocolError::InternalError)?; match context_type { ContextType::None => { hkdf_label.push(0).map_err(|_| ProtocolError::InternalError)?; } ContextType::Hash(context) => { hkdf_label .extend_from_slice(&(context.len() as u8).to_be_bytes()) .map_err(|()| ProtocolError::InternalError)?; hkdf_label .extend_from_slice(&context) .map_err(|()| ProtocolError::InternalError)?; } } let mut okm = GenericArray::default(); //info!("label {:x?}", label); self.as_ref()? .expand(&hkdf_label, &mut okm) .map_err(|_| ProtocolError::CryptoError)?; //info!("expand {:x?}", okm); Ok(okm) } } pub struct SharedState where CipherSuite: TlsCipherSuite, { secret: HashArray, hkdf: Secret, } impl SharedState where CipherSuite: TlsCipherSuite, { fn new() -> Self { Self { secret: GenericArray::default(), hkdf: Secret::Uninitialized, } } fn initialize(&mut self, ikm: &[u8]) { let (secret, hkdf) = Hkdf::::extract(Some(self.secret.as_ref()), ikm); self.hkdf.replace(hkdf); self.secret = secret; } fn derive_secret( &mut self, label: &[u8], context_type: ContextType, ) -> Result, ProtocolError> { self.hkdf .make_expanded_hkdf_label::>(label, context_type) } fn derived(&mut self) -> Result<(), ProtocolError> { self.secret = self.derive_secret(b"derived", ContextType::empty_hash())?; Ok(()) } } pub(crate) struct KeyScheduleState where CipherSuite: TlsCipherSuite, { traffic_secret: Secret, counter: u64, key: KeyArray, iv: IvArray, } impl KeyScheduleState where CipherSuite: TlsCipherSuite, { fn new() -> Self { Self { traffic_secret: Secret::Uninitialized, counter: 0, key: KeyArray::::default(), iv: IvArray::::default(), } } #[inline] pub fn get_key(&self) -> Result<&KeyArray, ProtocolError> { Ok(&self.key) } #[inline] pub fn get_iv(&self) -> Result<&IvArray, ProtocolError> { Ok(&self.iv) } pub fn get_nonce(&self) -> Result, ProtocolError> { let iv = self.get_iv()?; Ok(KeySchedule::::get_nonce(self.counter, iv)) } fn calculate_traffic_secret( &mut self, label: &[u8], shared: &mut SharedState, transcript_hash: &CipherSuite::Hash, ) -> Result<(), ProtocolError> { let secret = shared.derive_secret(label, ContextType::transcript_hash(transcript_hash))?; let traffic_secret = Hkdf::::from_prk(&secret).map_err(|_| ProtocolError::InternalError)?; self.traffic_secret.replace(traffic_secret); self.key = self .traffic_secret .make_expanded_hkdf_label(b"key", ContextType::None)?; self.iv = self .traffic_secret .make_expanded_hkdf_label(b"iv", ContextType::None)?; self.counter = 0; Ok(()) } pub fn increment_counter(&mut self) { self.counter = unwrap!(self.counter.checked_add(1)); } } enum ContextType where CipherSuite: TlsCipherSuite, { None, Hash(HashArray), } impl ContextType where CipherSuite: TlsCipherSuite, { fn transcript_hash(hash: &CipherSuite::Hash) -> Self { Self::Hash(hash.clone().finalize()) } fn empty_hash() -> Self { Self::Hash( ::new() .chain_update([]) .finalize(), ) } } pub struct KeySchedule where CipherSuite: TlsCipherSuite, { shared: SharedState, client_state: WriteKeySchedule, server_state: ReadKeySchedule, } impl KeySchedule where CipherSuite: TlsCipherSuite, { pub fn new() -> Self { Self { shared: SharedState::new(), client_state: WriteKeySchedule { state: KeyScheduleState::new(), binder_key: Secret::Uninitialized, }, server_state: ReadKeySchedule { state: KeyScheduleState::new(), transcript_hash: ::new(), }, } } pub(crate) fn transcript_hash(&mut self) -> &mut CipherSuite::Hash { &mut self.server_state.transcript_hash } pub(crate) fn replace_transcript_hash(&mut self, hash: CipherSuite::Hash) { self.server_state.transcript_hash = hash; } pub fn as_split( &mut self, ) -> ( &mut WriteKeySchedule, &mut ReadKeySchedule, ) { (&mut self.client_state, &mut self.server_state) } pub(crate) fn write_state(&mut self) -> &mut WriteKeySchedule { &mut self.client_state } pub(crate) fn read_state(&mut self) -> &mut ReadKeySchedule { &mut self.server_state } pub fn create_client_finished( &self, ) -> Result>, ProtocolError> { let key = self .client_state .state .traffic_secret .make_expanded_hkdf_label::>( b"finished", ContextType::None, )?; let mut hmac = SimpleHmac::::new_from_slice(&key) .map_err(|_| ProtocolError::CryptoError)?; Mac::update( &mut hmac, &self.server_state.transcript_hash.clone().finalize(), ); let verify = hmac.finalize().into_bytes(); Ok(Finished { verify, hash: None }) } fn get_nonce(counter: u64, iv: &IvArray) -> IvArray { //info!("counter = {} {:x?}", counter, &counter.to_be_bytes(),); let counter = Self::pad::(&counter.to_be_bytes()); //info!("counter = {:x?}", counter); // info!("iv = {:x?}", iv); let mut nonce = GenericArray::default(); for (index, (l, r)) in iv[0..CipherSuite::IvLen::to_usize()] .iter() .zip(counter.iter()) .enumerate() { nonce[index] = l ^ r; } //debug!("nonce {:x?}", nonce); nonce } fn pad>(input: &[u8]) -> GenericArray { // info!("padding input = {:x?}", input); let mut padded = GenericArray::default(); for (index, byte) in input.iter().rev().enumerate() { /*info!( "{} pad {}={:x?}", index, ((N::to_usize() - index) - 1), *byte );*/ padded[(N::to_usize() - index) - 1] = *byte; } padded } fn zero() -> HashArray { GenericArray::default() } // Initializes the early secrets with a callback for any PSK binders pub fn initialize_early_secret(&mut self, psk: Option<&[u8]>) -> Result<(), ProtocolError> { self.shared.initialize( #[allow(clippy::or_fun_call)] psk.unwrap_or(Self::zero().as_slice()), ); let binder_key = self .shared .derive_secret(b"ext binder", ContextType::empty_hash())?; self.client_state.binder_key.replace( Hkdf::::from_prk(&binder_key).map_err(|_| ProtocolError::InternalError)?, ); self.shared.derived() } pub fn initialize_handshake_secret(&mut self, ikm: &[u8]) -> Result<(), ProtocolError> { self.shared.initialize(ikm); self.calculate_traffic_secrets(b"c hs traffic", b"s hs traffic")?; self.shared.derived() } pub fn initialize_master_secret(&mut self) -> Result<(), ProtocolError> { self.shared.initialize(Self::zero().as_slice()); //let context = self.transcript_hash.as_ref().unwrap().clone().finalize(); //info!("Derive keys, hash: {:x?}", context); self.calculate_traffic_secrets(b"c ap traffic", b"s ap traffic")?; self.shared.derived() } fn calculate_traffic_secrets( &mut self, client_label: &[u8], server_label: &[u8], ) -> Result<(), ProtocolError> { self.client_state.state.calculate_traffic_secret( client_label, &mut self.shared, &self.server_state.transcript_hash, )?; self.server_state.state.calculate_traffic_secret( server_label, &mut self.shared, &self.server_state.transcript_hash, )?; Ok(()) } } impl Default for KeySchedule where CipherSuite: TlsCipherSuite, { fn default() -> Self { KeySchedule::new() } } pub struct WriteKeySchedule where CipherSuite: TlsCipherSuite, { state: KeyScheduleState, binder_key: Secret, } impl WriteKeySchedule where CipherSuite: TlsCipherSuite, { pub(crate) fn increment_counter(&mut self) { self.state.increment_counter(); } pub(crate) fn get_key(&self) -> Result<&KeyArray, ProtocolError> { self.state.get_key() } pub(crate) fn get_nonce(&self) -> Result, ProtocolError> { self.state.get_nonce() } pub fn create_psk_binder( &self, transcript_hash: &CipherSuite::Hash, ) -> Result>, ProtocolError> { let key = self .binder_key .make_expanded_hkdf_label::>( b"finished", ContextType::None, )?; let mut hmac = SimpleHmac::::new_from_slice(&key) .map_err(|_| ProtocolError::CryptoError)?; Mac::update(&mut hmac, &transcript_hash.clone().finalize()); let verify = hmac.finalize().into_bytes(); Ok(PskBinder { verify }) } } pub struct ReadKeySchedule where CipherSuite: TlsCipherSuite, { state: KeyScheduleState, transcript_hash: CipherSuite::Hash, } impl ReadKeySchedule where CipherSuite: TlsCipherSuite, { pub(crate) fn increment_counter(&mut self) { self.state.increment_counter(); } pub(crate) fn transcript_hash(&mut self) -> &mut CipherSuite::Hash { &mut self.transcript_hash } pub(crate) fn get_key(&self) -> Result<&KeyArray, ProtocolError> { self.state.get_key() } pub(crate) fn get_nonce(&self) -> Result, ProtocolError> { self.state.get_nonce() } pub fn verify_server_finished( &self, finished: &Finished>, ) -> Result { //info!("verify server finished: {:x?}", finished.verify); //self.client_traffic_secret.as_ref().unwrap().expand() //info!("size ===> {}", D::OutputSize::to_u16()); let key = self .state .traffic_secret .make_expanded_hkdf_label::>( b"finished", ContextType::None, )?; // info!("hmac sign key {:x?}", key); let mut hmac = SimpleHmac::::new_from_slice(&key) .map_err(|_| ProtocolError::InternalError)?; Mac::update( &mut hmac, finished.hash.as_ref().ok_or_else(|| { warn!("No hash in Finished"); ProtocolError::InternalError })?, ); //let code = hmac.clone().finalize().into_bytes(); Ok(hmac.verify(&finished.verify).is_ok()) //info!("verified {:?}", verified); //unimplemented!() } }