You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

346 lines
9.5 KiB

  1. #include "server.h"
  2. #include <vector>
  3. #include <set>
  4. #include <boost/thread/thread.hpp>
  5. #include <boost/thread/mutex.hpp>
  6. #include <boost/thread/condition_variable.hpp>
  7. #include <boost/thread/locks.hpp>
  8. #include <boost/pending/queue.hpp>
  9. #include <netinet/tcp.h>
  10. #include <fcntl.h>
  11. #include <iostream>
  12. #include <exception>
  13. #include <openssl/ssl.h>
  14. #include <openssl/err.h>
  15. #include "defs.h"
  16. using namespace std;
  17. using namespace boost;
  18. typedef pair<int, SSL*> SocketSSLHandles_t;
  19. SocketSSLHandles_t WriteHandler(0,0);
  20. // Keeps socket handles that are ready to be read
  21. boost::queue<SocketSSLHandles_t> ReadQueue;
  22. // Socket Set is a set that keeps sockets on which we can 'select()'
  23. typedef set<SocketSSLHandles_t> SocketSet_t;
  24. SocketSet_t SocketSet;
  25. mutex SocketSetMutex;
  26. // WaitForWrite is a condition variable that is signaled when Sender must start sending the data
  27. condition_variable WaitForWrite;
  28. mutex WaitForWriteMutex;
  29. // mutex that synchronizes access to SSL_read/SSL_write
  30. mutex WriteReadMutex;
  31. // thread functions to send and receive
  32. void Receive();
  33. void Send();
  34. Server::Server()
  35. : SSLProcess(true)
  36. , _master(0)
  37. {
  38. }
  39. void Server::startListen(void) {
  40. struct sockaddr_in local_address;
  41. _master = ::socket(PF_INET, SOCK_STREAM, 0);
  42. memset(&local_address, 0, sizeof(local_address));
  43. local_address.sin_family = AF_INET;
  44. local_address.sin_port = htons(PORT);
  45. local_address.sin_addr.s_addr = INADDR_ANY;
  46. int reuseval = 1;
  47. setsockopt(_master,SOL_SOCKET,SO_REUSEADDR, &reuseval, sizeof(reuseval));
  48. // set socket non-blocking
  49. fcntl(_master, F_SETFL, O_NONBLOCK);
  50. // Bind to the socket
  51. if(::bind(_master, (struct sockaddr *)&local_address, sizeof(local_address)) != 0)
  52. throw runtime_error("Couldn't bind to local port");
  53. // Set a limit on connection queue.
  54. if(::listen(_master, 5) != 0)
  55. throw runtime_error("Not possible to get into listen state");
  56. }
  57. void Server::start()
  58. {
  59. if( !_ctx )
  60. throw runtime_error("SSL not initialized");
  61. startListen();
  62. Acceptor ac(_master, _ctx);
  63. _sender =new thread( Send );
  64. _reciver =new thread( Receive );
  65. _reactor =new thread( ac );
  66. }
  67. void Server::init()
  68. {
  69. sslInit();
  70. doServerSSLInit();
  71. }
  72. void Server::doServerSSLInit()
  73. {
  74. // Load certificate & private key
  75. if ( SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) <= 0) {
  76. ERR_print_errors_fp(stderr);
  77. _exit(1);
  78. }
  79. if ( SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
  80. ERR_print_errors_fp(stderr);
  81. _exit(1);
  82. }
  83. // Verify if public-private keypair matches
  84. if ( !SSL_CTX_check_private_key(_ctx) ) {
  85. fprintf(stderr, "Private key is invalid.\n");
  86. _exit(1);
  87. }
  88. // set weak protocol, so it is easy to debug with wireshark
  89. SSL_CTX_set_options(_ctx, SSL_OP_NO_TLSv1_2 | SSL_OP_NO_TLSv1_1 | SSL_OP_NO_TLSv1 | SSL_OP_ALL | SSL_OP_SINGLE_DH_USE );
  90. }
  91. void Acceptor::operator()()
  92. {
  93. cout << "Entering acceptor loop..." << endl;
  94. while(1)
  95. {
  96. // wait timer for select
  97. struct timeval tv;
  98. tv.tv_sec = 0;
  99. tv.tv_usec = 10;
  100. fd_set fd_read;
  101. // set fd_sets
  102. FD_ZERO(&fd_read);
  103. // set max fd for select
  104. int maxv = _master;
  105. FD_SET(_master, &fd_read);
  106. // add all the sockets to sets
  107. {
  108. lock_guard<mutex> guard(SocketSetMutex);
  109. for(SocketSet_t::const_iterator aIt=SocketSet.begin();
  110. aIt!=SocketSet.end(); ++aIt)
  111. {
  112. FD_SET(aIt->first, &fd_read);
  113. if(aIt->first > maxv)
  114. maxv=aIt->first;
  115. }
  116. }
  117. // wait in select now
  118. select(maxv+1, &fd_read, NULL, NULL, (struct timeval *)&tv);
  119. {
  120. lock_guard<mutex> guard(SocketSetMutex);
  121. // check if you can read
  122. SocketSet_t::const_iterator aIt=SocketSet.begin();
  123. while(aIt!=SocketSet.end())
  124. {
  125. SocketSSLHandles_t aTmpHandles = *aIt;
  126. aIt++;
  127. if( FD_ISSET(aTmpHandles.first, &fd_read ) )
  128. {
  129. // you need to erase tmpHd from SocketSet - otherwise it will
  130. // be ready to read until SSL_read is not called on it
  131. SocketSet.erase(aTmpHandles);
  132. ReadQueue.push(aTmpHandles);
  133. }
  134. }
  135. }
  136. // if master is in fd_read - then it means new connection req
  137. // has arrived
  138. if( FD_ISSET(_master, &fd_read) )
  139. {
  140. int new_fd=openTCPSocket();
  141. if( new_fd >= 0 )
  142. {
  143. cout << "New socket with ID : " << new_fd
  144. << " is going to be added to map" << endl;
  145. SSL* ssl = openSSLSession(new_fd);
  146. lock_guard<mutex> guard(SocketSetMutex);
  147. SocketSet.insert(make_pair(new_fd, ssl));
  148. }
  149. }
  150. }
  151. }
  152. int Acceptor::openTCPSocket()
  153. {
  154. // Open up new connection
  155. cout << "New connection has arrived" << endl;
  156. struct sockaddr_in addr;
  157. int len = sizeof(addr);
  158. int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
  159. if(client == -1)
  160. perror("accept");
  161. return client;
  162. }
  163. SSL* Acceptor::openSSLSession(int iTCPHandle)
  164. {
  165. SSL *ssl = (SSL*) SSL_new(_ctx);
  166. SSL_set_fd(ssl, iTCPHandle);
  167. // normally this would be in other thread
  168. if(SSL_accept(ssl) == -1) {
  169. ERR_print_errors_fp(stderr);
  170. throw runtime_error("Can't SSL_accept => can't continue");
  171. }
  172. return ssl;
  173. }
  174. void Receive()
  175. {
  176. while(1)
  177. {
  178. char buf[1024];
  179. SocketSSLHandles_t handler;
  180. // TO-DO: this way it takes 100% CPU, some signal would be usefull
  181. while (!ReadQueue.empty())
  182. {
  183. handler = ReadQueue.front();
  184. ReadQueue.pop();
  185. memset(buf,'\0',1024);
  186. int len_rcv=0;
  187. cout << SSL_state_string(handler.second) << endl;
  188. {
  189. lock_guard<mutex> lock(WriteReadMutex);
  190. len_rcv = SSL_read(handler.second, buf, 1024);
  191. switch( SSL_get_error(handler.second, len_rcv) )
  192. {
  193. case SSL_ERROR_NONE:
  194. {
  195. // dirty thing - if it has \n on the end - remove it
  196. if( buf[len_rcv-1] == '\n' )
  197. buf[len_rcv-1] = '\0';
  198. cout << buf << endl;
  199. {
  200. // add it back to the socket so that select can use it
  201. lock_guard<mutex> guard(SocketSetMutex);
  202. SocketSet.insert(handler);
  203. // push handler ID and notify sender thread
  204. WriteHandler = handler;
  205. WaitForWrite.notify_one();
  206. }
  207. break;
  208. }
  209. case SSL_ERROR_WANT_READ:
  210. case SSL_ERROR_WANT_WRITE:
  211. {
  212. cout << "WANT_SOMETHING WHEN Receive" << endl;
  213. exit(1);
  214. break;
  215. }
  216. default :
  217. {
  218. cout << "Closing connection " << handler.first << endl;
  219. ::close(handler.first);
  220. exit(1);
  221. }
  222. }
  223. }
  224. }
  225. }
  226. }
  227. void Send()
  228. {
  229. while(1)
  230. {
  231. SocketSSLHandles_t handler(0,0);
  232. {
  233. unique_lock<mutex> lock(WaitForWriteMutex);
  234. WaitForWrite.wait(lock);
  235. handler = WriteHandler;
  236. }
  237. cout << "Writing to handler " << handler.first << endl;
  238. string buf(EXCHANGE_STRING);
  239. for(int i=0; i<SEND_ITERATIONS; ++i)
  240. {
  241. int len = 0;
  242. do
  243. {
  244. lock_guard<mutex> lock(WriteReadMutex);
  245. int write_len=SSL_write(handler.second, buf.c_str()+len, buf.size()-len);
  246. switch( SSL_get_error(handler.second, write_len) )
  247. {
  248. case SSL_ERROR_NONE:
  249. {
  250. len += write_len;
  251. break;
  252. }
  253. case SSL_ERROR_WANT_READ:
  254. case SSL_ERROR_WANT_WRITE:
  255. {
  256. cout << "WANT_SOMETHING WHEN Send" << endl;
  257. exit(1);
  258. break;
  259. }
  260. default :
  261. {
  262. cout << "Closing connection " << handler.first << endl;
  263. ::close(handler.first);
  264. exit(1);
  265. }
  266. }
  267. // for debugging re-neg
  268. // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl;
  269. } while( len != static_cast<int>(buf.size()) );
  270. }
  271. }
  272. }
  273. void Server::waitForStop()
  274. {
  275. _sender->join();
  276. _reciver->join();
  277. _reactor->join();
  278. }
  279. Server::~Server()
  280. {
  281. delete _sender;
  282. delete _reciver;
  283. delete _reactor;
  284. }
  285. /// --- MAIN --- ///
  286. int main() {
  287. Server server;
  288. server.init();
  289. server.start();
  290. server.waitForStop();
  291. return 0;
  292. }