diff --git a/tool/transport_common.cc b/tool/transport_common.cc index 1e9b72f9..f1a944aa 100644 --- a/tool/transport_common.cc +++ b/tool/transport_common.cc @@ -365,6 +365,11 @@ bool SocketSetNonBlocking(int sock, bool is_non_blocking) { return ok; } +enum class StdinWait { + kStdinRead, + kSocketWrite, +}; + #if !defined(OPENSSL_WINDOWS) // SocketWaiter abstracts waiting for either the socket or stdin to be readable @@ -380,23 +385,28 @@ class SocketWaiter { // Wait waits for at least on of the socket or stdin or be ready. On success, // it sets |*socket_ready| and |*stdin_ready| to whether the respective - // objects are readable and returns true. On error, it returns false. - bool Wait(bool *socket_ready, bool *stdin_ready) { + // objects are readable and returns true. On error, it returns false. stdin's + // readiness may either be the socket being writable or stdin being readable, + // depending on |stdin_wait|. + bool Wait(StdinWait stdin_wait, bool *socket_ready, bool *stdin_ready) { *socket_ready = true; *stdin_ready = false; - fd_set read_fds; + fd_set read_fds, write_fds; FD_ZERO(&read_fds); - if (stdin_open_) { + FD_ZERO(&write_fds); + if (stdin_wait == StdinWait::kSocketWrite) { + FD_SET(sock_, &write_fds); + } else if (stdin_open_) { FD_SET(STDIN_FILENO, &read_fds); } FD_SET(sock_, &read_fds); - if (select(sock_ + 1, &read_fds, NULL, NULL, NULL) <= 0) { + if (select(sock_ + 1, &read_fds, &write_fds, NULL, NULL) <= 0) { perror("select"); return false; } - if (FD_ISSET(STDIN_FILENO, &read_fds)) { + if (FD_ISSET(STDIN_FILENO, &read_fds) || FD_ISSET(sock_, &write_fds)) { *stdin_ready = true; } if (FD_ISSET(sock_, &read_fds)) { @@ -522,20 +532,30 @@ class SocketWaiter { return true; } - bool Wait(bool *socket_ready, bool *stdin_ready) { + bool Wait(StdinWait stdin_wait, bool *socket_ready, bool *stdin_ready) { *socket_ready = true; *stdin_ready = false; - ScopedWSAEVENT sock_event(WSACreateEvent()); - if (!sock_event || - WSAEventSelect(sock_, sock_event.get(), FD_READ | FD_CLOSE) != 0) { - PrintSocketError("Error waiting for socket"); + ScopedWSAEVENT sock_read_event(WSACreateEvent()); + if (!sock_read_event || + WSAEventSelect(sock_, sock_read_event.get(), FD_READ | FD_CLOSE) != 0) { + PrintSocketError("Error waiting for socket read"); return false; } DWORD count = 1; - WSAEVENT events[2] = {sock_event.get(), WSA_INVALID_EVENT}; - if (listen_stdin_) { + WSAEVENT events[3] = {sock_read_event.get(), WSA_INVALID_EVENT}; + ScopedWSAEVENT sock_write_event; + if (stdin_wait == StdinWait::kSocketWrite) { + sock_write_event.reset(WSACreateEvent()); + if (!sock_write_event || WSAEventSelect(sock_, sock_write_event.get(), + FD_WRITE | FD_CLOSE) != 0) { + PrintSocketError("Error waiting for socket write"); + return false; + } + events[1] = sock_write_event.get(); + count++; + } else if (listen_stdin_) { events[1] = stdin_->event.get(); count++; } @@ -651,51 +671,49 @@ bool TransferData(SSL *ssl, int sock) { if (!waiter.Init()) { return false; } + + uint8_t pending_write[512]; + size_t pending_write_len = 0; for (;;) { bool socket_ready = false; bool stdin_ready = false; - if (!waiter.Wait(&socket_ready, &stdin_ready)) { + if (!waiter.Wait(pending_write_len == 0 ? StdinWait::kStdinRead + : StdinWait::kSocketWrite, + &socket_ready, &stdin_ready)) { return false; } if (stdin_ready) { - uint8_t buffer[512]; - size_t n; - if (!waiter.ReadStdin(buffer, &n, sizeof(buffer))) { - return false; - } - if (n == 0) { -#if !defined(OPENSSL_WINDOWS) - shutdown(sock, SHUT_WR); -#else - shutdown(sock, SD_SEND); -#endif - continue; + if (pending_write_len == 0) { + if (!waiter.ReadStdin(pending_write, &pending_write_len, + sizeof(pending_write))) { + return false; + } + if (pending_write_len == 0) { + #if !defined(OPENSSL_WINDOWS) + shutdown(sock, SHUT_WR); + #else + shutdown(sock, SD_SEND); + #endif + continue; + } } - // TODO(davidben): On Windows, |WSAEventSocket| sets |sock| to non- - // blocking and forbids setting it back to blocking. Rather than toggle - // the blocking state, loop waiting for the socket to be writable. -#if !defined(OPENSSL_WINDOWS) - if (!SocketSetNonBlocking(sock, false)) { - return false; - } -#endif - int ssl_ret = SSL_write(ssl, buffer, static_cast(n)); + int ssl_ret = + SSL_write(ssl, pending_write, static_cast(pending_write_len)); if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl, ssl_ret); + if (ssl_err == SSL_ERROR_WANT_WRITE) { + continue; + } PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret); return false; - } else if (ssl_ret != static_cast(n)) { + } + if (ssl_ret != static_cast(pending_write_len)) { fprintf(stderr, "Short write from SSL_write.\n"); return false; } - - // Note we handle errors before restoring the non-blocking state. On - // Windows, |SocketSetNonBlocking| internally clears the last error. - if (!SocketSetNonBlocking(sock, true)) { - return false; - } + pending_write_len = 0; } if (socket_ready) {