diff --git a/tool/client.cc b/tool/client.cc index b9f1c13b..39cb7f0f 100644 --- a/tool/client.cc +++ b/tool/client.cc @@ -16,6 +16,12 @@ #include +#if !defined(OPENSSL_WINDOWS) +#include +#else +#include +#endif + #include #include #include @@ -92,6 +98,10 @@ static const struct argument kArguments[] = { "-grease", kBooleanArgument, "Enable GREASE", }, + { + "-resume", kBooleanArgument, + "Establish a second connection resuming the original connection.", + }, { "", kOptionalArgument, "", }, @@ -122,6 +132,7 @@ static void KeyLogCallback(const SSL *ssl, const char *line) { } static bssl::UniquePtr session_out; +static bssl::UniquePtr resume_session; static int NewSessionCallback(SSL *ssl, SSL_SESSION *session) { if (session_out) { @@ -132,8 +143,105 @@ static int NewSessionCallback(SSL *ssl, SSL_SESSION *session) { return 0; } } + resume_session = bssl::UniquePtr(session); + return 1; +} + +static bool WaitForSession(SSL *ssl, int sock) { + fd_set read_fds; + FD_ZERO(&read_fds); - return 0; + if (!SocketSetNonBlocking(sock, true)) { + return false; + } + + while (!resume_session) { + FD_SET(sock, &read_fds); + int ret = select(sock + 1, &read_fds, NULL, NULL, NULL); + if (ret <= 0) { + perror("select"); + return false; + } + + uint8_t buffer[512]; + int ssl_ret = SSL_read(ssl, buffer, sizeof(buffer)); + + if (ssl_ret <= 0) { + int ssl_err = SSL_get_error(ssl, ssl_ret); + if (ssl_err == SSL_ERROR_WANT_READ) { + continue; + } + fprintf(stderr, "Error while reading: %d\n", ssl_err); + ERR_print_errors_cb(PrintErrorCallback, stderr); + return false; + } + } + + return true; +} + +static bool DoConnection(SSL_CTX *ctx, + std::map args_map, + bool (*cb)(SSL *ssl, int sock)) { + int sock = -1; + if (!Connect(&sock, args_map["-connect"])) { + return false; + } + + if (args_map.count("-starttls") != 0) { + const std::string& starttls = args_map["-starttls"]; + if (starttls == "smtp") { + if (!DoSMTPStartTLS(sock)) { + return false; + } + } else { + fprintf(stderr, "Unknown value for -starttls: %s\n", starttls.c_str()); + return false; + } + } + + bssl::UniquePtr bio(BIO_new_socket(sock, BIO_CLOSE)); + bssl::UniquePtr ssl(SSL_new(ctx)); + + if (args_map.count("-server-name") != 0) { + SSL_set_tlsext_host_name(ssl.get(), args_map["-server-name"].c_str()); + } + + if (args_map.count("-session-in") != 0) { + bssl::UniquePtr in(BIO_new_file(args_map["-session-in"].c_str(), + "rb")); + if (!in) { + fprintf(stderr, "Error reading session\n"); + ERR_print_errors_cb(PrintErrorCallback, stderr); + return false; + } + bssl::UniquePtr session(PEM_read_bio_SSL_SESSION(in.get(), + nullptr, nullptr, nullptr)); + if (!session) { + fprintf(stderr, "Error reading session\n"); + ERR_print_errors_cb(PrintErrorCallback, stderr); + return false; + } + SSL_set_session(ssl.get(), session.get()); + } else if (resume_session) { + SSL_set_session(ssl.get(), resume_session.get()); + } + + SSL_set_bio(ssl.get(), bio.get(), bio.get()); + bio.release(); + + int ret = SSL_connect(ssl.get()); + if (ret != 1) { + int ssl_err = SSL_get_error(ssl.get(), ret); + fprintf(stderr, "Error while connecting: %d\n", ssl_err); + ERR_print_errors_cb(PrintErrorCallback, stderr); + return false; + } + + fprintf(stderr, "Connected.\n"); + PrintConnectionInfo(ssl.get()); + + return cb(ssl.get(), sock); } bool Client(const std::vector &args) { @@ -261,6 +369,9 @@ bool Client(const std::vector &args) { } } + SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_CLIENT); + SSL_CTX_sess_set_new_cb(ctx.get(), NewSessionCallback); + if (args_map.count("-session-out") != 0) { session_out.reset(BIO_new_file(args_map["-session-out"].c_str(), "wb")); if (!session_out) { @@ -269,71 +380,16 @@ bool Client(const std::vector &args) { ERR_print_errors_cb(PrintErrorCallback, stderr); return false; } - SSL_CTX_set_session_cache_mode(ctx.get(), SSL_SESS_CACHE_CLIENT); - SSL_CTX_sess_set_new_cb(ctx.get(), NewSessionCallback); } if (args_map.count("-grease") != 0) { SSL_CTX_set_grease_enabled(ctx.get(), 1); } - int sock = -1; - if (!Connect(&sock, args_map["-connect"])) { - return false; - } - - if (args_map.count("-starttls") != 0) { - const std::string& starttls = args_map["-starttls"]; - if (starttls == "smtp") { - if (!DoSMTPStartTLS(sock)) { - return false; - } - } else { - fprintf(stderr, "Unknown value for -starttls: %s\n", starttls.c_str()); - return false; - } - } - - bssl::UniquePtr bio(BIO_new_socket(sock, BIO_CLOSE)); - bssl::UniquePtr ssl(SSL_new(ctx.get())); - - if (args_map.count("-server-name") != 0) { - SSL_set_tlsext_host_name(ssl.get(), args_map["-server-name"].c_str()); - } - - if (args_map.count("-session-in") != 0) { - bssl::UniquePtr in(BIO_new_file(args_map["-session-in"].c_str(), - "rb")); - if (!in) { - fprintf(stderr, "Error reading session\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); - return false; - } - bssl::UniquePtr session(PEM_read_bio_SSL_SESSION(in.get(), - nullptr, nullptr, nullptr)); - if (!session) { - fprintf(stderr, "Error reading session\n"); - ERR_print_errors_cb(PrintErrorCallback, stderr); - return false; - } - SSL_set_session(ssl.get(), session.get()); - } - - SSL_set_bio(ssl.get(), bio.get(), bio.get()); - bio.release(); - - int ret = SSL_connect(ssl.get()); - if (ret != 1) { - int ssl_err = SSL_get_error(ssl.get(), ret); - fprintf(stderr, "Error while connecting: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); + if (args_map.count("-resume") != 0 && + !DoConnection(ctx.get(), args_map, &WaitForSession)) { return false; } - fprintf(stderr, "Connected.\n"); - PrintConnectionInfo(ssl.get()); - - bool ok = TransferData(ssl.get(), sock); - - return ok; + return DoConnection(ctx.get(), args_map, &TransferData); } diff --git a/tool/server.cc b/tool/server.cc index a049422e..94abbbd9 100644 --- a/tool/server.cc +++ b/tool/server.cc @@ -24,33 +24,37 @@ static const struct argument kArguments[] = { { - "-accept", kRequiredArgument, - "The port of the server to bind on; eg 45102", + "-accept", kRequiredArgument, + "The port of the server to bind on; eg 45102", }, { - "-cipher", kOptionalArgument, - "An OpenSSL-style cipher suite string that configures the offered ciphers", + "-cipher", kOptionalArgument, + "An OpenSSL-style cipher suite string that configures the offered " + "ciphers", }, { - "-max-version", kOptionalArgument, - "The maximum acceptable protocol version", + "-max-version", kOptionalArgument, + "The maximum acceptable protocol version", }, { - "-min-version", kOptionalArgument, - "The minimum acceptable protocol version", + "-min-version", kOptionalArgument, + "The minimum acceptable protocol version", }, { - "-key", kOptionalArgument, - "PEM-encoded file containing the private key, leaf certificate and " - "optional certificate chain. A self-signed certificate is generated " - "at runtime if this argument is not provided.", + "-key", kOptionalArgument, + "PEM-encoded file containing the private key, leaf certificate and " + "optional certificate chain. A self-signed certificate is generated " + "at runtime if this argument is not provided.", }, { - "-ocsp-response", kOptionalArgument, - "OCSP response file to send", + "-ocsp-response", kOptionalArgument, "OCSP response file to send", }, { - "", kOptionalArgument, "", + "-loop", kBooleanArgument, + "The server will continue accepting new sequential connections.", + }, + { + "", kOptionalArgument, "", }, }; @@ -219,25 +223,30 @@ bool Server(const std::vector &args) { return false; } - int sock = -1; - if (!Accept(&sock, args_map["-accept"])) { - return false; - } + bool result = true; + do { + int sock = -1; + if (!Accept(&sock, args_map["-accept"])) { + return false; + } - BIO *bio = BIO_new_socket(sock, BIO_CLOSE); - bssl::UniquePtr ssl(SSL_new(ctx.get())); - SSL_set_bio(ssl.get(), bio, bio); + BIO *bio = BIO_new_socket(sock, BIO_CLOSE); + bssl::UniquePtr ssl(SSL_new(ctx.get())); + SSL_set_bio(ssl.get(), bio, bio); - int ret = SSL_accept(ssl.get()); - if (ret != 1) { - int ssl_err = SSL_get_error(ssl.get(), ret); - fprintf(stderr, "Error while connecting: %d\n", ssl_err); - ERR_print_errors_cb(PrintErrorCallback, stderr); - return false; - } + int ret = SSL_accept(ssl.get()); + if (ret != 1) { + int ssl_err = SSL_get_error(ssl.get(), ret); + fprintf(stderr, "Error while connecting: %d\n", ssl_err); + ERR_print_errors_cb(PrintErrorCallback, stderr); + return false; + } + + fprintf(stderr, "Connected.\n"); + PrintConnectionInfo(ssl.get()); - fprintf(stderr, "Connected.\n"); - PrintConnectionInfo(ssl.get()); + result = TransferData(ssl.get(), sock); + } while (result && args_map.count("-loop") != 0); - return TransferData(ssl.get(), sock); + return result; }