buggy_openssl_with_fullduplex/server.cpp

304 lines
7.9 KiB
C++
Raw Normal View History

2013-09-09 14:11:51 +01:00
#include "server.h"
#include <vector>
#include <set>
#include <boost/thread/thread.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/thread/condition_variable.hpp>
#include <boost/thread/locks.hpp>
#include <boost/pending/queue.hpp>
#include <netinet/tcp.h>
#include <fcntl.h>
#include <iostream>
#include <exception>
#include <csignal>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include "defs.h"
using namespace std;
using namespace boost;
typedef pair<int, SSL*> SocketSSLHandles_t;
SocketSSLHandles_t WriteHandler(0,0);
// Keeps socket handles that are ready to be read
boost::queue<SocketSSLHandles_t> ReadQueue;
// Socket Set is a set that keeps sockets on which we can 'select()'
typedef set<SocketSSLHandles_t> SocketSet_t;
SocketSet_t SocketSet;
mutex SocketSetMutex;
// WaitForWrite is a condition variable that is signaled when Sender must start sending the data
condition_variable WaitForWrite;
mutex WaitForWriteMutex;
mutex WriteReadMutex;
void Receive();
void Send();
// Constructs an idle server; call init() and then start() to run it.
// SSLProcess(true) configures the SSL base class (NOTE(review): presumably
// "server mode" - confirm in SSLProcess).
// The worker-thread pointers start out null so that ~Server(), which deletes
// them, is safe even when start() was never called or threw before the
// threads were created.
Server::Server()
: SSLProcess(true)
, _master(0)
, _sender(0)
, _reciver(0)
, _reactor(0)
{
}
// Creates the listening TCP socket for PORT on all local IPv4 interfaces,
// switches it to non-blocking mode and stores it in _master.
// Non-blocking mode means accept() (called by the reactor after select())
// returns -1 instead of blocking if the pending connection has gone away.
// Throws runtime_error if the socket cannot be created, bound or put into
// the listening state. On failure the descriptor is closed and _master is
// left unchanged.
void Server::startListen(void) {
    const int fd = ::socket(PF_INET, SOCK_STREAM, 0);
    if(fd < 0)
        throw runtime_error("Couldn't create listening socket");

    struct sockaddr_in local_address;
    memset(&local_address, 0, sizeof(local_address));
    local_address.sin_family = AF_INET;
    local_address.sin_port = htons(PORT);
    local_address.sin_addr.s_addr = htonl(INADDR_ANY);

    // Allow rebinding the port immediately after a restart.
    int reuseval = 1;
    setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuseval, sizeof(reuseval));

    // Set the socket non-blocking without clobbering its other status flags
    // (best effort, as before).
    const int flags = fcntl(fd, F_GETFL, 0);
    fcntl(fd, F_SETFL, (flags < 0 ? 0 : flags) | O_NONBLOCK);

    // Bind to the socket
    if(::bind(fd, reinterpret_cast<struct sockaddr *>(&local_address), sizeof(local_address)) != 0)
    {
        ::close(fd);
        throw runtime_error("Couldn't bind to local port");
    }
    // Set a limit on connection queue.
    if(::listen(fd, 5) != 0)
    {
        ::close(fd);
        throw runtime_error("Not possible to get into listen state");
    }
    _master = fd;
}
// Opens the listening socket and launches the three worker threads, in this
// order: Send, Receive and the Acceptor reactor loop (the thread gets its
// own copy of the Acceptor functor).
// Throws runtime_error if the SSL context has not been set up by init(), and
// propagates the runtime_error thrown by startListen().
void Server::start()
{
    if( !_ctx )
    {
        throw runtime_error("SSL not initialized");
    }

    startListen();

    Acceptor reactorLoop(_master, _ctx);
    _sender  = new thread( &Send );
    _reciver = new thread( &Receive );
    _reactor = new thread( reactorLoop );
}
// Prepares the SSL layer before start():
// first the generic setup done by sslInit() (NOTE(review): presumably
// provided by the SSLProcess base and responsible for creating _ctx -
// confirm), then loading of the server certificate and private key.
void Server::init()
{
    // Generic OpenSSL / context initialisation.
    sslInit();

    // Server-specific credentials.
    doServerSSLInit();
}
// Loads the certificate chain (CERTIFICATE_FILE) and the PEM private key
// (PRIVATE_KEY_FILE) into _ctx, then checks that the key matches the
// certificate. Any failure prints a diagnostic to stderr and terminates the
// process immediately with _exit(1).
void Server::doServerSSLInit()
{
    // The key is only loaded once the certificate has been loaded successfully.
    bool loaded = SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) > 0;
    if( loaded )
        loaded = SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) > 0;

    if( !loaded )
    {
        // OpenSSL's error queue explains which file failed and why.
        ERR_print_errors_fp(stderr);
        _exit(1);
    }

    // Verify if public-private keypair matches
    if( !SSL_CTX_check_private_key(_ctx) )
    {
        fprintf(stderr, "Private key is invalid.\n");
        _exit(1);
    }
}
// Reactor loop; runs on its own thread (Server::start() gives a copy of the
// Acceptor to a new thread). Never returns.
//
// Each iteration:
//   1. builds an fd_set from the listening socket and every connection in
//      SocketSet,
//   2. waits briefly in select(),
//   3. moves every readable connection from SocketSet to ReadQueue (both
//      under SocketSetMutex) for the Receive thread, and
//   4. if the listening socket is readable, accepts the new connection,
//      performs the TLS handshake and adds it to SocketSet.
void Acceptor::operator()()
{
    cout << "Entering acceptor loop..." << endl;
    while(1)
    {
        // Very short timeout: select() only watches the descriptors that were
        // in SocketSet when it was called, so connections Receive puts back
        // are picked up on the next (quick) iteration.
        struct timeval tv;
        tv.tv_sec = 0;
        tv.tv_usec = 10;

        fd_set fd_read;
        FD_ZERO(&fd_read);
        int maxv = _master;
        FD_SET(_master, &fd_read);
        {
            lock_guard<mutex> guard(SocketSetMutex);
            for(SocketSet_t::const_iterator aIt=SocketSet.begin();
                aIt!=SocketSet.end(); ++aIt)
            {
                FD_SET(aIt->first, &fd_read);
                if(aIt->first > maxv)
                    maxv=aIt->first;
            }
        }

        const int nready = select(maxv+1, &fd_read, NULL, NULL, &tv);
        if( nready < 0 )
        {
            // The fd_set contents are unspecified after a failed select(),
            // so they must not be inspected with FD_ISSET.
            perror("select");
            continue;
        }
        if( nready == 0 )
            continue; // timed out - nothing is ready

        {
            lock_guard<mutex> guard(SocketSetMutex);
            SocketSet_t::const_iterator aIt=SocketSet.begin();
            while(aIt!=SocketSet.end())
            {
                // Copy and advance first: erasing the current element would
                // invalidate aIt.
                SocketSSLHandles_t aTmpHandles = *aIt;
                ++aIt;
                if( FD_ISSET(aTmpHandles.first, &fd_read ) )
                {
                    // The socket stays readable until SSL_read consumes the
                    // data, so keep it out of select() until Receive has
                    // handled it.
                    SocketSet.erase(aTmpHandles);
                    ReadQueue.push(aTmpHandles);
                }
            }
        }

        // Listening socket readable => a new connection request has arrived.
        if( FD_ISSET(_master, &fd_read) )
        {
            int new_fd=openTCPSocket();
            if( new_fd >= 0 )
            {
                cout << "New socket with ID : " << new_fd
                     << " is going to be added to map" << endl;
                SSL* ssl = NULL;
                try
                {
                    ssl = openSSLSession(new_fd);
                }
                catch(const std::exception& e)
                {
                    // One client failing the handshake must not stop the
                    // server: an exception escaping a boost::thread function
                    // ends in std::terminate().
                    cerr << e.what() << " - closing socket " << new_fd << endl;
                    ::close(new_fd);
                    continue;
                }
                lock_guard<mutex> guard(SocketSetMutex);
                SocketSet.insert(make_pair(new_fd, ssl));
            }
        }
    }
}
int Acceptor::openTCPSocket()
{
// Open up new connection
cout << "New connection has arrived" << endl;
struct sockaddr_in addr;
int len = sizeof(addr);
int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
if(client == -1)
perror("accept");
return client;
}
// Creates an SSL session for the connected socket iTCPHandle and performs
// the server-side TLS handshake.
// Returns the established session. On any failure, the OpenSSL error queue
// is printed, the partially built SSL object is freed and runtime_error is
// thrown. The socket itself is not closed; the caller still owns it.
SSL* Acceptor::openSSLSession(int iTCPHandle)
{
    SSL *ssl = SSL_new(_ctx);
    if(!ssl)
    {
        ERR_print_errors_fp(stderr);
        throw runtime_error("Can't create SSL session");
    }
    // SSL_set_fd returns 1 on success.
    if(SSL_set_fd(ssl, iTCPHandle) != 1)
    {
        ERR_print_errors_fp(stderr);
        SSL_free(ssl);
        throw runtime_error("Can't attach socket to SSL session");
    }
    // The handshake runs on the calling (reactor) thread, so other
    // connections wait until it finishes; normally this would be in another
    // thread. SSL_accept returns 1 on success, 0 on a controlled handshake
    // failure and < 0 on a fatal error.
    if(SSL_accept(ssl) <= 0) {
        ERR_print_errors_fp(stderr);
        SSL_free(ssl);
        throw runtime_error("Can't SSL_accept => can't continue");
    }
    return ssl;
}
void Receive()
{
while(1)
{
char buf[1024];
SocketSSLHandles_t handler;
// TO-DO: this way it takes 100% CPU, some signal would be usefull
while (!ReadQueue.empty())
{
handler = ReadQueue.front();
ReadQueue.pop();
memset(buf,'\0',1024);
int len_rcv=0;
cout << SSL_state_string(handler.second) << endl;
{
lock_guard<mutex> lock(WriteReadMutex);
len_rcv = SSL_read(handler.second, buf, 1024);
}
if( len_rcv )
{
// dirty thing - if it has \n on the end - remove it
if( buf[len_rcv-1] == '\n' )
buf[len_rcv-1] = '\0';
cout << buf << endl;
// add it back to the socket so that select can use it
lock_guard<mutex> guard(SocketSetMutex);
SocketSet.insert(handler);
// push handler ID and notify sender thread
WriteHandler = handler;
WaitForWrite.notify_one();
}
else
{
cout << "Closing connection " << handler.first << endl;
::close(handler.first);
}
}
}
}
void Send()
{
while(1)
{
SocketSSLHandles_t handler(0,0);
{
unique_lock<mutex> lock(WaitForWriteMutex);
WaitForWrite.wait(lock);
handler = WriteHandler;
}
cout << "Writing to handler " << handler.first << endl;
string buf(EXCHANGE_STRING);
for(int i=0; i<SEND_ITERATIONS; ++i)
{
int len = 0;
do
{
lock_guard<mutex> lock(WriteReadMutex);
len+=SSL_write(handler.second, buf.c_str()+len, buf.size()-len);
// for debugging re-neg
// cout << "SSL STATE: " << SSL_state_string(handler.second) << endl;
} while( len != static_cast<int>(buf.size()) );
}
}
}
void Server::waitForStop()
{
_sender->join();
_reciver->join();
_reactor->join();
}
// Releases the three thread objects allocated by start(), in the order
// Send, Receive, reactor.
Server::~Server()
{
    thread* const workers[] = { _sender, _reciver, _reactor };
    for( size_t i = 0; i < sizeof(workers) / sizeof(workers[0]); ++i )
        delete workers[i];
}
/// --- MAIN --- ///
// Entry point: initialises SSL, starts the worker threads and blocks until
// they finish.
int main() {
    // Writing to a connection whose peer has already closed raises SIGPIPE,
    // whose default action would kill the whole server. With the signal
    // ignored, the write fails with EPIPE instead.
    std::signal(SIGPIPE, SIG_IGN);

    Server server;
    server.init();
    server.start();
    server.waitForStop();
    return 0;
}