You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

304 lines
7.9 KiB

  1. #include "server.h"
  2. #include <vector>
  3. #include <set>
  4. #include <boost/thread/thread.hpp>
  5. #include <boost/thread/mutex.hpp>
  6. #include <boost/thread/condition_variable.hpp>
  7. #include <boost/thread/locks.hpp>
  8. #include <boost/pending/queue.hpp>
  9. #include <netinet/tcp.h>
  10. #include <fcntl.h>
  11. #include <iostream>
  12. #include <exception>
  13. #include <openssl/ssl.h>
  14. #include <openssl/err.h>
  15. #include "defs.h"
  16. using namespace std;
  17. using namespace boost;
  18. typedef pair<int, SSL*> SocketSSLHandles_t;
  19. SocketSSLHandles_t WriteHandler(0,0);
  20. // Keeps socket handles that are ready to be read
  21. boost::queue<SocketSSLHandles_t> ReadQueue;
  22. // Socket Set is a set that keeps sockets on which we can 'select()'
  23. typedef set<SocketSSLHandles_t> SocketSet_t;
  24. SocketSet_t SocketSet;
  25. mutex SocketSetMutex;
  26. // WaitForWrite is a condition variable that is signaled when Sender must start sending the data
  27. condition_variable WaitForWrite;
  28. mutex WaitForWriteMutex;
  29. mutex WriteReadMutex;
  30. void Receive();
  31. void Send();
  32. Server::Server()
  33. : SSLProcess(true)
  34. , _master(0)
  35. {
  36. }
  37. void Server::startListen(void) {
  38. struct sockaddr_in local_address;
  39. _master = ::socket(PF_INET, SOCK_STREAM, 0);
  40. memset(&local_address, 0, sizeof(local_address));
  41. local_address.sin_family = AF_INET;
  42. local_address.sin_port = htons(PORT);
  43. local_address.sin_addr.s_addr = INADDR_ANY;
  44. int reuseval = 1;
  45. setsockopt(_master,SOL_SOCKET,SO_REUSEADDR, &reuseval, sizeof(reuseval));
  46. // set socket non-blocking
  47. fcntl(_master, F_SETFL, O_NONBLOCK);
  48. // Bind to the socket
  49. if(::bind(_master, (struct sockaddr *)&local_address, sizeof(local_address)) != 0)
  50. throw runtime_error("Couldn't bind to local port");
  51. // Set a limit on connection queue.
  52. if(::listen(_master, 5) != 0)
  53. throw runtime_error("Not possible to get into listen state");
  54. }
  55. void Server::start()
  56. {
  57. if( !_ctx )
  58. throw runtime_error("SSL not initialized");
  59. startListen();
  60. Acceptor ac(_master, _ctx);
  61. _sender =new thread( Send );
  62. _reciver =new thread( Receive );
  63. _reactor =new thread( ac );
  64. }
  65. void Server::init()
  66. {
  67. sslInit();
  68. doServerSSLInit();
  69. }
  70. void Server::doServerSSLInit()
  71. {
  72. // Load certificate & private key
  73. if ( SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) <= 0) {
  74. ERR_print_errors_fp(stderr);
  75. _exit(1);
  76. }
  77. if ( SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
  78. ERR_print_errors_fp(stderr);
  79. _exit(1);
  80. }
  81. // Verify if public-private keypair matches
  82. if ( !SSL_CTX_check_private_key(_ctx) ) {
  83. fprintf(stderr, "Private key is invalid.\n");
  84. _exit(1);
  85. }
  86. }
  87. void Acceptor::operator()()
  88. {
  89. cout << "Entering acceptor loop..." << endl;
  90. while(1)
  91. {
  92. // wait timer for select
  93. struct timeval tv;
  94. tv.tv_sec = 0;
  95. tv.tv_usec = 10;
  96. fd_set fd_read;
  97. // set fd_sets
  98. FD_ZERO(&fd_read);
  99. // set max fd for select
  100. int maxv = _master;
  101. FD_SET(_master, &fd_read);
  102. // add all the sockets to sets
  103. {
  104. lock_guard<mutex> guard(SocketSetMutex);
  105. for(SocketSet_t::const_iterator aIt=SocketSet.begin();
  106. aIt!=SocketSet.end(); ++aIt)
  107. {
  108. FD_SET(aIt->first, &fd_read);
  109. if(aIt->first > maxv)
  110. maxv=aIt->first;
  111. }
  112. }
  113. // wait in select now
  114. select(maxv+1, &fd_read, NULL, NULL, (struct timeval *)&tv);
  115. {
  116. lock_guard<mutex> guard(SocketSetMutex);
  117. // check if you can read
  118. SocketSet_t::const_iterator aIt=SocketSet.begin();
  119. while(aIt!=SocketSet.end())
  120. {
  121. SocketSSLHandles_t aTmpHandles = *aIt;
  122. aIt++;
  123. if( FD_ISSET(aTmpHandles.first, &fd_read ) )
  124. {
  125. // you need to erase tmpHd from SocketSet - otherwise it will
  126. // be ready to read until SSL_read is not called on it
  127. SocketSet.erase(aTmpHandles);
  128. ReadQueue.push(aTmpHandles);
  129. }
  130. }
  131. }
  132. // if master is in fd_read - then it means new connection req
  133. // has arrived
  134. if( FD_ISSET(_master, &fd_read) )
  135. {
  136. int new_fd=openTCPSocket();
  137. if( new_fd >= 0 )
  138. {
  139. cout << "New socket with ID : " << new_fd
  140. << " is going to be added to map" << endl;
  141. SSL* ssl = openSSLSession(new_fd);
  142. lock_guard<mutex> guard(SocketSetMutex);
  143. SocketSet.insert(make_pair(new_fd, ssl));
  144. }
  145. }
  146. }
  147. }
  148. int Acceptor::openTCPSocket()
  149. {
  150. // Open up new connection
  151. cout << "New connection has arrived" << endl;
  152. struct sockaddr_in addr;
  153. int len = sizeof(addr);
  154. int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
  155. if(client == -1)
  156. perror("accept");
  157. return client;
  158. }
  159. SSL* Acceptor::openSSLSession(int iTCPHandle)
  160. {
  161. SSL *ssl = (SSL*) SSL_new(_ctx);
  162. SSL_set_fd(ssl, iTCPHandle);
  163. // normally this would be in other thread
  164. if(SSL_accept(ssl) == -1) {
  165. ERR_print_errors_fp(stderr);
  166. throw runtime_error("Can't SSL_accept => can't continue");
  167. }
  168. return ssl;
  169. }
  170. void Receive()
  171. {
  172. while(1)
  173. {
  174. char buf[1024];
  175. SocketSSLHandles_t handler;
  176. // TO-DO: this way it takes 100% CPU, some signal would be usefull
  177. while (!ReadQueue.empty())
  178. {
  179. handler = ReadQueue.front();
  180. ReadQueue.pop();
  181. memset(buf,'\0',1024);
  182. int len_rcv=0;
  183. cout << SSL_state_string(handler.second) << endl;
  184. {
  185. lock_guard<mutex> lock(WriteReadMutex);
  186. len_rcv = SSL_read(handler.second, buf, 1024);
  187. }
  188. if( len_rcv )
  189. {
  190. // dirty thing - if it has \n on the end - remove it
  191. if( buf[len_rcv-1] == '\n' )
  192. buf[len_rcv-1] = '\0';
  193. cout << buf << endl;
  194. // add it back to the socket so that select can use it
  195. lock_guard<mutex> guard(SocketSetMutex);
  196. SocketSet.insert(handler);
  197. // push handler ID and notify sender thread
  198. WriteHandler = handler;
  199. WaitForWrite.notify_one();
  200. }
  201. else
  202. {
  203. cout << "Closing connection " << handler.first << endl;
  204. ::close(handler.first);
  205. }
  206. }
  207. }
  208. }
  209. void Send()
  210. {
  211. while(1)
  212. {
  213. SocketSSLHandles_t handler(0,0);
  214. {
  215. unique_lock<mutex> lock(WaitForWriteMutex);
  216. WaitForWrite.wait(lock);
  217. handler = WriteHandler;
  218. }
  219. cout << "Writing to handler " << handler.first << endl;
  220. string buf(EXCHANGE_STRING);
  221. for(int i=0; i<SEND_ITERATIONS; ++i)
  222. {
  223. int len = 0;
  224. do
  225. {
  226. lock_guard<mutex> lock(WriteReadMutex);
  227. len+=SSL_write(handler.second, buf.c_str()+len, buf.size()-len);
  228. // for debugging re-neg
  229. // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl;
  230. } while( len != static_cast<int>(buf.size()) );
  231. }
  232. }
  233. }
  234. void Server::waitForStop()
  235. {
  236. _sender->join();
  237. _reciver->join();
  238. _reactor->join();
  239. }
  240. Server::~Server()
  241. {
  242. delete _sender;
  243. delete _reciver;
  244. delete _reactor;
  245. }
  246. /// --- MAIN --- ///
  247. int main() {
  248. Server server;
  249. server.init();
  250. server.start();
  251. server.waitForStop();
  252. return 0;
  253. }