|
#include "server.h"

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iostream>
#include <set>
#include <stdexcept>
#include <vector>

#include <boost/pending/queue.hpp>
#include <boost/thread/condition_variable.hpp>
#include <boost/thread/locks.hpp>
#include <boost/thread/mutex.hpp>
#include <boost/thread/thread.hpp>

#include <openssl/err.h>
#include <openssl/ssl.h>

#include "defs.h"
-
- using namespace std;
- using namespace boost;
-
- typedef pair<int, SSL*> SocketSSLHandles_t;
- SocketSSLHandles_t WriteHandler(0,0);
-
- // Keeps socket handles that are ready to be read
- boost::queue<SocketSSLHandles_t> ReadQueue;
-
- // Socket Set is a set that keeps sockets on which we can 'select()'
- typedef set<SocketSSLHandles_t> SocketSet_t;
- SocketSet_t SocketSet;
- mutex SocketSetMutex;
-
- // WaitForWrite is a condition variable that is signaled when Sender must start sending the data
- condition_variable WaitForWrite;
- mutex WaitForWriteMutex;
-
- // mutex that synchronizes access to SSL_read/SSL_write
- mutex WriteReadMutex;
-
- // thread functions to send and receive
- void Receive();
- void Send();
-
-
- Server::Server()
- : SSLProcess(true)
- , _master(0)
- {
- }
-
- void Server::startListen(void) {
- struct sockaddr_in local_address;
-
- _master = ::socket(PF_INET, SOCK_STREAM, 0);
- memset(&local_address, 0, sizeof(local_address));
-
- local_address.sin_family = AF_INET;
- local_address.sin_port = htons(PORT);
- local_address.sin_addr.s_addr = INADDR_ANY;
-
- int reuseval = 1;
- setsockopt(_master,SOL_SOCKET,SO_REUSEADDR, &reuseval, sizeof(reuseval));
-
- // set socket non-blocking
- fcntl(_master, F_SETFL, O_NONBLOCK);
-
- // Bind to the socket
- if(::bind(_master, (struct sockaddr *)&local_address, sizeof(local_address)) != 0)
- throw runtime_error("Couldn't bind to local port");
-
- // Set a limit on connection queue.
- if(::listen(_master, 5) != 0)
- throw runtime_error("Not possible to get into listen state");
- }
-
- void Server::start()
- {
- if( !_ctx )
- throw runtime_error("SSL not initialized");
-
- startListen();
-
- Acceptor ac(_master, _ctx);
- _sender =new thread( Send );
- _reciver =new thread( Receive );
- _reactor =new thread( ac );
-
-
- }
-
- void Server::init()
- {
- sslInit();
- doServerSSLInit();
- }
-
- void Server::doServerSSLInit()
- {
- // Load certificate & private key
- if ( SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) <= 0) {
- ERR_print_errors_fp(stderr);
- _exit(1);
- }
-
- if ( SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
- ERR_print_errors_fp(stderr);
- _exit(1);
- }
-
- // Verify if public-private keypair matches
- if ( !SSL_CTX_check_private_key(_ctx) ) {
- fprintf(stderr, "Private key is invalid.\n");
- _exit(1);
- }
- }
-
- void Acceptor::operator()()
- {
- cout << "Entering acceptor loop..." << endl;
- while(1)
- {
- // wait timer for select
- struct timeval tv;
- tv.tv_sec = 0;
- tv.tv_usec = 10;
-
- fd_set fd_read;
-
- // set fd_sets
- FD_ZERO(&fd_read);
-
- // set max fd for select
- int maxv = _master;
- FD_SET(_master, &fd_read);
-
- // add all the sockets to sets
- {
- lock_guard<mutex> guard(SocketSetMutex);
- for(SocketSet_t::const_iterator aIt=SocketSet.begin();
- aIt!=SocketSet.end(); ++aIt)
- {
- FD_SET(aIt->first, &fd_read);
- if(aIt->first > maxv)
- maxv=aIt->first;
- }
- }
-
- // wait in select now
- select(maxv+1, &fd_read, NULL, NULL, (struct timeval *)&tv);
- {
- lock_guard<mutex> guard(SocketSetMutex);
-
- // check if you can read
- SocketSet_t::const_iterator aIt=SocketSet.begin();
- while(aIt!=SocketSet.end())
- {
-
- SocketSSLHandles_t aTmpHandles = *aIt;
- aIt++;
- if( FD_ISSET(aTmpHandles.first, &fd_read ) )
- {
- // you need to erase tmpHd from SocketSet - otherwise it will
- // be ready to read until SSL_read is not called on it
- SocketSet.erase(aTmpHandles);
- ReadQueue.push(aTmpHandles);
- }
- }
- }
-
- // if master is in fd_read - then it means new connection req
- // has arrived
- if( FD_ISSET(_master, &fd_read) )
- {
- int new_fd=openTCPSocket();
- if( new_fd >= 0 )
- {
- cout << "New socket with ID : " << new_fd
- << " is going to be added to map" << endl;
- SSL* ssl = openSSLSession(new_fd);
- lock_guard<mutex> guard(SocketSetMutex);
- SocketSet.insert(make_pair(new_fd, ssl));
- }
- }
- }
- }
-
- int Acceptor::openTCPSocket()
- {
- // Open up new connection
- cout << "New connection has arrived" << endl;
- struct sockaddr_in addr;
- int len = sizeof(addr);
- int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
- if(client == -1)
- perror("accept");
- return client;
- }
-
- SSL* Acceptor::openSSLSession(int iTCPHandle)
- {
- SSL *ssl = (SSL*) SSL_new(_ctx);
- SSL_set_fd(ssl, iTCPHandle);
-
- // normally this would be in other thread
- if(SSL_accept(ssl) == -1) {
- ERR_print_errors_fp(stderr);
- throw runtime_error("Can't SSL_accept => can't continue");
- }
- return ssl;
- }
-
- void Receive()
- {
- while(1)
- {
- char buf[1024];
- SocketSSLHandles_t handler;
-
- // TO-DO: this way it takes 100% CPU, some signal would be usefull
- while (!ReadQueue.empty())
- {
- handler = ReadQueue.front();
- ReadQueue.pop();
-
- memset(buf,'\0',1024);
- int len_rcv=0;
- cout << SSL_state_string(handler.second) << endl;
- {
- lock_guard<mutex> lock(WriteReadMutex);
- len_rcv = SSL_read(handler.second, buf, 1024);
- }
-
- if( len_rcv )
- {
- // dirty thing - if it has \n on the end - remove it
- if( buf[len_rcv-1] == '\n' )
- buf[len_rcv-1] = '\0';
-
- cout << buf << endl;
-
- // add it back to the socket so that select can use it
- lock_guard<mutex> guard(SocketSetMutex);
- SocketSet.insert(handler);
-
- // push handler ID and notify sender thread
- WriteHandler = handler;
- WaitForWrite.notify_one();
- }
- else
- {
- cout << "Closing connection " << handler.first << endl;
- ::close(handler.first);
- }
- }
- }
- }
-
- void Send()
- {
- while(1)
- {
- SocketSSLHandles_t handler(0,0);
- {
- unique_lock<mutex> lock(WaitForWriteMutex);
- WaitForWrite.wait(lock);
- handler = WriteHandler;
- }
-
- cout << "Writing to handler " << handler.first << endl;
- string buf(EXCHANGE_STRING);
- for(int i=0; i<SEND_ITERATIONS; ++i)
- {
- int len = 0;
- do
- {
- lock_guard<mutex> lock(WriteReadMutex);
- len+=SSL_write(handler.second, buf.c_str()+len, buf.size()-len);
- // for debugging re-neg
- // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl;
- } while( len != static_cast<int>(buf.size()) );
- }
- }
- }
-
-
- void Server::waitForStop()
- {
- _sender->join();
- _reciver->join();
- _reactor->join();
- }
-
- Server::~Server()
- {
- delete _sender;
- delete _reciver;
- delete _reactor;
- }
-
- /// --- MAIN --- ///
- int main() {
- Server server;
- server.init();
- server.start();
- server.waitForStop();
-
- return 0;
- }
|