  #include "server.h"

  // C / POSIX system headers
  #include <arpa/inet.h>
  #include <fcntl.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <netinet/tcp.h>
  #include <sys/select.h>
  #include <sys/socket.h>
  #include <unistd.h>

  // C++ standard library
  #include <cerrno>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <exception>
  #include <iostream>
  #include <set>
  #include <stdexcept>
  #include <string>
  #include <vector>

  // Third-party libraries
  #include <boost/pending/queue.hpp>
  #include <boost/thread/condition_variable.hpp>
  #include <boost/thread/locks.hpp>
  #include <boost/thread/mutex.hpp>
  #include <boost/thread/thread.hpp>
  #include <openssl/err.h>
  #include <openssl/ssl.h>

  // Project headers
  #include "defs.h"
  17. using namespace std;
  18. using namespace boost;
  19. typedef pair<int, SSL*> SocketSSLHandles_t;
  20. SocketSSLHandles_t WriteHandler(0,0);
  21. // Keeps socket handles that are ready to be read
  22. boost::queue<SocketSSLHandles_t> ReadQueue;
  23. // Socket Set is a set that keeps sockets on which we can 'select()'
  24. typedef set<SocketSSLHandles_t> SocketSet_t;
  25. SocketSet_t SocketSet;
  26. mutex SocketSetMutex;
  27. // WaitForWrite is a condition variable that is signaled when Sender must start sending the data
  28. condition_variable WaitForWrite;
  29. mutex WaitForWriteMutex;
  30. // mutex that synchronizes access to SSL_read/SSL_write
  31. mutex WriteReadMutex;
  32. // thread functions to send and receive
  33. void Receive();
  34. void Send();
  35. int Gmaster=0;
  36. bool handle_error_code(int& len, SSL* SSLHandler, int code, const char* func)
  37. {
  38. switch( SSL_get_error( SSLHandler, code ) )
  39. {
  40. case SSL_ERROR_NONE:
  41. len+=code;
  42. return false;
  44. cout << "CONNETION CLOSE ON WRITE" << endl;
  45. exit(1);
  46. break;
  48. cout << func << " WANT READ" << endl;
  49. break;
  51. cout << func << " WANT WRITE" << endl;
  52. break;
  54. cout << func << " ESYSCALL" << endl;
  55. // exit(1);
  56. break;
  57. case SSL_ERROR_SSL:
  58. cout << func << " ESSL" << endl;
  59. exit(1);
  60. break;
  61. default:
  62. cout << func << " SOMETHING ELSE" << endl;
  63. }
  64. return true;
  65. }
  66. Server::Server()
  67. : SSLProcess(true)
  68. , _master(0)
  69. {
  70. }
  71. void Server::startListen(void) {
  72. struct sockaddr_in local_address;
  73. _master = ::socket(PF_INET, SOCK_STREAM, 0);
  74. memset(&local_address, 0, sizeof(local_address));
  75. local_address.sin_family = AF_INET;
  76. local_address.sin_port = htons(PORT);
  77. local_address.sin_addr.s_addr = INADDR_ANY;
  78. int reuseval = 1;
  79. setsockopt(_master,SOL_SOCKET,SO_REUSEADDR, &reuseval, sizeof(reuseval));
  80. // set socket non-blocking
  81. fcntl(_master, F_SETFL, O_NONBLOCK);
  82. // Bind to the socket
  83. if(::bind(_master, (struct sockaddr *)&local_address, sizeof(local_address)) != 0)
  84. throw runtime_error("Couldn't bind to local port");
  85. // Set a limit on connection queue.
  86. if(::listen(_master, 5) != 0)
  87. throw runtime_error("Not possible to get into listen state");
  88. }
  89. void Server::start()
  90. {
  91. if( !_ctx )
  92. throw runtime_error("SSL not initialized");
  93. startListen();
  94. Acceptor ac(_master, _ctx);
  95. Gmaster=_master;
  96. _sender =new thread( Send );
  97. _reciver =new thread( Receive );
  98. _reactor =new thread( ac );
  99. }
  100. void Server::init()
  101. {
  102. sslInit();
  103. doServerSSLInit();
  104. }
  105. void Server::doServerSSLInit()
  106. {
  107. // Load certificate & private key
  108. if ( SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) <= 0) {
  109. ERR_print_errors_fp(stderr);
  110. _exit(1);
  111. }
  112. if ( SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
  113. ERR_print_errors_fp(stderr);
  114. _exit(1);
  115. }
  116. // Verify if public-private keypair matches
  117. if ( !SSL_CTX_check_private_key(_ctx) ) {
  118. fprintf(stderr, "Private key is invalid.\n");
  119. _exit(1);
  120. }
  121. // set weak protocol, so it is easy to debug with wireshark
  122. SSL_CTX_set_options(_ctx, SSL_OP_NO_TLSv1_2
  123. | SSL_OP_NO_TLSv1_1
  124. | SSL_OP_NO_TLSv1
  125. | SSL_OP_ALL
  126. | SSL_OP_SINGLE_DH_USE );
  127. }
  128. void Acceptor::operator()()
  129. {
  130. cout << "Entering acceptor loop..." << endl;
  131. while(1)
  132. {
  133. // wait timer for select
  134. struct timeval tv;
  135. tv.tv_sec = 0;
  136. tv.tv_usec = 10;
  137. fd_set fd_read;
  138. // set fd_sets
  139. FD_ZERO(&fd_read);
  140. // set max fd for select
  141. int maxv = _master;
  142. FD_SET(_master, &fd_read);
  143. // add all the sockets to sets
  144. {
  145. lock_guard<mutex> guard(SocketSetMutex);
  146. for(SocketSet_t::const_iterator aIt=SocketSet.begin();
  147. aIt!=SocketSet.end(); ++aIt)
  148. {
  149. FD_SET(aIt->first, &fd_read);
  150. if(aIt->first > maxv)
  151. maxv=aIt->first;
  152. }
  153. }
  154. // wait in select now
  155. select(maxv+1, &fd_read, NULL, NULL, (struct timeval *)&tv);
  156. {
  157. lock_guard<mutex> guard(SocketSetMutex);
  158. // check if you can read
  159. SocketSet_t::const_iterator aIt=SocketSet.begin();
  160. while(aIt!=SocketSet.end())
  161. {
  162. SocketSSLHandles_t aTmpHandles = *aIt;
  163. aIt++;
  164. if( FD_ISSET(aTmpHandles.first, &fd_read ) )
  165. {
  166. // you need to erase tmpHd from SocketSet - otherwise it will
  167. // be ready to read until SSL_read is not called on it
  168. SocketSet.erase(aTmpHandles);
  169. ReadQueue.push(aTmpHandles);
  170. }
  171. }
  172. }
  173. // if master is in fd_read - then it means new connection req
  174. // has arrived
  175. if( FD_ISSET(_master, &fd_read) )
  176. {
  177. int new_fd=openTCPSocket();
  178. if( new_fd >= 0 )
  179. {
  180. int flag =1;
  181. // setsockopt(new_fd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
  182. cout << "New socket with ID : " << new_fd
  183. << " is going to be added to map" << endl;
  184. SSL* ssl = openSSLSession(new_fd);
  185. lock_guard<mutex> guard(SocketSetMutex);
  186. SocketSet.insert(make_pair(new_fd, ssl));
  187. }
  188. }
  189. }
  190. }
  191. int Acceptor::openTCPSocket()
  192. {
  193. // Open up new connection
  194. cout << "New connection has arrived" << endl;
  195. struct sockaddr_in addr;
  196. int len = sizeof(addr);
  197. int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
  198. if(client == -1)
  199. perror("accept");
  200. return client;
  201. }
  202. SSL* Acceptor::openSSLSession(int iTCPHandle)
  203. {
  204. SSL *ssl = (SSL*) SSL_new(_ctx);
  205. SSL_set_fd(ssl, iTCPHandle);
  206. // normally this would be in other thread
  207. if(SSL_accept(ssl) == -1) {
  208. ERR_print_errors_fp(stderr);
  209. throw runtime_error("Can't SSL_accept => can't continue");
  210. }
  211. return ssl;
  212. }
  213. void Receive()
  214. {
  215. while(1)
  216. {
  217. char buf[1024];
  218. SocketSSLHandles_t handler;
  219. // TO-DO: this way it takes 100% CPU, some signal would be usefull
  220. while (!ReadQueue.empty())
  221. {
  222. handler = ReadQueue.front();
  223. ReadQueue.pop();
  224. memset(buf,'\0',1024);
  225. int len_rcv=0;
  226. cout << SSL_state_string(handler.second) << endl;
  227. {
  228. lock_guard<mutex> lock(WriteReadMutex);
  229. int flag = 1;
  230. while( flag!=0 )
  231. {
  232. cout << "SSL_read: start" << endl;
  233. len_rcv = SSL_read(handler.second, buf, 1024);
  234. flag = SSL_pending(handler.second);
  235. cout << "PENDING: " << flag << endl;
  236. // cout << "SSL_read: stop" << endl;
  237. if( !handle_error_code(len_rcv, handler.second, len_rcv, "rcv") )
  238. {
  239. // dirty thing - if it has \n on the end - remove it
  240. if( buf[len_rcv-1] == '\n' )
  241. buf[len_rcv-1] = '\0';
  242. cout << buf << endl;
  243. {
  244. // add it back to the socket so that select can use it
  245. lock_guard<mutex> guard(SocketSetMutex);
  246. SocketSet.insert(handler);
  247. // push handler ID and notify sender thread
  248. WriteHandler = handler;
  249. WaitForWrite.notify_one();
  250. }
  251. break;
  252. }
  253. }
  254. }
  255. }
  256. }
  257. }
  258. void Send()
  259. {
  260. while(1)
  261. {
  262. SocketSSLHandles_t handler(0,0);
  263. {
  264. unique_lock<mutex> lock(WaitForWriteMutex);
  265. WaitForWrite.wait(lock);
  266. handler = WriteHandler;
  267. }
  268. cout << "Writing to handler " << handler.first << endl;
  269. string buf(EXCHANGE_STRING);
  270. for(int i=0; i<SEND_ITERATIONS; ++i)
  271. {
  272. int len = 0;
  273. // wait timer for select
  274. struct timeval tv;
  275. tv.tv_sec = 0;
  276. tv.tv_usec = 10;
  277. do
  278. {
  279. fd_set fd_write;
  280. FD_ZERO(&fd_write);
  281. FD_SET(Gmaster, &fd_write);
  282. FD_SET(handler.first, &fd_write);
  283. int maxv=Gmaster;
  284. if(Gmaster < handler.first)
  285. maxv=handler.first;
  286. select(maxv+1, NULL, &fd_write, NULL, (struct timeval *)&tv);
  287. if( FD_ISSET(handler.first, &fd_write) )
  288. {
  289. lock_guard<mutex> lock(WriteReadMutex);
  290. // cout << "SSL_write: start" << endl;
  291. int write_len=SSL_write(handler.second, buf.c_str()+len, buf.size()-len);
  292. // cout << "SSL_write: stop " << endl;
  293. handle_error_code(len, handler.second, write_len, "write");
  294. // for debugging re-neg
  295. // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl;
  296. }
  297. } while( len != static_cast<int>(buf.size()) );
  298. }
  299. }
  300. }
  301. void Server::waitForStop()
  302. {
  303. _sender->join();
  304. _reciver->join();
  305. _reactor->join();
  306. }
  307. Server::~Server()
  308. {
  309. delete _sender;
  310. delete _reciver;
  311. delete _reactor;
  312. }
  313. /// --- MAIN --- ///
  314. int main() {
  315. Server server;
  316. server.init();
  317. server.start();
  318. server.waitForStop();
  319. return 0;
  320. }