This commit is contained in:
2026-02-21 08:05:44 +00:00
commit af63ad4870
75 changed files with 11874 additions and 0 deletions

121
src/alert.rs Normal file
View File

@@ -0,0 +1,121 @@
use crate::ProtocolError;
use crate::buffer::CryptoBuffer;
use crate::parse_buffer::ParseBuffer;
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum AlertLevel {
Warning = 1,
Fatal = 2,
}
impl AlertLevel {
#[must_use]
pub fn of(num: u8) -> Option<Self> {
match num {
1 => Some(AlertLevel::Warning),
2 => Some(AlertLevel::Fatal),
_ => None,
}
}
}
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum AlertDescription {
CloseNotify = 0,
UnexpectedMessage = 10,
BadRecordMac = 20,
RecordOverflow = 22,
HandshakeFailure = 40,
BadCertificate = 42,
UnsupportedCertificate = 43,
CertificateRevoked = 44,
CertificateExpired = 45,
CertificateUnknown = 46,
IllegalParameter = 47,
UnknownCa = 48,
AccessDenied = 49,
DecodeError = 50,
DecryptError = 51,
ProtocolVersion = 70,
InsufficientSecurity = 71,
InternalError = 80,
InappropriateFallback = 86,
UserCanceled = 90,
MissingExtension = 109,
UnsupportedExtension = 110,
UnrecognizedName = 112,
BadCertificateStatusResponse = 113,
UnknownPskIdentity = 115,
CertificateRequired = 116,
NoApplicationProtocol = 120,
}
impl AlertDescription {
#[must_use]
pub fn of(num: u8) -> Option<Self> {
match num {
0 => Some(AlertDescription::CloseNotify),
10 => Some(AlertDescription::UnexpectedMessage),
20 => Some(AlertDescription::BadRecordMac),
22 => Some(AlertDescription::RecordOverflow),
40 => Some(AlertDescription::HandshakeFailure),
42 => Some(AlertDescription::BadCertificate),
43 => Some(AlertDescription::UnsupportedCertificate),
44 => Some(AlertDescription::CertificateRevoked),
45 => Some(AlertDescription::CertificateExpired),
46 => Some(AlertDescription::CertificateUnknown),
47 => Some(AlertDescription::IllegalParameter),
48 => Some(AlertDescription::UnknownCa),
49 => Some(AlertDescription::AccessDenied),
50 => Some(AlertDescription::DecodeError),
51 => Some(AlertDescription::DecryptError),
70 => Some(AlertDescription::ProtocolVersion),
71 => Some(AlertDescription::InsufficientSecurity),
80 => Some(AlertDescription::InternalError),
86 => Some(AlertDescription::InappropriateFallback),
90 => Some(AlertDescription::UserCanceled),
109 => Some(AlertDescription::MissingExtension),
110 => Some(AlertDescription::UnsupportedExtension),
112 => Some(AlertDescription::UnrecognizedName),
113 => Some(AlertDescription::BadCertificateStatusResponse),
115 => Some(AlertDescription::UnknownPskIdentity),
116 => Some(AlertDescription::CertificateRequired),
120 => Some(AlertDescription::NoApplicationProtocol),
_ => None,
}
}
}
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Alert {
pub(crate) level: AlertLevel,
pub(crate) description: AlertDescription,
}
impl Alert {
#[must_use]
pub fn new(level: AlertLevel, description: AlertDescription) -> Self {
Self { level, description }
}
pub fn parse(buf: &mut ParseBuffer<'_>) -> Result<Alert, ProtocolError> {
let level = buf.read_u8()?;
let desc = buf.read_u8()?;
Ok(Self {
level: AlertLevel::of(level).ok_or(ProtocolError::DecodeError)?,
description: AlertDescription::of(desc).ok_or(ProtocolError::DecodeError)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
buf.push(self.level as u8)
.map_err(|_| ProtocolError::EncodeError)?;
buf.push(self.description as u8)
.map_err(|_| ProtocolError::EncodeError)?;
Ok(())
}
}

30
src/application_data.rs Normal file
View File

@@ -0,0 +1,30 @@
use crate::buffer::CryptoBuffer;
use crate::record::RecordHeader;
use core::fmt::{Debug, Formatter};
pub struct ApplicationData<'a> {
pub(crate) header: RecordHeader,
pub(crate) data: CryptoBuffer<'a>,
}
impl Debug for ApplicationData<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
write!(f, "ApplicationData {:x?}", self.data.len())
}
}
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for ApplicationData<'a> {
fn format(&self, f: defmt::Formatter<'_>) {
defmt::write!(f, "ApplicationData {}", self.data.len());
}
}
impl<'a> ApplicationData<'a> {
pub fn new(rx_buf: CryptoBuffer<'a>, header: RecordHeader) -> ApplicationData<'a> {
Self {
header,
data: rx_buf,
}
}
}

541
src/asynch.rs Normal file
View File

@@ -0,0 +1,541 @@
use core::sync::atomic::{AtomicBool, Ordering};
use crate::ProtocolError;
use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
use crate::common::decrypted_read_handler::DecryptedReadHandler;
use crate::connection::{Handshake, State, decrypt_record};
use crate::send_policy::FlushPolicy;
use crate::key_schedule::KeySchedule;
use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
use crate::read_buffer::ReadBuffer;
use crate::record::{ClientRecord, ClientRecordHeader};
use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
use embedded_io::Error as _;
use embedded_io::ErrorType;
use embedded_io_async::{BufRead, Read as AsyncRead, Write as AsyncWrite};
pub use crate::config::*;
/// Type representing an async TLS connection. An instance of this type can
/// be used to establish a TLS connection, write and read encrypted data over this connection,
/// and closing to free up the underlying resources.
pub struct SecureStream<'a, Socket, CipherSuite>
where
Socket: AsyncRead + AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
delegate: Socket,
opened: AtomicBool,
key_schedule: KeySchedule<CipherSuite>,
record_reader: RecordReader<'a>,
record_write_buf: WriteBuffer<'a>,
decrypted: DecryptedBufferInfo,
flush_policy: FlushPolicy,
}
impl<'a, Socket, CipherSuite> SecureStream<'a, Socket, CipherSuite>
where
Socket: AsyncRead + AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
pub fn is_opened(&mut self) -> bool {
*self.opened.get_mut()
}
/// Create a new TLS connection with the provided context and a async I/O implementation
///
/// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
/// depends on the server configuration, but the maximum allowed value for a TLS record is 16640 bytes,
/// which should be a safe value to use.
///
/// The write record buffer can be smaller than the read buffer. During writes [`TLS_RECORD_OVERHEAD`] bytes of
/// overhead is added per record, so the buffer must at least be this large. Large writes are split into multiple
/// records if depending on the size of the write buffer.
/// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
/// buffers must at least be large enough to encode a handshake.
pub fn new(
delegate: Socket,
record_read_buf: &'a mut [u8],
record_write_buf: &'a mut [u8],
) -> Self {
Self {
delegate,
opened: AtomicBool::new(false),
key_schedule: KeySchedule::new(),
record_reader: RecordReader::new(record_read_buf),
record_write_buf: WriteBuffer::new(record_write_buf),
decrypted: DecryptedBufferInfo::default(),
flush_policy: FlushPolicy::default(),
}
}
/// Returns a reference to the current flush policy.
///
/// The flush policy controls whether the underlying transport is flushed
/// (via its `flush()` method) after writing a TLS record.
#[inline]
pub fn flush_policy(&self) -> FlushPolicy {
self.flush_policy
}
/// Replace the current flush policy with the provided one.
///
/// This sets how and when the connection will call `flush()` on the
/// underlying transport after writing records.
#[inline]
pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
self.flush_policy = policy;
}
/// Open a TLS connection, performing the handshake with the configuration provided when
/// creating the connection instance.
///
/// Returns an error if the handshake does not proceed. If an error occurs, the connection
/// instance must be recreated.
pub async fn open<CP>(
&mut self,
mut context: ConnectContext<'_, CP>,
) -> Result<(), ProtocolError>
where
CP: CryptoBackend<CipherSuite = CipherSuite>,
{
let mut handshake: Handshake<CipherSuite> = Handshake::new();
if let (Ok(verifier), Some(server_name)) = (
context.crypto_provider.verifier(),
context.config.server_name,
) {
verifier.set_hostname_verification(server_name)?;
}
let mut state = State::ClientHello;
while state != State::ApplicationData {
let next_state = state
.process(
&mut self.delegate,
&mut handshake,
&mut self.record_reader,
&mut self.record_write_buf,
&mut self.key_schedule,
context.config,
&mut context.crypto_provider,
)
.await?;
trace!("State {:?} -> {:?}", state, next_state);
state = next_state;
}
*self.opened.get_mut() = true;
Ok(())
}
/// Encrypt and send the provided slice over the connection. The connection
/// must be opened before writing.
///
/// The slice may be buffered internally and not written to the connection immediately.
/// In this case [`Self::flush()`] should be called to force the currently buffered writes
/// to be written to the connection.
///
/// Returns the number of bytes buffered/written.
pub async fn write(&mut self, buf: &[u8]) -> Result<usize, ProtocolError> {
if self.is_opened() {
if !self
.record_write_buf
.contains(ClientRecordHeader::ApplicationData)
{
self.flush().await?;
self.record_write_buf
.start_record(ClientRecordHeader::ApplicationData)?;
}
let buffered = self.record_write_buf.append(buf);
if self.record_write_buf.is_full() {
self.flush().await?;
}
Ok(buffered)
} else {
Err(ProtocolError::MissingHandshake)
}
}
/// Force all previously written, buffered bytes to be encoded into a tls record and written
/// to the connection.
pub async fn flush(&mut self) -> Result<(), ProtocolError> {
if !self.record_write_buf.is_empty() {
let key_schedule = self.key_schedule.write_state();
let slice = self.record_write_buf.close_record(key_schedule)?;
self.delegate
.write_all(slice)
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
key_schedule.increment_counter();
if self.flush_policy.flush_transport() {
self.flush_transport().await?;
}
}
Ok(())
}
#[inline]
async fn flush_transport(&mut self) -> Result<(), ProtocolError> {
self.delegate
.flush()
.await
.map_err(|e| ProtocolError::Io(e.kind()))
}
fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
self.decrypted.create_read_buffer(self.record_reader.buf)
}
/// Read and decrypt data filling the provided slice.
pub async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProtocolError> {
if buf.is_empty() {
return Ok(0);
}
let mut buffer = self.read_buffered().await?;
let len = buffer.pop_into(buf);
trace!("Copied {} bytes", len);
Ok(len)
}
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
pub async fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, ProtocolError> {
if self.is_opened() {
while self.decrypted.is_empty() {
self.read_application_data().await?;
}
Ok(self.create_read_buffer())
} else {
Err(ProtocolError::MissingHandshake)
}
}
async fn read_application_data(&mut self) -> Result<(), ProtocolError> {
let buf_ptr_range = self.record_reader.buf.as_ptr_range();
let record = self
.record_reader
.read(&mut self.delegate, self.key_schedule.read_state())
.await?;
let mut handler = DecryptedReadHandler {
source_buffer: buf_ptr_range,
buffer_info: &mut self.decrypted,
is_open: self.opened.get_mut(),
};
decrypt_record(
self.key_schedule.read_state(),
record,
|_key_schedule, record| handler.handle(record),
)?;
Ok(())
}
/// Close a connection instance, returning the ownership of the config, random generator and the async I/O provider.
async fn close_internal(&mut self) -> Result<(), ProtocolError> {
self.flush().await?;
let is_opened = self.is_opened();
let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
let slice = self.record_write_buf.write_record(
&ClientRecord::close_notify(is_opened),
write_key_schedule,
Some(read_key_schedule),
)?;
self.delegate
.write_all(slice)
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
self.key_schedule.write_state().increment_counter();
self.flush_transport().await
}
/// Close a connection instance, returning the ownership of the async I/O provider.
pub async fn close(mut self) -> Result<Socket, (Socket, ProtocolError)> {
match self.close_internal().await {
Ok(()) => Ok(self.delegate),
Err(e) => Err((self.delegate, e)),
}
}
pub fn split(
&mut self,
) -> (
TlsReader<'_, Socket, CipherSuite>,
TlsWriter<'_, Socket, CipherSuite>,
)
where
Socket: Clone,
{
let (wks, rks) = self.key_schedule.as_split();
let reader = TlsReader {
opened: &self.opened,
delegate: self.delegate.clone(),
key_schedule: rks,
record_reader: self.record_reader.reborrow_mut(),
decrypted: &mut self.decrypted,
};
let writer = TlsWriter {
opened: &self.opened,
delegate: self.delegate.clone(),
key_schedule: wks,
record_write_buf: self.record_write_buf.reborrow_mut(),
flush_policy: self.flush_policy,
};
(reader, writer)
}
}
impl<'a, Socket, CipherSuite> ErrorType for SecureStream<'a, Socket, CipherSuite>
where
Socket: AsyncRead + AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
type Error = ProtocolError;
}
impl<'a, Socket, CipherSuite> AsyncRead for SecureStream<'a, Socket, CipherSuite>
where
Socket: AsyncRead + AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
SecureStream::read(self, buf).await
}
}
impl<'a, Socket, CipherSuite> BufRead for SecureStream<'a, Socket, CipherSuite>
where
Socket: AsyncRead + AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
self.read_buffered().await.map(|mut buf| buf.peek_all())
}
fn consume(&mut self, amt: usize) {
self.create_read_buffer().pop(amt);
}
}
impl<'a, Socket, CipherSuite> AsyncWrite for SecureStream<'a, Socket, CipherSuite>
where
Socket: AsyncRead + AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
SecureStream::write(self, buf).await
}
async fn flush(&mut self) -> Result<(), Self::Error> {
SecureStream::flush(self).await
}
}
pub struct TlsReader<'a, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
opened: &'a AtomicBool,
delegate: Socket,
key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
record_reader: RecordReaderBorrowMut<'a>,
decrypted: &'a mut DecryptedBufferInfo,
}
impl<Socket, CipherSuite> AsRef<Socket> for TlsReader<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
fn as_ref(&self) -> &Socket {
&self.delegate
}
}
impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
where
Socket: AsyncRead + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
self.decrypted.create_read_buffer(self.record_reader.buf)
}
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
pub async fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, ProtocolError> {
if self.opened.load(Ordering::Acquire) {
while self.decrypted.is_empty() {
self.read_application_data().await?;
}
Ok(self.create_read_buffer())
} else {
Err(ProtocolError::MissingHandshake)
}
}
async fn read_application_data(&mut self) -> Result<(), ProtocolError> {
let buf_ptr_range = self.record_reader.buf.as_ptr_range();
let record = self
.record_reader
.read(&mut self.delegate, self.key_schedule)
.await?;
let mut opened = self.opened.load(Ordering::Acquire);
let mut handler = DecryptedReadHandler {
source_buffer: buf_ptr_range,
buffer_info: self.decrypted,
is_open: &mut opened,
};
let result = decrypt_record(self.key_schedule, record, |_key_schedule, record| {
handler.handle(record)
});
if !opened {
self.opened.store(false, Ordering::Release);
}
result
}
}
pub struct TlsWriter<'a, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
opened: &'a AtomicBool,
delegate: Socket,
key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
record_write_buf: WriteBufferBorrowMut<'a>,
flush_policy: FlushPolicy,
}
impl<'a, Socket, CipherSuite> TlsWriter<'a, Socket, CipherSuite>
where
Socket: AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
#[inline]
async fn flush_transport(&mut self) -> Result<(), ProtocolError> {
self.delegate
.flush()
.await
.map_err(|e| ProtocolError::Io(e.kind()))
}
}
impl<Socket, CipherSuite> AsRef<Socket> for TlsWriter<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
fn as_ref(&self) -> &Socket {
&self.delegate
}
}
impl<Socket, CipherSuite> ErrorType for TlsWriter<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
type Error = ProtocolError;
}
impl<Socket, CipherSuite> ErrorType for TlsReader<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
type Error = ProtocolError;
}
impl<'a, Socket, CipherSuite> AsyncRead for TlsReader<'a, Socket, CipherSuite>
where
Socket: AsyncRead + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if buf.is_empty() {
return Ok(0);
}
let mut buffer = self.read_buffered().await?;
let len = buffer.pop_into(buf);
trace!("Copied {} bytes", len);
Ok(len)
}
}
impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
where
Socket: AsyncRead + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
async fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
self.read_buffered().await.map(|mut buf| buf.peek_all())
}
fn consume(&mut self, amt: usize) {
self.create_read_buffer().pop(amt);
}
}
impl<'a, Socket, CipherSuite> AsyncWrite for TlsWriter<'a, Socket, CipherSuite>
where
Socket: AsyncWrite + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if self.opened.load(Ordering::Acquire) {
if !self
.record_write_buf
.contains(ClientRecordHeader::ApplicationData)
{
self.flush().await?;
self.record_write_buf
.start_record(ClientRecordHeader::ApplicationData)?;
}
let buffered = self.record_write_buf.append(buf);
if self.record_write_buf.is_full() {
self.flush().await?;
}
Ok(buffered)
} else {
Err(ProtocolError::MissingHandshake)
}
}
async fn flush(&mut self) -> Result<(), Self::Error> {
if !self.record_write_buf.is_empty() {
let slice = self.record_write_buf.close_record(self.key_schedule)?;
self.delegate
.write_all(slice)
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
self.key_schedule.increment_counter();
if self.flush_policy.flush_transport() {
self.flush_transport().await?;
}
}
Ok(())
}
}

525
src/blocking.rs Normal file
View File

@@ -0,0 +1,525 @@
use core::sync::atomic::Ordering;
use crate::common::decrypted_buffer_info::DecryptedBufferInfo;
use crate::common::decrypted_read_handler::DecryptedReadHandler;
use crate::connection::{Handshake, State, decrypt_record};
use crate::send_policy::FlushPolicy;
use crate::key_schedule::KeySchedule;
use crate::key_schedule::{ReadKeySchedule, WriteKeySchedule};
use crate::read_buffer::ReadBuffer;
use crate::record::{ClientRecord, ClientRecordHeader};
use crate::record_reader::{RecordReader, RecordReaderBorrowMut};
use crate::write_buffer::{WriteBuffer, WriteBufferBorrowMut};
use embedded_io::Error as _;
use embedded_io::{BufRead, ErrorType, Read, Write};
use portable_atomic::AtomicBool;
pub use crate::ProtocolError;
pub use crate::config::*;
/// Type representing a TLS connection. An instance of this type can
/// be used to establish a TLS connection, write and read encrypted data over this connection,
/// and closing to free up the underlying resources.
pub struct SecureStream<'a, Socket, CipherSuite>
where
Socket: Read + Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
delegate: Socket,
opened: AtomicBool,
key_schedule: KeySchedule<CipherSuite>,
record_reader: RecordReader<'a>,
record_write_buf: WriteBuffer<'a>,
decrypted: DecryptedBufferInfo,
flush_policy: FlushPolicy,
}
impl<'a, Socket, CipherSuite> SecureStream<'a, Socket, CipherSuite>
where
Socket: Read + Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn is_opened(&mut self) -> bool {
*self.opened.get_mut()
}
/// Create a new TLS connection with the provided context and a blocking I/O implementation
///
/// NOTE: The record read buffer should be sized to fit an encrypted TLS record. The size of this record
/// depends on the server configuration, but the maximum allowed value for a TLS record is 16640 bytes,
/// which should be a safe value to use.
///
/// The write record buffer can be smaller than the read buffer. During writes [`TLS_RECORD_OVERHEAD`] bytes of
/// overhead is added per record, so the buffer must at least be this large. Large writes are split into multiple
/// records if depending on the size of the write buffer.
/// The largest of the two buffers will be used to encode the TLS handshake record, hence either of the
/// buffers must at least be large enough to encode a handshake.
pub fn new(
delegate: Socket,
record_read_buf: &'a mut [u8],
record_write_buf: &'a mut [u8],
) -> Self {
Self {
delegate,
opened: AtomicBool::new(false),
key_schedule: KeySchedule::new(),
record_reader: RecordReader::new(record_read_buf),
record_write_buf: WriteBuffer::new(record_write_buf),
decrypted: DecryptedBufferInfo::default(),
flush_policy: FlushPolicy::default(),
}
}
/// Returns a reference to the current flush policy.
///
/// The flush policy controls whether the underlying transport is flushed
/// (via its `flush()` method) after writing a TLS record.
#[inline]
pub fn flush_policy(&self) -> FlushPolicy {
self.flush_policy
}
/// Replace the current flush policy with the provided one.
///
/// This sets how and when the connection will call `flush()` on the
/// underlying transport after writing records.
#[inline]
pub fn set_flush_policy(&mut self, policy: FlushPolicy) {
self.flush_policy = policy;
}
/// Open a TLS connection, performing the handshake with the configuration provided when
/// creating the connection instance.
///
/// Returns an error if the handshake does not proceed. If an error occurs, the connection
/// instance must be recreated.
pub fn open<CP>(&mut self, mut context: ConnectContext<CP>) -> Result<(), ProtocolError>
where
CP: CryptoBackend<CipherSuite = CipherSuite>,
{
let mut handshake: Handshake<CipherSuite> = Handshake::new();
if let (Ok(verifier), Some(server_name)) = (
context.crypto_provider.verifier(),
context.config.server_name,
) {
verifier.set_hostname_verification(server_name)?;
}
let mut state = State::ClientHello;
while state != State::ApplicationData {
let next_state = state.process_blocking(
&mut self.delegate,
&mut handshake,
&mut self.record_reader,
&mut self.record_write_buf,
&mut self.key_schedule,
context.config,
&mut context.crypto_provider,
)?;
trace!("State {:?} -> {:?}", state, next_state);
state = next_state;
}
*self.opened.get_mut() = true;
Ok(())
}
/// Encrypt and send the provided slice over the connection. The connection
/// must be opened before writing.
///
/// The slice may be buffered internally and not written to the connection immediately.
/// In this case [`Self::flush()`] should be called to force the currently buffered writes
/// to be written to the connection.
///
/// Returns the number of bytes buffered/written.
pub fn write(&mut self, buf: &[u8]) -> Result<usize, ProtocolError> {
if self.is_opened() {
if !self
.record_write_buf
.contains(ClientRecordHeader::ApplicationData)
{
self.flush()?;
self.record_write_buf
.start_record(ClientRecordHeader::ApplicationData)?;
}
let buffered = self.record_write_buf.append(buf);
if self.record_write_buf.is_full() {
self.flush()?;
}
Ok(buffered)
} else {
Err(ProtocolError::MissingHandshake)
}
}
/// Force all previously written, buffered bytes to be encoded into a tls record and written
/// to the connection.
pub fn flush(&mut self) -> Result<(), ProtocolError> {
if !self.record_write_buf.is_empty() {
let key_schedule = self.key_schedule.write_state();
let slice = self.record_write_buf.close_record(key_schedule)?;
self.delegate
.write_all(slice)
.map_err(|e| ProtocolError::Io(e.kind()))?;
key_schedule.increment_counter();
if self.flush_policy.flush_transport() {
self.flush_transport()?;
}
}
Ok(())
}
#[inline]
fn flush_transport(&mut self) -> Result<(), ProtocolError> {
self.delegate.flush().map_err(|e| ProtocolError::Io(e.kind()))
}
fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
self.decrypted.create_read_buffer(self.record_reader.buf)
}
/// Read and decrypt data filling the provided slice.
pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, ProtocolError> {
if buf.is_empty() {
return Ok(0);
}
let mut buffer = self.read_buffered()?;
let len = buffer.pop_into(buf);
trace!("Copied {} bytes", len);
Ok(len)
}
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
pub fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, ProtocolError> {
if self.is_opened() {
while self.decrypted.is_empty() {
self.read_application_data()?;
}
Ok(self.create_read_buffer())
} else {
Err(ProtocolError::MissingHandshake)
}
}
fn read_application_data(&mut self) -> Result<(), ProtocolError> {
let buf_ptr_range = self.record_reader.buf.as_ptr_range();
let key_schedule = self.key_schedule.read_state();
let record = self
.record_reader
.read_blocking(&mut self.delegate, key_schedule)?;
let mut handler = DecryptedReadHandler {
source_buffer: buf_ptr_range,
buffer_info: &mut self.decrypted,
is_open: self.opened.get_mut(),
};
decrypt_record(key_schedule, record, |_key_schedule, record| {
handler.handle(record)
})?;
Ok(())
}
fn close_internal(&mut self) -> Result<(), ProtocolError> {
self.flush()?;
let is_opened = self.is_opened();
let (write_key_schedule, read_key_schedule) = self.key_schedule.as_split();
let slice = self.record_write_buf.write_record(
&ClientRecord::close_notify(is_opened),
write_key_schedule,
Some(read_key_schedule),
)?;
self.delegate
.write_all(slice)
.map_err(|e| ProtocolError::Io(e.kind()))?;
self.key_schedule.write_state().increment_counter();
self.flush_transport()?;
Ok(())
}
/// Close a connection instance, returning the ownership of the I/O provider.
pub fn close(mut self) -> Result<Socket, (Socket, ProtocolError)> {
match self.close_internal() {
Ok(()) => Ok(self.delegate),
Err(e) => Err((self.delegate, e)),
}
}
pub fn split(
&mut self,
) -> (
TlsReader<'_, Socket, CipherSuite>,
TlsWriter<'_, Socket, CipherSuite>,
)
where
Socket: Clone,
{
let (wks, rks) = self.key_schedule.as_split();
let reader = TlsReader {
opened: &self.opened,
delegate: self.delegate.clone(),
key_schedule: rks,
record_reader: self.record_reader.reborrow_mut(),
decrypted: &mut self.decrypted,
};
let writer = TlsWriter {
opened: &self.opened,
delegate: self.delegate.clone(),
key_schedule: wks,
record_write_buf: self.record_write_buf.reborrow_mut(),
flush_policy: self.flush_policy,
};
(reader, writer)
}
}
impl<'a, Socket, CipherSuite> ErrorType for SecureStream<'a, Socket, CipherSuite>
where
Socket: Read + Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
type Error = ProtocolError;
}
impl<'a, Socket, CipherSuite> Read for SecureStream<'a, Socket, CipherSuite>
where
Socket: Read + Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
SecureStream::read(self, buf)
}
}
impl<'a, Socket, CipherSuite> BufRead for SecureStream<'a, Socket, CipherSuite>
where
Socket: Read + Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
self.read_buffered().map(|mut buf| buf.peek_all())
}
fn consume(&mut self, amt: usize) {
self.create_read_buffer().pop(amt);
}
}
impl<'a, Socket, CipherSuite> Write for SecureStream<'a, Socket, CipherSuite>
where
Socket: Read + Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
SecureStream::write(self, buf)
}
fn flush(&mut self) -> Result<(), Self::Error> {
SecureStream::flush(self)
}
}
pub struct TlsReader<'a, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
opened: &'a AtomicBool,
delegate: Socket,
key_schedule: &'a mut ReadKeySchedule<CipherSuite>,
record_reader: RecordReaderBorrowMut<'a>,
decrypted: &'a mut DecryptedBufferInfo,
}
impl<Socket, CipherSuite> AsRef<Socket> for TlsReader<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
fn as_ref(&self) -> &Socket {
&self.delegate
}
}
impl<'a, Socket, CipherSuite> TlsReader<'a, Socket, CipherSuite>
where
Socket: Read + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn create_read_buffer(&mut self) -> ReadBuffer<'_> {
self.decrypted.create_read_buffer(self.record_reader.buf)
}
/// Reads buffered data. If nothing is in memory, it'll wait for a TLS record and process it.
pub fn read_buffered(&mut self) -> Result<ReadBuffer<'_>, ProtocolError> {
if self.opened.load(Ordering::Acquire) {
while self.decrypted.is_empty() {
self.read_application_data()?;
}
Ok(self.create_read_buffer())
} else {
Err(ProtocolError::MissingHandshake)
}
}
fn read_application_data(&mut self) -> Result<(), ProtocolError> {
let buf_ptr_range = self.record_reader.buf.as_ptr_range();
let record = self
.record_reader
.read_blocking(&mut self.delegate, self.key_schedule)?;
let mut opened = self.opened.load(Ordering::Acquire);
let mut handler = DecryptedReadHandler {
source_buffer: buf_ptr_range,
buffer_info: self.decrypted,
is_open: &mut opened,
};
let result = decrypt_record(self.key_schedule, record, |_key_schedule, record| {
handler.handle(record)
});
if !opened {
self.opened.store(false, Ordering::Release);
}
result
}
}
pub struct TlsWriter<'a, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
opened: &'a AtomicBool,
delegate: Socket,
key_schedule: &'a mut WriteKeySchedule<CipherSuite>,
record_write_buf: WriteBufferBorrowMut<'a>,
flush_policy: FlushPolicy,
}
impl<'a, Socket, CipherSuite> TlsWriter<'a, Socket, CipherSuite>
where
Socket: Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn flush_transport(&mut self) -> Result<(), ProtocolError> {
self.delegate.flush().map_err(|e| ProtocolError::Io(e.kind()))
}
}
impl<Socket, CipherSuite> AsRef<Socket> for TlsWriter<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
fn as_ref(&self) -> &Socket {
&self.delegate
}
}
impl<Socket, CipherSuite> ErrorType for TlsWriter<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
type Error = ProtocolError;
}
impl<Socket, CipherSuite> ErrorType for TlsReader<'_, Socket, CipherSuite>
where
CipherSuite: TlsCipherSuite + 'static,
{
type Error = ProtocolError;
}
impl<'a, Socket, CipherSuite> Read for TlsReader<'a, Socket, CipherSuite>
where
Socket: Read + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
if buf.is_empty() {
return Ok(0);
}
let mut buffer = self.read_buffered()?;
let len = buffer.pop_into(buf);
trace!("Copied {} bytes", len);
Ok(len)
}
}
impl<'a, Socket, CipherSuite> BufRead for TlsReader<'a, Socket, CipherSuite>
where
Socket: Read + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn fill_buf(&mut self) -> Result<&[u8], Self::Error> {
self.read_buffered().map(|mut buf| buf.peek_all())
}
fn consume(&mut self, amt: usize) {
self.create_read_buffer().pop(amt);
}
}
impl<'a, Socket, CipherSuite> Write for TlsWriter<'a, Socket, CipherSuite>
where
Socket: Write + 'a,
CipherSuite: TlsCipherSuite + 'static,
{
fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
if self.opened.load(Ordering::Acquire) {
if !self
.record_write_buf
.contains(ClientRecordHeader::ApplicationData)
{
self.flush()?;
self.record_write_buf
.start_record(ClientRecordHeader::ApplicationData)?;
}
let buffered = self.record_write_buf.append(buf);
if self.record_write_buf.is_full() {
self.flush()?;
}
Ok(buffered)
} else {
Err(ProtocolError::MissingHandshake)
}
}
fn flush(&mut self) -> Result<(), Self::Error> {
if !self.record_write_buf.is_empty() {
let slice = self.record_write_buf.close_record(self.key_schedule)?;
self.delegate
.write_all(slice)
.map_err(|e| ProtocolError::Io(e.kind()))?;
self.key_schedule.increment_counter();
if self.flush_policy.flush_transport() {
self.flush_transport()?;
}
}
Ok(())
}
}

304
src/buffer.rs Normal file
View File

@@ -0,0 +1,304 @@
use crate::ProtocolError;
use aes_gcm::Error;
use aes_gcm::aead::Buffer;
pub struct CryptoBuffer<'b> {
buf: &'b mut [u8],
offset: usize,
len: usize,
}
impl<'b> CryptoBuffer<'b> {
#[allow(dead_code)]
pub(crate) fn empty() -> Self {
Self {
buf: &mut [],
offset: 0,
len: 0,
}
}
pub(crate) fn wrap(buf: &'b mut [u8]) -> Self {
Self {
buf,
offset: 0,
len: 0,
}
}
pub(crate) fn wrap_with_pos(buf: &'b mut [u8], pos: usize) -> Self {
Self {
buf,
offset: 0,
len: pos,
}
}
pub fn push(&mut self, b: u8) -> Result<(), ProtocolError> {
if self.space() > 0 {
self.buf[self.offset + self.len] = b;
self.len += 1;
Ok(())
} else {
error!("Failed to push byte");
Err(ProtocolError::InsufficientSpace)
}
}
pub fn push_u16(&mut self, num: u16) -> Result<(), ProtocolError> {
let data = num.to_be_bytes();
self.extend_from_slice(&data)
}
pub fn push_u24(&mut self, num: u32) -> Result<(), ProtocolError> {
let data = num.to_be_bytes();
self.extend_from_slice(&[data[1], data[2], data[3]])
}
pub fn push_u32(&mut self, num: u32) -> Result<(), ProtocolError> {
let data = num.to_be_bytes();
self.extend_from_slice(&data)
}
fn set(&mut self, idx: usize, val: u8) -> Result<(), ProtocolError> {
if idx < self.len {
self.buf[self.offset + idx] = val;
Ok(())
} else {
error!(
"Failed to set byte: index {} is out of range for {} elements",
idx, self.len
);
Err(ProtocolError::InsufficientSpace)
}
}
fn set_u16(&mut self, idx: usize, val: u16) -> Result<(), ProtocolError> {
let [upper, lower] = val.to_be_bytes();
self.set(idx, upper)?;
self.set(idx + 1, lower)?;
Ok(())
}
fn set_u24(&mut self, idx: usize, val: u32) -> Result<(), ProtocolError> {
let [_, upper, mid, lower] = val.to_be_bytes();
self.set(idx, upper)?;
self.set(idx + 1, mid)?;
self.set(idx + 2, lower)?;
Ok(())
}
pub fn as_mut_slice(&mut self) -> &mut [u8] {
&mut self.buf[self.offset..self.offset + self.len]
}
pub fn as_slice(&self) -> &[u8] {
&self.buf[self.offset..self.offset + self.len]
}
fn extend_internal(&mut self, other: &[u8]) -> Result<(), ProtocolError> {
if self.space() < other.len() {
error!(
"Failed to extend buffer. Space: {} required: {}",
self.space(),
other.len()
);
Err(ProtocolError::InsufficientSpace)
} else {
let start = self.offset + self.len;
self.buf[start..start + other.len()].clone_from_slice(other);
self.len += other.len();
Ok(())
}
}
fn space(&self) -> usize {
self.capacity() - (self.offset + self.len)
}
pub fn extend_from_slice(&mut self, other: &[u8]) -> Result<(), ProtocolError> {
self.extend_internal(other)
}
fn truncate_internal(&mut self, len: usize) {
if len <= self.capacity() - self.offset {
self.len = len;
}
}
pub fn truncate(&mut self, len: usize) {
self.truncate_internal(len);
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn release(self) -> (&'b mut [u8], usize, usize) {
(self.buf, self.offset, self.len)
}
pub fn capacity(&self) -> usize {
self.buf.len()
}
pub fn forward(self) -> CryptoBuffer<'b> {
let len = self.len;
self.offset(len)
}
pub fn rewind(self) -> CryptoBuffer<'b> {
self.offset(0)
}
pub(crate) fn offset(self, offset: usize) -> CryptoBuffer<'b> {
let new_len = self.len + self.offset - offset;
/*info!(
"offset({}) len({}) -> offset({}), len({})",
self.offset, self.len, offset, new_len
);*/
CryptoBuffer {
buf: self.buf,
len: new_len,
offset,
}
}
pub fn with_u8_length<R>(
&mut self,
op: impl FnOnce(&mut Self) -> Result<R, ProtocolError>,
) -> Result<R, ProtocolError> {
let len_pos = self.len;
self.push(0)?;
let start = self.len;
let r = op(self)?;
let len = (self.len() - start) as u8;
self.set(len_pos, len)?;
Ok(r)
}
pub fn with_u16_length<R>(
&mut self,
op: impl FnOnce(&mut Self) -> Result<R, ProtocolError>,
) -> Result<R, ProtocolError> {
let len_pos = self.len;
self.push_u16(0)?;
let start = self.len;
let r = op(self)?;
let len = (self.len() - start) as u16;
self.set_u16(len_pos, len)?;
Ok(r)
}
pub fn with_u24_length<R>(
&mut self,
op: impl FnOnce(&mut Self) -> Result<R, ProtocolError>,
) -> Result<R, ProtocolError> {
let len_pos = self.len;
self.push_u24(0)?;
let start = self.len;
let r = op(self)?;
let len = (self.len() - start) as u32;
self.set_u24(len_pos, len)?;
Ok(r)
}
}
impl AsRef<[u8]> for CryptoBuffer<'_> {
fn as_ref(&self) -> &[u8] {
self.as_slice()
}
}
impl AsMut<[u8]> for CryptoBuffer<'_> {
fn as_mut(&mut self) -> &mut [u8] {
self.as_mut_slice()
}
}
impl Buffer for CryptoBuffer<'_> {
fn extend_from_slice(&mut self, other: &[u8]) -> Result<(), Error> {
self.extend_internal(other).map_err(|_| Error)
}
fn truncate(&mut self, len: usize) {
self.truncate_internal(len);
}
}
#[cfg(test)]
mod test {
use super::CryptoBuffer;
#[test]
fn encode() {
let mut buf1 = [0; 4];
let mut c = CryptoBuffer::wrap(&mut buf1);
c.push_u24(1027).unwrap();
let mut buf2 = [0; 4];
let mut c = CryptoBuffer::wrap(&mut buf2);
c.push_u24(0).unwrap();
c.set_u24(0, 1027).unwrap();
assert_eq!(buf1, buf2);
let decoded = u32::from_be_bytes([0, buf1[0], buf1[1], buf1[2]]);
assert_eq!(1027, decoded);
}
#[test]
fn offset_calc() {
let mut buf = [0; 8];
let mut c = CryptoBuffer::wrap(&mut buf);
c.push(1).unwrap();
c.push(2).unwrap();
c.push(3).unwrap();
assert_eq!(&[1, 2, 3], c.as_slice());
let l = c.len();
let mut c = c.offset(l);
c.push(4).unwrap();
c.push(5).unwrap();
c.push(6).unwrap();
assert_eq!(&[4, 5, 6], c.as_slice());
let mut c = c.offset(0);
c.push(7).unwrap();
c.push(8).unwrap();
assert_eq!(&[1, 2, 3, 4, 5, 6, 7, 8], c.as_slice());
let mut c = c.offset(6);
c.set(0, 14).unwrap();
c.set(1, 15).unwrap();
let c = c.offset(0);
assert_eq!(&[1, 2, 3, 4, 5, 6, 14, 15], c.as_slice());
let mut c = c.offset(4);
c.truncate(0);
c.extend_from_slice(&[10, 11, 12, 13]).unwrap();
assert_eq!(&[10, 11, 12, 13], c.as_slice());
let c = c.offset(0);
assert_eq!(&[1, 2, 3, 4, 10, 11, 12, 13], c.as_slice());
}
}

296
src/cert_verify.rs Normal file
View File

@@ -0,0 +1,296 @@
use crate::ProtocolError;
use crate::config::{Certificate, TlsCipherSuite, TlsClock, Verifier};
use crate::extensions::extension_data::signature_algorithms::SignatureScheme;
use crate::handshake::{
certificate::{
Certificate as OwnedCertificate, CertificateEntryRef, CertificateRef as ServerCertificate,
},
certificate_verify::HandshakeVerifyRef,
};
use core::marker::PhantomData;
use digest::Digest;
use heapless::Vec;
#[cfg(all(not(feature = "alloc"), feature = "webpki"))]
impl TryInto<&'static webpki::SignatureAlgorithm> for SignatureScheme {
type Error = ProtocolError;
fn try_into(self) -> Result<&'static webpki::SignatureAlgorithm, Self::Error> {
// TODO: support other schemes via 'alloc' feature
#[allow(clippy::match_same_arms)] // Style
match self {
SignatureScheme::RsaPkcs1Sha256
| SignatureScheme::RsaPkcs1Sha384
| SignatureScheme::RsaPkcs1Sha512 => Err(ProtocolError::InvalidSignatureScheme),
/* ECDSA algorithms */
SignatureScheme::EcdsaSecp256r1Sha256 => Ok(&webpki::ECDSA_P256_SHA256),
SignatureScheme::EcdsaSecp384r1Sha384 => Ok(&webpki::ECDSA_P384_SHA384),
SignatureScheme::EcdsaSecp521r1Sha512 => Err(ProtocolError::InvalidSignatureScheme),
/* RSASSA-PSS algorithms with public key OID rsaEncryption */
SignatureScheme::RsaPssRsaeSha256
| SignatureScheme::RsaPssRsaeSha384
| SignatureScheme::RsaPssRsaeSha512 => Err(ProtocolError::InvalidSignatureScheme),
/* EdDSA algorithms */
SignatureScheme::Ed25519 => Ok(&webpki::ED25519),
SignatureScheme::Ed448
| SignatureScheme::Sha224Ecdsa
| SignatureScheme::Sha224Rsa
| SignatureScheme::Sha224Dsa => Err(ProtocolError::InvalidSignatureScheme),
/* RSASSA-PSS algorithms with public key OID RSASSA-PSS */
SignatureScheme::RsaPssPssSha256
| SignatureScheme::RsaPssPssSha384
| SignatureScheme::RsaPssPssSha512 => Err(ProtocolError::InvalidSignatureScheme),
/* Legacy algorithms */
SignatureScheme::RsaPkcs1Sha1 | SignatureScheme::EcdsaSha1 => {
Err(ProtocolError::InvalidSignatureScheme)
}
/* Ml-DSA */
SignatureScheme::MlDsa44 | SignatureScheme::MlDsa65 | SignatureScheme::MlDsa87 => {
Err(ProtocolError::InvalidSignatureScheme)
}
/* Brainpool */
SignatureScheme::Sha256BrainpoolP256r1
| SignatureScheme::Sha384BrainpoolP384r1
| SignatureScheme::Sha512BrainpoolP512r1 => Err(ProtocolError::InvalidSignatureScheme),
}
}
}
#[cfg(all(feature = "alloc", feature = "webpki"))]
impl TryInto<&'static webpki::SignatureAlgorithm> for SignatureScheme {
type Error = ProtocolError;
fn try_into(self) -> Result<&'static webpki::SignatureAlgorithm, Self::Error> {
match self {
SignatureScheme::RsaPkcs1Sha256 => Ok(&webpki::RSA_PKCS1_2048_8192_SHA256),
SignatureScheme::RsaPkcs1Sha384 => Ok(&webpki::RSA_PKCS1_2048_8192_SHA384),
SignatureScheme::RsaPkcs1Sha512 => Ok(&webpki::RSA_PKCS1_2048_8192_SHA512),
/* ECDSA algorithms */
SignatureScheme::EcdsaSecp256r1Sha256 => Ok(&webpki::ECDSA_P256_SHA256),
SignatureScheme::EcdsaSecp384r1Sha384 => Ok(&webpki::ECDSA_P384_SHA384),
SignatureScheme::EcdsaSecp521r1Sha512 => Err(ProtocolError::InvalidSignatureScheme),
/* RSASSA-PSS algorithms with public key OID rsaEncryption */
SignatureScheme::RsaPssRsaeSha256 => Ok(&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY),
SignatureScheme::RsaPssRsaeSha384 => Ok(&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY),
SignatureScheme::RsaPssRsaeSha512 => Ok(&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY),
/* EdDSA algorithms */
SignatureScheme::Ed25519 => Ok(&webpki::ED25519),
SignatureScheme::Ed448 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::Sha224Ecdsa => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::Sha224Rsa => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::Sha224Dsa => Err(ProtocolError::InvalidSignatureScheme),
/* RSASSA-PSS algorithms with public key OID RSASSA-PSS */
SignatureScheme::RsaPssPssSha256 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::RsaPssPssSha384 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::RsaPssPssSha512 => Err(ProtocolError::InvalidSignatureScheme),
/* Legacy algorithms */
SignatureScheme::RsaPkcs1Sha1 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::EcdsaSha1 => Err(ProtocolError::InvalidSignatureScheme),
/* MlDsa */
SignatureScheme::MlDsa44 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::MlDsa65 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::MlDsa87 => Err(ProtocolError::InvalidSignatureScheme),
/* Brainpool */
SignatureScheme::Sha256BrainpoolP256r1 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::Sha384BrainpoolP384r1 => Err(ProtocolError::InvalidSignatureScheme),
SignatureScheme::Sha512BrainpoolP512r1 => Err(ProtocolError::InvalidSignatureScheme),
}
}
}
static ALL_SIGALGS: &[&webpki::SignatureAlgorithm] = &[
&webpki::ECDSA_P256_SHA256,
&webpki::ECDSA_P256_SHA384,
&webpki::ECDSA_P384_SHA256,
&webpki::ECDSA_P384_SHA384,
&webpki::ED25519,
];
pub struct CertVerifier<'a, CipherSuite, Clock, const CERT_SIZE: usize>
where
Clock: TlsClock,
CipherSuite: TlsCipherSuite,
{
ca: Certificate<&'a [u8]>,
host: Option<heapless::String<64>>,
certificate_transcript: Option<CipherSuite::Hash>,
certificate: Option<OwnedCertificate<CERT_SIZE>>,
_clock: PhantomData<Clock>,
}
impl<'a, CipherSuite, Clock, const CERT_SIZE: usize> CertVerifier<'a, CipherSuite, Clock, CERT_SIZE>
where
Clock: TlsClock,
CipherSuite: TlsCipherSuite,
{
#[must_use]
pub fn new(ca: Certificate<&'a [u8]>) -> Self {
Self {
ca,
host: None,
certificate_transcript: None,
certificate: None,
_clock: PhantomData,
}
}
}
impl<CipherSuite, Clock, const CERT_SIZE: usize> Verifier<CipherSuite>
for CertVerifier<'_, CipherSuite, Clock, CERT_SIZE>
where
CipherSuite: TlsCipherSuite,
Clock: TlsClock,
{
fn set_hostname_verification(&mut self, hostname: &str) -> Result<(), ProtocolError> {
self.host.replace(
heapless::String::try_from(hostname).map_err(|_| ProtocolError::InsufficientSpace)?,
);
Ok(())
}
fn verify_certificate(
&mut self,
transcript: &CipherSuite::Hash,
cert: ServerCertificate,
) -> Result<(), ProtocolError> {
verify_certificate(self.host.as_deref(), &self.ca, &cert, Clock::now())?;
self.certificate.replace(cert.try_into()?);
self.certificate_transcript.replace(transcript.clone());
Ok(())
}
fn verify_signature(&mut self, verify: HandshakeVerifyRef) -> Result<(), ProtocolError> {
let handshake_hash = unwrap!(self.certificate_transcript.take());
let ctx_str = b"TLS 1.3, server CertificateVerify\x00";
let mut msg: Vec<u8, 130> = Vec::new();
msg.resize(64, 0x20).map_err(|_| ProtocolError::EncodeError)?;
msg.extend_from_slice(ctx_str)
.map_err(|_| ProtocolError::EncodeError)?;
msg.extend_from_slice(&handshake_hash.finalize())
.map_err(|_| ProtocolError::EncodeError)?;
let certificate = unwrap!(self.certificate.as_ref()).try_into()?;
verify_signature(&msg[..], &certificate, &verify)?;
Ok(())
}
}
fn verify_signature(
message: &[u8],
certificate: &ServerCertificate,
verify: &HandshakeVerifyRef,
) -> Result<(), ProtocolError> {
let mut verified = false;
if !certificate.entries.is_empty() {
// TODO: Support intermediates...
if let CertificateEntryRef::X509(certificate) = certificate.entries[0] {
let cert = webpki::EndEntityCert::try_from(certificate).map_err(|e| {
warn!("ProtocolError loading cert: {:?}", e);
ProtocolError::DecodeError
})?;
trace!(
"Verifying with signature scheme {:?}",
verify.signature_scheme
);
info!("Signature: {:x?}", verify.signature);
let pkisig = verify.signature_scheme.try_into()?;
match cert.verify_signature(pkisig, message, verify.signature) {
Ok(()) => {
verified = true;
}
Err(e) => {
info!("ProtocolError verifying signature: {:?}", e);
}
}
}
}
if !verified {
return Err(ProtocolError::InvalidSignature);
}
Ok(())
}
fn verify_certificate(
verify_host: Option<&str>,
ca: &Certificate<&[u8]>,
certificate: &ServerCertificate,
now: Option<u64>,
) -> Result<(), ProtocolError> {
let mut verified = false;
let mut host_verified = false;
if let Certificate::X509(ca) = ca {
let trust = webpki::TrustAnchor::try_from_cert_der(ca).map_err(|e| {
warn!("ProtocolError loading CA: {:?}", e);
ProtocolError::DecodeError
})?;
trace!("We got {} certificate entries", certificate.entries.len());
if !certificate.entries.is_empty() {
// TODO: Support intermediates...
if let CertificateEntryRef::X509(certificate) = certificate.entries[0] {
let cert = webpki::EndEntityCert::try_from(certificate).map_err(|e| {
warn!("ProtocolError loading cert: {:?}", e);
ProtocolError::DecodeError
})?;
let time = if let Some(now) = now {
webpki::Time::from_seconds_since_unix_epoch(now)
} else {
// If no clock is provided, the validity check will fail
webpki::Time::from_seconds_since_unix_epoch(0)
};
info!("Certificate is loaded!");
match cert.verify_for_usage(
ALL_SIGALGS,
&[trust],
&[],
time,
webpki::KeyUsage::server_auth(),
&[],
) {
Ok(()) => verified = true,
Err(e) => {
warn!("ProtocolError verifying certificate: {:?}", e);
}
}
if let Some(server_name) = verify_host {
match webpki::SubjectNameRef::try_from_ascii(server_name.as_bytes()) {
Ok(subject) => match cert.verify_is_valid_for_subject_name(subject) {
Ok(()) => host_verified = true,
Err(e) => {
warn!("ProtocolError verifying host: {:?}", e);
}
},
Err(e) => {
warn!("ProtocolError verifying host: {:?}", e);
}
}
}
}
}
}
if !verified {
return Err(ProtocolError::InvalidCertificate);
}
if !host_verified && verify_host.is_some() {
return Err(ProtocolError::InvalidCertificate);
}
Ok(())
}

160
src/certificate.rs Normal file
View File

@@ -0,0 +1,160 @@
use core::cmp::Ordering;
use der::asn1::{
BitStringRef, GeneralizedTime, IntRef, ObjectIdentifier, SequenceOf, SetOf, UtcTime,
};
use der::{AnyRef, Choice, Enumerated, Sequence, ValueOrd};
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Sequence, ValueOrd)]
pub struct AlgorithmIdentifier<'a> {
pub oid: ObjectIdentifier,
pub parameters: Option<AnyRef<'a>>,
}
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for AlgorithmIdentifier<'a> {
fn format(&self, fmt: defmt::Formatter) {
defmt::write!(fmt, "AlgorithmIdentifier:{}", &self.oid.as_bytes())
}
}
pub const ECDSA_SHA256: AlgorithmIdentifier = AlgorithmIdentifier {
oid: ObjectIdentifier::new_unwrap("1.2.840.10045.4.3.2"),
parameters: None,
};
#[cfg(feature = "p384")]
pub const ECDSA_SHA384: AlgorithmIdentifier = AlgorithmIdentifier {
oid: ObjectIdentifier::new_unwrap("1.2.840.10045.4.3.3"),
parameters: None,
};
#[cfg(feature = "ed25519")]
pub const ED25519: AlgorithmIdentifier = AlgorithmIdentifier {
oid: ObjectIdentifier::new_unwrap("1.3.101.112"),
parameters: None,
};
#[cfg(feature = "rsa")]
pub const RSA_PKCS1_SHA256: AlgorithmIdentifier = AlgorithmIdentifier {
oid: ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.11"),
parameters: Some(AnyRef::NULL),
};
#[cfg(feature = "rsa")]
pub const RSA_PKCS1_SHA384: AlgorithmIdentifier = AlgorithmIdentifier {
oid: ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.12"),
parameters: Some(AnyRef::NULL),
};
#[cfg(feature = "rsa")]
pub const RSA_PKCS1_SHA512: AlgorithmIdentifier = AlgorithmIdentifier {
oid: ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.13"),
parameters: Some(AnyRef::NULL),
};
#[derive(Debug, Clone, PartialEq, Eq, Copy, Enumerated)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[asn1(type = "INTEGER")]
#[repr(u8)]
pub enum Version {
/// Version 1 (default)
V1 = 0,
/// Version 2
V2 = 1,
/// Version 3
V3 = 2,
}
impl ValueOrd for Version {
fn value_cmp(&self, other: &Self) -> der::Result<Ordering> {
(*self as u8).value_cmp(&(*other as u8))
}
}
impl Default for Version {
fn default() -> Self {
Self::V1
}
}
#[derive(Sequence, ValueOrd)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct DecodedCertificate<'a> {
pub tbs_certificate: TbsCertificate<'a>,
pub signature_algorithm: AlgorithmIdentifier<'a>,
pub signature: BitStringRef<'a>,
}
#[derive(Debug, Sequence, ValueOrd)]
pub struct AttributeTypeAndValue<'a> {
pub oid: ObjectIdentifier,
pub value: AnyRef<'a>,
}
#[cfg(feature = "defmt")]
impl<'a> defmt::Format for AttributeTypeAndValue<'a> {
fn format(&self, fmt: defmt::Formatter) {
defmt::write!(
fmt,
"Attribute:{} Value:{}",
&self.oid.as_bytes(),
&self.value.value()
)
}
}
#[derive(Debug, Choice, ValueOrd)]
pub enum Time {
#[asn1(type = "UTCTime")]
UtcTime(UtcTime),
#[asn1(type = "GeneralizedTime")]
GeneralTime(GeneralizedTime),
}
#[cfg(feature = "defmt")]
impl defmt::Format for Time {
fn format(&self, fmt: defmt::Formatter) {
match self {
Time::UtcTime(utc_time) => {
defmt::write!(fmt, "UtcTime:{}", utc_time.to_unix_duration())
}
Time::GeneralTime(generalized_time) => {
defmt::write!(fmt, "GeneralTime:{}", generalized_time.to_unix_duration())
}
}
}
}
#[derive(Debug, Sequence, ValueOrd)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Validity {
pub not_before: Time,
pub not_after: Time,
}
#[derive(Debug, Sequence, ValueOrd)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SubjectPublicKeyInfoRef<'a> {
pub algorithm: AlgorithmIdentifier<'a>,
pub public_key: BitStringRef<'a>,
}
#[derive(Debug, Sequence, ValueOrd)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TbsCertificate<'a> {
#[asn1(context_specific = "0", default = "Default::default")]
pub version: Version,
pub serial_number: IntRef<'a>,
pub signature: AlgorithmIdentifier<'a>,
pub issuer: SequenceOf<SetOf<AttributeTypeAndValue<'a>, 1>, 7>,
pub validity: Validity,
pub subject: SequenceOf<SetOf<AttributeTypeAndValue<'a>, 1>, 7>,
pub subject_public_key_info: SubjectPublicKeyInfoRef<'a>,
#[asn1(context_specific = "1", tag_mode = "IMPLICIT", optional = "true")]
pub issuer_unique_id: Option<BitStringRef<'a>>,
#[asn1(context_specific = "2", tag_mode = "IMPLICIT", optional = "true")]
pub subject_unique_id: Option<BitStringRef<'a>>,
#[asn1(context_specific = "3", tag_mode = "EXPLICIT", optional = "true")]
pub extensions: Option<AnyRef<'a>>,
}

37
src/change_cipher_spec.rs Normal file
View File

@@ -0,0 +1,37 @@
use crate::ProtocolError;
use crate::buffer::CryptoBuffer;
use crate::parse_buffer::ParseBuffer;
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ChangeCipherSpec {}
#[allow(clippy::unnecessary_wraps)] // TODO
impl ChangeCipherSpec {
pub fn new() -> Self {
Self {}
}
pub fn read(_rx_buf: &mut [u8]) -> Result<Self, ProtocolError> {
// info!("change cipher spec of len={}", rx_buf.len());
// TODO: Decode data
Ok(Self {})
}
#[allow(dead_code)]
pub fn parse(_: &mut ParseBuffer) -> Result<Self, ProtocolError> {
Ok(Self {})
}
#[allow(dead_code, clippy::unused_self)]
pub(crate) fn encode(self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
buf.push(1).map_err(|_| ProtocolError::EncodeError)?;
Ok(())
}
}
impl Default for ChangeCipherSpec {
fn default() -> Self {
ChangeCipherSpec::new()
}
}

15
src/cipher.rs Normal file
View File

@@ -0,0 +1,15 @@
use crate::application_data::ApplicationData;
use crate::extensions::extension_data::supported_groups::NamedGroup;
use p256::ecdh::SharedSecret;
pub struct CryptoEngine {}
#[allow(clippy::unused_self, clippy::needless_pass_by_value)] // TODO
impl CryptoEngine {
pub fn new(_group: NamedGroup, _shared: SharedSecret) -> Self {
Self {}
}
#[allow(dead_code)]
pub fn decrypt(&self, _: &ApplicationData) {}
}

26
src/cipher_suites.rs Normal file
View File

@@ -0,0 +1,26 @@
use crate::parse_buffer::{ParseBuffer, ParseError};
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum CipherSuite {
TlsAes128GcmSha256 = 0x1301,
TlsAes256GcmSha384 = 0x1302,
TlsChacha20Poly1305Sha256 = 0x1303,
TlsAes128CcmSha256 = 0x1304,
TlsAes128Ccm8Sha256 = 0x1305,
TlsPskAes128GcmSha256 = 0x00A8,
}
impl CipherSuite {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u16()? {
v if v == Self::TlsAes128GcmSha256 as u16 => Ok(Self::TlsAes128GcmSha256),
v if v == Self::TlsAes256GcmSha384 as u16 => Ok(Self::TlsAes256GcmSha384),
v if v == Self::TlsChacha20Poly1305Sha256 as u16 => Ok(Self::TlsChacha20Poly1305Sha256),
v if v == Self::TlsAes128CcmSha256 as u16 => Ok(Self::TlsAes128CcmSha256),
v if v == Self::TlsAes128Ccm8Sha256 as u16 => Ok(Self::TlsAes128Ccm8Sha256),
v if v == Self::TlsPskAes128GcmSha256 as u16 => Ok(Self::TlsPskAes128GcmSha256),
_ => Err(ParseError::InvalidData),
}
}
}

View File

@@ -0,0 +1,24 @@
use crate::read_buffer::ReadBuffer;
#[derive(Default)]
pub struct DecryptedBufferInfo {
pub offset: usize,
pub len: usize,
pub consumed: usize,
}
impl DecryptedBufferInfo {
pub fn create_read_buffer<'b>(&'b mut self, buffer: &'b [u8]) -> ReadBuffer<'b> {
let offset = self.offset + self.consumed;
let end = self.offset + self.len;
ReadBuffer::new(&buffer[offset..end], &mut self.consumed)
}
pub fn len(&self) -> usize {
self.len - self.consumed
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}

View File

@@ -0,0 +1,64 @@
use core::ops::Range;
use crate::{
ProtocolError, alert::AlertDescription, common::decrypted_buffer_info::DecryptedBufferInfo,
config::TlsCipherSuite, handshake::ServerHandshake, record::ServerRecord,
};
pub struct DecryptedReadHandler<'a> {
pub source_buffer: Range<*const u8>,
pub buffer_info: &'a mut DecryptedBufferInfo,
pub is_open: &'a mut bool,
}
impl DecryptedReadHandler<'_> {
pub fn handle<CipherSuite: TlsCipherSuite>(
&mut self,
record: ServerRecord<'_, CipherSuite>,
) -> Result<(), ProtocolError> {
match record {
ServerRecord::ApplicationData(data) => {
let slice = data.data.as_slice();
let slice_ptrs = slice.as_ptr_range();
debug_assert!(
self.source_buffer.contains(&slice_ptrs.start)
&& self.source_buffer.contains(&slice_ptrs.end)
);
let offset = unsafe {
// SAFETY: The assertion above ensures `slice` is a subslice of the read buffer.
// This, in turn, ensures we don't violate safety constraints of `offset_from`.
// TODO: We are only assuming here that the pointers are derived from the read
// buffer. While this is reasonable, and we don't do any pointer magic,
// it's not an invariant.
slice_ptrs.start.offset_from(self.source_buffer.start) as usize
};
self.buffer_info.offset = offset;
self.buffer_info.len = slice.len();
self.buffer_info.consumed = 0;
Ok(())
}
ServerRecord::Alert(alert) => {
if let AlertDescription::CloseNotify = alert.description {
*self.is_open = false;
Err(ProtocolError::ConnectionClosed)
} else {
Err(ProtocolError::InternalError)
}
}
ServerRecord::ChangeCipherSpec(_) => Err(ProtocolError::InternalError),
ServerRecord::Handshake(ServerHandshake::NewSessionTicket(_)) => {
// TODO: we should validate extensions and abort. We can do this automatically
// as long as the connection is unsplit, however, split connections must be aborted
// by the user.
Ok(())
}
ServerRecord::Handshake(_) => {
unimplemented!()
}
}
}
}

2
src/common/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod decrypted_buffer_info;
pub mod decrypted_read_handler;

421
src/config.rs Normal file
View File

@@ -0,0 +1,421 @@
use core::marker::PhantomData;
use crate::ProtocolError;
use crate::cipher_suites::CipherSuite;
use crate::extensions::extension_data::signature_algorithms::SignatureScheme;
use crate::extensions::extension_data::supported_groups::NamedGroup;
pub use crate::handshake::certificate::{CertificateEntryRef, CertificateRef};
pub use crate::handshake::certificate_verify::HandshakeVerifyRef;
use aes_gcm::{AeadInPlace, Aes128Gcm, Aes256Gcm, KeyInit};
use digest::core_api::BlockSizeUser;
use digest::{Digest, FixedOutput, OutputSizeUser, Reset};
use ecdsa::elliptic_curve::SecretKey;
use generic_array::ArrayLength;
use heapless::Vec;
use p256::ecdsa::SigningKey;
use rand_core::CryptoRngCore;
pub use sha2::{Sha256, Sha384};
use typenum::{Sum, U10, U12, U16, U32};
pub use crate::extensions::extension_data::max_fragment_length::MaxFragmentLength;
pub const TLS_RECORD_OVERHEAD: usize = 128;
// longest label is 12b -> buf <= 2 + 1 + 6 + longest + 1 + hash_out = hash_out + 22
type LongestLabel = U12;
type LabelOverhead = U10;
type LabelBuffer<CipherSuite> = Sum<
<<CipherSuite as TlsCipherSuite>::Hash as OutputSizeUser>::OutputSize,
Sum<LongestLabel, LabelOverhead>,
>;
/// Represents a TLS 1.3 cipher suite
pub trait TlsCipherSuite {
const CODE_POINT: u16;
type Cipher: KeyInit<KeySize = Self::KeyLen> + AeadInPlace<NonceSize = Self::IvLen>;
type KeyLen: ArrayLength<u8>;
type IvLen: ArrayLength<u8>;
type Hash: Digest + Reset + Clone + OutputSizeUser + BlockSizeUser + FixedOutput;
type LabelBufferSize: ArrayLength<u8>;
}
pub struct Aes128GcmSha256;
impl TlsCipherSuite for Aes128GcmSha256 {
const CODE_POINT: u16 = CipherSuite::TlsAes128GcmSha256 as u16;
type Cipher = Aes128Gcm;
type KeyLen = U16;
type IvLen = U12;
type Hash = Sha256;
type LabelBufferSize = LabelBuffer<Self>;
}
pub struct Aes256GcmSha384;
impl TlsCipherSuite for Aes256GcmSha384 {
const CODE_POINT: u16 = CipherSuite::TlsAes256GcmSha384 as u16;
type Cipher = Aes256Gcm;
type KeyLen = U32;
type IvLen = U12;
type Hash = Sha384;
type LabelBufferSize = LabelBuffer<Self>;
}
/// A TLS 1.3 verifier.
///
/// The verifier is responsible for verifying certificates and signatures. Since certificate verification is
/// an expensive process, this trait allows clients to choose how much verification should take place,
/// and also to skip the verification if the server is verified through other means (I.e. a pre-shared key).
pub trait Verifier<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
/// Host verification is enabled by passing a server hostname.
fn set_hostname_verification(&mut self, hostname: &str) -> Result<(), crate::ProtocolError>;
/// Verify a certificate.
///
/// The handshake transcript up to this point and the server certificate is provided
/// for the implementation to use. The verifier is responsible for resolving the CA
/// certificate internally.
fn verify_certificate(
&mut self,
transcript: &CipherSuite::Hash,
cert: CertificateRef,
) -> Result<(), ProtocolError>;
/// Verify the certificate signature.
///
/// The signature verification uses the transcript and certificate provided earlier to decode the provided signature.
fn verify_signature(&mut self, verify: HandshakeVerifyRef) -> Result<(), crate::ProtocolError>;
}
pub struct NoVerify;
impl<CipherSuite> Verifier<CipherSuite> for NoVerify
where
CipherSuite: TlsCipherSuite,
{
fn set_hostname_verification(&mut self, _hostname: &str) -> Result<(), crate::ProtocolError> {
Ok(())
}
fn verify_certificate(
&mut self,
_transcript: &CipherSuite::Hash,
_cert: CertificateRef,
) -> Result<(), ProtocolError> {
Ok(())
}
fn verify_signature(&mut self, _verify: HandshakeVerifyRef) -> Result<(), crate::ProtocolError> {
Ok(())
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[must_use = "ConnectConfig does nothing unless consumed"]
pub struct ConnectConfig<'a> {
pub(crate) server_name: Option<&'a str>,
pub(crate) alpn_protocols: Option<&'a [&'a [u8]]>,
pub(crate) psk: Option<(&'a [u8], Vec<&'a [u8], 4>)>,
pub(crate) signature_schemes: Vec<SignatureScheme, 25>,
pub(crate) named_groups: Vec<NamedGroup, 13>,
pub(crate) max_fragment_length: Option<MaxFragmentLength>,
}
pub trait TlsClock {
fn now() -> Option<u64>;
}
pub struct NoClock;
impl TlsClock for NoClock {
fn now() -> Option<u64> {
None
}
}
pub trait CryptoBackend {
type CipherSuite: TlsCipherSuite;
type Signature: AsRef<[u8]>;
fn rng(&mut self) -> impl CryptoRngCore;
fn verifier(&mut self) -> Result<&mut impl Verifier<Self::CipherSuite>, crate::ProtocolError> {
Err::<&mut NoVerify, _>(crate::ProtocolError::Unimplemented)
}
/// Provide a signing key for client certificate authentication.
///
/// The provider resolves the private key internally (e.g. from memory, flash, or a hardware
/// crypto module such as an HSM/TPM/secure element).
fn signer(
&mut self,
) -> Result<(impl signature::SignerMut<Self::Signature>, SignatureScheme), crate::ProtocolError>
{
Err::<(NoSign, _), crate::ProtocolError>(crate::ProtocolError::Unimplemented)
}
/// Resolve the client certificate for mutual TLS authentication.
///
/// Return `None` if no client certificate is available (an empty certificate message will
/// be sent to the server). The data type `D` can be borrowed (`&[u8]`) or owned
/// (e.g. `heapless::Vec<u8, N>`) — the certificate is only needed long enough to encode
/// into the TLS message.
fn client_cert(&mut self) -> Option<Certificate<impl AsRef<[u8]>>> {
None::<Certificate<&[u8]>>
}
}
impl<T: CryptoBackend> CryptoBackend for &mut T {
type CipherSuite = T::CipherSuite;
type Signature = T::Signature;
fn rng(&mut self) -> impl CryptoRngCore {
T::rng(self)
}
fn verifier(&mut self) -> Result<&mut impl Verifier<Self::CipherSuite>, crate::ProtocolError> {
T::verifier(self)
}
fn signer(
&mut self,
) -> Result<(impl signature::SignerMut<Self::Signature>, SignatureScheme), crate::ProtocolError>
{
T::signer(self)
}
fn client_cert(&mut self) -> Option<Certificate<impl AsRef<[u8]>>> {
T::client_cert(self)
}
}
pub struct NoSign;
impl<S> signature::Signer<S> for NoSign {
fn try_sign(&self, _msg: &[u8]) -> Result<S, signature::Error> {
unimplemented!()
}
}
pub struct SkipVerifyProvider<'a, CipherSuite, RNG> {
rng: RNG,
priv_key: Option<&'a [u8]>,
client_cert: Option<Certificate<&'a [u8]>>,
_marker: PhantomData<CipherSuite>,
}
impl<RNG: CryptoRngCore> SkipVerifyProvider<'_, (), RNG> {
pub fn new<CipherSuite: TlsCipherSuite>(
rng: RNG,
) -> SkipVerifyProvider<'static, CipherSuite, RNG> {
SkipVerifyProvider {
rng,
priv_key: None,
client_cert: None,
_marker: PhantomData,
}
}
}
impl<'a, CipherSuite: TlsCipherSuite, RNG: CryptoRngCore> SkipVerifyProvider<'a, CipherSuite, RNG> {
pub fn with_priv_key(mut self, priv_key: &'a [u8]) -> Self {
self.priv_key = Some(priv_key);
self
}
pub fn with_cert(mut self, cert: Certificate<&'a [u8]>) -> Self {
self.client_cert = Some(cert);
self
}
}
impl<CipherSuite: TlsCipherSuite, RNG: CryptoRngCore> CryptoBackend
for SkipVerifyProvider<'_, CipherSuite, RNG>
{
type CipherSuite = CipherSuite;
type Signature = p256::ecdsa::DerSignature;
fn rng(&mut self) -> impl CryptoRngCore {
&mut self.rng
}
fn signer(
&mut self,
) -> Result<(impl signature::SignerMut<Self::Signature>, SignatureScheme), crate::ProtocolError>
{
let key_der = self.priv_key.ok_or(ProtocolError::InvalidPrivateKey)?;
let secret_key =
SecretKey::from_sec1_der(key_der).map_err(|_| ProtocolError::InvalidPrivateKey)?;
Ok((
SigningKey::from(&secret_key),
SignatureScheme::EcdsaSecp256r1Sha256,
))
}
fn client_cert(&mut self) -> Option<Certificate<impl AsRef<[u8]>>> {
self.client_cert.clone()
}
}
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ConnectContext<'a, CP>
where
CP: CryptoBackend,
{
pub(crate) config: &'a ConnectConfig<'a>,
pub(crate) crypto_provider: CP,
}
impl<'a, CP> ConnectContext<'a, CP>
where
CP: CryptoBackend,
{
/// Create a new context with a given config and a crypto provider.
pub fn new(config: &'a ConnectConfig<'a>, crypto_provider: CP) -> Self {
Self {
config,
crypto_provider,
}
}
}
impl<'a> ConnectConfig<'a> {
pub fn new() -> Self {
let mut config = Self {
signature_schemes: Vec::new(),
named_groups: Vec::new(),
max_fragment_length: None,
psk: None,
server_name: None,
alpn_protocols: None,
};
if cfg!(feature = "alloc") {
config = config.enable_rsa_signatures();
}
unwrap!(
config
.signature_schemes
.push(SignatureScheme::EcdsaSecp256r1Sha256)
.ok()
);
unwrap!(
config
.signature_schemes
.push(SignatureScheme::EcdsaSecp384r1Sha384)
.ok()
);
unwrap!(config.signature_schemes.push(SignatureScheme::Ed25519).ok());
unwrap!(config.named_groups.push(NamedGroup::Secp256r1));
config
}
/// Enable RSA ciphers even if they might not be supported.
pub fn enable_rsa_signatures(mut self) -> Self {
unwrap!(
self.signature_schemes
.push(SignatureScheme::RsaPkcs1Sha256)
.ok()
);
unwrap!(
self.signature_schemes
.push(SignatureScheme::RsaPkcs1Sha384)
.ok()
);
unwrap!(
self.signature_schemes
.push(SignatureScheme::RsaPkcs1Sha512)
.ok()
);
unwrap!(
self.signature_schemes
.push(SignatureScheme::RsaPssRsaeSha256)
.ok()
);
unwrap!(
self.signature_schemes
.push(SignatureScheme::RsaPssRsaeSha384)
.ok()
);
unwrap!(
self.signature_schemes
.push(SignatureScheme::RsaPssRsaeSha512)
.ok()
);
self
}
pub fn with_server_name(mut self, server_name: &'a str) -> Self {
self.server_name = Some(server_name);
self
}
/// Configure ALPN protocol names to send in the ClientHello.
///
/// The server will select one of the offered protocols and echo it back
/// in EncryptedExtensions. This is required for endpoints that multiplex
/// protocols on a single port (e.g. AWS IoT Core MQTT over port 443).
pub fn with_alpn(mut self, protocols: &'a [&'a [u8]]) -> Self {
self.alpn_protocols = Some(protocols);
self
}
/// Configures the maximum plaintext fragment size.
///
/// This option may help reduce memory size, as smaller fragment lengths require smaller
/// read/write buffers. Note that mote-tls does not currently use this option to fragment
/// writes. Note that the buffers need to include some overhead over the configured fragment
/// length.
///
/// From [RFC 6066, Section 4. Maximum Fragment Length Negotiation](https://www.rfc-editor.org/rfc/rfc6066#page-8):
///
/// > Without this extension, TLS specifies a fixed maximum plaintext
/// > fragment length of 2^14 bytes. It may be desirable for constrained
/// > clients to negotiate a smaller maximum fragment length due to memory
/// > limitations or bandwidth limitations.
///
/// > For example, if the negotiated length is 2^9=512, then, when using currently defined
/// > cipher suites ([...]) and null compression, the record-layer output can be at most
/// > 805 bytes: 5 bytes of headers, 512 bytes of application data, 256 bytes of padding,
/// > and 32 bytes of MAC.
pub fn with_max_fragment_length(mut self, max_fragment_length: MaxFragmentLength) -> Self {
self.max_fragment_length = Some(max_fragment_length);
self
}
/// Resets the max fragment length to 14 bits (16384).
pub fn reset_max_fragment_length(mut self) -> Self {
self.max_fragment_length = None;
self
}
pub fn with_psk(mut self, psk: &'a [u8], identities: &[&'a [u8]]) -> Self {
// TODO: Remove potential panic
self.psk = Some((psk, unwrap!(Vec::from_slice(identities).ok())));
self
}
}
impl Default for ConnectConfig<'_> {
fn default() -> Self {
ConnectConfig::new()
}
}
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Certificate<D> {
X509(D),
RawPublicKey(D),
}

636
src/connection.rs Normal file
View File

@@ -0,0 +1,636 @@
use crate::config::{TlsCipherSuite, ConnectConfig};
use crate::handshake::{ClientHandshake, ServerHandshake};
use crate::key_schedule::{KeySchedule, ReadKeySchedule, WriteKeySchedule};
use crate::record::{ClientRecord, ServerRecord};
use crate::record_reader::RecordReader;
use crate::write_buffer::WriteBuffer;
use crate::{HandshakeVerify, CryptoBackend, ProtocolError, Verifier};
use crate::{
alert::{Alert, AlertDescription, AlertLevel},
handshake::{certificate::CertificateRef, certificate_request::CertificateRequest},
};
use core::fmt::Debug;
use digest::Digest;
use embedded_io::Error as _;
use embedded_io::{Read as BlockingRead, Write as BlockingWrite};
use embedded_io_async::{Read as AsyncRead, Write as AsyncWrite};
use crate::application_data::ApplicationData;
use crate::buffer::CryptoBuffer;
use digest::generic_array::typenum::Unsigned;
use p256::ecdh::EphemeralSecret;
use signature::SignerMut;
use crate::content_types::ContentType;
use crate::parse_buffer::ParseBuffer;
use aes_gcm::aead::{AeadCore, AeadInPlace, KeyInit};
pub(crate) fn decrypt_record<CipherSuite>(
key_schedule: &mut ReadKeySchedule<CipherSuite>,
record: ServerRecord<'_, CipherSuite>,
mut cb: impl FnMut(
&mut ReadKeySchedule<CipherSuite>,
ServerRecord<'_, CipherSuite>,
) -> Result<(), ProtocolError>,
) -> Result<(), ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
if let ServerRecord::ApplicationData(ApplicationData {
header,
data: mut app_data,
}) = record
{
let server_key = key_schedule.get_key()?;
let nonce = key_schedule.get_nonce()?;
let crypto = <CipherSuite::Cipher as KeyInit>::new(server_key);
crypto
.decrypt_in_place(&nonce, header.data(), &mut app_data)
.map_err(|_| ProtocolError::CryptoError)?;
let padding = app_data
.as_slice()
.iter()
.enumerate()
.rfind(|(_, b)| **b != 0);
if let Some((index, _)) = padding {
app_data.truncate(index + 1);
};
let content_type =
ContentType::of(*app_data.as_slice().last().unwrap()).ok_or(ProtocolError::InvalidRecord)?;
trace!("Decrypting: content type = {:?}", content_type);
// Remove the content type
app_data.truncate(app_data.len() - 1);
let mut buf = ParseBuffer::new(app_data.as_slice());
match content_type {
ContentType::Handshake => {
// Decode potentially coalesced handshake messages
while buf.remaining() > 0 {
let inner = ServerHandshake::read(&mut buf, key_schedule.transcript_hash())?;
cb(key_schedule, ServerRecord::Handshake(inner))?;
}
}
ContentType::ApplicationData => {
let inner = ApplicationData::new(app_data, header);
cb(key_schedule, ServerRecord::ApplicationData(inner))?;
}
ContentType::Alert => {
let alert = Alert::parse(&mut buf)?;
cb(key_schedule, ServerRecord::Alert(alert))?;
}
_ => return Err(ProtocolError::Unimplemented),
}
key_schedule.increment_counter();
} else {
trace!("Not decrypting: content_type = {:?}", record.content_type());
cb(key_schedule, record)?;
}
Ok(())
}
pub(crate) fn encrypt<CipherSuite>(
key_schedule: &WriteKeySchedule<CipherSuite>,
buf: &mut CryptoBuffer<'_>,
) -> Result<(), ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
let client_key = key_schedule.get_key()?;
let nonce = key_schedule.get_nonce()?;
// trace!("encrypt key {:02x?}", client_key);
// trace!("encrypt nonce {:02x?}", nonce);
// trace!("plaintext {} {:02x?}", buf.len(), buf.as_slice(),);
//let crypto = Aes128Gcm::new_varkey(&self.key_schedule.get_client_key()).unwrap();
let crypto = <CipherSuite::Cipher as KeyInit>::new(client_key);
let len = buf.len() + <CipherSuite::Cipher as AeadCore>::TagSize::to_usize();
if len > buf.capacity() {
return Err(ProtocolError::InsufficientSpace);
}
trace!("output size {}", len);
let len_bytes = (len as u16).to_be_bytes();
let additional_data = [
ContentType::ApplicationData as u8,
0x03,
0x03,
len_bytes[0],
len_bytes[1],
];
crypto
.encrypt_in_place(&nonce, &additional_data, buf)
.map_err(|_| ProtocolError::InvalidApplicationData)
}
pub struct Handshake<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
traffic_hash: Option<CipherSuite::Hash>,
secret: Option<EphemeralSecret>,
certificate_request: Option<CertificateRequest>,
}
impl<CipherSuite> Handshake<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub fn new() -> Handshake<CipherSuite> {
Handshake {
traffic_hash: None,
secret: None,
certificate_request: None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum State {
ClientHello,
ServerHello,
ServerVerify,
ClientCert,
ClientCertVerify,
ClientFinished,
ApplicationData,
}
impl<'a> State {
#[allow(clippy::too_many_arguments)]
pub async fn process<'v, Transport, CP>(
self,
transport: &mut Transport,
handshake: &mut Handshake<CP::CipherSuite>,
record_reader: &mut RecordReader<'_>,
tx_buf: &mut WriteBuffer<'_>,
key_schedule: &mut KeySchedule<CP::CipherSuite>,
config: &ConnectConfig<'a>,
crypto_provider: &mut CP,
) -> Result<State, ProtocolError>
where
Transport: AsyncRead + AsyncWrite + 'a,
CP: CryptoBackend,
{
match self {
State::ClientHello => {
let (state, tx) =
client_hello(key_schedule, config, crypto_provider, tx_buf, handshake)?;
respond(tx, transport, key_schedule).await?;
Ok(state)
}
State::ServerHello => {
let record = record_reader
.read(transport, key_schedule.read_state())
.await?;
let result = process_server_hello(handshake, key_schedule, record);
handle_processing_error(result, transport, key_schedule, tx_buf).await
}
State::ServerVerify => {
let record = record_reader
.read(transport, key_schedule.read_state())
.await?;
let result =
process_server_verify(handshake, key_schedule, crypto_provider, record);
handle_processing_error(result, transport, key_schedule, tx_buf).await
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, crypto_provider, tx_buf)?;
respond(tx, transport, key_schedule).await?;
Ok(state)
}
State::ClientCertVerify => {
let (result, tx) = client_cert_verify(key_schedule, crypto_provider, tx_buf)?;
respond(tx, transport, key_schedule).await?;
result
}
State::ClientFinished => {
let tx = client_finished(key_schedule, tx_buf)?;
respond(tx, transport, key_schedule).await?;
client_finished_finalize(key_schedule, handshake)
}
State::ApplicationData => Ok(State::ApplicationData),
}
}
#[allow(clippy::too_many_arguments)]
pub fn process_blocking<'v, Transport, CP>(
self,
transport: &mut Transport,
handshake: &mut Handshake<CP::CipherSuite>,
record_reader: &mut RecordReader<'_>,
tx_buf: &mut WriteBuffer,
key_schedule: &mut KeySchedule<CP::CipherSuite>,
config: &ConnectConfig<'a>,
crypto_provider: &mut CP,
) -> Result<State, ProtocolError>
where
Transport: BlockingRead + BlockingWrite + 'a,
CP: CryptoBackend,
{
match self {
State::ClientHello => {
let (state, tx) =
client_hello(key_schedule, config, crypto_provider, tx_buf, handshake)?;
respond_blocking(tx, transport, key_schedule)?;
Ok(state)
}
State::ServerHello => {
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
let result = process_server_hello(handshake, key_schedule, record);
handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
}
State::ServerVerify => {
let record = record_reader.read_blocking(transport, key_schedule.read_state())?;
let result =
process_server_verify(handshake, key_schedule, crypto_provider, record);
handle_processing_error_blocking(result, transport, key_schedule, tx_buf)
}
State::ClientCert => {
let (state, tx) = client_cert(handshake, key_schedule, crypto_provider, tx_buf)?;
respond_blocking(tx, transport, key_schedule)?;
Ok(state)
}
State::ClientCertVerify => {
let (result, tx) = client_cert_verify(key_schedule, crypto_provider, tx_buf)?;
respond_blocking(tx, transport, key_schedule)?;
result
}
State::ClientFinished => {
let tx = client_finished(key_schedule, tx_buf)?;
respond_blocking(tx, transport, key_schedule)?;
client_finished_finalize(key_schedule, handshake)
}
State::ApplicationData => Ok(State::ApplicationData),
}
}
}
fn handle_processing_error_blocking<CipherSuite>(
result: Result<State, ProtocolError>,
transport: &mut impl BlockingWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer,
) -> Result<State, ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(ProtocolError::AbortHandshake(level, description)) = result {
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
write_key_schedule,
Some(read_key_schedule),
)?;
respond_blocking(tx, transport, key_schedule)?;
}
result
}
fn respond_blocking<CipherSuite>(
tx: &[u8],
transport: &mut impl BlockingWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
) -> Result<(), ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
transport
.write_all(tx)
.map_err(|e| ProtocolError::Io(e.kind()))?;
key_schedule.write_state().increment_counter();
transport.flush().map_err(|e| ProtocolError::Io(e.kind()))?;
Ok(())
}
async fn handle_processing_error<CipherSuite>(
result: Result<State, ProtocolError>,
transport: &mut impl AsyncWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
tx_buf: &mut WriteBuffer<'_>,
) -> Result<State, ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
if let Err(ProtocolError::AbortHandshake(level, description)) = result {
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let tx = tx_buf.write_record(
&ClientRecord::Alert(Alert { level, description }, false),
write_key_schedule,
Some(read_key_schedule),
)?;
respond(tx, transport, key_schedule).await?;
}
result
}
async fn respond<CipherSuite>(
tx: &[u8],
transport: &mut impl AsyncWrite,
key_schedule: &mut KeySchedule<CipherSuite>,
) -> Result<(), ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
transport
.write_all(tx)
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
key_schedule.write_state().increment_counter();
transport
.flush()
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
Ok(())
}
fn client_hello<'r, CP>(
key_schedule: &mut KeySchedule<CP::CipherSuite>,
config: &ConnectConfig,
crypto_provider: &mut CP,
tx_buf: &'r mut WriteBuffer,
handshake: &mut Handshake<CP::CipherSuite>,
) -> Result<(State, &'r [u8]), ProtocolError>
where
CP: CryptoBackend,
{
key_schedule.initialize_early_secret(config.psk.as_ref().map(|p| p.0))?;
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
let client_hello = ClientRecord::client_hello(config, crypto_provider);
let slice = tx_buf.write_record(&client_hello, write_key_schedule, Some(read_key_schedule))?;
if let ClientRecord::Handshake(ClientHandshake::ClientHello(client_hello), _) = client_hello {
handshake.secret.replace(client_hello.secret);
Ok((State::ServerHello, slice))
} else {
Err(ProtocolError::EncodeError)
}
}
fn process_server_hello<CipherSuite>(
handshake: &mut Handshake<CipherSuite>,
key_schedule: &mut KeySchedule<CipherSuite>,
record: ServerRecord<'_, CipherSuite>,
) -> Result<State, ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
match record {
ServerRecord::Handshake(server_handshake) => match server_handshake {
ServerHandshake::ServerHello(server_hello) => {
trace!("********* ServerHello");
let secret = handshake.secret.take().ok_or(ProtocolError::InvalidHandshake)?;
let shared = server_hello
.calculate_shared_secret(&secret)
.ok_or(ProtocolError::InvalidKeyShare)?;
key_schedule.initialize_handshake_secret(shared.raw_secret_bytes())?;
Ok(State::ServerVerify)
}
_ => Err(ProtocolError::InvalidHandshake),
},
ServerRecord::Alert(alert) => {
Err(ProtocolError::HandshakeAborted(alert.level, alert.description))
}
_ => Err(ProtocolError::InvalidRecord),
}
}
fn process_server_verify<CP>(
handshake: &mut Handshake<CP::CipherSuite>,
key_schedule: &mut KeySchedule<CP::CipherSuite>,
crypto_provider: &mut CP,
record: ServerRecord<'_, CP::CipherSuite>,
) -> Result<State, ProtocolError>
where
CP: CryptoBackend,
{
let mut state = State::ServerVerify;
decrypt_record(key_schedule.read_state(), record, |key_schedule, record| {
match record {
ServerRecord::Handshake(server_handshake) => {
match server_handshake {
ServerHandshake::EncryptedExtensions(_) => {}
ServerHandshake::Certificate(certificate) => {
let transcript = key_schedule.transcript_hash();
if let Ok(verifier) = crypto_provider.verifier() {
verifier.verify_certificate(transcript, certificate)?;
debug!("Certificate verified!");
} else {
debug!("Certificate verification skipped due to no verifier!");
}
}
ServerHandshake::HandshakeVerify(verify) => {
if let Ok(verifier) = crypto_provider.verifier() {
verifier.verify_signature(verify)?;
debug!("Signature verified!");
} else {
debug!("Signature verification skipped due to no verifier!");
}
}
ServerHandshake::CertificateRequest(request) => {
handshake.certificate_request.replace(request.try_into()?);
}
ServerHandshake::Finished(finished) => {
if !key_schedule.verify_server_finished(&finished)? {
warn!("Server signature verification failed");
return Err(ProtocolError::InvalidSignature);
}
// trace!("server verified {}", verified);
state = if handshake.certificate_request.is_some() {
State::ClientCert
} else {
handshake
.traffic_hash
.replace(key_schedule.transcript_hash().clone());
State::ClientFinished
};
}
_ => return Err(ProtocolError::InvalidHandshake),
}
}
ServerRecord::ChangeCipherSpec(_) => {}
_ => return Err(ProtocolError::InvalidRecord),
}
Ok(())
})?;
Ok(state)
}
fn client_cert<'r, CP>(
handshake: &mut Handshake<CP::CipherSuite>,
key_schedule: &mut KeySchedule<CP::CipherSuite>,
crypto_provider: &mut CP,
buffer: &'r mut WriteBuffer,
) -> Result<(State, &'r [u8]), ProtocolError>
where
CP: CryptoBackend,
{
handshake
.traffic_hash
.replace(key_schedule.transcript_hash().clone());
let request_context = &handshake
.certificate_request
.as_ref()
.ok_or(ProtocolError::InvalidHandshake)?
.request_context;
// Declare cert before certificate so owned data outlives the CertificateRef that borrows it
let cert = crypto_provider.client_cert();
let mut certificate = CertificateRef::with_context(request_context);
let next_state = if let Some(ref cert) = cert {
certificate.add(cert.into())?;
State::ClientCertVerify
} else {
State::ClientFinished
};
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
buffer
.write_record(
&ClientRecord::Handshake(ClientHandshake::ClientCert(certificate), true),
write_key_schedule,
Some(read_key_schedule),
)
.map(|slice| (next_state, slice))
}
fn client_cert_verify<'r, CP>(
key_schedule: &mut KeySchedule<CP::CipherSuite>,
crypto_provider: &mut CP,
buffer: &'r mut WriteBuffer,
) -> Result<(Result<State, ProtocolError>, &'r [u8]), ProtocolError>
where
CP: CryptoBackend,
{
let (result, record) = match crypto_provider.signer() {
Ok((mut signing_key, signature_scheme)) => {
let ctx_str = b"TLS 1.3, client CertificateVerify\x00";
// 64 (pad) + 34 (ctx) + 48 (SHA-384) = 146 bytes required
let mut msg: heapless::Vec<u8, 146> = heapless::Vec::new();
msg.resize(64, 0x20).map_err(|_| ProtocolError::EncodeError)?;
msg.extend_from_slice(ctx_str)
.map_err(|_| ProtocolError::EncodeError)?;
msg.extend_from_slice(&key_schedule.transcript_hash().clone().finalize())
.map_err(|_| ProtocolError::EncodeError)?;
let signature = signing_key.sign(&msg);
trace!(
"Signature: {:?} ({})",
signature.as_ref(),
signature.as_ref().len()
);
let certificate_verify = HandshakeVerify {
signature_scheme,
signature: heapless::Vec::from_slice(signature.as_ref()).unwrap(),
};
(
Ok(State::ClientFinished),
ClientRecord::Handshake(
ClientHandshake::ClientCertVerify(certificate_verify),
true,
),
)
}
Err(e) => {
error!("Failed to obtain signing key: {:?}", e);
(
Err(e),
ClientRecord::Alert(
Alert::new(AlertLevel::Warning, AlertDescription::CloseNotify),
true,
),
)
}
};
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
buffer
.write_record(&record, write_key_schedule, Some(read_key_schedule))
.map(|slice| (result, slice))
}
fn client_finished<'r, CipherSuite>(
key_schedule: &mut KeySchedule<CipherSuite>,
buffer: &'r mut WriteBuffer,
) -> Result<&'r [u8], ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
let client_finished = key_schedule
.create_client_finished()
.map_err(|_| ProtocolError::InvalidHandshake)?;
let (write_key_schedule, read_key_schedule) = key_schedule.as_split();
buffer.write_record(
&ClientRecord::Handshake(ClientHandshake::Finished(client_finished), true),
write_key_schedule,
Some(read_key_schedule),
)
}
fn client_finished_finalize<CipherSuite>(
key_schedule: &mut KeySchedule<CipherSuite>,
handshake: &mut Handshake<CipherSuite>,
) -> Result<State, ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
key_schedule.replace_transcript_hash(
handshake
.traffic_hash
.take()
.ok_or(ProtocolError::InvalidHandshake)?,
);
key_schedule.initialize_master_secret()?;
Ok(State::ApplicationData)
}

22
src/content_types.rs Normal file
View File

@@ -0,0 +1,22 @@
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ContentType {
Invalid = 0,
ChangeCipherSpec = 20,
Alert = 21,
Handshake = 22,
ApplicationData = 23,
}
impl ContentType {
pub fn of(num: u8) -> Option<Self> {
match num {
0 => Some(Self::Invalid),
20 => Some(Self::ChangeCipherSpec),
21 => Some(Self::Alert),
22 => Some(Self::Handshake),
23 => Some(Self::ApplicationData),
_ => None,
}
}
}

View File

@@ -0,0 +1,56 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
/// ALPN protocol name list per RFC 7301, Section 3.1.
///
/// Wire format:
/// ```text
/// opaque ProtocolName<1..2^8-1>;
///
/// struct {
/// ProtocolName protocol_name_list<2..2^16-1>
/// } ProtocolNameList;
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AlpnProtocolNameList<'a> {
pub protocols: &'a [&'a [u8]],
}
impl<'a> AlpnProtocolNameList<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ParseError> {
// We parse but don't store the individual protocol names in a heapless
// container — just validate the wire format. The slice reference is kept
// for the lifetime of the parse buffer, but since we can't reconstruct
// `&[&[u8]]` from a flat buffer without allocation, we store an empty
// slice. Callers that need the parsed protocols (server-side) would need
// a different approach; for our client-side use we only need encode().
let list_len = buf.read_u16()? as usize;
let mut list_buf = buf.slice(list_len)?;
while !list_buf.is_empty() {
let name_len = list_buf.read_u8()? as usize;
if name_len == 0 {
return Err(ParseError::InvalidData);
}
let _name = list_buf.slice(name_len)?;
}
Ok(Self { protocols: &[] })
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
// Outer u16 length prefix for the ProtocolNameList
buf.with_u16_length(|buf| {
for protocol in self.protocols {
buf.push(protocol.len() as u8)
.map_err(|_| ProtocolError::EncodeError)?;
buf.extend_from_slice(protocol)?;
}
Ok(())
})
}
}

View File

@@ -0,0 +1,135 @@
use heapless::Vec;
use crate::buffer::CryptoBuffer;
use crate::extensions::extension_data::supported_groups::NamedGroup;
use crate::ProtocolError;
use crate::parse_buffer::{ParseBuffer, ParseError};
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyShareServerHello<'a>(pub KeyShareEntry<'a>);
impl<'a> KeyShareServerHello<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ParseError> {
Ok(KeyShareServerHello(KeyShareEntry::parse(buf)?))
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
self.0.encode(buf)
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyShareClientHello<'a, const N: usize> {
pub client_shares: Vec<KeyShareEntry<'a>, N>,
}
impl<'a, const N: usize> KeyShareClientHello<'a, N> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ParseError> {
let len = buf.read_u16()? as usize;
Ok(KeyShareClientHello {
client_shares: buf.read_list(len, KeyShareEntry::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u16_length(|buf| {
for client_share in &self.client_shares {
client_share.encode(buf)?;
}
Ok(())
})
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyShareHelloRetryRequest {
pub selected_group: NamedGroup,
}
#[allow(dead_code)]
impl KeyShareHelloRetryRequest {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
Ok(Self {
selected_group: NamedGroup::parse(buf)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
self.selected_group.encode(buf)
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct KeyShareEntry<'a> {
pub(crate) group: NamedGroup,
pub(crate) opaque: &'a [u8],
}
impl<'a> KeyShareEntry<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ParseError> {
let group = NamedGroup::parse(buf)?;
let opaque_len = buf.read_u16()?;
let opaque = buf.slice(opaque_len as usize)?;
Ok(Self {
group,
opaque: opaque.as_slice(),
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
self.group.encode(buf)?;
buf.with_u16_length(|buf| buf.extend_from_slice(self.opaque))
.map_err(|_| ProtocolError::EncodeError)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Once;
static INIT: Once = Once::new();
fn setup() {
INIT.call_once(|| {
env_logger::init();
});
}
#[test]
fn test_parse_empty() {
setup();
let buffer = [
0x00, 0x17, // Secp256r1
0x00, 0x00, // key_exchange length = 0 bytes
];
let result = KeyShareEntry::parse(&mut ParseBuffer::new(&buffer)).unwrap();
assert_eq!(NamedGroup::Secp256r1, result.group);
assert_eq!(0, result.opaque.len());
}
#[test]
fn test_parse() {
setup();
let buffer = [
0x00, 0x17, // Secp256r1
0x00, 0x02, // key_exchange length = 2 bytes
0xAA, 0xBB,
];
let result = KeyShareEntry::parse(&mut ParseBuffer::new(&buffer)).unwrap();
assert_eq!(NamedGroup::Secp256r1, result.group);
assert_eq!(2, result.opaque.len());
assert_eq!([0xAA, 0xBB], result.opaque);
}
}

View File

@@ -0,0 +1,44 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
/// Maximum plaintext fragment length
///
/// RFC 6066, Section 4. Maximum Fragment Length Negotiation
/// Without this extension, TLS specifies a fixed maximum plaintext
/// fragment length of 2^14 bytes. It may be desirable for constrained
/// clients to negotiate a smaller maximum fragment length due to memory
/// limitations or bandwidth limitations.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum MaxFragmentLength {
/// 512 bytes
Bits9 = 1,
/// 1024 bytes
Bits10 = 2,
/// 2048 bytes
Bits11 = 3,
/// 4096 bytes
Bits12 = 4,
}
impl MaxFragmentLength {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u8()? {
1 => Ok(Self::Bits9),
2 => Ok(Self::Bits10),
3 => Ok(Self::Bits11),
4 => Ok(Self::Bits12),
other => {
warn!("Read unknown MaxFragmentLength: {}", other);
Err(ParseError::InvalidData)
}
}
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push(*self as u8).map_err(|_| ProtocolError::EncodeError)
}
}

View File

@@ -0,0 +1,11 @@
pub mod alpn;
pub mod key_share;
pub mod max_fragment_length;
pub mod pre_shared_key;
pub mod psk_key_exchange_modes;
pub mod server_name;
pub mod signature_algorithms;
pub mod signature_algorithms_cert;
pub mod supported_groups;
pub mod supported_versions;
pub mod unimplemented;

View File

@@ -0,0 +1,63 @@
use crate::buffer::CryptoBuffer;
use crate::ProtocolError;
use crate::parse_buffer::{ParseBuffer, ParseError};
use heapless::Vec;
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PreSharedKeyClientHello<'a, const N: usize> {
pub identities: Vec<&'a [u8], N>,
pub hash_size: usize,
}
impl<const N: usize> PreSharedKeyClientHello<'_, N> {
pub fn parse(_buf: &mut ParseBuffer) -> Result<Self, ParseError> {
unimplemented!()
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u16_length(|buf| {
for identity in &self.identities {
buf.with_u16_length(|buf| buf.extend_from_slice(identity))
.map_err(|_| ProtocolError::EncodeError)?;
// NOTE: No support for ticket age, set to 0 as recommended by RFC
buf.push_u32(0).map_err(|_| ProtocolError::EncodeError)?;
}
Ok(())
})
.map_err(|_| ProtocolError::EncodeError)?;
// NOTE: We encode binders later after computing the transcript.
let binders_len = (1 + self.hash_size) * self.identities.len();
buf.push_u16(binders_len as u16)
.map_err(|_| ProtocolError::EncodeError)?;
for _ in 0..binders_len {
buf.push(0).map_err(|_| ProtocolError::EncodeError)?;
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PreSharedKeyServerHello {
pub selected_identity: u16,
}
impl PreSharedKeyServerHello {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
Ok(Self {
selected_identity: buf.read_u16()?,
})
}
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push_u16(self.selected_identity)
.map_err(|_| ProtocolError::EncodeError)
}
}

View File

@@ -0,0 +1,53 @@
use crate::buffer::CryptoBuffer;
use crate::ProtocolError;
use crate::parse_buffer::{ParseBuffer, ParseError};
use heapless::Vec;
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum PskKeyExchangeMode {
PskKe = 0,
PskDheKe = 1,
}
impl PskKeyExchangeMode {
fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u8()? {
0 => Ok(Self::PskKe),
1 => Ok(Self::PskDheKe),
other => {
warn!("Read unknown PskKeyExchangeMode: {}", other);
Err(ParseError::InvalidData)
}
}
}
fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push(self as u8).map_err(|_| ProtocolError::EncodeError)
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct PskKeyExchangeModes<const N: usize> {
pub modes: Vec<PskKeyExchangeMode, N>,
}
impl<const N: usize> PskKeyExchangeModes<N> {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
let data_length = buf.read_u8()? as usize;
Ok(Self {
modes: buf.read_list::<_, N>(data_length, PskKeyExchangeMode::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u8_length(|buf| {
for mode in &self.modes {
mode.encode(buf)?;
}
Ok(())
})
}
}

View File

@@ -0,0 +1,133 @@
use heapless::Vec;
use crate::{
ProtocolError,
buffer::CryptoBuffer,
extensions::ExtensionType,
parse_buffer::{ParseBuffer, ParseError},
};
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum NameType {
HostName = 0,
}
impl NameType {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u8()? {
0 => Ok(Self::HostName),
other => {
warn!("Read unknown NameType: {}", other);
Err(ParseError::InvalidData)
}
}
}
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push(self as u8).map_err(|_| ProtocolError::EncodeError)
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ServerName<'a> {
pub name_type: NameType,
pub name: &'a str,
}
impl<'a> ServerName<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<ServerName<'a>, ParseError> {
let name_type = NameType::parse(buf)?;
let name_len = buf.read_u16()?;
let name = buf.slice(name_len as usize)?.as_slice();
// RFC 6066, Section 3. Server Name Indication
// The hostname is represented as a byte
// string using ASCII encoding without a trailing dot.
if name.is_ascii() {
Ok(ServerName {
name_type,
name: core::str::from_utf8(name).map_err(|_| ParseError::InvalidData)?,
})
} else {
Err(ParseError::InvalidData)
}
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
self.name_type.encode(buf)?;
buf.with_u16_length(|buf| buf.extend_from_slice(self.name.as_bytes()))
.map_err(|_| ProtocolError::EncodeError)
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ServerNameList<'a, const N: usize> {
pub names: Vec<ServerName<'a>, N>,
}
impl<'a> ServerNameList<'a, 1> {
pub fn single(server_name: &'a str) -> Self {
let mut names = Vec::<_, 1>::new();
names
.push(ServerName {
name_type: NameType::HostName,
name: server_name,
})
.unwrap();
ServerNameList { names }
}
}
impl<'a, const N: usize> ServerNameList<'a, N> {
#[allow(dead_code)]
pub const EXTENSION_TYPE: ExtensionType = ExtensionType::ServerName;
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<ServerNameList<'a, N>, ParseError> {
let data_length = buf.read_u16()? as usize;
Ok(Self {
names: buf.read_list::<_, N>(data_length, ServerName::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u16_length(|buf| {
for name in &self.names {
name.encode(buf)?;
}
Ok(())
})
}
}
// RFC 6066, Section 3. Server Name Indication
// A server that receives a client hello containing the "server_name"
// extension [..]. In this event, the server
// SHALL include an extension of type "server_name" in the (extended)
// server hello. The "extension_data" field of this extension SHALL be
// empty.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ServerNameResponse;
impl ServerNameResponse {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
if buf.is_empty() {
Ok(Self)
} else {
Err(ParseError::InvalidData)
}
}
#[allow(clippy::unused_self, clippy::unnecessary_wraps)]
pub fn encode(&self, _buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
Ok(())
}
}

View File

@@ -0,0 +1,164 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
use heapless::Vec;
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum SignatureScheme {
/* RSASSA-PKCS1-v1_5 algorithms */
RsaPkcs1Sha256,
RsaPkcs1Sha384,
RsaPkcs1Sha512,
/* ECDSA algorithms */
EcdsaSecp256r1Sha256,
EcdsaSecp384r1Sha384,
EcdsaSecp521r1Sha512,
/* RSASSA-PSS algorithms with public key OID rsaEncryption */
RsaPssRsaeSha256,
RsaPssRsaeSha384,
RsaPssRsaeSha512,
/* EdDSA algorithms */
Ed25519,
Ed448,
/* RSASSA-PSS algorithms with public key OID RSASSA-PSS */
RsaPssPssSha256,
RsaPssPssSha384,
RsaPssPssSha512,
Sha224Ecdsa,
Sha224Rsa,
Sha224Dsa,
/* Legacy algorithms */
RsaPkcs1Sha1,
EcdsaSha1,
/* Brainpool */
Sha256BrainpoolP256r1,
Sha384BrainpoolP384r1,
Sha512BrainpoolP512r1,
/* ML-DSA */
MlDsa44,
MlDsa65,
MlDsa87,
/* Reserved Code Points */
//private_use(0xFE00..0xFFFF),
//(0xFFFF)
}
impl SignatureScheme {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u16()? {
0x0401 => Ok(Self::RsaPkcs1Sha256),
0x0501 => Ok(Self::RsaPkcs1Sha384),
0x0601 => Ok(Self::RsaPkcs1Sha512),
0x0403 => Ok(Self::EcdsaSecp256r1Sha256),
0x0503 => Ok(Self::EcdsaSecp384r1Sha384),
0x0603 => Ok(Self::EcdsaSecp521r1Sha512),
0x0804 => Ok(Self::RsaPssRsaeSha256),
0x0805 => Ok(Self::RsaPssRsaeSha384),
0x0806 => Ok(Self::RsaPssRsaeSha512),
0x0807 => Ok(Self::Ed25519),
0x0808 => Ok(Self::Ed448),
0x0809 => Ok(Self::RsaPssPssSha256),
0x080a => Ok(Self::RsaPssPssSha384),
0x080b => Ok(Self::RsaPssPssSha512),
0x0303 => Ok(Self::Sha224Ecdsa),
0x0301 => Ok(Self::Sha224Rsa),
0x0302 => Ok(Self::Sha224Dsa),
0x0201 => Ok(Self::RsaPkcs1Sha1),
0x0203 => Ok(Self::EcdsaSha1),
0x081A => Ok(Self::Sha256BrainpoolP256r1),
0x081B => Ok(Self::Sha384BrainpoolP384r1),
0x081C => Ok(Self::Sha512BrainpoolP512r1),
0x0904 => Ok(Self::MlDsa44),
0x0905 => Ok(Self::MlDsa65),
0x0906 => Ok(Self::MlDsa87),
_ => Err(ParseError::InvalidData),
}
}
#[must_use]
pub fn as_u16(self) -> u16 {
match self {
Self::RsaPkcs1Sha256 => 0x0401,
Self::RsaPkcs1Sha384 => 0x0501,
Self::RsaPkcs1Sha512 => 0x0601,
Self::EcdsaSecp256r1Sha256 => 0x0403,
Self::EcdsaSecp384r1Sha384 => 0x0503,
Self::EcdsaSecp521r1Sha512 => 0x0603,
Self::RsaPssRsaeSha256 => 0x0804,
Self::RsaPssRsaeSha384 => 0x0805,
Self::RsaPssRsaeSha512 => 0x0806,
Self::Ed25519 => 0x0807,
Self::Ed448 => 0x0808,
Self::RsaPssPssSha256 => 0x0809,
Self::RsaPssPssSha384 => 0x080a,
Self::RsaPssPssSha512 => 0x080b,
Self::Sha224Ecdsa => 0x0303,
Self::Sha224Rsa => 0x0301,
Self::Sha224Dsa => 0x0302,
Self::RsaPkcs1Sha1 => 0x0201,
Self::EcdsaSha1 => 0x0203,
Self::Sha256BrainpoolP256r1 => 0x081A,
Self::Sha384BrainpoolP384r1 => 0x081B,
Self::Sha512BrainpoolP512r1 => 0x081C,
Self::MlDsa44 => 0x0904,
Self::MlDsa65 => 0x0905,
Self::MlDsa87 => 0x0906,
}
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SignatureAlgorithms<const N: usize> {
pub supported_signature_algorithms: Vec<SignatureScheme, N>,
}
impl<const N: usize> SignatureAlgorithms<N> {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
let data_length = buf.read_u16()? as usize;
Ok(Self {
supported_signature_algorithms: buf
.read_list::<_, N>(data_length, SignatureScheme::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u16_length(|buf| {
for &a in &self.supported_signature_algorithms {
buf.push_u16(a.as_u16())
.map_err(|_| ProtocolError::EncodeError)?;
}
Ok(())
})
}
}

View File

@@ -0,0 +1,34 @@
use crate::buffer::CryptoBuffer;
use crate::extensions::extension_data::signature_algorithms::SignatureScheme;
use crate::ProtocolError;
use crate::parse_buffer::{ParseBuffer, ParseError};
use heapless::Vec;
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SignatureAlgorithmsCert<const N: usize> {
pub supported_signature_algorithms: Vec<SignatureScheme, N>,
}
impl<const N: usize> SignatureAlgorithmsCert<N> {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
let data_length = buf.read_u16()? as usize;
Ok(Self {
supported_signature_algorithms: buf
.read_list::<_, N>(data_length, SignatureScheme::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u16_length(|buf| {
for &a in &self.supported_signature_algorithms {
buf.push_u16(a.as_u16())
.map_err(|_| ProtocolError::EncodeError)?;
}
Ok(())
})
}
}

View File

@@ -0,0 +1,104 @@
use heapless::Vec;
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
#[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum NamedGroup {
/* Elliptic Curve Groups (ECDHE) */
Secp256r1,
Secp384r1,
Secp521r1,
X25519,
X448,
/* Finite Field Groups (DHE) */
Ffdhe2048,
Ffdhe3072,
Ffdhe4096,
Ffdhe6144,
Ffdhe8192,
/* Post-quantum hybrid groups */
X25519MLKEM768,
SecP256r1MLKEM768,
SecP384r1MLKEM1024,
}
impl NamedGroup {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u16()? {
0x0017 => Ok(Self::Secp256r1),
0x0018 => Ok(Self::Secp384r1),
0x0019 => Ok(Self::Secp521r1),
0x001D => Ok(Self::X25519),
0x001E => Ok(Self::X448),
0x0100 => Ok(Self::Ffdhe2048),
0x0101 => Ok(Self::Ffdhe3072),
0x0102 => Ok(Self::Ffdhe4096),
0x0103 => Ok(Self::Ffdhe6144),
0x0104 => Ok(Self::Ffdhe8192),
0x11EB => Ok(Self::SecP256r1MLKEM768),
0x11EC => Ok(Self::X25519MLKEM768),
0x11ED => Ok(Self::SecP384r1MLKEM1024),
_ => Err(ParseError::InvalidData),
}
}
pub fn as_u16(self) -> u16 {
match self {
Self::Secp256r1 => 0x0017,
Self::Secp384r1 => 0x0018,
Self::Secp521r1 => 0x0019,
Self::X25519 => 0x001D,
Self::X448 => 0x001E,
Self::Ffdhe2048 => 0x0100,
Self::Ffdhe3072 => 0x0101,
Self::Ffdhe4096 => 0x0102,
Self::Ffdhe6144 => 0x0103,
Self::Ffdhe8192 => 0x0104,
Self::SecP256r1MLKEM768 => 0x11EB,
Self::X25519MLKEM768 => 0x11EC,
Self::SecP384r1MLKEM1024 => 0x11ED,
}
}
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push_u16(self.as_u16())
.map_err(|_| ProtocolError::EncodeError)
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SupportedGroups<const N: usize> {
pub supported_groups: Vec<NamedGroup, N>,
}
impl<const N: usize> SupportedGroups<N> {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
let data_length = buf.read_u16()? as usize;
Ok(Self {
supported_groups: buf.read_list::<_, N>(data_length, NamedGroup::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u16_length(|buf| {
for g in &self.supported_groups {
g.encode(buf)?;
}
Ok(())
})
}
}

View File

@@ -0,0 +1,65 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
use heapless::Vec;
#[derive(Clone, Copy, PartialEq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ProtocolVersion(u16);
impl ProtocolVersion {
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push_u16(self.0).map_err(|_| ProtocolError::EncodeError)
}
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
buf.read_u16().map(Self)
}
}
pub const TLS13: ProtocolVersion = ProtocolVersion(0x0304);
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SupportedVersionsClientHello<const N: usize> {
pub versions: Vec<ProtocolVersion, N>,
}
impl<const N: usize> SupportedVersionsClientHello<N> {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
let data_length = buf.read_u8()? as usize;
Ok(Self {
versions: buf.read_list::<_, N>(data_length, ProtocolVersion::parse)?,
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.with_u8_length(|buf| {
for v in &self.versions {
v.encode(buf)?;
}
Ok(())
})
}
}
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct SupportedVersionsServerHello {
pub selected_version: ProtocolVersion,
}
impl SupportedVersionsServerHello {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
Ok(Self {
selected_version: ProtocolVersion::parse(buf)?,
})
}
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
self.selected_version.encode(buf)
}
}

View File

@@ -0,0 +1,24 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Unimplemented<'a> {
pub data: &'a [u8],
}
impl<'a> Unimplemented<'a> {
#[allow(clippy::unnecessary_wraps)]
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ParseError> {
Ok(Self {
data: buf.as_slice(),
})
}
pub fn encode(&self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.extend_from_slice(self.data)
}
}

View File

@@ -0,0 +1,102 @@
macro_rules! extension_group {
(pub enum $name:ident$(<$lt:lifetime>)? {
$($extension:ident($extension_data:ty)),+
}) => {
#[derive(Debug, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[allow(dead_code)] // extension_data may not be used
pub enum $name$(<$lt>)? {
$($extension($extension_data)),+
}
#[allow(dead_code)] // not all methods are used
impl$(<$lt>)? $name$(<$lt>)? {
pub fn extension_type(&self) -> crate::extensions::ExtensionType {
match self {
$(Self::$extension(_) => crate::extensions::ExtensionType::$extension),+
}
}
pub fn encode(&self, buf: &mut crate::buffer::CryptoBuffer) -> Result<(), crate::ProtocolError> {
self.extension_type().encode(buf)?;
buf.with_u16_length(|buf| match self {
$(Self::$extension(ext_data) => ext_data.encode(buf)),+
})
}
pub fn parse(buf: &mut crate::parse_buffer::ParseBuffer$(<$lt>)?) -> Result<Self, crate::ProtocolError> {
// Consume extension data even if we don't recognize the extension
let extension_type = crate::extensions::ExtensionType::parse(buf);
let data_len = buf.read_u16().map_err(|_| crate::ProtocolError::DecodeError)? as usize;
let mut ext_data = buf.slice(data_len).map_err(|_| crate::ProtocolError::DecodeError)?;
let ext_type = extension_type.map_err(|err| {
warn!("Failed to read extension type: {:?}", err);
match err {
crate::parse_buffer::ParseError::InvalidData => crate::ProtocolError::UnknownExtensionType,
_ => crate::ProtocolError::DecodeError,
}
})?;
debug!("Read extension type {:?}", ext_type);
trace!("Extension data length: {}", data_len);
match ext_type {
$(crate::extensions::ExtensionType::$extension => Ok(Self::$extension(<$extension_data>::parse(&mut ext_data).map_err(|err| {
warn!("Failed to parse extension data: {:?}", err);
crate::ProtocolError::DecodeError
})?)),)+
#[allow(unreachable_patterns)]
other => {
warn!("Read unexpected ExtensionType: {:?}", other);
// Section 4.2. Extensions
// If an implementation receives an extension
// which it recognizes and which is not specified for the message in
// which it appears, it MUST abort the handshake with an
// "illegal_parameter" alert.
Err(crate::ProtocolError::AbortHandshake(
crate::alert::AlertLevel::Fatal,
crate::alert::AlertDescription::IllegalParameter,
))
}
}
}
pub fn parse_vector<const N: usize>(
buf: &mut crate::parse_buffer::ParseBuffer$(<$lt>)?,
) -> Result<heapless::Vec<Self, N>, crate::ProtocolError> {
let extensions_len = buf
.read_u16()
.map_err(|_| crate::ProtocolError::InvalidExtensionsLength)?;
let mut ext_buf = buf.slice(extensions_len as usize)?;
let mut extensions = heapless::Vec::new();
while !ext_buf.is_empty() {
trace!("Extension buffer: {}", ext_buf.remaining());
match Self::parse(&mut ext_buf) {
Ok(extension) => {
extensions
.push(extension)
.map_err(|_| crate::ProtocolError::DecodeError)?;
}
Err(crate::ProtocolError::UnknownExtensionType) => {
// ignore unrecognized extension type
}
Err(err) => return Err(err),
}
}
trace!("Read {} extensions", extensions.len());
Ok(extensions)
}
}
};
}
// This re-export makes it possible to omit #[macro_export]
// https://stackoverflow.com/a/67140319
pub(crate) use extension_group;

106
src/extensions/messages.rs Normal file
View File

@@ -0,0 +1,106 @@
use crate::extensions::{
extension_data::{
alpn::AlpnProtocolNameList,
key_share::{KeyShareClientHello, KeyShareServerHello},
max_fragment_length::MaxFragmentLength,
pre_shared_key::{PreSharedKeyClientHello, PreSharedKeyServerHello},
psk_key_exchange_modes::PskKeyExchangeModes,
server_name::{ServerNameList, ServerNameResponse},
signature_algorithms::SignatureAlgorithms,
signature_algorithms_cert::SignatureAlgorithmsCert,
supported_groups::SupportedGroups,
supported_versions::{SupportedVersionsClientHello, SupportedVersionsServerHello},
unimplemented::Unimplemented,
},
extension_group_macro::extension_group,
};
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with CH
extension_group! {
pub enum ClientHelloExtension<'a> {
ServerName(ServerNameList<'a, 1>),
SupportedVersions(SupportedVersionsClientHello<1>),
SignatureAlgorithms(SignatureAlgorithms<25>),
SupportedGroups(SupportedGroups<13>),
KeyShare(KeyShareClientHello<'a, 1>),
PreSharedKey(PreSharedKeyClientHello<'a, 4>),
PskKeyExchangeModes(PskKeyExchangeModes<4>),
SignatureAlgorithmsCert(SignatureAlgorithmsCert<25>),
MaxFragmentLength(MaxFragmentLength),
StatusRequest(Unimplemented<'a>),
UseSrtp(Unimplemented<'a>),
Heartbeat(Unimplemented<'a>),
ApplicationLayerProtocolNegotiation(AlpnProtocolNameList<'a>),
SignedCertificateTimestamp(Unimplemented<'a>),
ClientCertificateType(Unimplemented<'a>),
ServerCertificateType(Unimplemented<'a>),
Padding(Unimplemented<'a>),
EarlyData(Unimplemented<'a>),
Cookie(Unimplemented<'a>),
CertificateAuthorities(Unimplemented<'a>),
OidFilters(Unimplemented<'a>),
PostHandshakeAuth(Unimplemented<'a>)
}
}
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with SH
extension_group! {
pub enum ServerHelloExtension<'a> {
KeyShare(KeyShareServerHello<'a>),
PreSharedKey(PreSharedKeyServerHello),
Cookie(Unimplemented<'a>), // temporary so we don't trip up on HelloRetryRequests
SupportedVersions(SupportedVersionsServerHello)
}
}
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with EE
extension_group! {
pub enum EncryptedExtensionsExtension<'a> {
ServerName(ServerNameResponse),
MaxFragmentLength(MaxFragmentLength),
SupportedGroups(SupportedGroups<13>),
UseSrtp(Unimplemented<'a>),
Heartbeat(Unimplemented<'a>),
ApplicationLayerProtocolNegotiation(AlpnProtocolNameList<'a>),
ClientCertificateType(Unimplemented<'a>),
ServerCertificateType(Unimplemented<'a>),
EarlyData(Unimplemented<'a>)
}
}
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with CR
extension_group! {
pub enum CertificateRequestExtension<'a> {
StatusRequest(Unimplemented<'a>),
SignatureAlgorithms(SignatureAlgorithms<25>),
SignedCertificateTimestamp(Unimplemented<'a>),
CertificateAuthorities(Unimplemented<'a>),
OidFilters(Unimplemented<'a>),
SignatureAlgorithmsCert(Unimplemented<'a>),
CompressCertificate(Unimplemented<'a>)
}
}
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with CT
extension_group! {
pub enum CertificateExtension<'a> {
StatusRequest(Unimplemented<'a>),
SignedCertificateTimestamp(Unimplemented<'a>)
}
}
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with NST
extension_group! {
pub enum NewSessionTicketExtension<'a> {
EarlyData(Unimplemented<'a>)
}
}
// Source: https://www.rfc-editor.org/rfc/rfc8446#section-4.2 table, rows marked with HRR
extension_group! {
pub enum HelloRetryRequestExtension<'a> {
KeyShare(Unimplemented<'a>),
Cookie(Unimplemented<'a>),
SupportedVersions(Unimplemented<'a>)
}
}

80
src/extensions/mod.rs Normal file
View File

@@ -0,0 +1,80 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
parse_buffer::{ParseBuffer, ParseError},
};
mod extension_group_macro;
pub mod extension_data;
pub mod messages;
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ExtensionType {
ServerName = 0,
MaxFragmentLength = 1,
StatusRequest = 5,
SupportedGroups = 10,
SignatureAlgorithms = 13,
UseSrtp = 14,
Heartbeat = 15,
ApplicationLayerProtocolNegotiation = 16,
SignedCertificateTimestamp = 18,
ClientCertificateType = 19,
ServerCertificateType = 20,
Padding = 21,
CompressCertificate = 27,
PreSharedKey = 41,
EarlyData = 42,
SupportedVersions = 43,
Cookie = 44,
PskKeyExchangeModes = 45,
CertificateAuthorities = 47,
OidFilters = 48,
PostHandshakeAuth = 49,
SignatureAlgorithmsCert = 50,
KeyShare = 51,
}
impl ExtensionType {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u16()? {
v if v == Self::ServerName as u16 => Ok(Self::ServerName),
v if v == Self::MaxFragmentLength as u16 => Ok(Self::MaxFragmentLength),
v if v == Self::StatusRequest as u16 => Ok(Self::StatusRequest),
v if v == Self::SupportedGroups as u16 => Ok(Self::SupportedGroups),
v if v == Self::SignatureAlgorithms as u16 => Ok(Self::SignatureAlgorithms),
v if v == Self::UseSrtp as u16 => Ok(Self::UseSrtp),
v if v == Self::Heartbeat as u16 => Ok(Self::Heartbeat),
v if v == Self::ApplicationLayerProtocolNegotiation as u16 => {
Ok(Self::ApplicationLayerProtocolNegotiation)
}
v if v == Self::SignedCertificateTimestamp as u16 => {
Ok(Self::SignedCertificateTimestamp)
}
v if v == Self::ClientCertificateType as u16 => Ok(Self::ClientCertificateType),
v if v == Self::ServerCertificateType as u16 => Ok(Self::ServerCertificateType),
v if v == Self::Padding as u16 => Ok(Self::Padding),
v if v == Self::CompressCertificate as u16 => Ok(Self::CompressCertificate),
v if v == Self::PreSharedKey as u16 => Ok(Self::PreSharedKey),
v if v == Self::EarlyData as u16 => Ok(Self::EarlyData),
v if v == Self::SupportedVersions as u16 => Ok(Self::SupportedVersions),
v if v == Self::Cookie as u16 => Ok(Self::Cookie),
v if v == Self::PskKeyExchangeModes as u16 => Ok(Self::PskKeyExchangeModes),
v if v == Self::CertificateAuthorities as u16 => Ok(Self::CertificateAuthorities),
v if v == Self::OidFilters as u16 => Ok(Self::OidFilters),
v if v == Self::PostHandshakeAuth as u16 => Ok(Self::PostHandshakeAuth),
v if v == Self::SignatureAlgorithmsCert as u16 => Ok(Self::SignatureAlgorithmsCert),
v if v == Self::KeyShare as u16 => Ok(Self::KeyShare),
other => {
warn!("Read unknown ExtensionType: {}", other);
Err(ParseError::InvalidData)
}
}
}
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push_u16(self as u16).map_err(|_| ProtocolError::EncodeError)
}
}

223
src/fmt.rs Normal file
View File

@@ -0,0 +1,223 @@
#![macro_use]
#![allow(unused_macros)]
macro_rules! assert {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::assert!($($x)*);
#[cfg(feature = "defmt")]
::defmt::assert!($($x)*);
}
};
}
macro_rules! assert_eq {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::assert_eq!($($x)*);
#[cfg(feature = "defmt")]
::defmt::assert_eq!($($x)*);
}
};
}
macro_rules! assert_ne {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::assert_ne!($($x)*);
#[cfg(feature = "defmt")]
::defmt::assert_ne!($($x)*);
}
};
}
macro_rules! debug_assert {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::debug_assert!($($x)*);
#[cfg(feature = "defmt")]
::defmt::debug_assert!($($x)*);
}
};
}
macro_rules! debug_assert_eq {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::debug_assert_eq!($($x)*);
#[cfg(feature = "defmt")]
::defmt::debug_assert_eq!($($x)*);
}
};
}
macro_rules! debug_assert_ne {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::debug_assert_ne!($($x)*);
#[cfg(feature = "defmt")]
::defmt::debug_assert_ne!($($x)*);
}
};
}
macro_rules! todo {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::todo!($($x)*);
#[cfg(feature = "defmt")]
::defmt::todo!($($x)*);
}
};
}
macro_rules! unreachable {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::unreachable!($($x)*);
#[cfg(feature = "defmt")]
::defmt::unreachable!($($x)*);
}
};
}
macro_rules! panic {
($($x:tt)*) => {
{
#[cfg(not(feature = "defmt"))]
::core::panic!($($x)*);
#[cfg(feature = "defmt")]
::defmt::panic!($($x)*);
}
};
}
macro_rules! trace {
($s:literal $(, $x:expr)* $(,)?) => {
{
#[cfg(feature = "log")]
::log::trace!($s $(, $x)*);
#[cfg(feature = "defmt")]
::defmt::trace!($s $(, $x)*);
#[cfg(not(any(feature = "log", feature="defmt")))]
let _ = ($( & $x ),*);
}
};
}
macro_rules! debug {
($s:literal $(, $x:expr)* $(,)?) => {
{
#[cfg(feature = "log")]
::log::debug!($s $(, $x)*);
#[cfg(feature = "defmt")]
::defmt::debug!($s $(, $x)*);
#[cfg(not(any(feature = "log", feature="defmt")))]
let _ = ($( & $x ),*);
}
};
}
macro_rules! info {
($s:literal $(, $x:expr)* $(,)?) => {
{
#[cfg(feature = "log")]
::log::info!($s $(, $x)*);
#[cfg(feature = "defmt")]
::defmt::info!($s $(, $x)*);
#[cfg(not(any(feature = "log", feature="defmt")))]
let _ = ($( & $x ),*);
}
};
}
macro_rules! warn {
($s:literal $(, $x:expr)* $(,)?) => {
{
#[cfg(feature = "log")]
::log::warn!($s $(, $x)*);
#[cfg(feature = "defmt")]
::defmt::warn!($s $(, $x)*);
#[cfg(not(any(feature = "log", feature="defmt")))]
let _ = ($( & $x ),*);
}
};
}
macro_rules! error {
($s:literal $(, $x:expr)* $(,)?) => {
{
#[cfg(feature = "log")]
::log::error!($s $(, $x)*);
#[cfg(feature = "defmt")]
::defmt::error!($s $(, $x)*);
#[cfg(not(any(feature = "log", feature="defmt")))]
let _ = ($( & $x ),*);
}
};
}
#[cfg(feature = "defmt")]
macro_rules! unwrap {
($($x:tt)*) => {
::defmt::unwrap!($($x)*)
};
}
#[cfg(not(feature = "defmt"))]
macro_rules! unwrap {
($arg:expr) => {
match $crate::fmt::Try::into_result($arg) {
::core::result::Result::Ok(t) => t,
::core::result::Result::Err(e) => {
::core::panic!("unwrap of `{}` failed: {:?}", ::core::stringify!($arg), e);
}
}
};
($arg:expr, $($msg:expr),+ $(,)? ) => {
match $crate::fmt::Try::into_result($arg) {
::core::result::Result::Ok(t) => t,
::core::result::Result::Err(e) => {
::core::panic!("unwrap of `{}` failed: {}: {:?}", ::core::stringify!($arg), ::core::format_args!($($msg,)*), e);
}
}
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct NoneError;
pub trait Try {
type Ok;
type Error;
fn into_result(self) -> Result<Self::Ok, Self::Error>;
}
impl<T> Try for Option<T> {
type Ok = T;
type Error = NoneError;
#[inline]
fn into_result(self) -> Result<T, NoneError> {
self.ok_or(NoneError)
}
}
impl<T, E> Try for Result<T, E> {
type Ok = T;
type Error = E;
#[inline]
fn into_result(self) -> Self {
self
}
}

39
src/handshake/binder.rs Normal file
View File

@@ -0,0 +1,39 @@
use crate::ProtocolError;
use crate::buffer::CryptoBuffer;
use core::fmt::{Debug, Formatter};
//use digest::generic_array::{ArrayLength, GenericArray};
use generic_array::{ArrayLength, GenericArray};
// use heapless::Vec;
pub struct PskBinder<N: ArrayLength<u8>> {
pub verify: GenericArray<u8, N>,
}
#[cfg(feature = "defmt")]
impl<N: ArrayLength<u8>> defmt::Format for PskBinder<N> {
fn format(&self, f: defmt::Formatter<'_>) {
defmt::write!(f, "verify length:{}", &self.verify.len());
}
}
impl<N: ArrayLength<u8>> Debug for PskBinder<N> {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
f.debug_struct("PskBinder").finish()
}
}
impl<N: ArrayLength<u8>> PskBinder<N> {
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
let len = self.verify.len() as u8;
//buf.extend_from_slice(&[len[1], len[2], len[3]]);
buf.push(len).map_err(|_| ProtocolError::EncodeError)?;
buf.extend_from_slice(&self.verify[..self.verify.len()])
.map_err(|_| ProtocolError::EncodeError)?;
Ok(())
}
#[allow(dead_code)]
pub fn len() -> usize {
N::to_usize()
}
}

View File

@@ -0,0 +1,174 @@
use crate::ProtocolError;
use crate::buffer::CryptoBuffer;
use crate::extensions::messages::CertificateExtension;
use crate::parse_buffer::ParseBuffer;
use heapless::Vec;
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct CertificateRef<'a> {
raw_entries: &'a [u8],
request_context: &'a [u8],
pub entries: Vec<CertificateEntryRef<'a>, 16>,
}
impl<'a> CertificateRef<'a> {
pub fn with_context(request_context: &'a [u8]) -> Self {
Self {
raw_entries: &[],
request_context,
entries: Vec::new(),
}
}
pub fn add(&mut self, entry: CertificateEntryRef<'a>) -> Result<(), ProtocolError> {
self.entries.push(entry).map_err(|_| {
error!("CertificateRef: InsufficientSpace");
ProtocolError::InsufficientSpace
})
}
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ProtocolError> {
let request_context_len = buf.read_u8().map_err(|_| ProtocolError::InvalidCertificate)?;
let request_context = buf
.slice(request_context_len as usize)
.map_err(|_| ProtocolError::InvalidCertificate)?;
let entries_len = buf.read_u24().map_err(|_| ProtocolError::InvalidCertificate)?;
let mut raw_entries = buf
.slice(entries_len as usize)
.map_err(|_| ProtocolError::InvalidCertificate)?;
let entries = CertificateEntryRef::parse_vector(&mut raw_entries)?;
Ok(Self {
raw_entries: raw_entries.as_slice(),
request_context: request_context.as_slice(),
entries,
})
}
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
buf.with_u8_length(|buf| buf.extend_from_slice(self.request_context))?;
buf.with_u24_length(|buf| {
for entry in &self.entries {
entry.encode(buf)?;
}
Ok(())
})?;
Ok(())
}
}
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum CertificateEntryRef<'a> {
X509(&'a [u8]),
RawPublicKey(&'a [u8]),
}
impl<'a> CertificateEntryRef<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ProtocolError> {
let entry_len = buf
.read_u24()
.map_err(|_| ProtocolError::InvalidCertificateEntry)?;
let cert = buf
.slice(entry_len as usize)
.map_err(|_| ProtocolError::InvalidCertificateEntry)?;
let entry = CertificateEntryRef::X509(cert.as_slice());
// Validate extensions
CertificateExtension::parse_vector::<2>(buf)?;
Ok(entry)
}
pub fn parse_vector<const N: usize>(
buf: &mut ParseBuffer<'a>,
) -> Result<Vec<Self, N>, ProtocolError> {
let mut result = Vec::new();
while !buf.is_empty() {
result
.push(Self::parse(buf)?)
.map_err(|_| ProtocolError::DecodeError)?;
}
Ok(result)
}
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
match *self {
CertificateEntryRef::RawPublicKey(_key) => {
todo!("ASN1_subjectPublicKeyInfo encoding?");
// buf.with_u24_length(|buf| buf.extend_from_slice(key))?;
}
CertificateEntryRef::X509(cert) => {
buf.with_u24_length(|buf| buf.extend_from_slice(cert))?;
}
}
// Zero extensions for now
buf.push_u16(0)?;
Ok(())
}
}
impl<'a, D: AsRef<[u8]>> From<&'a crate::config::Certificate<D>> for CertificateEntryRef<'a> {
fn from(cert: &'a crate::config::Certificate<D>) -> Self {
match cert {
crate::config::Certificate::X509(data) => CertificateEntryRef::X509(data.as_ref()),
crate::config::Certificate::RawPublicKey(data) => {
CertificateEntryRef::RawPublicKey(data.as_ref())
}
}
}
}
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Certificate<const N: usize> {
request_context: Vec<u8, 256>,
entries_data: Vec<u8, N>,
}
impl<const N: usize> Certificate<N> {
pub fn request_context(&self) -> &[u8] {
&self.request_context[..]
}
}
impl<'a, const N: usize> TryFrom<CertificateRef<'a>> for Certificate<N> {
type Error = ProtocolError;
fn try_from(cert: CertificateRef<'a>) -> Result<Self, Self::Error> {
let mut request_context = Vec::new();
request_context
.extend_from_slice(cert.request_context)
.map_err(|_| ProtocolError::OutOfMemory)?;
let mut entries_data = Vec::new();
entries_data
.extend_from_slice(cert.raw_entries)
.map_err(|_| ProtocolError::OutOfMemory)?;
Ok(Self {
request_context,
entries_data,
})
}
}
impl<'a, const N: usize> TryFrom<&'a Certificate<N>> for CertificateRef<'a> {
type Error = ProtocolError;
fn try_from(cert: &'a Certificate<N>) -> Result<Self, Self::Error> {
let request_context = cert.request_context();
let entries =
CertificateEntryRef::parse_vector(&mut ParseBuffer::from(&cert.entries_data[..]))?;
Ok(Self {
raw_entries: &cert.entries_data[..],
request_context,
entries,
})
}
}

View File

@@ -0,0 +1,50 @@
use crate::extensions::messages::CertificateRequestExtension;
use crate::parse_buffer::ParseBuffer;
use crate::{ProtocolError, unused};
use heapless::Vec;
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct CertificateRequestRef<'a> {
pub(crate) request_context: &'a [u8],
}
impl<'a> CertificateRequestRef<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<CertificateRequestRef<'a>, ProtocolError> {
let request_context_len = buf
.read_u8()
.map_err(|_| ProtocolError::InvalidCertificateRequest)?;
let request_context = buf
.slice(request_context_len as usize)
.map_err(|_| ProtocolError::InvalidCertificateRequest)?;
// Validate extensions
let extensions = CertificateRequestExtension::parse_vector::<6>(buf)?;
unused(extensions);
Ok(Self {
request_context: request_context.as_slice(),
})
}
}
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct CertificateRequest {
pub(crate) request_context: Vec<u8, 256>,
}
impl<'a> TryFrom<CertificateRequestRef<'a>> for CertificateRequest {
type Error = ProtocolError;
fn try_from(cert: CertificateRequestRef<'a>) -> Result<Self, Self::Error> {
let mut request_context = Vec::new();
request_context
.extend_from_slice(cert.request_context)
.map_err(|_| {
error!("CertificateRequest: InsufficientSpace");
ProtocolError::InsufficientSpace
})?;
Ok(Self { request_context })
}
}

View File

@@ -0,0 +1,56 @@
use crate::ProtocolError;
use crate::extensions::extension_data::signature_algorithms::SignatureScheme;
use crate::parse_buffer::ParseBuffer;
use super::CryptoBuffer;
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HandshakeVerifyRef<'a> {
pub signature_scheme: SignatureScheme,
pub signature: &'a [u8],
}
impl<'a> HandshakeVerifyRef<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<HandshakeVerifyRef<'a>, ProtocolError> {
let signature_scheme =
SignatureScheme::parse(buf).map_err(|_| ProtocolError::InvalidSignatureScheme)?;
let len = buf.read_u16().map_err(|_| ProtocolError::InvalidSignature)?;
let signature = buf
.slice(len as usize)
.map_err(|_| ProtocolError::InvalidSignature)?;
Ok(Self {
signature_scheme,
signature: signature.as_slice(),
})
}
}
// Calculations for max. signature sizes:
// ecdsaSHA256 -> 6 bytes (ASN.1 structure) + 32-33 bytes (r) + 32-33 bytes (s) = 70..72 bytes
// ecdsaSHA384 -> 6 bytes (ASN.1 structure) + 48-49 bytes (r) + 48-49 bytes (s) = 102..104 bytes
// Ed25519 -> 6 bytes (ASN.1 structure) + 32-33 bytes (r) + 32-33 bytes (s) = 70..72 bytes
// RSA2048 -> 256 bytes
// RSA3072 -> 384 bytee
// RSA4096 -> 512 bytes
#[cfg(feature = "rsa")]
const SIGNATURE_SIZE: usize = 512;
#[cfg(not(feature = "rsa"))]
const SIGNATURE_SIZE: usize = 104;
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HandshakeVerify {
pub(crate) signature_scheme: SignatureScheme,
pub(crate) signature: heapless::Vec<u8, SIGNATURE_SIZE>,
}
impl HandshakeVerify {
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
buf.push_u16(self.signature_scheme.as_u16())?;
buf.with_u16_length(|buf| buf.extend_from_slice(self.signature.as_slice()))?;
Ok(())
}
}

View File

@@ -0,0 +1,188 @@
use core::marker::PhantomData;
use digest::{Digest, OutputSizeUser};
use heapless::Vec;
use p256::EncodedPoint;
use p256::ecdh::EphemeralSecret;
use p256::elliptic_curve::rand_core::RngCore;
use typenum::Unsigned;
use crate::ProtocolError;
use crate::config::{TlsCipherSuite, ConnectConfig};
use crate::extensions::extension_data::alpn::AlpnProtocolNameList;
use crate::extensions::extension_data::key_share::{KeyShareClientHello, KeyShareEntry};
use crate::extensions::extension_data::pre_shared_key::PreSharedKeyClientHello;
use crate::extensions::extension_data::psk_key_exchange_modes::{
PskKeyExchangeMode, PskKeyExchangeModes,
};
use crate::extensions::extension_data::server_name::ServerNameList;
use crate::extensions::extension_data::signature_algorithms::SignatureAlgorithms;
use crate::extensions::extension_data::supported_groups::{NamedGroup, SupportedGroups};
use crate::extensions::extension_data::supported_versions::{SupportedVersionsClientHello, TLS13};
use crate::extensions::messages::ClientHelloExtension;
use crate::handshake::{LEGACY_VERSION, Random};
use crate::key_schedule::{HashOutputSize, WriteKeySchedule};
use crate::{CryptoBackend, buffer::CryptoBuffer};
pub struct ClientHello<'config, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub(crate) config: &'config ConnectConfig<'config>,
random: Random,
cipher_suite: PhantomData<CipherSuite>,
pub(crate) secret: EphemeralSecret,
}
impl<'config, CipherSuite> ClientHello<'config, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub fn new<CP>(config: &'config ConnectConfig<'config>, mut provider: CP) -> Self
where
CP: CryptoBackend,
{
let mut random = [0; 32];
provider.rng().fill_bytes(&mut random);
Self {
config,
random,
cipher_suite: PhantomData,
secret: EphemeralSecret::random(&mut provider.rng()),
}
}
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
let public_key = EncodedPoint::from(&self.secret.public_key());
let public_key = public_key.as_ref();
buf.push_u16(LEGACY_VERSION)
.map_err(|_| ProtocolError::EncodeError)?;
buf.extend_from_slice(&self.random)
.map_err(|_| ProtocolError::EncodeError)?;
// session id (empty)
buf.push(0).map_err(|_| ProtocolError::EncodeError)?;
// cipher suites (2+)
//buf.extend_from_slice(&((self.config.cipher_suites.len() * 2) as u16).to_be_bytes());
//for c in self.config.cipher_suites.iter() {
//buf.extend_from_slice(&(*c as u16).to_be_bytes());
//}
buf.push_u16(2).map_err(|_| ProtocolError::EncodeError)?;
buf.push_u16(CipherSuite::CODE_POINT)
.map_err(|_| ProtocolError::EncodeError)?;
// compression methods, 1 byte of 0
buf.push(1).map_err(|_| ProtocolError::EncodeError)?;
buf.push(0).map_err(|_| ProtocolError::EncodeError)?;
// extensions (1+)
buf.with_u16_length(|buf| {
// Section 4.2.1. Supported Versions
// Implementations of this specification MUST send this extension in the
// ClientHello containing all versions of TLS which they are prepared to
// negotiate
ClientHelloExtension::SupportedVersions(SupportedVersionsClientHello {
versions: Vec::from_slice(&[TLS13]).unwrap(),
})
.encode(buf)?;
ClientHelloExtension::SignatureAlgorithms(SignatureAlgorithms {
supported_signature_algorithms: self.config.signature_schemes.clone(),
})
.encode(buf)?;
if let Some(max_fragment_length) = self.config.max_fragment_length {
ClientHelloExtension::MaxFragmentLength(max_fragment_length).encode(buf)?;
}
ClientHelloExtension::SupportedGroups(SupportedGroups {
supported_groups: self.config.named_groups.clone(),
})
.encode(buf)?;
ClientHelloExtension::PskKeyExchangeModes(PskKeyExchangeModes {
modes: Vec::from_slice(&[PskKeyExchangeMode::PskDheKe]).unwrap(),
})
.encode(buf)?;
ClientHelloExtension::KeyShare(KeyShareClientHello {
client_shares: Vec::from_slice(&[KeyShareEntry {
group: NamedGroup::Secp256r1,
opaque: public_key,
}])
.unwrap(),
})
.encode(buf)?;
if let Some(server_name) = self.config.server_name {
ClientHelloExtension::ServerName(ServerNameList::single(server_name))
.encode(buf)?;
}
if let Some(alpn_protocols) = self.config.alpn_protocols {
ClientHelloExtension::ApplicationLayerProtocolNegotiation(AlpnProtocolNameList {
protocols: alpn_protocols,
})
.encode(buf)?;
}
// Section 4.2
// When multiple extensions of different types are present, the
// extensions MAY appear in any order, with the exception of
// "pre_shared_key" which MUST be the last extension in
// the ClientHello.
if let Some((_, identities)) = &self.config.psk {
ClientHelloExtension::PreSharedKey(PreSharedKeyClientHello {
identities: identities.clone(),
hash_size: <CipherSuite::Hash as OutputSizeUser>::output_size(),
})
.encode(buf)?;
}
Ok(())
})?;
Ok(())
}
pub fn finalize(
&self,
enc_buf: &mut [u8],
transcript: &mut CipherSuite::Hash,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<(), ProtocolError> {
// Special case for PSK which needs to:
//
// 1. Add the client hello without the binders to the transcript
// 2. Create the binders for each identity using the transcript
// 3. Add the rest of the client hello.
//
// This causes a few issues since lengths must be correctly inside the payload,
// but won't actually be added to the record buffer until the end.
if let Some((_, identities)) = &self.config.psk {
let binders_len = identities.len() * (1 + HashOutputSize::<CipherSuite>::to_usize());
let binders_pos = enc_buf.len() - binders_len;
// NOTE: Exclude the binders_len itself from the digest
transcript.update(&enc_buf[0..binders_pos - 2]);
// Append after the client hello data. Sizes have already been set.
let mut buf = CryptoBuffer::wrap(&mut enc_buf[binders_pos..]);
// Create a binder and encode for each identity
for _id in identities {
let binder = write_key_schedule.create_psk_binder(transcript)?;
binder.encode(&mut buf)?;
}
transcript.update(&enc_buf[binders_pos - 2..]);
} else {
transcript.update(enc_buf);
}
Ok(())
}
}

View File

@@ -0,0 +1,19 @@
use core::marker::PhantomData;
use crate::extensions::messages::EncryptedExtensionsExtension;
use crate::ProtocolError;
use crate::parse_buffer::ParseBuffer;
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct EncryptedExtensions<'a> {
_todo: PhantomData<&'a ()>,
}
impl<'a> EncryptedExtensions<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<EncryptedExtensions<'a>, ProtocolError> {
EncryptedExtensionsExtension::parse_vector::<16>(buf)?;
Ok(EncryptedExtensions { _todo: PhantomData })
}
}

52
src/handshake/finished.rs Normal file
View File

@@ -0,0 +1,52 @@
use crate::ProtocolError;
use crate::buffer::CryptoBuffer;
use crate::parse_buffer::ParseBuffer;
use core::fmt::{Debug, Formatter};
//use digest::generic_array::{ArrayLength, GenericArray};
use generic_array::{ArrayLength, GenericArray};
// use heapless::Vec;
pub struct Finished<N: ArrayLength<u8>> {
pub verify: GenericArray<u8, N>,
pub hash: Option<GenericArray<u8, N>>,
}
#[cfg(feature = "defmt")]
impl<N: ArrayLength<u8>> defmt::Format for Finished<N> {
fn format(&self, f: defmt::Formatter<'_>) {
defmt::write!(f, "verify length:{}", &self.verify.len());
}
}
impl<N: ArrayLength<u8>> Debug for Finished<N> {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
f.debug_struct("Finished")
.field("verify", &self.hash)
.finish()
}
}
impl<N: ArrayLength<u8>> Finished<N> {
pub fn parse(buf: &mut ParseBuffer, _len: u32) -> Result<Self, ProtocolError> {
// info!("finished len: {}", len);
let mut verify = GenericArray::default();
buf.fill(&mut verify)?;
//let hash = GenericArray::from_slice()
//let hash: Result<Vec<u8, _>, ()> = buf
//.slice(len as usize)
//.map_err(|_| ProtocolError::InvalidHandshake)?
//.into();
// info!("hash {:?}", verify);
//let hash = hash.map_err(|_| ProtocolError::InvalidHandshake)?;
// info!("hash ng {:?}", verify);
Ok(Self { verify, hash: None })
}
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
//let len = self.verify.len().to_be_bytes();
//buf.extend_from_slice(&[len[1], len[2], len[3]]);
buf.extend_from_slice(&self.verify[..self.verify.len()])
.map_err(|_| ProtocolError::EncodeError)?;
Ok(())
}
}

241
src/handshake/mod.rs Normal file
View File

@@ -0,0 +1,241 @@
//use p256::elliptic_curve::AffinePoint;
use crate::ProtocolError;
use crate::config::TlsCipherSuite;
use crate::handshake::certificate::CertificateRef;
use crate::handshake::certificate_request::CertificateRequestRef;
use crate::handshake::certificate_verify::{HandshakeVerify, HandshakeVerifyRef};
use crate::handshake::client_hello::ClientHello;
use crate::handshake::encrypted_extensions::EncryptedExtensions;
use crate::handshake::finished::Finished;
use crate::handshake::new_session_ticket::NewSessionTicket;
use crate::handshake::server_hello::ServerHello;
use crate::key_schedule::HashOutputSize;
use crate::parse_buffer::{ParseBuffer, ParseError};
use crate::{buffer::CryptoBuffer, key_schedule::WriteKeySchedule};
use core::fmt::{Debug, Formatter};
use sha2::Digest;
pub mod binder;
pub mod certificate;
pub mod certificate_request;
pub mod certificate_verify;
pub mod client_hello;
pub mod encrypted_extensions;
pub mod finished;
pub mod new_session_ticket;
pub mod server_hello;
const LEGACY_VERSION: u16 = 0x0303;
type Random = [u8; 32];
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum HandshakeType {
ClientHello = 1,
ServerHello = 2,
NewSessionTicket = 4,
EndOfEarlyData = 5,
EncryptedExtensions = 8,
Certificate = 11,
CertificateRequest = 13,
HandshakeVerify = 15,
Finished = 20,
KeyUpdate = 24,
MessageHash = 254,
}
impl HandshakeType {
pub fn parse(buf: &mut ParseBuffer) -> Result<Self, ParseError> {
match buf.read_u8()? {
1 => Ok(HandshakeType::ClientHello),
2 => Ok(HandshakeType::ServerHello),
4 => Ok(HandshakeType::NewSessionTicket),
5 => Ok(HandshakeType::EndOfEarlyData),
8 => Ok(HandshakeType::EncryptedExtensions),
11 => Ok(HandshakeType::Certificate),
13 => Ok(HandshakeType::CertificateRequest),
15 => Ok(HandshakeType::HandshakeVerify),
20 => Ok(HandshakeType::Finished),
24 => Ok(HandshakeType::KeyUpdate),
254 => Ok(HandshakeType::MessageHash),
_ => Err(ParseError::InvalidData),
}
}
}
#[allow(clippy::large_enum_variant)]
pub enum ClientHandshake<'config, 'a, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
ClientCert(CertificateRef<'a>),
ClientCertVerify(HandshakeVerify),
ClientHello(ClientHello<'config, CipherSuite>),
Finished(Finished<HashOutputSize<CipherSuite>>),
}
impl<CipherSuite> ClientHandshake<'_, '_, CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
fn handshake_type(&self) -> HandshakeType {
match self {
ClientHandshake::ClientHello(_) => HandshakeType::ClientHello,
ClientHandshake::Finished(_) => HandshakeType::Finished,
ClientHandshake::ClientCert(_) => HandshakeType::Certificate,
ClientHandshake::ClientCertVerify(_) => HandshakeType::HandshakeVerify,
}
}
fn encode_inner(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
match self {
ClientHandshake::ClientHello(inner) => inner.encode(buf),
ClientHandshake::Finished(inner) => inner.encode(buf),
ClientHandshake::ClientCert(inner) => inner.encode(buf),
ClientHandshake::ClientCertVerify(inner) => inner.encode(buf),
}
}
pub(crate) fn encode(&self, buf: &mut CryptoBuffer<'_>) -> Result<(), ProtocolError> {
buf.push(self.handshake_type() as u8)
.map_err(|_| ProtocolError::EncodeError)?;
buf.with_u24_length(|buf| self.encode_inner(buf))
}
pub fn finalize(
&self,
buf: &mut CryptoBuffer,
transcript: &mut CipherSuite::Hash,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<(), ProtocolError> {
let enc_buf = buf.as_mut_slice();
if let ClientHandshake::ClientHello(hello) = self {
hello.finalize(enc_buf, transcript, write_key_schedule)
} else {
transcript.update(enc_buf);
Ok(())
}
}
pub fn finalize_encrypted(buf: &mut CryptoBuffer, transcript: &mut CipherSuite::Hash) {
let enc_buf = buf.as_slice();
let end = enc_buf.len();
transcript.update(&enc_buf[0..end]);
}
}
#[allow(clippy::large_enum_variant)]
pub enum ServerHandshake<'a, CipherSuite: TlsCipherSuite> {
ServerHello(ServerHello<'a>),
EncryptedExtensions(EncryptedExtensions<'a>),
NewSessionTicket(NewSessionTicket<'a>),
Certificate(CertificateRef<'a>),
CertificateRequest(CertificateRequestRef<'a>),
HandshakeVerify(HandshakeVerifyRef<'a>),
Finished(Finished<HashOutputSize<CipherSuite>>),
}
impl<CipherSuite: TlsCipherSuite> ServerHandshake<'_, CipherSuite> {
#[allow(dead_code)]
pub fn handshake_type(&self) -> HandshakeType {
match self {
ServerHandshake::ServerHello(_) => HandshakeType::ServerHello,
ServerHandshake::EncryptedExtensions(_) => HandshakeType::EncryptedExtensions,
ServerHandshake::NewSessionTicket(_) => HandshakeType::NewSessionTicket,
ServerHandshake::Certificate(_) => HandshakeType::Certificate,
ServerHandshake::CertificateRequest(_) => HandshakeType::CertificateRequest,
ServerHandshake::HandshakeVerify(_) => HandshakeType::HandshakeVerify,
ServerHandshake::Finished(_) => HandshakeType::Finished,
}
}
}
impl<CipherSuite: TlsCipherSuite> Debug for ServerHandshake<'_, CipherSuite> {
fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
match self {
ServerHandshake::ServerHello(inner) => Debug::fmt(inner, f),
ServerHandshake::EncryptedExtensions(inner) => Debug::fmt(inner, f),
ServerHandshake::Certificate(inner) => Debug::fmt(inner, f),
ServerHandshake::CertificateRequest(inner) => Debug::fmt(inner, f),
ServerHandshake::HandshakeVerify(inner) => Debug::fmt(inner, f),
ServerHandshake::Finished(inner) => Debug::fmt(inner, f),
ServerHandshake::NewSessionTicket(inner) => Debug::fmt(inner, f),
}
}
}
#[cfg(feature = "defmt")]
impl<'a, CipherSuite: TlsCipherSuite> defmt::Format for ServerHandshake<'a, CipherSuite> {
fn format(&self, f: defmt::Formatter<'_>) {
match self {
ServerHandshake::ServerHello(inner) => defmt::write!(f, "{}", inner),
ServerHandshake::EncryptedExtensions(inner) => defmt::write!(f, "{}", inner),
ServerHandshake::Certificate(inner) => defmt::write!(f, "{}", inner),
ServerHandshake::CertificateRequest(inner) => defmt::write!(f, "{}", inner),
ServerHandshake::HandshakeVerify(inner) => defmt::write!(f, "{}", inner),
ServerHandshake::Finished(inner) => defmt::write!(f, "{}", inner),
ServerHandshake::NewSessionTicket(inner) => defmt::write!(f, "{}", inner),
}
}
}
impl<'a, CipherSuite: TlsCipherSuite> ServerHandshake<'a, CipherSuite> {
pub fn read(
buf: &mut ParseBuffer<'a>,
digest: &mut CipherSuite::Hash,
) -> Result<Self, ProtocolError> {
let handshake_start = buf.offset();
let mut handshake = Self::parse(buf)?;
let handshake_end = buf.offset();
if let ServerHandshake::Finished(finished) = &mut handshake {
finished.hash.replace(digest.clone().finalize());
}
digest.update(&buf.as_slice()[handshake_start..handshake_end]);
Ok(handshake)
}
fn parse(buf: &mut ParseBuffer<'a>) -> Result<Self, ProtocolError> {
let handshake_type = HandshakeType::parse(buf).map_err(|_| ProtocolError::InvalidHandshake)?;
trace!("handshake = {:?}", handshake_type);
let content_len = buf.read_u24().map_err(|_| ProtocolError::InvalidHandshake)?;
let handshake = match handshake_type {
//HandshakeType::ClientHello => {}
HandshakeType::ServerHello => ServerHandshake::ServerHello(ServerHello::parse(buf)?),
HandshakeType::NewSessionTicket => {
ServerHandshake::NewSessionTicket(NewSessionTicket::parse(buf)?)
}
//HandshakeType::EndOfEarlyData => {}
HandshakeType::EncryptedExtensions => {
ServerHandshake::EncryptedExtensions(EncryptedExtensions::parse(buf)?)
}
HandshakeType::Certificate => ServerHandshake::Certificate(CertificateRef::parse(buf)?),
HandshakeType::CertificateRequest => {
ServerHandshake::CertificateRequest(CertificateRequestRef::parse(buf)?)
}
HandshakeType::HandshakeVerify => {
ServerHandshake::HandshakeVerify(HandshakeVerifyRef::parse(buf)?)
}
HandshakeType::Finished => {
ServerHandshake::Finished(Finished::parse(buf, content_len)?)
}
//HandshakeType::KeyUpdate => {}
//HandshakeType::MessageHash => {}
t => {
warn!("Unimplemented handshake type: {:?}", t);
return Err(ProtocolError::Unimplemented);
}
};
Ok(handshake)
}
}

View File

@@ -0,0 +1,33 @@
use core::marker::PhantomData;
use crate::extensions::messages::NewSessionTicketExtension;
use crate::parse_buffer::ParseBuffer;
use crate::{ProtocolError, unused};
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct NewSessionTicket<'a> {
_todo: PhantomData<&'a ()>,
}
impl<'a> NewSessionTicket<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<NewSessionTicket<'a>, ProtocolError> {
let lifetime = buf.read_u32()?;
let age_add = buf.read_u32()?;
let nonce_length = buf.read_u8()?;
let nonce = buf
.slice(nonce_length as usize)
.map_err(|_| ProtocolError::InvalidNonceLength)?;
let ticket_length = buf.read_u16()?;
let ticket = buf
.slice(ticket_length as usize)
.map_err(|_| ProtocolError::InvalidTicketLength)?;
let extensions = NewSessionTicketExtension::parse_vector::<1>(buf)?;
unused((lifetime, age_add, nonce, ticket, extensions));
Ok(Self { _todo: PhantomData })
}
}

View File

@@ -0,0 +1,83 @@
use heapless::Vec;
use crate::cipher_suites::CipherSuite;
use crate::cipher::CryptoEngine;
use crate::extensions::extension_data::key_share::KeyShareEntry;
use crate::extensions::messages::ServerHelloExtension;
use crate::parse_buffer::ParseBuffer;
use crate::{ProtocolError, unused};
use p256::PublicKey;
use p256::ecdh::{EphemeralSecret, SharedSecret};
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ServerHello<'a> {
extensions: Vec<ServerHelloExtension<'a>, 4>,
}
impl<'a> ServerHello<'a> {
pub fn parse(buf: &mut ParseBuffer<'a>) -> Result<ServerHello<'a>, ProtocolError> {
//let mut buf = ParseBuffer::new(&buf[0..content_length]);
//let mut buf = ParseBuffer::new(&buf);
let _version = buf.read_u16().map_err(|_| ProtocolError::InvalidHandshake)?;
let mut random = [0; 32];
buf.fill(&mut random)?;
let session_id_length = buf
.read_u8()
.map_err(|_| ProtocolError::InvalidSessionIdLength)?;
//info!("sh 1");
let session_id = buf
.slice(session_id_length as usize)
.map_err(|_| ProtocolError::InvalidSessionIdLength)?;
//info!("sh 2");
let cipher_suite = CipherSuite::parse(buf).map_err(|_| ProtocolError::InvalidCipherSuite)?;
////info!("sh 3");
// skip compression method, it's 0.
buf.read_u8()?;
let extensions = ServerHelloExtension::parse_vector(buf)?;
// debug!("server random {:x}", random);
// debug!("server session-id {:x}", session_id.as_slice());
debug!("server cipher_suite {:?}", cipher_suite);
debug!("server extensions {:?}", extensions);
unused(session_id);
Ok(Self { extensions })
}
pub fn key_share(&self) -> Option<&KeyShareEntry<'_>> {
self.extensions.iter().find_map(|e| {
if let ServerHelloExtension::KeyShare(entry) = e {
Some(&entry.0)
} else {
None
}
})
}
pub fn calculate_shared_secret(&self, secret: &EphemeralSecret) -> Option<SharedSecret> {
let server_key_share = self.key_share()?;
let server_public_key = PublicKey::from_sec1_bytes(server_key_share.opaque).ok()?;
Some(secret.diffie_hellman(&server_public_key))
}
#[allow(dead_code)]
pub fn initialize_crypto_engine(&self, secret: &EphemeralSecret) -> Option<CryptoEngine> {
let server_key_share = self.key_share()?;
let group = server_key_share.group;
let server_public_key = PublicKey::from_sec1_bytes(server_key_share.opaque).ok()?;
let shared = secret.diffie_hellman(&server_public_key);
Some(CryptoEngine::new(group, shared))
}
}

499
src/key_schedule.rs Normal file
View File

@@ -0,0 +1,499 @@
use crate::handshake::binder::PskBinder;
use crate::handshake::finished::Finished;
use crate::{ProtocolError, config::TlsCipherSuite};
use digest::OutputSizeUser;
use digest::generic_array::ArrayLength;
use hmac::{Mac, SimpleHmac};
use sha2::Digest;
use sha2::digest::generic_array::{GenericArray, typenum::Unsigned};
pub type HashOutputSize<CipherSuite> =
<<CipherSuite as TlsCipherSuite>::Hash as OutputSizeUser>::OutputSize;
pub type LabelBufferSize<CipherSuite> = <CipherSuite as TlsCipherSuite>::LabelBufferSize;
pub type IvArray<CipherSuite> = GenericArray<u8, <CipherSuite as TlsCipherSuite>::IvLen>;
pub type KeyArray<CipherSuite> = GenericArray<u8, <CipherSuite as TlsCipherSuite>::KeyLen>;
pub type HashArray<CipherSuite> = GenericArray<u8, HashOutputSize<CipherSuite>>;
type Hkdf<CipherSuite> = hkdf::Hkdf<
<CipherSuite as TlsCipherSuite>::Hash,
SimpleHmac<<CipherSuite as TlsCipherSuite>::Hash>,
>;
enum Secret<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
Uninitialized,
Initialized(Hkdf<CipherSuite>),
}
impl<CipherSuite> Secret<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
fn replace(&mut self, secret: Hkdf<CipherSuite>) {
*self = Self::Initialized(secret);
}
fn as_ref(&self) -> Result<&Hkdf<CipherSuite>, ProtocolError> {
match self {
Secret::Initialized(secret) => Ok(secret),
Secret::Uninitialized => Err(ProtocolError::InternalError),
}
}
fn make_expanded_hkdf_label<N: ArrayLength<u8>>(
&self,
label: &[u8],
context_type: ContextType<CipherSuite>,
) -> Result<GenericArray<u8, N>, ProtocolError> {
//info!("make label {:?} {}", label, len);
let mut hkdf_label = heapless_typenum::Vec::<u8, LabelBufferSize<CipherSuite>>::new();
hkdf_label
.extend_from_slice(&N::to_u16().to_be_bytes())
.map_err(|()| ProtocolError::InternalError)?;
let label_len = 6 + label.len() as u8;
hkdf_label
.extend_from_slice(&label_len.to_be_bytes())
.map_err(|()| ProtocolError::InternalError)?;
hkdf_label
.extend_from_slice(b"tls13 ")
.map_err(|()| ProtocolError::InternalError)?;
hkdf_label
.extend_from_slice(label)
.map_err(|()| ProtocolError::InternalError)?;
match context_type {
ContextType::None => {
hkdf_label.push(0).map_err(|_| ProtocolError::InternalError)?;
}
ContextType::Hash(context) => {
hkdf_label
.extend_from_slice(&(context.len() as u8).to_be_bytes())
.map_err(|()| ProtocolError::InternalError)?;
hkdf_label
.extend_from_slice(&context)
.map_err(|()| ProtocolError::InternalError)?;
}
}
let mut okm = GenericArray::default();
//info!("label {:x?}", label);
self.as_ref()?
.expand(&hkdf_label, &mut okm)
.map_err(|_| ProtocolError::CryptoError)?;
//info!("expand {:x?}", okm);
Ok(okm)
}
}
pub struct SharedState<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
secret: HashArray<CipherSuite>,
hkdf: Secret<CipherSuite>,
}
impl<CipherSuite> SharedState<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
fn new() -> Self {
Self {
secret: GenericArray::default(),
hkdf: Secret::Uninitialized,
}
}
fn initialize(&mut self, ikm: &[u8]) {
let (secret, hkdf) = Hkdf::<CipherSuite>::extract(Some(self.secret.as_ref()), ikm);
self.hkdf.replace(hkdf);
self.secret = secret;
}
fn derive_secret(
&mut self,
label: &[u8],
context_type: ContextType<CipherSuite>,
) -> Result<HashArray<CipherSuite>, ProtocolError> {
self.hkdf
.make_expanded_hkdf_label::<HashOutputSize<CipherSuite>>(label, context_type)
}
fn derived(&mut self) -> Result<(), ProtocolError> {
self.secret = self.derive_secret(b"derived", ContextType::empty_hash())?;
Ok(())
}
}
pub(crate) struct KeyScheduleState<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
traffic_secret: Secret<CipherSuite>,
counter: u64,
key: KeyArray<CipherSuite>,
iv: IvArray<CipherSuite>,
}
impl<CipherSuite> KeyScheduleState<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
fn new() -> Self {
Self {
traffic_secret: Secret::Uninitialized,
counter: 0,
key: KeyArray::<CipherSuite>::default(),
iv: IvArray::<CipherSuite>::default(),
}
}
#[inline]
pub fn get_key(&self) -> Result<&KeyArray<CipherSuite>, ProtocolError> {
Ok(&self.key)
}
#[inline]
pub fn get_iv(&self) -> Result<&IvArray<CipherSuite>, ProtocolError> {
Ok(&self.iv)
}
pub fn get_nonce(&self) -> Result<IvArray<CipherSuite>, ProtocolError> {
let iv = self.get_iv()?;
Ok(KeySchedule::<CipherSuite>::get_nonce(self.counter, iv))
}
fn calculate_traffic_secret(
&mut self,
label: &[u8],
shared: &mut SharedState<CipherSuite>,
transcript_hash: &CipherSuite::Hash,
) -> Result<(), ProtocolError> {
let secret = shared.derive_secret(label, ContextType::transcript_hash(transcript_hash))?;
let traffic_secret =
Hkdf::<CipherSuite>::from_prk(&secret).map_err(|_| ProtocolError::InternalError)?;
self.traffic_secret.replace(traffic_secret);
self.key = self
.traffic_secret
.make_expanded_hkdf_label(b"key", ContextType::None)?;
self.iv = self
.traffic_secret
.make_expanded_hkdf_label(b"iv", ContextType::None)?;
self.counter = 0;
Ok(())
}
pub fn increment_counter(&mut self) {
self.counter = unwrap!(self.counter.checked_add(1));
}
}
enum ContextType<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
None,
Hash(HashArray<CipherSuite>),
}
impl<CipherSuite> ContextType<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
fn transcript_hash(hash: &CipherSuite::Hash) -> Self {
Self::Hash(hash.clone().finalize())
}
fn empty_hash() -> Self {
Self::Hash(
<CipherSuite::Hash as Digest>::new()
.chain_update([])
.finalize(),
)
}
}
pub struct KeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
shared: SharedState<CipherSuite>,
client_state: WriteKeySchedule<CipherSuite>,
server_state: ReadKeySchedule<CipherSuite>,
}
impl<CipherSuite> KeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub fn new() -> Self {
Self {
shared: SharedState::new(),
client_state: WriteKeySchedule {
state: KeyScheduleState::new(),
binder_key: Secret::Uninitialized,
},
server_state: ReadKeySchedule {
state: KeyScheduleState::new(),
transcript_hash: <CipherSuite::Hash as Digest>::new(),
},
}
}
pub(crate) fn transcript_hash(&mut self) -> &mut CipherSuite::Hash {
&mut self.server_state.transcript_hash
}
pub(crate) fn replace_transcript_hash(&mut self, hash: CipherSuite::Hash) {
self.server_state.transcript_hash = hash;
}
pub fn as_split(
&mut self,
) -> (
&mut WriteKeySchedule<CipherSuite>,
&mut ReadKeySchedule<CipherSuite>,
) {
(&mut self.client_state, &mut self.server_state)
}
pub(crate) fn write_state(&mut self) -> &mut WriteKeySchedule<CipherSuite> {
&mut self.client_state
}
pub(crate) fn read_state(&mut self) -> &mut ReadKeySchedule<CipherSuite> {
&mut self.server_state
}
pub fn create_client_finished(
&self,
) -> Result<Finished<HashOutputSize<CipherSuite>>, ProtocolError> {
let key = self
.client_state
.state
.traffic_secret
.make_expanded_hkdf_label::<HashOutputSize<CipherSuite>>(
b"finished",
ContextType::None,
)?;
let mut hmac = SimpleHmac::<CipherSuite::Hash>::new_from_slice(&key)
.map_err(|_| ProtocolError::CryptoError)?;
Mac::update(
&mut hmac,
&self.server_state.transcript_hash.clone().finalize(),
);
let verify = hmac.finalize().into_bytes();
Ok(Finished { verify, hash: None })
}
fn get_nonce(counter: u64, iv: &IvArray<CipherSuite>) -> IvArray<CipherSuite> {
//info!("counter = {} {:x?}", counter, &counter.to_be_bytes(),);
let counter = Self::pad::<CipherSuite::IvLen>(&counter.to_be_bytes());
//info!("counter = {:x?}", counter);
// info!("iv = {:x?}", iv);
let mut nonce = GenericArray::default();
for (index, (l, r)) in iv[0..CipherSuite::IvLen::to_usize()]
.iter()
.zip(counter.iter())
.enumerate()
{
nonce[index] = l ^ r;
}
//debug!("nonce {:x?}", nonce);
nonce
}
fn pad<N: ArrayLength<u8>>(input: &[u8]) -> GenericArray<u8, N> {
// info!("padding input = {:x?}", input);
let mut padded = GenericArray::default();
for (index, byte) in input.iter().rev().enumerate() {
/*info!(
"{} pad {}={:x?}",
index,
((N::to_usize() - index) - 1),
*byte
);*/
padded[(N::to_usize() - index) - 1] = *byte;
}
padded
}
fn zero() -> HashArray<CipherSuite> {
GenericArray::default()
}
// Initializes the early secrets with a callback for any PSK binders
pub fn initialize_early_secret(&mut self, psk: Option<&[u8]>) -> Result<(), ProtocolError> {
self.shared.initialize(
#[allow(clippy::or_fun_call)]
psk.unwrap_or(Self::zero().as_slice()),
);
let binder_key = self
.shared
.derive_secret(b"ext binder", ContextType::empty_hash())?;
self.client_state.binder_key.replace(
Hkdf::<CipherSuite>::from_prk(&binder_key).map_err(|_| ProtocolError::InternalError)?,
);
self.shared.derived()
}
pub fn initialize_handshake_secret(&mut self, ikm: &[u8]) -> Result<(), ProtocolError> {
self.shared.initialize(ikm);
self.calculate_traffic_secrets(b"c hs traffic", b"s hs traffic")?;
self.shared.derived()
}
pub fn initialize_master_secret(&mut self) -> Result<(), ProtocolError> {
self.shared.initialize(Self::zero().as_slice());
//let context = self.transcript_hash.as_ref().unwrap().clone().finalize();
//info!("Derive keys, hash: {:x?}", context);
self.calculate_traffic_secrets(b"c ap traffic", b"s ap traffic")?;
self.shared.derived()
}
fn calculate_traffic_secrets(
&mut self,
client_label: &[u8],
server_label: &[u8],
) -> Result<(), ProtocolError> {
self.client_state.state.calculate_traffic_secret(
client_label,
&mut self.shared,
&self.server_state.transcript_hash,
)?;
self.server_state.state.calculate_traffic_secret(
server_label,
&mut self.shared,
&self.server_state.transcript_hash,
)?;
Ok(())
}
}
impl<CipherSuite> Default for KeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
fn default() -> Self {
KeySchedule::new()
}
}
pub struct WriteKeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
state: KeyScheduleState<CipherSuite>,
binder_key: Secret<CipherSuite>,
}
impl<CipherSuite> WriteKeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub(crate) fn increment_counter(&mut self) {
self.state.increment_counter();
}
pub(crate) fn get_key(&self) -> Result<&KeyArray<CipherSuite>, ProtocolError> {
self.state.get_key()
}
pub(crate) fn get_nonce(&self) -> Result<IvArray<CipherSuite>, ProtocolError> {
self.state.get_nonce()
}
pub fn create_psk_binder(
&self,
transcript_hash: &CipherSuite::Hash,
) -> Result<PskBinder<HashOutputSize<CipherSuite>>, ProtocolError> {
let key = self
.binder_key
.make_expanded_hkdf_label::<HashOutputSize<CipherSuite>>(
b"finished",
ContextType::None,
)?;
let mut hmac = SimpleHmac::<CipherSuite::Hash>::new_from_slice(&key)
.map_err(|_| ProtocolError::CryptoError)?;
Mac::update(&mut hmac, &transcript_hash.clone().finalize());
let verify = hmac.finalize().into_bytes();
Ok(PskBinder { verify })
}
}
pub struct ReadKeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
state: KeyScheduleState<CipherSuite>,
transcript_hash: CipherSuite::Hash,
}
impl<CipherSuite> ReadKeySchedule<CipherSuite>
where
CipherSuite: TlsCipherSuite,
{
pub(crate) fn increment_counter(&mut self) {
self.state.increment_counter();
}
pub(crate) fn transcript_hash(&mut self) -> &mut CipherSuite::Hash {
&mut self.transcript_hash
}
pub(crate) fn get_key(&self) -> Result<&KeyArray<CipherSuite>, ProtocolError> {
self.state.get_key()
}
pub(crate) fn get_nonce(&self) -> Result<IvArray<CipherSuite>, ProtocolError> {
self.state.get_nonce()
}
pub fn verify_server_finished(
&self,
finished: &Finished<HashOutputSize<CipherSuite>>,
) -> Result<bool, ProtocolError> {
//info!("verify server finished: {:x?}", finished.verify);
//self.client_traffic_secret.as_ref().unwrap().expand()
//info!("size ===> {}", D::OutputSize::to_u16());
let key = self
.state
.traffic_secret
.make_expanded_hkdf_label::<HashOutputSize<CipherSuite>>(
b"finished",
ContextType::None,
)?;
// info!("hmac sign key {:x?}", key);
let mut hmac = SimpleHmac::<CipherSuite::Hash>::new_from_slice(&key)
.map_err(|_| ProtocolError::InternalError)?;
Mac::update(
&mut hmac,
finished.hash.as_ref().ok_or_else(|| {
warn!("No hash in Finished");
ProtocolError::InternalError
})?,
);
//let code = hmac.clone().finalize().into_bytes();
Ok(hmac.verify(&finished.verify).is_ok())
//info!("verified {:?}", verified);
//unimplemented!()
}
}

169
src/lib.rs Normal file
View File

@@ -0,0 +1,169 @@
#![cfg_attr(not(any(test, feature = "std")), no_std)]
#![warn(clippy::pedantic)]
#![allow(
clippy::module_name_repetitions,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::missing_errors_doc // TODO
)]
/*!
# Example
```
use mote_tls::*;
use embedded_io_adapters::tokio_1::FromTokio;
use rand::rngs::OsRng;
use tokio::net::TcpStream;
#[tokio::main]
async fn main() {
let stream = TcpStream::connect("google.com:443")
.await
.expect("error creating TCP connection");
println!("TCP connection opened");
let mut read_record_buffer = [0; 16384];
let mut write_record_buffer = [0; 16384];
let config = ConnectConfig::new().with_server_name("google.com").enable_rsa_signatures();
let mut tls = SecureStream::new(
FromTokio::new(stream),
&mut read_record_buffer,
&mut write_record_buffer,
);
// Allows disabling cert verification, in case you are using PSK and don't need it, or are just testing.
// otherwise, use mote_tls::cert_verify::CertVerifier, which only works on std for now.
tls.open(ConnectContext::new(
&config,
SkipVerifyProvider::new::<Aes128GcmSha256>(OsRng),
))
.await
.expect("error establishing TLS connection");
println!("TLS session opened");
}
```
*/
// This mod MUST go first, so that the others see its macros.
pub(crate) mod fmt;
use parse_buffer::ParseError;
pub mod alert;
mod application_data;
pub mod blocking;
mod buffer;
mod change_cipher_spec;
mod cipher_suites;
mod common;
mod config;
mod connection;
mod content_types;
mod cipher;
mod extensions;
pub mod send_policy;
mod handshake;
mod key_schedule;
mod parse_buffer;
pub mod read_buffer;
mod record;
mod record_reader;
mod write_buffer;
pub use config::SkipVerifyProvider;
pub use extensions::extension_data::signature_algorithms::SignatureScheme;
pub use handshake::certificate_verify::HandshakeVerify;
pub use rand_core::{CryptoRng, CryptoRngCore};
#[cfg(feature = "webpki")]
pub mod cert_verify;
#[cfg(feature = "native-pki")]
mod certificate;
#[cfg(feature = "native-pki")]
pub mod native_pki;
mod asynch;
pub use asynch::*;
pub use send_policy::*;
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ProtocolError {
ConnectionClosed,
Unimplemented,
MissingHandshake,
HandshakeAborted(alert::AlertLevel, alert::AlertDescription),
AbortHandshake(alert::AlertLevel, alert::AlertDescription),
IoError,
InternalError,
InvalidRecord,
UnknownContentType,
InvalidNonceLength,
InvalidTicketLength,
UnknownExtensionType,
InsufficientSpace,
InvalidHandshake,
InvalidCipherSuite,
InvalidSignatureScheme,
InvalidSignature,
InvalidExtensionsLength,
InvalidSessionIdLength,
InvalidSupportedVersions,
InvalidApplicationData,
InvalidKeyShare,
InvalidCertificate,
InvalidCertificateEntry,
InvalidCertificateRequest,
InvalidPrivateKey,
UnableToInitializeCryptoEngine,
ParseError(ParseError),
OutOfMemory,
CryptoError,
EncodeError,
DecodeError,
Io(embedded_io::ErrorKind),
}
impl embedded_io::Error for ProtocolError {
fn kind(&self) -> embedded_io::ErrorKind {
if let Self::Io(k) = self {
*k
} else {
error!("TLS error: {:?}", self);
embedded_io::ErrorKind::Other
}
}
}
impl core::fmt::Display for ProtocolError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{self:?}")
}
}
impl core::error::Error for ProtocolError {}
#[cfg(feature = "std")]
mod stdlib {
use crate::config::TlsClock;
use std::time::SystemTime;
impl TlsClock for SystemTime {
fn now() -> Option<u64> {
Some(
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs(),
)
}
}
}
/// An internal function to mark an unused value.
///
/// All calls to this should be removed before 1.x.
fn unused<T>(_: T) {}

468
src/native_pki.rs Normal file
View File

@@ -0,0 +1,468 @@
use crate::ProtocolError;
use crate::config::{Certificate, TlsCipherSuite, TlsClock, Verifier};
#[cfg(feature = "p384")]
use crate::certificate::ECDSA_SHA384;
#[cfg(feature = "ed25519")]
use crate::certificate::ED25519;
use crate::certificate::{DecodedCertificate, ECDSA_SHA256, Time};
#[cfg(feature = "rsa")]
use crate::certificate::{RSA_PKCS1_SHA256, RSA_PKCS1_SHA384, RSA_PKCS1_SHA512};
use crate::extensions::extension_data::signature_algorithms::SignatureScheme;
use crate::handshake::{
certificate::{
Certificate as OwnedCertificate, CertificateEntryRef, CertificateRef as ServerCertificate,
},
certificate_verify::HandshakeVerifyRef,
};
use crate::parse_buffer::ParseError;
use const_oid::ObjectIdentifier;
use core::marker::PhantomData;
use der::Decode;
use digest::Digest;
use heapless::{String, Vec};
const HOSTNAME_MAXLEN: usize = 64;
const COMMON_NAME_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("2.5.4.3");
pub struct CertificateChain<'a> {
prev: Option<&'a CertificateEntryRef<'a>>,
chain: &'a ServerCertificate<'a>,
idx: isize,
}
impl<'a> CertificateChain<'a> {
pub fn new(ca: &'a CertificateEntryRef, chain: &'a ServerCertificate<'a>) -> Self {
Self {
prev: Some(ca),
chain,
idx: chain.entries.len() as isize - 1,
}
}
}
impl<'a> Iterator for CertificateChain<'a> {
type Item = (&'a CertificateEntryRef<'a>, &'a CertificateEntryRef<'a>);
fn next(&mut self) -> Option<Self::Item> {
if self.idx < 0 {
return None;
}
let cur = &self.chain.entries[self.idx as usize];
let out = (self.prev.unwrap(), cur);
self.prev = Some(cur);
self.idx -= 1;
Some(out)
}
}
pub struct CertVerifier<'a, CipherSuite, Clock, const CERT_SIZE: usize>
where
Clock: TlsClock,
CipherSuite: TlsCipherSuite,
{
ca: Certificate<&'a [u8]>,
host: Option<heapless::String<64>>,
certificate_transcript: Option<CipherSuite::Hash>,
certificate: Option<OwnedCertificate<CERT_SIZE>>,
_clock: PhantomData<Clock>,
}
impl<'a, CipherSuite, Clock, const CERT_SIZE: usize> CertVerifier<'a, CipherSuite, Clock, CERT_SIZE>
where
Clock: TlsClock,
CipherSuite: TlsCipherSuite,
{
#[must_use]
pub fn new(ca: Certificate<&'a [u8]>) -> Self {
Self {
ca,
host: None,
certificate_transcript: None,
certificate: None,
_clock: PhantomData,
}
}
}
impl<CipherSuite, Clock, const CERT_SIZE: usize> Verifier<CipherSuite>
for CertVerifier<'_, CipherSuite, Clock, CERT_SIZE>
where
CipherSuite: TlsCipherSuite,
Clock: TlsClock,
{
fn set_hostname_verification(&mut self, hostname: &str) -> Result<(), ProtocolError> {
self.host.replace(
heapless::String::try_from(hostname).map_err(|_| ProtocolError::InsufficientSpace)?,
);
Ok(())
}
fn verify_certificate(
&mut self,
transcript: &CipherSuite::Hash,
cert: ServerCertificate,
) -> Result<(), ProtocolError> {
let mut cn = None;
for (p, q) in CertificateChain::new(&(&self.ca).into(), &cert) {
cn = verify_certificate(p, q, Clock::now())?;
}
if self.host.ne(&cn) {
error!(
"Hostname ({:?}) does not match CommonName ({:?})",
self.host, cn
);
return Err(ProtocolError::InvalidCertificate);
}
self.certificate.replace(cert.try_into()?);
self.certificate_transcript.replace(transcript.clone());
Ok(())
}
fn verify_signature(&mut self, verify: HandshakeVerifyRef) -> Result<(), ProtocolError> {
let handshake_hash = unwrap!(self.certificate_transcript.take());
let ctx_str = b"TLS 1.3, server HandshakeVerify\x00";
let mut msg: Vec<u8, 146> = Vec::new();
msg.resize(64, 0x20).map_err(|_| ProtocolError::EncodeError)?;
msg.extend_from_slice(ctx_str)
.map_err(|_| ProtocolError::EncodeError)?;
msg.extend_from_slice(&handshake_hash.finalize())
.map_err(|_| ProtocolError::EncodeError)?;
let certificate = unwrap!(self.certificate.as_ref()).try_into()?;
verify_signature(&msg[..], &certificate, &verify)?;
Ok(())
}
}
fn verify_signature(
message: &[u8],
certificate: &ServerCertificate,
verify: &HandshakeVerifyRef,
) -> Result<(), ProtocolError> {
let verified;
let certificate =
if let Some(CertificateEntryRef::X509(certificate)) = certificate.entries.first() {
certificate
} else {
return Err(ProtocolError::DecodeError);
};
let certificate =
DecodedCertificate::from_der(certificate).map_err(|_| ProtocolError::DecodeError)?;
let public_key = certificate
.tbs_certificate
.subject_public_key_info
.public_key
.as_bytes()
.ok_or(ProtocolError::DecodeError)?;
match verify.signature_scheme {
SignatureScheme::EcdsaSecp256r1Sha256 => {
use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
let verifying_key =
VerifyingKey::from_sec1_bytes(public_key).map_err(|_| ProtocolError::DecodeError)?;
let signature =
Signature::from_der(&verify.signature).map_err(|_| ProtocolError::DecodeError)?;
verified = verifying_key.verify(message, &signature).is_ok();
}
#[cfg(feature = "p384")]
SignatureScheme::EcdsaSecp384r1Sha384 => {
use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
let verifying_key =
VerifyingKey::from_sec1_bytes(public_key).map_err(|_| ProtocolError::DecodeError)?;
let signature =
Signature::from_der(&verify.signature).map_err(|_| ProtocolError::DecodeError)?;
verified = verifying_key.verify(message, &signature).is_ok();
}
#[cfg(feature = "ed25519")]
SignatureScheme::Ed25519 => {
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
let verifying_key: VerifyingKey =
VerifyingKey::from_bytes(public_key.try_into().unwrap())
.map_err(|_| ProtocolError::DecodeError)?;
let signature =
Signature::try_from(verify.signature).map_err(|_| ProtocolError::DecodeError)?;
verified = verifying_key.verify(message, &signature).is_ok();
}
#[cfg(feature = "rsa")]
SignatureScheme::RsaPssRsaeSha256 => {
use rsa::{
RsaPublicKey,
pkcs1::DecodeRsaPublicKey,
pss::{Signature, VerifyingKey},
signature::Verifier,
};
use sha2::Sha256;
let der_pubkey = RsaPublicKey::from_pkcs1_der(public_key).unwrap();
let verifying_key = VerifyingKey::<Sha256>::from(der_pubkey);
let signature =
Signature::try_from(verify.signature).map_err(|_| ProtocolError::DecodeError)?;
verified = verifying_key.verify(message, &signature).is_ok();
}
#[cfg(feature = "rsa")]
SignatureScheme::RsaPssRsaeSha384 => {
use rsa::{
RsaPublicKey,
pkcs1::DecodeRsaPublicKey,
pss::{Signature, VerifyingKey},
signature::Verifier,
};
use sha2::Sha384;
let der_pubkey =
RsaPublicKey::from_pkcs1_der(public_key).map_err(|_| ProtocolError::DecodeError)?;
let verifying_key = VerifyingKey::<Sha384>::from(der_pubkey);
let signature =
Signature::try_from(verify.signature).map_err(|_| ProtocolError::DecodeError)?;
verified = verifying_key.verify(message, &signature).is_ok();
}
#[cfg(feature = "rsa")]
SignatureScheme::RsaPssRsaeSha512 => {
use rsa::{
RsaPublicKey,
pkcs1::DecodeRsaPublicKey,
pss::{Signature, VerifyingKey},
signature::Verifier,
};
use sha2::Sha512;
let der_pubkey =
RsaPublicKey::from_pkcs1_der(public_key).map_err(|_| ProtocolError::DecodeError)?;
let verifying_key = VerifyingKey::<Sha512>::from(der_pubkey);
let signature =
Signature::try_from(verify.signature).map_err(|_| ProtocolError::DecodeError)?;
verified = verifying_key.verify(message, &signature).is_ok();
}
_ => {
error!(
"InvalidSignatureScheme: {:?} Are you missing a feature?",
verify.signature_scheme
);
return Err(ProtocolError::InvalidSignatureScheme);
}
}
if !verified {
return Err(ProtocolError::InvalidSignature);
}
Ok(())
}
fn get_certificate_tlv_bytes<'a>(input: &[u8]) -> der::Result<&[u8]> {
use der::{Decode, Reader, SliceReader};
let mut reader = SliceReader::new(input)?;
let top_header = der::Header::decode(&mut reader)?;
top_header.tag().assert_eq(der::Tag::Sequence)?;
let header = der::Header::peek(&mut reader)?;
header.tag().assert_eq(der::Tag::Sequence)?;
// Should we read the remaining two fields and call reader.finish() just be certain here?
reader.tlv_bytes()
}
fn get_cert_time(time: Time) -> u64 {
match time {
Time::UtcTime(utc_time) => utc_time.to_unix_duration().as_secs(),
Time::GeneralTime(generalized_time) => generalized_time.to_unix_duration().as_secs(),
}
}
fn verify_certificate(
verifier: &CertificateEntryRef,
certificate: &CertificateEntryRef,
now: Option<u64>,
) -> Result<Option<heapless::String<HOSTNAME_MAXLEN>>, ProtocolError> {
let mut verified = false;
let mut common_name = None;
let ca_certificate = if let CertificateEntryRef::X509(verifier) = verifier {
DecodedCertificate::from_der(verifier).map_err(|_| ProtocolError::DecodeError)?
} else {
return Err(ProtocolError::DecodeError);
};
if let CertificateEntryRef::X509(certificate) = certificate {
let parsed_certificate =
DecodedCertificate::from_der(certificate).map_err(|_| ProtocolError::DecodeError)?;
let ca_public_key = ca_certificate
.tbs_certificate
.subject_public_key_info
.public_key
.as_bytes()
.ok_or(ProtocolError::DecodeError)?;
for elems in parsed_certificate.tbs_certificate.subject.iter() {
let attrs = elems
.get(0)
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?;
if attrs.oid == COMMON_NAME_OID {
let mut v: Vec<u8, HOSTNAME_MAXLEN> = Vec::new();
v.extend_from_slice(attrs.value.value())
.map_err(|_| ProtocolError::ParseError(ParseError::InvalidData))?;
common_name = String::from_utf8(v).ok();
debug!("CommonName: {:?}", common_name);
}
}
if let Some(now) = now {
if get_cert_time(parsed_certificate.tbs_certificate.validity.not_before) > now
|| get_cert_time(parsed_certificate.tbs_certificate.validity.not_after) < now
{
return Err(ProtocolError::InvalidCertificate);
}
debug!("Epoch is {} and certificate is valid!", now)
}
let certificate_data =
get_certificate_tlv_bytes(certificate).map_err(|_| ProtocolError::DecodeError)?;
match parsed_certificate.signature_algorithm {
ECDSA_SHA256 => {
use p256::ecdsa::{Signature, VerifyingKey, signature::Verifier};
let verifying_key = VerifyingKey::from_sec1_bytes(ca_public_key)
.map_err(|_| ProtocolError::DecodeError)?;
let signature = Signature::from_der(
parsed_certificate
.signature
.as_bytes()
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?,
)
.map_err(|_| ProtocolError::ParseError(ParseError::InvalidData))?;
verified = verifying_key.verify(&certificate_data, &signature).is_ok();
}
#[cfg(feature = "p384")]
ECDSA_SHA384 => {
use p384::ecdsa::{Signature, VerifyingKey, signature::Verifier};
let verifying_key = VerifyingKey::from_sec1_bytes(ca_public_key)
.map_err(|_| ProtocolError::DecodeError)?;
let signature = Signature::from_der(
parsed_certificate
.signature
.as_bytes()
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?,
)
.map_err(|_| ProtocolError::ParseError(ParseError::InvalidData))?;
verified = verifying_key.verify(&certificate_data, &signature).is_ok();
}
#[cfg(feature = "ed25519")]
ED25519 => {
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
let verifying_key: VerifyingKey =
VerifyingKey::from_bytes(ca_public_key.try_into().unwrap())
.map_err(|_| ProtocolError::DecodeError)?;
let signature = Signature::try_from(
parsed_certificate
.signature
.as_bytes()
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?,
)
.map_err(|_| ProtocolError::ParseError(ParseError::InvalidData))?;
verified = verifying_key.verify(certificate_data, &signature).is_ok();
}
#[cfg(feature = "rsa")]
a if a == RSA_PKCS1_SHA256 => {
use rsa::{
pkcs1::DecodeRsaPublicKey,
pkcs1v15::{Signature, VerifyingKey},
signature::Verifier,
};
use sha2::Sha256;
let verifying_key =
VerifyingKey::<Sha256>::from_pkcs1_der(ca_public_key).map_err(|e| {
error!("VerifyingKey: {}", e);
ProtocolError::DecodeError
})?;
let signature = Signature::try_from(
parsed_certificate
.signature
.as_bytes()
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?,
)
.map_err(|e| {
error!("Signature: {}", e);
ProtocolError::ParseError(ParseError::InvalidData)
})?;
verified = verifying_key.verify(certificate_data, &signature).is_ok();
}
#[cfg(feature = "rsa")]
a if a == RSA_PKCS1_SHA384 => {
use rsa::{
pkcs1::DecodeRsaPublicKey,
pkcs1v15::{Signature, VerifyingKey},
signature::Verifier,
};
use sha2::Sha384;
let verifying_key = VerifyingKey::<Sha384>::from_pkcs1_der(ca_public_key)
.map_err(|_| ProtocolError::DecodeError)?;
let signature = Signature::try_from(
parsed_certificate
.signature
.as_bytes()
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?,
)
.map_err(|_| ProtocolError::ParseError(ParseError::InvalidData))?;
verified = verifying_key.verify(certificate_data, &signature).is_ok();
}
#[cfg(feature = "rsa")]
a if a == RSA_PKCS1_SHA512 => {
use rsa::{
pkcs1::DecodeRsaPublicKey,
pkcs1v15::{Signature, VerifyingKey},
signature::Verifier,
};
use sha2::Sha512;
let verifying_key = VerifyingKey::<Sha512>::from_pkcs1_der(ca_public_key)
.map_err(|_| ProtocolError::DecodeError)?;
let signature = Signature::try_from(
parsed_certificate
.signature
.as_bytes()
.ok_or(ProtocolError::ParseError(ParseError::InvalidData))?,
)
.map_err(|_| ProtocolError::ParseError(ParseError::InvalidData))?;
verified = verifying_key.verify(certificate_data, &signature).is_ok();
}
_ => {
error!(
"Unsupported signature alg: {:?}",
parsed_certificate.signature_algorithm
);
return Err(ProtocolError::InvalidSignatureScheme);
}
}
}
if !verified {
return Err(ProtocolError::InvalidCertificate);
}
Ok(common_name)
}

173
src/parse_buffer.rs Normal file
View File

@@ -0,0 +1,173 @@
use crate::ProtocolError;
use heapless::{CapacityError, Vec};
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ParseError {
InsufficientBytes,
InsufficientSpace,
InvalidData,
}
pub struct ParseBuffer<'b> {
pos: usize,
buffer: &'b [u8],
}
impl<'b> From<&'b [u8]> for ParseBuffer<'b> {
fn from(val: &'b [u8]) -> Self {
ParseBuffer::new(val)
}
}
impl<'b, const N: usize> From<ParseBuffer<'b>> for Result<Vec<u8, N>, CapacityError> {
fn from(val: ParseBuffer<'b>) -> Self {
Vec::from_slice(&val.buffer[val.pos..])
}
}
impl<'b> ParseBuffer<'b> {
pub fn new(buffer: &'b [u8]) -> Self {
Self { pos: 0, buffer }
}
pub fn is_empty(&self) -> bool {
self.pos == self.buffer.len()
}
pub fn remaining(&self) -> usize {
self.buffer.len() - self.pos
}
pub fn offset(&self) -> usize {
self.pos
}
pub fn as_slice(&self) -> &'b [u8] {
self.buffer
}
pub fn slice(&mut self, len: usize) -> Result<ParseBuffer<'b>, ParseError> {
if self.pos + len <= self.buffer.len() {
let slice = ParseBuffer::new(&self.buffer[self.pos..self.pos + len]);
self.pos += len;
Ok(slice)
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn read_u8(&mut self) -> Result<u8, ParseError> {
if self.pos < self.buffer.len() {
let value = self.buffer[self.pos];
self.pos += 1;
Ok(value)
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn read_u16(&mut self) -> Result<u16, ParseError> {
//info!("pos={} len={}", self.pos, self.buffer.len());
if self.pos + 2 <= self.buffer.len() {
let value = u16::from_be_bytes([self.buffer[self.pos], self.buffer[self.pos + 1]]);
self.pos += 2;
Ok(value)
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn read_u24(&mut self) -> Result<u32, ParseError> {
if self.pos + 3 <= self.buffer.len() {
let value = u32::from_be_bytes([
0,
self.buffer[self.pos],
self.buffer[self.pos + 1],
self.buffer[self.pos + 2],
]);
self.pos += 3;
Ok(value)
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn read_u32(&mut self) -> Result<u32, ParseError> {
if self.pos + 4 <= self.buffer.len() {
let value = u32::from_be_bytes([
self.buffer[self.pos],
self.buffer[self.pos + 1],
self.buffer[self.pos + 2],
self.buffer[self.pos + 3],
]);
self.pos += 4;
Ok(value)
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn fill(&mut self, dest: &mut [u8]) -> Result<(), ParseError> {
if self.pos + dest.len() <= self.buffer.len() {
dest.copy_from_slice(&self.buffer[self.pos..self.pos + dest.len()]);
self.pos += dest.len();
// info!("Copied {} bytes", dest.len());
Ok(())
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn copy<const N: usize>(
&mut self,
dest: &mut Vec<u8, N>,
num_bytes: usize,
) -> Result<(), ParseError> {
let space = dest.capacity() - dest.len();
if space < num_bytes {
error!(
"Insufficient space in destination buffer. Space: {} required: {}",
space, num_bytes
);
Err(ParseError::InsufficientSpace)
} else if self.pos + num_bytes <= self.buffer.len() {
dest.extend_from_slice(&self.buffer[self.pos..self.pos + num_bytes])
.map_err(|_| {
error!(
"Failed to extend destination buffer. Space: {} required: {}",
space, num_bytes
);
ParseError::InsufficientSpace
})?;
self.pos += num_bytes;
Ok(())
} else {
Err(ParseError::InsufficientBytes)
}
}
pub fn read_list<T, const N: usize>(
&mut self,
data_length: usize,
read: impl Fn(&mut ParseBuffer<'b>) -> Result<T, ParseError>,
) -> Result<Vec<T, N>, ParseError> {
let mut result = Vec::new();
let mut data = self.slice(data_length)?;
while !data.is_empty() {
result.push(read(&mut data)?).map_err(|_| {
error!("Failed to store parse result");
ParseError::InsufficientSpace
})?;
}
Ok(result)
}
}
impl From<ParseError> for ProtocolError {
fn from(e: ParseError) -> Self {
ProtocolError::ParseError(e)
}
}

180
src/read_buffer.rs Normal file
View File

@@ -0,0 +1,180 @@
/// A reference to consume bytes from the internal buffer.
#[must_use]
pub struct ReadBuffer<'a> {
data: &'a [u8],
consumed: usize,
used: bool,
decrypted_consumed: &'a mut usize,
}
impl<'a> ReadBuffer<'a> {
#[inline]
pub(crate) fn new(buffer: &'a [u8], decrypted_consumed: &'a mut usize) -> Self {
Self {
data: buffer,
consumed: 0,
used: false,
decrypted_consumed,
}
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.data.len() - self.consumed
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Consumes and returns a slice of at most `count` bytes.
#[inline]
pub fn peek(&mut self, count: usize) -> &'a [u8] {
let count = self.len().min(count);
let start = self.consumed;
// We mark the buffer used to prevent dropping unconsumed bytes.
self.used = true;
&self.data[start..start + count]
}
/// Consumes and returns a slice of at most `count` bytes.
#[inline]
pub fn peek_all(&mut self) -> &'a [u8] {
self.peek(self.len())
}
/// Consumes and returns a slice of at most `count` bytes.
#[inline]
pub fn pop(&mut self, count: usize) -> &'a [u8] {
let count = self.len().min(count);
let start = self.consumed;
self.consumed += count;
self.used = true;
&self.data[start..start + count]
}
/// Consumes and returns the internal buffer.
#[inline]
pub fn pop_all(&mut self) -> &'a [u8] {
self.pop(self.len())
}
/// Drops the reference and restores internal buffer.
#[inline]
pub fn revert(self) {
core::mem::forget(self);
}
/// Tries to fills the buffer by consuming and copying bytes into it.
#[inline]
pub fn pop_into(&mut self, buf: &mut [u8]) -> usize {
let to_copy = self.pop(buf.len());
buf[..to_copy.len()].copy_from_slice(to_copy);
to_copy.len()
}
}
impl Drop for ReadBuffer<'_> {
#[inline]
fn drop(&mut self) {
*self.decrypted_consumed += if self.used {
self.consumed
} else {
// Consume all if dropped unused
self.data.len()
};
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn dropping_unused_buffer_consumes_all() {
let mut consumed = 1000;
let buffer = [0, 1, 2, 3];
_ = ReadBuffer::new(&buffer, &mut consumed);
assert_eq!(consumed, 1004);
}
#[test]
fn pop_moves_internal_cursor() {
let mut consumed = 0;
let mut buffer = ReadBuffer::new(&[0, 1, 2, 3], &mut consumed);
assert_eq!(buffer.pop(1), &[0]);
assert_eq!(buffer.pop(1), &[1]);
assert_eq!(buffer.pop(1), &[2]);
}
#[test]
fn dropping_consumes_as_many_bytes_as_used() {
let mut consumed = 0;
let mut buffer = ReadBuffer::new(&[0, 1, 2, 3], &mut consumed);
assert_eq!(buffer.pop(1), &[0]);
assert_eq!(buffer.pop(1), &[1]);
assert_eq!(buffer.pop(1), &[2]);
core::mem::drop(buffer);
assert_eq!(consumed, 3);
}
#[test]
fn pop_returns_fewer_bytes_if_requested_more_than_what_it_has() {
let mut consumed = 0;
let mut buffer = ReadBuffer::new(&[0, 1, 2, 3], &mut consumed);
assert_eq!(buffer.pop(1), &[0]);
assert_eq!(buffer.pop(1), &[1]);
assert_eq!(buffer.pop(4), &[2, 3]);
assert_eq!(buffer.pop(1), &[]);
core::mem::drop(buffer);
assert_eq!(consumed, 4);
}
#[test]
fn peek_does_not_consume() {
let mut consumed = 0;
let mut buffer = ReadBuffer::new(&[0, 1, 2, 3], &mut consumed);
assert_eq!(buffer.peek(1), &[0]);
assert_eq!(buffer.peek(1), &[0]);
core::mem::drop(buffer);
assert_eq!(consumed, 0);
}
#[test]
fn revert_undoes_pop() {
let mut consumed = 0;
let mut buffer = ReadBuffer::new(&[0, 1, 2, 3], &mut consumed);
assert_eq!(buffer.pop(4), &[0, 1, 2, 3]);
buffer.revert();
assert_eq!(consumed, 0);
}
}

224
src/record.rs Normal file
View File

@@ -0,0 +1,224 @@
use crate::ProtocolError;
use crate::application_data::ApplicationData;
use crate::change_cipher_spec::ChangeCipherSpec;
use crate::config::{TlsCipherSuite, ConnectConfig};
use crate::content_types::ContentType;
use crate::handshake::client_hello::ClientHello;
use crate::handshake::{ClientHandshake, ServerHandshake};
use crate::key_schedule::WriteKeySchedule;
use crate::{CryptoBackend, buffer::CryptoBuffer};
use crate::{
alert::{Alert, AlertDescription, AlertLevel},
parse_buffer::ParseBuffer,
};
use core::fmt::Debug;
pub type Encrypted = bool;
#[allow(clippy::large_enum_variant)]
pub enum ClientRecord<'config, 'a, CipherSuite>
where
// N: ArrayLength<u8>,
CipherSuite: TlsCipherSuite,
{
Handshake(ClientHandshake<'config, 'a, CipherSuite>, Encrypted),
Alert(Alert, Encrypted),
}
#[derive(Clone, Copy, PartialEq, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ClientRecordHeader {
Handshake(Encrypted),
Alert(Encrypted),
ApplicationData,
}
impl ClientRecordHeader {
pub fn is_encrypted(self) -> bool {
match self {
ClientRecordHeader::Handshake(encrypted) | ClientRecordHeader::Alert(encrypted) => {
encrypted
}
ClientRecordHeader::ApplicationData => true,
}
}
pub fn header_content_type(self) -> ContentType {
match self {
Self::Handshake(false) => ContentType::Handshake,
Self::Alert(false) => ContentType::ChangeCipherSpec,
Self::Handshake(true) | Self::Alert(true) | Self::ApplicationData => {
ContentType::ApplicationData
}
}
}
pub fn trailer_content_type(self) -> ContentType {
match self {
Self::Handshake(_) => ContentType::Handshake,
Self::Alert(_) => ContentType::Alert,
Self::ApplicationData => ContentType::ApplicationData,
}
}
pub fn version(self) -> [u8; 2] {
match self {
Self::Handshake(true) | Self::Alert(true) | Self::ApplicationData => [0x03, 0x03],
Self::Handshake(false) | Self::Alert(false) => [0x03, 0x01],
}
}
pub fn encode(self, buf: &mut CryptoBuffer) -> Result<(), ProtocolError> {
buf.push(self.header_content_type() as u8)
.map_err(|_| ProtocolError::EncodeError)?;
buf.extend_from_slice(&self.version())
.map_err(|_| ProtocolError::EncodeError)?;
Ok(())
}
}
impl<'config, CipherSuite> ClientRecord<'config, '_, CipherSuite>
where
//N: ArrayLength<u8>,
CipherSuite: TlsCipherSuite,
{
pub fn header(&self) -> ClientRecordHeader {
match self {
ClientRecord::Handshake(_, encrypted) => ClientRecordHeader::Handshake(*encrypted),
ClientRecord::Alert(_, encrypted) => ClientRecordHeader::Alert(*encrypted),
}
}
pub fn client_hello<CP>(
config: &'config ConnectConfig<'config>,
provider: &mut CP,
) -> Self
where
CP: CryptoBackend,
{
ClientRecord::Handshake(
ClientHandshake::ClientHello(ClientHello::new(config, provider)),
false,
)
}
pub fn close_notify(opened: bool) -> Self {
ClientRecord::Alert(
Alert::new(AlertLevel::Warning, AlertDescription::CloseNotify),
opened,
)
}
pub(crate) fn encode_payload(&self, buf: &mut CryptoBuffer) -> Result<usize, ProtocolError> {
let record_length_marker = buf.len();
match self {
ClientRecord::Handshake(handshake, _) => handshake.encode(buf)?,
ClientRecord::Alert(alert, _) => alert.encode(buf)?,
};
Ok(buf.len() - record_length_marker)
}
pub fn finish_record(
&self,
buf: &mut CryptoBuffer,
transcript: &mut CipherSuite::Hash,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<(), ProtocolError> {
match self {
ClientRecord::Handshake(handshake, false) => {
handshake.finalize(buf, transcript, write_key_schedule)
}
ClientRecord::Handshake(_, true) => {
ClientHandshake::<CipherSuite>::finalize_encrypted(buf, transcript);
Ok(())
}
_ => Ok(()),
}
}
}
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[allow(clippy::large_enum_variant)]
pub enum ServerRecord<'a, CipherSuite: TlsCipherSuite> {
Handshake(ServerHandshake<'a, CipherSuite>),
ChangeCipherSpec(ChangeCipherSpec),
Alert(Alert),
ApplicationData(ApplicationData<'a>),
}
pub struct RecordHeader {
header: [u8; 5],
}
impl RecordHeader {
pub const LEN: usize = 5;
pub fn content_type(&self) -> ContentType {
// Content type already validated in read
unwrap!(ContentType::of(self.header[0]))
}
pub fn content_length(&self) -> usize {
// Content length already validated in read
u16::from_be_bytes([self.header[3], self.header[4]]) as usize
}
pub fn data(&self) -> &[u8; 5] {
&self.header
}
pub fn decode(header: [u8; 5]) -> Result<RecordHeader, ProtocolError> {
match ContentType::of(header[0]) {
None => Err(ProtocolError::InvalidRecord),
Some(_) => Ok(RecordHeader { header }),
}
}
}
impl<'a, CipherSuite: TlsCipherSuite> ServerRecord<'a, CipherSuite> {
pub fn content_type(&self) -> ContentType {
match self {
ServerRecord::Handshake(_) => ContentType::Handshake,
ServerRecord::ChangeCipherSpec(_) => ContentType::ChangeCipherSpec,
ServerRecord::Alert(_) => ContentType::Alert,
ServerRecord::ApplicationData(_) => ContentType::ApplicationData,
}
}
pub fn decode(
header: RecordHeader,
data: &'a mut [u8],
digest: &mut CipherSuite::Hash,
) -> Result<ServerRecord<'a, CipherSuite>, ProtocolError> {
assert_eq!(header.content_length(), data.len());
match header.content_type() {
ContentType::Invalid => Err(ProtocolError::Unimplemented),
ContentType::ChangeCipherSpec => Ok(ServerRecord::ChangeCipherSpec(
ChangeCipherSpec::read(data)?,
)),
ContentType::Alert => {
let mut parse = ParseBuffer::new(data);
let alert = Alert::parse(&mut parse)?;
Ok(ServerRecord::Alert(alert))
}
ContentType::Handshake => {
let mut parse = ParseBuffer::new(data);
Ok(ServerRecord::Handshake(ServerHandshake::read(
&mut parse, digest,
)?))
}
ContentType::ApplicationData => {
let buf = CryptoBuffer::wrap_with_pos(data, data.len());
Ok(ServerRecord::ApplicationData(ApplicationData::new(
buf, header,
)))
}
}
}
//pub fn parse<D: Digest>(buf: &[u8]) -> Result<Self, ProtocolError> {}
}

479
src/record_reader.rs Normal file
View File

@@ -0,0 +1,479 @@
use crate::key_schedule::ReadKeySchedule;
use embedded_io::{Error as _, Read as BlockingRead};
use embedded_io_async::Read as AsyncRead;
use crate::{
ProtocolError,
config::TlsCipherSuite,
record::{RecordHeader, ServerRecord},
};
pub struct RecordReader<'a> {
pub(crate) buf: &'a mut [u8],
/// The number of decoded bytes in the buffer
decoded: usize,
/// The number of read but not yet decoded bytes in the buffer
pending: usize,
}
pub struct RecordReaderBorrowMut<'a> {
pub(crate) buf: &'a mut [u8],
/// The number of decoded bytes in the buffer
decoded: &'a mut usize,
/// The number of read but not yet decoded bytes in the buffer
pending: &'a mut usize,
}
impl<'a> RecordReader<'a> {
pub fn new(buf: &'a mut [u8]) -> Self {
if buf.len() < 16640 {
warn!("Read buffer is smaller than 16640 bytes, which may cause problems!");
}
Self {
buf,
decoded: 0,
pending: 0,
}
}
pub fn reborrow_mut(&mut self) -> RecordReaderBorrowMut<'_> {
RecordReaderBorrowMut {
buf: self.buf,
decoded: &mut self.decoded,
pending: &mut self.pending,
}
}
pub async fn read<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
transport: &mut impl AsyncRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
read(
self.buf,
&mut self.decoded,
&mut self.pending,
transport,
key_schedule,
)
.await
}
pub fn read_blocking<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
transport: &mut impl BlockingRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
read_blocking(
self.buf,
&mut self.decoded,
&mut self.pending,
transport,
key_schedule,
)
}
}
impl RecordReaderBorrowMut<'_> {
pub async fn read<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
transport: &mut impl AsyncRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
read(
self.buf,
self.decoded,
self.pending,
transport,
key_schedule,
)
.await
}
pub fn read_blocking<'m, CipherSuite: TlsCipherSuite>(
&'m mut self,
transport: &mut impl BlockingRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
read_blocking(
self.buf,
self.decoded,
self.pending,
transport,
key_schedule,
)
}
}
pub async fn read<'m, CipherSuite: TlsCipherSuite>(
buf: &'m mut [u8],
decoded: &mut usize,
pending: &mut usize,
transport: &mut impl AsyncRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
let header: RecordHeader = next_record_header(transport).await?;
advance(buf, decoded, pending, transport, header.content_length()).await?;
consume(
buf,
decoded,
pending,
header,
key_schedule.transcript_hash(),
)
}
pub fn read_blocking<'m, CipherSuite: TlsCipherSuite>(
buf: &'m mut [u8],
decoded: &mut usize,
pending: &mut usize,
transport: &mut impl BlockingRead,
key_schedule: &mut ReadKeySchedule<CipherSuite>,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
let header: RecordHeader = next_record_header_blocking(transport)?;
advance_blocking(buf, decoded, pending, transport, header.content_length())?;
consume(
buf,
decoded,
pending,
header,
key_schedule.transcript_hash(),
)
}
async fn next_record_header(transport: &mut impl AsyncRead) -> Result<RecordHeader, ProtocolError> {
let mut buf: [u8; RecordHeader::LEN] = [0; RecordHeader::LEN];
let mut total_read: usize = 0;
while total_read != RecordHeader::LEN {
let read: usize = transport
.read(&mut buf[total_read..])
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
if read == 0 {
return Err(ProtocolError::IoError);
}
total_read += read;
}
RecordHeader::decode(buf)
}
fn next_record_header_blocking(
transport: &mut impl BlockingRead,
) -> Result<RecordHeader, ProtocolError> {
let mut buf: [u8; RecordHeader::LEN] = [0; RecordHeader::LEN];
let mut total_read: usize = 0;
while total_read != RecordHeader::LEN {
let read: usize = transport
.read(&mut buf[total_read..])
.map_err(|e| ProtocolError::Io(e.kind()))?;
if read == 0 {
return Err(ProtocolError::IoError);
}
total_read += read;
}
RecordHeader::decode(buf)
}
async fn advance(
buf: &mut [u8],
decoded: &mut usize,
pending: &mut usize,
transport: &mut impl AsyncRead,
amount: usize,
) -> Result<(), ProtocolError> {
ensure_contiguous(buf, decoded, pending, amount)?;
let mut remain: usize = amount;
while *pending < amount {
let read = transport
.read(&mut buf[*decoded + *pending..][..remain])
.await
.map_err(|e| ProtocolError::Io(e.kind()))?;
if read == 0 {
return Err(ProtocolError::IoError);
}
remain -= read;
*pending += read;
}
Ok(())
}
fn advance_blocking(
buf: &mut [u8],
decoded: &mut usize,
pending: &mut usize,
transport: &mut impl BlockingRead,
amount: usize,
) -> Result<(), ProtocolError> {
ensure_contiguous(buf, decoded, pending, amount)?;
let mut remain: usize = amount;
while *pending < amount {
let read = transport
.read(&mut buf[*decoded + *pending..][..remain])
.map_err(|e| ProtocolError::Io(e.kind()))?;
if read == 0 {
return Err(ProtocolError::IoError);
}
remain -= read;
*pending += read;
}
Ok(())
}
fn consume<'m, CipherSuite: TlsCipherSuite>(
buf: &'m mut [u8],
decoded: &mut usize,
pending: &mut usize,
header: RecordHeader,
digest: &mut CipherSuite::Hash,
) -> Result<ServerRecord<'m, CipherSuite>, ProtocolError> {
let content_len = header.content_length();
let slice = &mut buf[*decoded..][..content_len];
*decoded += content_len;
*pending -= content_len;
ServerRecord::decode(header, slice, digest)
}
fn ensure_contiguous(
buf: &mut [u8],
decoded: &mut usize,
pending: &mut usize,
len: usize,
) -> Result<(), ProtocolError> {
if *decoded + len > buf.len() {
if len > buf.len() {
error!(
"Record too large for buffer. Size: {} Buffer size: {}",
len,
buf.len()
);
return Err(ProtocolError::InsufficientSpace);
}
buf.copy_within(*decoded..*decoded + *pending, 0);
*decoded = 0;
}
Ok(())
}
#[cfg(test)]
mod tests {
use core::convert::Infallible;
use super::*;
use crate::{Aes128GcmSha256, content_types::ContentType, key_schedule::KeySchedule};
struct ChunkRead<'a>(&'a [u8], usize);
impl embedded_io::ErrorType for ChunkRead<'_> {
type Error = Infallible;
}
impl BlockingRead for ChunkRead<'_> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
let len = usize::min(self.1, buf.len());
let len = usize::min(len, self.0.len());
buf[..len].copy_from_slice(&self.0[..len]);
self.0 = &self.0[len..];
Ok(len)
}
}
#[test]
fn can_read_blocking() {
can_read_blocking_case(1);
can_read_blocking_case(2);
can_read_blocking_case(3);
can_read_blocking_case(4);
can_read_blocking_case(5);
can_read_blocking_case(6);
can_read_blocking_case(7);
can_read_blocking_case(8);
can_read_blocking_case(9);
can_read_blocking_case(10);
can_read_blocking_case(11);
can_read_blocking_case(12);
can_read_blocking_case(13);
can_read_blocking_case(14);
can_read_blocking_case(15);
can_read_blocking_case(16);
}
fn can_read_blocking_case(chunk_size: usize) {
let mut transport = ChunkRead(
&[
// Header
ContentType::ApplicationData as u8,
0x03,
0x03,
0x00,
0x04,
// Data
0xde,
0xad,
0xbe,
0xef,
// Header
ContentType::ApplicationData as u8,
0x03,
0x03,
0x00,
0x02,
// Data
0xaa,
0xbb,
],
chunk_size,
);
let mut buf = [0; 32];
let mut reader = RecordReader::new(&mut buf);
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();
{
if let ServerRecord::ApplicationData(data) = reader
.read_blocking(&mut transport, key_schedule.read_state())
.unwrap()
{
assert_eq!([0xde, 0xad, 0xbe, 0xef], data.data.as_slice());
} else {
panic!("Wrong server record");
}
assert_eq!(4, reader.decoded);
assert_eq!(0, reader.pending);
}
{
if let ServerRecord::ApplicationData(data) = reader
.read_blocking(&mut transport, key_schedule.read_state())
.unwrap()
{
assert_eq!([0xaa, 0xbb], data.data.as_slice());
} else {
panic!("Wrong server record");
}
assert_eq!(6, reader.decoded);
assert_eq!(0, reader.pending);
}
}
#[test]
fn can_read_blocking_must_rotate_buffer() {
let mut transport = [
// Header
ContentType::ApplicationData as u8,
0x03,
0x03,
0x00,
0x04,
// Data
0xde,
0xad,
0xbe,
0xef,
// Header
ContentType::ApplicationData as u8,
0x03,
0x03,
0x00,
0x02,
// Data
0xaa,
0xbb,
]
.as_slice();
let mut buf = [0; 4]; // cannot contain both data portions
let mut reader = RecordReader::new(&mut buf);
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();
{
if let ServerRecord::ApplicationData(data) = reader
.read_blocking(&mut transport, key_schedule.read_state())
.unwrap()
{
assert_eq!([0xde, 0xad, 0xbe, 0xef], data.data.as_slice());
} else {
panic!("Wrong server record");
}
assert_eq!(4, reader.decoded);
assert_eq!(0, reader.pending);
}
{
if let ServerRecord::ApplicationData(data) = reader
.read_blocking(&mut transport, key_schedule.read_state())
.unwrap()
{
assert_eq!([0xaa, 0xbb], data.data.as_slice());
} else {
panic!("Wrong server record");
}
assert_eq!(2, reader.decoded);
assert_eq!(0, reader.pending);
}
}
#[test]
fn can_read_empty_record() {
let mut transport = [
// Header
ContentType::ApplicationData as u8,
0x03,
0x03,
0x00,
0x00,
// Header
ContentType::ApplicationData as u8,
0x03,
0x03,
0x00,
0x00,
]
.as_slice();
let mut buf = [0; 32];
let mut reader = RecordReader::new(&mut buf);
let mut key_schedule = KeySchedule::<Aes128GcmSha256>::new();
{
if let ServerRecord::ApplicationData(data) = reader
.read_blocking(&mut transport, key_schedule.read_state())
.unwrap()
{
assert!(data.data.is_empty());
} else {
panic!("Wrong server record");
}
assert_eq!(0, reader.decoded);
assert_eq!(0, reader.pending);
}
{
if let ServerRecord::ApplicationData(data) = reader
.read_blocking(&mut transport, key_schedule.read_state())
.unwrap()
{
assert!(data.data.is_empty());
} else {
panic!("Wrong server record");
}
assert_eq!(0, reader.decoded);
assert_eq!(0, reader.pending);
}
}
}

37
src/send_policy.rs Normal file
View File

@@ -0,0 +1,37 @@
//! Flush policy for TLS sockets.
//!
//! Two strategies are provided:
//! - `Relaxed`: close the TLS encryption buffer and hand the data to the transport
//! delegate without forcing a transport-level flush.
//! - `Strict`: in addition to handing the data to the transport delegate, also
//! request a flush of the transport. For TCP transports this typically means
//! waiting for an ACK (e.g. on embassy TCP sockets) before considering the
//! data fully flushed.
/// Policy controlling how TLS layer flushes encrypted data to the transport.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FlushPolicy {
/// Close the TLS encryption buffer and pass bytes to the transport delegate.
/// Do not force a transport-level flush or wait for an ACK.
Relaxed,
/// In addition to passing bytes to the transport delegate, request a
/// transport-level flush and wait for confirmation (ACK) before returning.
Strict,
}
impl FlushPolicy {
/// Returns true when the transport delegate should be explicitly flushed.
///
/// Relaxed -> false, Strict -> true.
pub fn flush_transport(&self) -> bool {
matches!(self, Self::Strict)
}
}
impl Default for FlushPolicy {
/// Default to `Strict` for compatibility with mote-tls 0.17.0.
fn default() -> Self {
FlushPolicy::Strict
}
}

287
src/write_buffer.rs Normal file
View File

@@ -0,0 +1,287 @@
use crate::{
ProtocolError,
buffer::CryptoBuffer,
config::{TLS_RECORD_OVERHEAD, TlsCipherSuite},
connection::encrypt,
key_schedule::{ReadKeySchedule, WriteKeySchedule},
record::{ClientRecord, ClientRecordHeader},
};
pub struct WriteBuffer<'a> {
buffer: &'a mut [u8],
pos: usize,
current_header: Option<ClientRecordHeader>,
}
pub(crate) struct WriteBufferBorrow<'a> {
buffer: &'a [u8],
pos: &'a usize,
current_header: &'a Option<ClientRecordHeader>,
}
pub(crate) struct WriteBufferBorrowMut<'a> {
buffer: &'a mut [u8],
pos: &'a mut usize,
current_header: &'a mut Option<ClientRecordHeader>,
}
impl<'a> WriteBuffer<'a> {
pub fn new(buffer: &'a mut [u8]) -> Self {
debug_assert!(
buffer.len() > TLS_RECORD_OVERHEAD,
"The write buffer must be sufficiently large to include the tls record overhead"
);
Self {
buffer,
pos: 0,
current_header: None,
}
}
pub(crate) fn reborrow_mut(&mut self) -> WriteBufferBorrowMut<'_> {
WriteBufferBorrowMut {
buffer: self.buffer,
pos: &mut self.pos,
current_header: &mut self.current_header,
}
}
pub(crate) fn reborrow(&self) -> WriteBufferBorrow<'_> {
WriteBufferBorrow {
buffer: self.buffer,
pos: &self.pos,
current_header: &self.current_header,
}
}
pub fn is_full(&self) -> bool {
self.reborrow().is_full()
}
pub fn append(&mut self, buf: &[u8]) -> usize {
self.reborrow_mut().append(buf)
}
pub fn is_empty(&self) -> bool {
self.reborrow().is_empty()
}
pub fn contains(&self, header: ClientRecordHeader) -> bool {
self.reborrow().contains(header)
}
pub(crate) fn start_record(&mut self, header: ClientRecordHeader) -> Result<(), ProtocolError> {
self.reborrow_mut().start_record(header)
}
pub(crate) fn close_record<CipherSuite>(
&mut self,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<&[u8], ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
close_record(
self.buffer,
&mut self.pos,
&mut self.current_header,
write_key_schedule,
)
}
pub fn write_record<CipherSuite>(
&mut self,
record: &ClientRecord<CipherSuite>,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
read_key_schedule: Option<&mut ReadKeySchedule<CipherSuite>>,
) -> Result<&[u8], ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
write_record(
self.buffer,
&mut self.pos,
&mut self.current_header,
record,
write_key_schedule,
read_key_schedule,
)
}
}
impl WriteBufferBorrow<'_> {
fn max_block_size(&self) -> usize {
self.buffer.len() - TLS_RECORD_OVERHEAD
}
pub fn is_full(&self) -> bool {
*self.pos == self.max_block_size()
}
pub fn len(&self) -> usize {
*self.pos
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn space(&self) -> usize {
self.max_block_size() - *self.pos
}
pub fn contains(&self, header: ClientRecordHeader) -> bool {
self.current_header.as_ref() == Some(&header)
}
}
impl WriteBufferBorrowMut<'_> {
fn reborrow(&self) -> WriteBufferBorrow<'_> {
WriteBufferBorrow {
buffer: self.buffer,
pos: self.pos,
current_header: self.current_header,
}
}
pub fn is_full(&self) -> bool {
self.reborrow().is_full()
}
pub fn is_empty(&self) -> bool {
self.reborrow().is_empty()
}
pub fn contains(&self, header: ClientRecordHeader) -> bool {
self.reborrow().contains(header)
}
pub fn append(&mut self, buf: &[u8]) -> usize {
let buffered = usize::min(buf.len(), self.reborrow().space());
if buffered > 0 {
self.buffer[*self.pos..*self.pos + buffered].copy_from_slice(&buf[..buffered]);
*self.pos += buffered;
}
buffered
}
pub(crate) fn start_record(&mut self, header: ClientRecordHeader) -> Result<(), ProtocolError> {
start_record(self.buffer, self.pos, self.current_header, header)
}
pub fn close_record<CipherSuite>(
&mut self,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<&[u8], ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
close_record(
self.buffer,
self.pos,
self.current_header,
write_key_schedule,
)
}
}
fn start_record(
buffer: &mut [u8],
pos: &mut usize,
current_header: &mut Option<ClientRecordHeader>,
header: ClientRecordHeader,
) -> Result<(), ProtocolError> {
debug_assert!(current_header.is_none());
debug!("start_record({:?})", header);
*current_header = Some(header);
with_buffer(buffer, pos, |mut buf| {
header.encode(&mut buf)?;
buf.push_u16(0)?;
Ok(buf.rewind())
})
}
fn with_buffer(
buffer: &mut [u8],
pos: &mut usize,
op: impl FnOnce(CryptoBuffer) -> Result<CryptoBuffer, ProtocolError>,
) -> Result<(), ProtocolError> {
let buf = CryptoBuffer::wrap_with_pos(buffer, *pos);
match op(buf) {
Ok(buf) => {
*pos = buf.len();
Ok(())
}
Err(err) => Err(err),
}
}
fn close_record<'a, CipherSuite>(
buffer: &'a mut [u8],
pos: &mut usize,
current_header: &mut Option<ClientRecordHeader>,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
) -> Result<&'a [u8], ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
const HEADER_SIZE: usize = 5;
let header = current_header.take().unwrap();
with_buffer(buffer, pos, |mut buf| {
if !header.is_encrypted() {
return Ok(buf);
}
buf.push(header.trailer_content_type() as u8)
.map_err(|_| ProtocolError::EncodeError)?;
let mut buf = buf.offset(HEADER_SIZE);
encrypt(write_key_schedule, &mut buf)?;
Ok(buf.rewind())
})?;
let [upper, lower] = ((*pos - HEADER_SIZE) as u16).to_be_bytes();
buffer[3] = upper;
buffer[4] = lower;
let slice = &buffer[..*pos];
*pos = 0;
*current_header = None;
Ok(slice)
}
fn write_record<'a, CipherSuite>(
buffer: &'a mut [u8],
pos: &mut usize,
current_header: &mut Option<ClientRecordHeader>,
record: &ClientRecord<CipherSuite>,
write_key_schedule: &mut WriteKeySchedule<CipherSuite>,
read_key_schedule: Option<&mut ReadKeySchedule<CipherSuite>>,
) -> Result<&'a [u8], ProtocolError>
where
CipherSuite: TlsCipherSuite,
{
if current_header.is_some() {
return Err(ProtocolError::InternalError);
}
start_record(buffer, pos, current_header, record.header())?;
with_buffer(buffer, pos, |buf| {
let mut buf = buf.forward();
record.encode_payload(&mut buf)?;
let transcript = read_key_schedule
.ok_or(ProtocolError::InternalError)?
.transcript_hash();
record.finish_record(&mut buf, transcript, write_key_schedule)?;
Ok(buf.rewind())
})?;
close_record(buffer, pos, current_header, write_key_schedule)
}