diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..43a441f --- /dev/null +++ b/Makefile @@ -0,0 +1,26 @@ +EXECUTABLE = server +EXECUTABLECLI = client + +CC=g++ -lboost_system -lboost_thread +CFLAGS=-Wall -ggdb -DDEBUG -Wreorder +COMPILE=$(CC) $(CFLAGS) + +all: server client + +server: server.o ssl_process.o + $(CC) -lpthread -lcrypto -lssl ssl_process.o server.o -o $(EXECUTABLE) + +client: client.o ssl_process.o + $(CC) -lpthread -lcrypto -lssl -o $(EXECUTABLECLI) ssl_process.o client.o + +ssl_process.o: ssl_process.cpp + $(COMPILE) -o ssl_process.o -c ssl_process.cpp + +server.o: server.cpp + $(COMPILE) -o server.o -c server.cpp + +client.o: client.cpp + $(COMPILE) -o client.o -c client.cpp + +clean: + rm -rf *.o server client diff --git a/client.cpp b/client.cpp new file mode 100644 index 0000000..3971f25 --- /dev/null +++ b/client.cpp @@ -0,0 +1,196 @@ +/******************************************************************************/ +/** + \Author Krzysztof Kwiatkowski + \File client.cpp + \Description The SSL client which connects to the server.cpp + and initiates renegotitaion after RENEG_INIT_LEN + chars exchanged with the server + +*******************************************************************************/ + +#include "client.h" +#include +#include "defs.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using namespace boost; + +int Handler = 0; +SSL* SSLHandler = 0; +int CharsRead = 0; + +// with this you can block sender thread during renegotiation +mutex WriteReadMutex; + +void Sender() +{ + while(1) + { + string buf(EXCHANGE_STRING); + int len = 0; + SSL_write(SSLHandler, buf.c_str()+len, buf.size()-len); +/* + do + { + lock_guard lock(WriteReadMutex); + len+=SSL_write(SSLHandler, buf.c_str()+len, buf.size()-len); + // for debugging re-neg + // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl; + } while( len != static_cast(buf.size()) ); +*/ + } +}; + +void Client::receive() +{ + char buf[MAX_PACKET_SIZE]; + + // TODO: this way it takes 100% CPU, some signal would be usefull + memset(buf,'\0',MAX_PACKET_SIZE); + int len_rcv = 0; + { + lock_guard lock(WriteReadMutex); + cout << "A" << endl; + len_rcv = SSL_read(SSLHandler, buf, MAX_PACKET_SIZE); + } + + if( len_rcv != 0 ) + { + CharsRead += len_rcv; + + // dirty thing - if it has \n on the end - remove it + if( buf[len_rcv-1] == '\n' ) + buf[len_rcv-1] = '\0'; + + cout << buf << endl; + } + else + { + cout << "Closing connection " << _handler << endl; + ::close(_handler); + } +} + + +void Client::connect() +{ + lock_guard lock(WriteReadMutex); + + struct sockaddr_in echoserver; + int sock = socket(AF_INET, SOCK_STREAM, 0); + memset(&echoserver, 0, sizeof(echoserver)); + echoserver.sin_family = AF_INET; + echoserver.sin_addr.s_addr = inet_addr(IP); + echoserver.sin_port = htons(PORT); + + /* Establish connection */ + if ( 0 > ::connect(sock, (struct sockaddr *) &echoserver, sizeof(echoserver)) ) + { + throw runtime_error("Can't connect to the server"); + } + Handler = sock; + SSLHandler = SSL_new(_ctx); + SSL_set_fd(SSLHandler, Handler); + if( SSL_connect(SSLHandler) <= 0) + { + cerr << "Can't setup SSL session" << endl; + exit(1); + } + fcntl(sock, F_SETFL, O_NONBLOCK); +} + +void Client::start() +{ + // start sender thread first + _sender =new thread( Sender ); + + struct timeval tv; + + // go to select loop + while(1) + { + // wait timer for select + tv.tv_sec = 0; + tv.tv_usec = 10; + + fd_set fd_read; + + FD_ZERO(&fd_read); + FD_SET(_handler, &fd_read); + select(_handler+1, &fd_read, NULL, NULL, (struct timeval *)&tv); + if( FD_ISSET(_handler, &fd_read ) ) + { + // this should be in other thread but... it works + if( CharsRead > RENEG_INIT_LEN ) + { + CharsRead = 0; + renegotiate(); + } + receive(); + } + } + +} + +void Client::renegotiate() +{ + lock_guard lock_reads(WriteReadMutex); + cout << "B" << endl; + + cout << "Starting SSL renegotiation on SSL" + << "client (initiating by SSL client)" << endl; + + cout << "SSL State: " << SSL_state_string(SSLHandler) << endl; + if(SSL_renegotiate(SSLHandler) <= 0){ + cerr << "SSL_renegotiate() failed. STATE: " + << SSL_state_string(SSLHandler) << endl; + ERR_print_errors_fp(stderr); + exit(1); + } + + cout << "SSL State: " << SSL_state_string(SSLHandler) << endl; + if(SSL_do_handshake(SSLHandler) <= 0){ + cerr << "SSL_do_handshake() failed. STATE: " + << SSL_state_string(SSLHandler) << endl; + ERR_print_errors_fp(stderr); + exit(1); + } +} + +void Client::init() +{ + sslInit(); +} + +// --- MAIN --- // +int main() +{ + try + { + Client client; + client.init(); + client.connect(); + client.start(); + } + catch(std::runtime_error& e) + { + cerr << "ERROR " << e.what() << endl; + } + catch(...) + { + cerr << "Unknown exception" << endl; + } + return 0; +} diff --git a/client.h b/client.h new file mode 100644 index 0000000..e96f8f3 --- /dev/null +++ b/client.h @@ -0,0 +1,25 @@ +#ifndef _SSLCLIENT_H_ +#define _SSLCLIENT_H_ +#include +#include +#include "ssl_process.h" + +class Client : public SSLProcess +{ + boost::thread* _sender; + boost::thread* _receiver; + int _handler; + + void receive(); + void renegotiate(); + +public: + Client() + : SSLProcess(false){}; + + void init(); + void connect(); + void start(); +}; + +#endif /* _SSLCLIENT_H_ */ diff --git a/defs.h b/defs.h new file mode 100644 index 0000000..1bc8594 --- /dev/null +++ b/defs.h @@ -0,0 +1,9 @@ +#define MAX_PACKET_SIZE 1024 +#define PORT 1420 +#define IP "127.0.0.1" +#define EXCHANGE_STRING "ABCDEFGHIJKLMNOPRSTUWXYZ" +#define EXCHANGE_STRING_LEN sizeof(EXCHANGE_STRING)/sizeof(EXCHANGE_STRING[0]) +#define RENEG_INIT_LEN 200 +#define CERTIFICATE_FILE "etc/cert" +#define PRIVATE_KEY_FILE "etc/pkey" +#define SEND_ITERATIONS 100000 diff --git a/server.cpp b/server.cpp new file mode 100644 index 0000000..514ab76 --- /dev/null +++ b/server.cpp @@ -0,0 +1,303 @@ +#include "server.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "defs.h" + +using namespace std; +using namespace boost; + +typedef pair SocketSSLHandles_t; +SocketSSLHandles_t WriteHandler(0,0); + +// Keeps socket handles that are ready to be read +boost::queue ReadQueue; + +// Socket Set is a set that keeps sockets on which we can 'select()' +typedef set SocketSet_t; +SocketSet_t SocketSet; +mutex SocketSetMutex; + +// WaitForWrite is a condition variable that is signaled when Sender must start sending the data +condition_variable WaitForWrite; +mutex WaitForWriteMutex; +mutex WriteReadMutex; + +void Receive(); +void Send(); + +Server::Server() + : SSLProcess(true) + , _master(0) +{ +} + +void Server::startListen(void) { + struct sockaddr_in local_address; + + _master = ::socket(PF_INET, SOCK_STREAM, 0); + memset(&local_address, 0, sizeof(local_address)); + + local_address.sin_family = AF_INET; + local_address.sin_port = htons(PORT); + local_address.sin_addr.s_addr = INADDR_ANY; + + int reuseval = 1; + setsockopt(_master,SOL_SOCKET,SO_REUSEADDR, &reuseval, sizeof(reuseval)); + + // set socket non-blocking + fcntl(_master, F_SETFL, O_NONBLOCK); + + // Bind to the socket + if(::bind(_master, (struct sockaddr *)&local_address, sizeof(local_address)) != 0) + throw runtime_error("Couldn't bind to local port"); + + // Set a limit on connection queue. + if(::listen(_master, 5) != 0) + throw runtime_error("Not possible to get into listen state"); +} + +void Server::start() +{ + if( !_ctx ) + throw runtime_error("SSL not initialized"); + + startListen(); + + Acceptor ac(_master, _ctx); + _sender =new thread( Send ); + _reciver =new thread( Receive ); + _reactor =new thread( ac ); + + +} + +void Server::init() +{ + sslInit(); + doServerSSLInit(); +} + +void Server::doServerSSLInit() +{ + // Load certificate & private key + if ( SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) <= 0) { + ERR_print_errors_fp(stderr); + _exit(1); + } + + if ( SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) <= 0) { + ERR_print_errors_fp(stderr); + _exit(1); + } + + // Verify if public-private keypair matches + if ( !SSL_CTX_check_private_key(_ctx) ) { + fprintf(stderr, "Private key is invalid.\n"); + _exit(1); + } +} + +void Acceptor::operator()() +{ + cout << "Entering acceptor loop..." << endl; + while(1) + { + // wait timer for select + struct timeval tv; + tv.tv_sec = 0; + tv.tv_usec = 10; + + fd_set fd_read; + + // set fd_sets + FD_ZERO(&fd_read); + + // set max fd for select + int maxv = _master; + FD_SET(_master, &fd_read); + + // add all the sockets to sets + { + lock_guard guard(SocketSetMutex); + for(SocketSet_t::const_iterator aIt=SocketSet.begin(); + aIt!=SocketSet.end(); ++aIt) + { + FD_SET(aIt->first, &fd_read); + if(aIt->first > maxv) + maxv=aIt->first; + } + } + + // wait in select now + select(maxv+1, &fd_read, NULL, NULL, (struct timeval *)&tv); + { + lock_guard guard(SocketSetMutex); + + // check if you can read + SocketSet_t::const_iterator aIt=SocketSet.begin(); + while(aIt!=SocketSet.end()) + { + + SocketSSLHandles_t aTmpHandles = *aIt; + aIt++; + if( FD_ISSET(aTmpHandles.first, &fd_read ) ) + { + // you need to erase tmpHd from SocketSet - otherwise it will + // be ready to read until SSL_read is not called on it + SocketSet.erase(aTmpHandles); + ReadQueue.push(aTmpHandles); + } + } + } + + // if master is in fd_read - then it means new connection req + // has arrived + if( FD_ISSET(_master, &fd_read) ) + { + int new_fd=openTCPSocket(); + if( new_fd >= 0 ) + { + cout << "New socket with ID : " << new_fd + << " is going to be added to map" << endl; + SSL* ssl = openSSLSession(new_fd); + lock_guard guard(SocketSetMutex); + SocketSet.insert(make_pair(new_fd, ssl)); + } + } + } +} + +int Acceptor::openTCPSocket() +{ + // Open up new connection + cout << "New connection has arrived" << endl; + struct sockaddr_in addr; + int len = sizeof(addr); + int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len); + if(client == -1) + perror("accept"); + return client; +} + +SSL* Acceptor::openSSLSession(int iTCPHandle) +{ + SSL *ssl = (SSL*) SSL_new(_ctx); + SSL_set_fd(ssl, iTCPHandle); + + // normally this would be in other thread + if(SSL_accept(ssl) == -1) { + ERR_print_errors_fp(stderr); + throw runtime_error("Can't SSL_accept => can't continue"); + } + return ssl; +} + +void Receive() +{ + while(1) + { + char buf[1024]; + SocketSSLHandles_t handler; + + // TO-DO: this way it takes 100% CPU, some signal would be usefull + while (!ReadQueue.empty()) + { + handler = ReadQueue.front(); + ReadQueue.pop(); + + memset(buf,'\0',1024); + int len_rcv=0; + cout << SSL_state_string(handler.second) << endl; + { + lock_guard lock(WriteReadMutex); + len_rcv = SSL_read(handler.second, buf, 1024); + } + + if( len_rcv ) + { + // dirty thing - if it has \n on the end - remove it + if( buf[len_rcv-1] == '\n' ) + buf[len_rcv-1] = '\0'; + + cout << buf << endl; + + // add it back to the socket so that select can use it + lock_guard guard(SocketSetMutex); + SocketSet.insert(handler); + + // push handler ID and notify sender thread + WriteHandler = handler; + WaitForWrite.notify_one(); + } + else + { + cout << "Closing connection " << handler.first << endl; + ::close(handler.first); + } + } + } +} + +void Send() +{ + while(1) + { + SocketSSLHandles_t handler(0,0); + { + unique_lock lock(WaitForWriteMutex); + WaitForWrite.wait(lock); + handler = WriteHandler; + } + + cout << "Writing to handler " << handler.first << endl; + string buf(EXCHANGE_STRING); + for(int i=0; i lock(WriteReadMutex); + len+=SSL_write(handler.second, buf.c_str()+len, buf.size()-len); + // for debugging re-neg + // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl; + } while( len != static_cast(buf.size()) ); + } + } +} + + +void Server::waitForStop() +{ + _sender->join(); + _reciver->join(); + _reactor->join(); +} + +Server::~Server() +{ + delete _sender; + delete _reciver; + delete _reactor; +} + +/// --- MAIN --- /// +int main() { + Server server; + server.init(); + server.start(); + server.waitForStop(); + + return 0; +} diff --git a/server.h b/server.h new file mode 100644 index 0000000..5480e73 --- /dev/null +++ b/server.h @@ -0,0 +1,58 @@ +#ifndef __SSL_SERVER_H__ +#define __SSL_SERVER_H__ + +#include "ssl_process.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +struct Acceptor +{ + int _master; + SSL_CTX* _ctx; + + Acceptor(int iMasterHd, SSL_CTX* iCtx) + : _master(iMasterHd) + , _ctx(iCtx) + { } + + void operator()(); + int openTCPSocket(); + SSL* openSSLSession(int); + +}; + +class Server : public SSLProcess +{ + + // Private data +private: + int _master; + + void startListen(); + void doServerSSLInit(); + + boost::thread* _sender; + boost::thread* _reciver; + boost::thread* _reactor; + +public: + // Constructors + Server(); + ~Server(); + + void init(); + void start(); + void waitForStop(); +}; + +#endif diff --git a/ssl_process.cpp b/ssl_process.cpp new file mode 100644 index 0000000..f2bc44c --- /dev/null +++ b/ssl_process.cpp @@ -0,0 +1,31 @@ +#include "ssl_process.h" + +SSLProcess::SSLProcess(bool isServer) + : _ctx(0) + , _isServer(isServer) +{ +} + +SSLProcess::~SSLProcess() +{ + if(!_ctx) + SSL_CTX_free(_ctx); +} + +void SSLProcess::sslInit() +{ + // Load algorithms and error strings. + ERR_load_crypto_strings(); + OpenSSL_add_all_algorithms(); + SSL_load_error_strings(); + SSL_library_init(); + + // Create new context for server method. + _ctx = SSL_CTX_new( + _isServer ? SSLv23_server_method() : SSLv23_client_method()); + if(_ctx == 0) + { + ERR_print_errors_fp(stderr); + exit(1); + } +} diff --git a/ssl_process.h b/ssl_process.h new file mode 100644 index 0000000..b446ae6 --- /dev/null +++ b/ssl_process.h @@ -0,0 +1,30 @@ +#ifndef __SSL_PROCESS__ +#define __SSL_PROCESS__ + +#include +#include + +class SSLProcess +{ +public: + virtual ~SSLProcess(); + virtual void init() = 0; + +protected: + SSLProcess(bool isServer); + + void sslInit(); + bool isServer() + { return _isServer; } + + SSL_CTX* _ctx; + +private: + SSLProcess(); + SSLProcess(const SSLProcess&); + SSLProcess& operator=(const SSLProcess&); + + bool _isServer; +}; + +#endif // __SSL_PROCESS__