Nie możesz wybrać więcej niż 25 tematów. Tematy muszą się zaczynać od litery lub cyfry, mogą zawierać myślniki ('-') i mogą mieć do 35 znaków.
 
 
 

308 wierszy
8.0 KiB

  #include "server.h"
  #include <fcntl.h>
  #include <unistd.h>
  #include <sys/select.h>
  #include <sys/socket.h>
  #include <netinet/in.h>
  #include <netinet/tcp.h>
  #include <cerrno>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <exception>
  #include <iostream>
  #include <set>
  #include <stdexcept>
  #include <string>
  #include <vector>
  #include <boost/thread/thread.hpp>
  #include <boost/thread/mutex.hpp>
  #include <boost/thread/condition_variable.hpp>
  #include <boost/thread/locks.hpp>
  #include <boost/pending/queue.hpp>
  #include <openssl/ssl.h>
  #include <openssl/err.h>
  #include "defs.h"
  using namespace std;
  using namespace boost;
  18. typedef pair<int, SSL*> SocketSSLHandles_t;
  19. SocketSSLHandles_t WriteHandler(0,0);
  20. // Keeps socket handles that are ready to be read
  21. boost::queue<SocketSSLHandles_t> ReadQueue;
  22. // Socket Set is a set that keeps sockets on which we can 'select()'
  23. typedef set<SocketSSLHandles_t> SocketSet_t;
  24. SocketSet_t SocketSet;
  25. mutex SocketSetMutex;
  26. // WaitForWrite is a condition variable that is signaled when Sender must start sending the data
  27. condition_variable WaitForWrite;
  28. mutex WaitForWriteMutex;
  29. // mutex that synchronizes access to SSL_read/SSL_write
  30. mutex WriteReadMutex;
  31. // thread functions to send and receive
  32. void Receive();
  33. void Send();
  34. Server::Server()
  35. : SSLProcess(true)
  36. , _master(0)
  37. {
  38. }
  39. void Server::startListen(void) {
  40. struct sockaddr_in local_address;
  41. _master = ::socket(PF_INET, SOCK_STREAM, 0);
  42. memset(&local_address, 0, sizeof(local_address));
  43. local_address.sin_family = AF_INET;
  44. local_address.sin_port = htons(PORT);
  45. local_address.sin_addr.s_addr = INADDR_ANY;
  46. int reuseval = 1;
  47. setsockopt(_master,SOL_SOCKET,SO_REUSEADDR, &reuseval, sizeof(reuseval));
  48. // set socket non-blocking
  49. fcntl(_master, F_SETFL, O_NONBLOCK);
  50. // Bind to the socket
  51. if(::bind(_master, (struct sockaddr *)&local_address, sizeof(local_address)) != 0)
  52. throw runtime_error("Couldn't bind to local port");
  53. // Set a limit on connection queue.
  54. if(::listen(_master, 5) != 0)
  55. throw runtime_error("Not possible to get into listen state");
  56. }
  57. void Server::start()
  58. {
  59. if( !_ctx )
  60. throw runtime_error("SSL not initialized");
  61. startListen();
  62. Acceptor ac(_master, _ctx);
  63. _sender =new thread( Send );
  64. _reciver =new thread( Receive );
  65. _reactor =new thread( ac );
  66. }
  67. void Server::init()
  68. {
  69. sslInit();
  70. doServerSSLInit();
  71. }
  72. void Server::doServerSSLInit()
  73. {
  74. // Load certificate & private key
  75. if ( SSL_CTX_use_certificate_chain_file(_ctx, CERTIFICATE_FILE) <= 0) {
  76. ERR_print_errors_fp(stderr);
  77. _exit(1);
  78. }
  79. if ( SSL_CTX_use_PrivateKey_file(_ctx, PRIVATE_KEY_FILE, SSL_FILETYPE_PEM) <= 0) {
  80. ERR_print_errors_fp(stderr);
  81. _exit(1);
  82. }
  83. // Verify if public-private keypair matches
  84. if ( !SSL_CTX_check_private_key(_ctx) ) {
  85. fprintf(stderr, "Private key is invalid.\n");
  86. _exit(1);
  87. }
  88. }
  89. void Acceptor::operator()()
  90. {
  91. cout << "Entering acceptor loop..." << endl;
  92. while(1)
  93. {
  94. // wait timer for select
  95. struct timeval tv;
  96. tv.tv_sec = 0;
  97. tv.tv_usec = 10;
  98. fd_set fd_read;
  99. // set fd_sets
  100. FD_ZERO(&fd_read);
  101. // set max fd for select
  102. int maxv = _master;
  103. FD_SET(_master, &fd_read);
  104. // add all the sockets to sets
  105. {
  106. lock_guard<mutex> guard(SocketSetMutex);
  107. for(SocketSet_t::const_iterator aIt=SocketSet.begin();
  108. aIt!=SocketSet.end(); ++aIt)
  109. {
  110. FD_SET(aIt->first, &fd_read);
  111. if(aIt->first > maxv)
  112. maxv=aIt->first;
  113. }
  114. }
  115. // wait in select now
  116. select(maxv+1, &fd_read, NULL, NULL, (struct timeval *)&tv);
  117. {
  118. lock_guard<mutex> guard(SocketSetMutex);
  119. // check if you can read
  120. SocketSet_t::const_iterator aIt=SocketSet.begin();
  121. while(aIt!=SocketSet.end())
  122. {
  123. SocketSSLHandles_t aTmpHandles = *aIt;
  124. aIt++;
  125. if( FD_ISSET(aTmpHandles.first, &fd_read ) )
  126. {
  127. // you need to erase tmpHd from SocketSet - otherwise it will
  128. // be ready to read until SSL_read is not called on it
  129. SocketSet.erase(aTmpHandles);
  130. ReadQueue.push(aTmpHandles);
  131. }
  132. }
  133. }
  134. // if master is in fd_read - then it means new connection req
  135. // has arrived
  136. if( FD_ISSET(_master, &fd_read) )
  137. {
  138. int new_fd=openTCPSocket();
  139. if( new_fd >= 0 )
  140. {
  141. cout << "New socket with ID : " << new_fd
  142. << " is going to be added to map" << endl;
  143. SSL* ssl = openSSLSession(new_fd);
  144. lock_guard<mutex> guard(SocketSetMutex);
  145. SocketSet.insert(make_pair(new_fd, ssl));
  146. }
  147. }
  148. }
  149. }
  150. int Acceptor::openTCPSocket()
  151. {
  152. // Open up new connection
  153. cout << "New connection has arrived" << endl;
  154. struct sockaddr_in addr;
  155. int len = sizeof(addr);
  156. int client = accept(_master, (struct sockaddr *)&addr, (socklen_t *)&len);
  157. if(client == -1)
  158. perror("accept");
  159. return client;
  160. }
  161. SSL* Acceptor::openSSLSession(int iTCPHandle)
  162. {
  163. SSL *ssl = (SSL*) SSL_new(_ctx);
  164. SSL_set_fd(ssl, iTCPHandle);
  165. // normally this would be in other thread
  166. if(SSL_accept(ssl) == -1) {
  167. ERR_print_errors_fp(stderr);
  168. throw runtime_error("Can't SSL_accept => can't continue");
  169. }
  170. return ssl;
  171. }
  172. void Receive()
  173. {
  174. while(1)
  175. {
  176. char buf[1024];
  177. SocketSSLHandles_t handler;
  178. // TO-DO: this way it takes 100% CPU, some signal would be usefull
  179. while (!ReadQueue.empty())
  180. {
  181. handler = ReadQueue.front();
  182. ReadQueue.pop();
  183. memset(buf,'\0',1024);
  184. int len_rcv=0;
  185. cout << SSL_state_string(handler.second) << endl;
  186. {
  187. lock_guard<mutex> lock(WriteReadMutex);
  188. len_rcv = SSL_read(handler.second, buf, 1024);
  189. }
  190. if( len_rcv )
  191. {
  192. // dirty thing - if it has \n on the end - remove it
  193. if( buf[len_rcv-1] == '\n' )
  194. buf[len_rcv-1] = '\0';
  195. cout << buf << endl;
  196. // add it back to the socket so that select can use it
  197. lock_guard<mutex> guard(SocketSetMutex);
  198. SocketSet.insert(handler);
  199. // push handler ID and notify sender thread
  200. WriteHandler = handler;
  201. WaitForWrite.notify_one();
  202. }
  203. else
  204. {
  205. cout << "Closing connection " << handler.first << endl;
  206. ::close(handler.first);
  207. }
  208. }
  209. }
  210. }
  211. void Send()
  212. {
  213. while(1)
  214. {
  215. SocketSSLHandles_t handler(0,0);
  216. {
  217. unique_lock<mutex> lock(WaitForWriteMutex);
  218. WaitForWrite.wait(lock);
  219. handler = WriteHandler;
  220. }
  221. cout << "Writing to handler " << handler.first << endl;
  222. string buf(EXCHANGE_STRING);
  223. for(int i=0; i<SEND_ITERATIONS; ++i)
  224. {
  225. int len = 0;
  226. do
  227. {
  228. lock_guard<mutex> lock(WriteReadMutex);
  229. len+=SSL_write(handler.second, buf.c_str()+len, buf.size()-len);
  230. // for debugging re-neg
  231. // cout << "SSL STATE: " << SSL_state_string(handler.second) << endl;
  232. } while( len != static_cast<int>(buf.size()) );
  233. }
  234. }
  235. }
  236. void Server::waitForStop()
  237. {
  238. _sender->join();
  239. _reciver->join();
  240. _reactor->join();
  241. }
  242. Server::~Server()
  243. {
  244. delete _sender;
  245. delete _reciver;
  246. delete _reactor;
  247. }
  248. /// --- MAIN --- ///
  249. int main() {
  250. Server server;
  251. server.init();
  252. server.start();
  253. server.waitForStop();
  254. return 0;
  255. }