Initial commit
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
373
tests/common/mod.rs
Normal file
373
tests/common/mod.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use mio::net::{TcpListener, TcpStream};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::io;
|
||||
use std::io::{BufReader, Read, Write};
|
||||
use std::net;
|
||||
|
||||
// Token for our listening socket.
|
||||
pub const LISTENER: mio::Token = mio::Token(0);
|
||||
|
||||
// Which mode the server operates in.
|
||||
#[derive(Clone)]
|
||||
pub enum ServerMode {
|
||||
/// Write back received bytes
|
||||
Echo,
|
||||
}
|
||||
|
||||
/// This binds together a TCP listening socket, some outstanding
|
||||
/// connections, and a TLS server configuration.
|
||||
pub struct EchoServer {
|
||||
server: TcpListener,
|
||||
connections: HashMap<mio::Token, Connection>,
|
||||
next_id: usize,
|
||||
tls_config: Arc<rustls::ServerConfig>,
|
||||
mode: ServerMode,
|
||||
}
|
||||
|
||||
impl EchoServer {
|
||||
pub fn new(
|
||||
server: TcpListener,
|
||||
mode: ServerMode,
|
||||
cfg: Arc<rustls::ServerConfig>,
|
||||
) -> EchoServer {
|
||||
EchoServer {
|
||||
server,
|
||||
connections: HashMap::new(),
|
||||
next_id: 2,
|
||||
tls_config: cfg,
|
||||
mode,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn accept(&mut self, registry: &mio::Registry) -> Result<(), io::Error> {
|
||||
loop {
|
||||
match self.server.accept() {
|
||||
Ok((socket, addr)) => {
|
||||
log::debug!("Accepting new connection from {:?}", addr);
|
||||
|
||||
let tls_session =
|
||||
rustls::ServerConnection::new(self.tls_config.clone()).unwrap();
|
||||
let mode = self.mode.clone();
|
||||
|
||||
let token = mio::Token(self.next_id);
|
||||
self.next_id += 1;
|
||||
|
||||
let mut connection = Connection::new(socket, token, mode, tls_session);
|
||||
connection.register(registry);
|
||||
self.connections.insert(token, connection);
|
||||
}
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(()),
|
||||
Err(err) => {
|
||||
println!(
|
||||
"encountered error while accepting connection; err={:?}",
|
||||
err
|
||||
);
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conn_event(&mut self, registry: &mio::Registry, event: &mio::event::Event) {
|
||||
let token = event.token();
|
||||
|
||||
if self.connections.contains_key(&token) {
|
||||
self.connections
|
||||
.get_mut(&token)
|
||||
.unwrap()
|
||||
.ready(registry, event);
|
||||
|
||||
if self.connections[&token].is_closed() {
|
||||
self.connections.remove(&token);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// This is a connection which has been accepted by the server,
|
||||
/// and is currently being served.
|
||||
///
|
||||
/// It has a TCP-level stream, a TLS-level session, and some
|
||||
/// other state/metadata.
|
||||
struct Connection {
|
||||
socket: TcpStream,
|
||||
token: mio::Token,
|
||||
closing: bool,
|
||||
closed: bool,
|
||||
mode: ServerMode,
|
||||
tls_session: rustls::ServerConnection,
|
||||
back: Option<TcpStream>,
|
||||
}
|
||||
|
||||
/// Open a plaintext TCP-level connection for forwarded connections.
|
||||
fn open_back(_mode: &ServerMode) -> Option<TcpStream> {
|
||||
None
|
||||
}
|
||||
|
||||
/// This used to be conveniently exposed by mio: map EWOULDBLOCK
|
||||
/// errors to something less-errory.
|
||||
fn try_read(r: io::Result<usize>) -> io::Result<Option<usize>> {
|
||||
match r {
|
||||
Ok(len) => Ok(Some(len)),
|
||||
Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
fn new(
|
||||
socket: TcpStream,
|
||||
token: mio::Token,
|
||||
mode: ServerMode,
|
||||
tls_session: rustls::ServerConnection,
|
||||
) -> Connection {
|
||||
let back = open_back(&mode);
|
||||
Connection {
|
||||
socket,
|
||||
token,
|
||||
closing: false,
|
||||
closed: false,
|
||||
mode,
|
||||
tls_session,
|
||||
back,
|
||||
}
|
||||
}
|
||||
|
||||
/// We're a connection, and we have something to do.
|
||||
fn ready(&mut self, registry: &mio::Registry, ev: &mio::event::Event) {
|
||||
if ev.is_readable() {
|
||||
self.do_tls_read();
|
||||
self.try_plain_read();
|
||||
self.try_back_read();
|
||||
}
|
||||
|
||||
if ev.is_writable() {
|
||||
self.do_tls_write_and_handle_error();
|
||||
}
|
||||
|
||||
if self.closing {
|
||||
let _ = self.socket.shutdown(net::Shutdown::Both);
|
||||
self.close_back();
|
||||
self.closed = true;
|
||||
self.deregister(registry);
|
||||
} else {
|
||||
self.reregister(registry);
|
||||
}
|
||||
}
|
||||
|
||||
fn close_back(&mut self) {
|
||||
if self.back.is_some() {
|
||||
let back = self.back.as_mut().unwrap();
|
||||
back.shutdown(net::Shutdown::Both).unwrap();
|
||||
}
|
||||
self.back = None;
|
||||
}
|
||||
|
||||
fn do_tls_read(&mut self) {
|
||||
let rc = self.tls_session.read_tls(&mut self.socket);
|
||||
if rc.is_err() {
|
||||
let err = rc.unwrap_err();
|
||||
if let io::ErrorKind::WouldBlock = err.kind() {
|
||||
return;
|
||||
}
|
||||
log::warn!("read error {:?}", err);
|
||||
self.closing = true;
|
||||
return;
|
||||
}
|
||||
if rc.unwrap() == 0 {
|
||||
log::debug!("eof");
|
||||
self.closing = true;
|
||||
return;
|
||||
}
|
||||
let processed = self.tls_session.process_new_packets();
|
||||
if processed.is_err() {
|
||||
log::warn!("cannot process packet: {:?}", processed);
|
||||
self.do_tls_write_and_handle_error();
|
||||
self.closing = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn try_plain_read(&mut self) {
|
||||
let mut buf = Vec::new();
|
||||
let rc = self.tls_session.reader().read_to_end(&mut buf);
|
||||
if let Err(ref e) = rc {
|
||||
if e.kind() != io::ErrorKind::WouldBlock {
|
||||
log::warn!("plaintext read failed: {:?}", rc);
|
||||
self.closing = true;
|
||||
return;
|
||||
}
|
||||
}
|
||||
if !buf.is_empty() {
|
||||
log::debug!("plaintext read {:?}", buf.len());
|
||||
self.incoming_plaintext(&buf);
|
||||
}
|
||||
}
|
||||
|
||||
fn try_back_read(&mut self) {
|
||||
if self.back.is_none() {
|
||||
return;
|
||||
}
|
||||
let mut buf = [0u8; 1024];
|
||||
let back = self.back.as_mut().unwrap();
|
||||
let rc = try_read(back.read(&mut buf));
|
||||
if rc.is_err() {
|
||||
log::warn!("backend read failed: {:?}", rc);
|
||||
self.closing = true;
|
||||
return;
|
||||
}
|
||||
let maybe_len = rc.unwrap();
|
||||
match maybe_len {
|
||||
Some(0) => {
|
||||
log::debug!("back eof");
|
||||
self.closing = true;
|
||||
}
|
||||
Some(len) => {
|
||||
self.tls_session.writer().write_all(&buf[..len]).unwrap();
|
||||
}
|
||||
None => {}
|
||||
};
|
||||
}
|
||||
|
||||
fn incoming_plaintext(&mut self, buf: &[u8]) {
|
||||
match self.mode {
|
||||
ServerMode::Echo => {
|
||||
self.tls_session.writer().write_all(buf).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn tls_write(&mut self) -> io::Result<usize> {
|
||||
self.tls_session.write_tls(&mut self.socket)
|
||||
}
|
||||
|
||||
fn do_tls_write_and_handle_error(&mut self) {
|
||||
let rc = self.tls_write();
|
||||
if rc.is_err() {
|
||||
log::warn!("write failed {:?}", rc);
|
||||
self.closing = true;
|
||||
}
|
||||
}
|
||||
|
||||
fn register(&mut self, registry: &mio::Registry) {
|
||||
let event_set = self.event_set();
|
||||
registry
|
||||
.register(&mut self.socket, self.token, event_set)
|
||||
.unwrap();
|
||||
if self.back.is_some() {
|
||||
registry
|
||||
.register(
|
||||
self.back.as_mut().unwrap(),
|
||||
self.token,
|
||||
mio::Interest::READABLE,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
fn reregister(&mut self, registry: &mio::Registry) {
|
||||
let event_set = self.event_set();
|
||||
registry
|
||||
.reregister(&mut self.socket, self.token, event_set)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn deregister(&mut self, registry: &mio::Registry) {
|
||||
registry.deregister(&mut self.socket).unwrap();
|
||||
if self.back.is_some() {
|
||||
registry.deregister(self.back.as_mut().unwrap()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
fn event_set(&self) -> mio::Interest {
|
||||
let rd = self.tls_session.wants_read();
|
||||
let wr = self.tls_session.wants_write();
|
||||
if rd && wr {
|
||||
mio::Interest::READABLE | mio::Interest::WRITABLE
|
||||
} else if wr {
|
||||
mio::Interest::WRITABLE
|
||||
} else {
|
||||
mio::Interest::READABLE
|
||||
}
|
||||
}
|
||||
|
||||
fn is_closed(&self) -> bool {
|
||||
self.closed
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_certs(filename: &PathBuf) -> Vec<rustls::Certificate> {
|
||||
let certfile = fs::File::open(filename).expect("cannot open certificate file");
|
||||
let mut reader = BufReader::new(certfile);
|
||||
rustls_pemfile::certs(&mut reader)
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|v| rustls::Certificate(v.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn load_private_key(filename: &PathBuf) -> rustls::PrivateKey {
|
||||
let keyfile = fs::File::open(filename).expect("cannot open private key file");
|
||||
let mut reader = BufReader::new(keyfile);
|
||||
loop {
|
||||
match rustls_pemfile::read_one(&mut reader).expect("cannot parse private key .pem file") {
|
||||
Some(rustls_pemfile::Item::RSAKey(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::PKCS8Key(key)) => return rustls::PrivateKey(key),
|
||||
Some(rustls_pemfile::Item::ECKey(key)) => return rustls::PrivateKey(key),
|
||||
None => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
panic!(
|
||||
"no keys found in {:?} (encrypted keys not supported)",
|
||||
filename
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn run(listener: TcpListener) {
|
||||
let versions = &[&rustls::version::TLS13];
|
||||
let test_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests");
|
||||
let certs = load_certs(&test_dir.join("fixtures").join("leaf-server.pem"));
|
||||
let privkey = load_private_key(&test_dir.join("fixtures").join("leaf-server-key.pem"));
|
||||
let config = rustls::ServerConfig::builder()
|
||||
.with_cipher_suites(rustls::ALL_CIPHER_SUITES)
|
||||
.with_kx_groups(&rustls::ALL_KX_GROUPS)
|
||||
.with_protocol_versions(versions)
|
||||
.unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, privkey)
|
||||
.unwrap();
|
||||
run_with_config(listener, config)
|
||||
}
|
||||
|
||||
pub fn run_with_config(mut listener: TcpListener, config: rustls::ServerConfig) {
|
||||
let mut poll = mio::Poll::new().unwrap();
|
||||
poll.registry()
|
||||
.register(&mut listener, LISTENER, mio::Interest::READABLE)
|
||||
.unwrap();
|
||||
let mut tlsserv = EchoServer::new(listener, ServerMode::Echo, Arc::new(config));
|
||||
let mut events = mio::Events::with_capacity(256);
|
||||
loop {
|
||||
if let Err(e) = poll.poll(&mut events, None) {
|
||||
if e.kind() == std::io::ErrorKind::Interrupted {
|
||||
log::debug!("I/O error {:?}", e);
|
||||
continue;
|
||||
}
|
||||
panic!("I/O error {:?}", e);
|
||||
}
|
||||
for event in events.iter() {
|
||||
match event.token() {
|
||||
LISTENER => {
|
||||
tlsserv
|
||||
.accept(poll.registry())
|
||||
.expect("error accepting socket");
|
||||
}
|
||||
_ => tlsserv.conn_event(poll.registry(), event),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user