diff --git a/tool/transport_common.cc b/tool/transport_common.cc index dcb8e0d3..f1b03abb 100644 --- a/tool/transport_common.cc +++ b/tool/transport_common.cc @@ -12,6 +12,12 @@ * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ +// Suppress MSVC's STL warnings. It flags |std::copy| calls with a raw output +// pointer, on grounds that MSVC cannot check them. Unfortunately, there is no +// way to suppress the warning just on one line. The warning is flagged inside +// the STL itself, so suppressing at the |std::copy| call does not work. +#define _SCL_SECURE_NO_WARNINGS + #include #include @@ -33,6 +39,13 @@ #include #include #else +#include +#include +#include +#include +#include +#include + #include OPENSSL_MSVC_PRAGMA(warning(push, 3)) #include @@ -347,64 +360,265 @@ bool SocketSetNonBlocking(int sock, bool is_non_blocking) { ok = 0 == fcntl(sock, F_SETFL, flags); #endif if (!ok) { - fprintf(stderr, "Failed to set socket non-blocking.\n"); + PrintSocketError("Failed to set socket non-blocking"); } return ok; } -static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready, - bool *stdin_ready) { #if !defined(OPENSSL_WINDOWS) - fd_set read_fds; - FD_ZERO(&read_fds); - if (stdin_open) { - FD_SET(0, &read_fds); - } - FD_SET(sock, &read_fds); - if (select(sock + 1, &read_fds, NULL, NULL, NULL) <= 0) { - perror("select"); - return false; - } - if (FD_ISSET(0, &read_fds)) { - *stdin_ready = true; - } - if (FD_ISSET(sock, &read_fds)) { +// SocketWaiter abstracts waiting for either the socket or stdin to be readable +// between Windows and POSIX. +class SocketWaiter { + public: + explicit SocketWaiter(int sock) : sock_(sock) {} + SocketWaiter(const SocketWaiter &) = delete; + SocketWaiter &operator=(const SocketWaiter &) = delete; + + // Init initializes the SocketWaiter. It returns whether it succeeded. + bool Init() { return true; } + + // Wait waits for at least on of the socket or stdin or be ready. On success, + // it sets |*socket_ready| and |*stdin_ready| to whether the respective + // objects are readable and returns true. On error, it returns false. + bool Wait(bool *socket_ready, bool *stdin_ready) { *socket_ready = true; - } + *stdin_ready = false; - return true; -#else - WSAEVENT socket_handle = WSACreateEvent(); - if (socket_handle == WSA_INVALID_EVENT || - WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) { - WSACloseEvent(socket_handle); - return false; - } - - HANDLE read_fds[2]; - read_fds[0] = socket_handle; - read_fds[1] = GetStdHandle(STD_INPUT_HANDLE); - - switch ( - WaitForMultipleObjects(stdin_open ? 2 : 1, read_fds, FALSE, INFINITE)) { - case WAIT_OBJECT_0 + 0: - *socket_ready = true; - break; - case WAIT_OBJECT_0 + 1: - *stdin_ready = true; - break; - case WAIT_TIMEOUT: - break; - default: - WSACloseEvent(socket_handle); + fd_set read_fds; + FD_ZERO(&read_fds); + if (stdin_open_) { + FD_SET(STDIN_FILENO, &read_fds); + } + FD_SET(sock_, &read_fds); + if (select(sock_ + 1, &read_fds, NULL, NULL, NULL) <= 0) { + perror("select"); return false; + } + + if (FD_ISSET(STDIN_FILENO, &read_fds)) { + *stdin_ready = true; + } + if (FD_ISSET(sock_, &read_fds)) { + *socket_ready = true; + } + + return true; } - WSACloseEvent(socket_handle); - return true; -#endif -} + // ReadStdin reads at most |max_out| bytes from stdin. On success, it writes + // them to |out| and sets |*out_len| to the number of bytes written. On error, + // it returns false. This method may only be called after |Wait| returned + // stdin was ready. + bool ReadStdin(void *out, size_t *out_len, size_t max_out) { + ssize_t n; + do { + n = read(STDIN_FILENO, out, max_out); + } while (n == -1 && errno == EINTR); + if (n < 0) { + perror("read from stdin"); + return false; + } + *out_len = static_cast(n); + return true; + } + + private: + bool stdin_open_ = true; + int sock_; +}; + +#else // OPENSSL_WINDOWs + +class ScopedWSAEVENT { + public: + ScopedWSAEVENT() = default; + ScopedWSAEVENT(WSAEVENT event) { reset(event); } + ScopedWSAEVENT(const ScopedWSAEVENT &) = delete; + ScopedWSAEVENT(ScopedWSAEVENT &&other) { *this = std::move(other); } + + ~ScopedWSAEVENT() { reset(); } + + ScopedWSAEVENT &operator=(const ScopedWSAEVENT &) = delete; + ScopedWSAEVENT &operator=(ScopedWSAEVENT &&other) { reset(other.release()); } + + explicit operator bool() const { return event_ != WSA_INVALID_EVENT; } + WSAEVENT get() const { return event_; } + + WSAEVENT release() { + WSAEVENT ret = event_; + event_ = WSA_INVALID_EVENT; + return ret; + } + + void reset(WSAEVENT event = WSA_INVALID_EVENT) { + if (event_ != WSA_INVALID_EVENT) { + WSACloseEvent(event_); + } + event_ = event; + } + + private: + WSAEVENT event_ = WSA_INVALID_EVENT; +}; + +// SocketWaiter, on Windows, is more complicated. While |WaitForMultipleObjects| +// works for both sockets and stdin, the latter is often a line-buffered +// console. The |HANDLE| is considered readable if there are any console events +// available, but reading blocks until a full line is available. +// +// So that |Wait| reflects final stdin read, we spawn a stdin reader thread that +// writes to an in-memory buffer and signals a |WSAEVENT| to coordinate with the +// socket. +class SocketWaiter { + public: + explicit SocketWaiter(int sock) : sock_(sock) {} + SocketWaiter(const SocketWaiter &) = delete; + SocketWaiter &operator=(const SocketWaiter &) = delete; + + bool Init() { + stdin_ = std::make_shared(); + stdin_->event.reset(WSACreateEvent()); + if (!stdin_->event) { + PrintSocketError("Error in WSACreateEvent"); + return false; + } + + // Spawn a thread to block on stdin. + std::shared_ptr state = stdin_; + std::thread thread([state]() { + for (;;) { + uint8_t buf[512]; + int ret = _read(0 /* stdin */, buf, sizeof(buf)); + if (ret <= 0) { + if (ret < 0) { + perror("read from stdin"); + } + // Report the error or EOF to the caller. + std::lock_guard lock(state->lock); + state->error = ret < 0; + state->open = false; + WSASetEvent(state->event.get()); + return; + } + + size_t len = static_cast(ret); + size_t written = 0; + while (written < len) { + std::unique_lock lock(state->lock); + // Wait for there to be room in the buffer. + state->cond.wait(lock, [&] { return !state->buffer_full(); }); + + // Copy what we can and signal to the caller. + size_t todo = std::min(len - written, state->buffer_remaining()); + state->buffer.insert(state->buffer.end(), buf + written, + buf + written + todo); + written += todo; + WSASetEvent(state->event.get()); + } + } + }); + thread.detach(); + return true; + } + + bool Wait(bool *socket_ready, bool *stdin_ready) { + *socket_ready = true; + *stdin_ready = false; + + ScopedWSAEVENT sock_event(WSACreateEvent()); + if (!sock_event || + WSAEventSelect(sock_, sock_event.get(), FD_READ | FD_CLOSE) != 0) { + PrintSocketError("Error waiting for socket"); + return false; + } + + DWORD count = 1; + WSAEVENT events[2] = {sock_event.get(), WSA_INVALID_EVENT}; + if (listen_stdin_) { + events[1] = stdin_->event.get(); + count++; + } + + switch (WSAWaitForMultipleEvents(count, events, FALSE /* wait all */, + WSA_INFINITE, FALSE /* alertable */)) { + case WSA_WAIT_EVENT_0 + 0: + *socket_ready = true; + return true; + case WSA_WAIT_EVENT_0 + 1: + *stdin_ready = true; + return true; + case WSA_WAIT_TIMEOUT: + return true; + default: + PrintSocketError("Error waiting for events"); + return false; + } + } + + bool ReadStdin(void *out, size_t *out_len, size_t max_out) { + std::lock_guard locked(stdin_->lock); + + if (stdin_->buffer.empty()) { + // |ReadStdin| may only be called when |Wait| signals it is ready, so + // stdin must have reached EOF or error. + assert(!stdin_->open); + listen_stdin_ = false; + if (stdin_->error) { + return false; + } + *out_len = 0; + return true; + } + + bool was_full = stdin_->buffer_full(); + // Copy as many bytes as well fit. + *out_len = std::min(max_out, stdin_->buffer.size()); + auto begin = stdin_->buffer.begin(); + auto end = stdin_->buffer.begin() + *out_len; + std::copy(begin, end, static_cast(out)); + stdin_->buffer.erase(begin, end); + // Notify the stdin thread if there is more space. + if (was_full && !stdin_->buffer_full()) { + stdin_->cond.notify_one(); + } + // If stdin is now waiting for input, clear the event. + if (stdin_->buffer.empty() && stdin_->open) { + WSAResetEvent(stdin_->event.get()); + } + return true; + } + + private: + struct StdinState { + static constexpr size_t kMaxBuffer = 1024; + + StdinState() = default; + StdinState(const StdinState &) = delete; + StdinState &operator=(const StdinState &) = delete; + + size_t buffer_remaining() const { return kMaxBuffer - buffer.size(); } + bool buffer_full() const { return buffer_remaining() == 0; } + + ScopedWSAEVENT event; + // lock protects the following fields. + std::mutex lock; + // cond notifies the stdin thread that |buffer| is no longer full. + std::condition_variable cond; + std::deque buffer; + bool open = true; + bool error = false; + }; + + int sock_; + std::shared_ptr stdin_; + // listen_stdin_ is set to false when we have consumed an EOF or error from + // |stdin_|. This is separate from |stdin_->open| because the signal may not + // have been consumed yet. + bool listen_stdin_ = true; +}; + +#endif // OPENSSL_WINDOWS void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) { switch (ssl_err) { @@ -433,47 +647,46 @@ bool TransferData(SSL *ssl, int sock) { return false; } - bool stdin_open = true; + SocketWaiter waiter(sock); + if (!waiter.Init()) { + return false; + } for (;;) { bool socket_ready = false; bool stdin_ready = false; - if (!SocketSelect(sock, stdin_open, &socket_ready, &stdin_ready)) { + if (!waiter.Wait(&socket_ready, &stdin_ready)) { return false; } if (stdin_ready) { uint8_t buffer[512]; - ssize_t n; - - do { - n = BORINGSSL_READ(0, buffer, sizeof(buffer)); - } while (n == -1 && errno == EINTR); - + size_t n; + if (!waiter.ReadStdin(buffer, &n, sizeof(buffer))) { + return false; + } if (n == 0) { - stdin_open = false; #if !defined(OPENSSL_WINDOWS) shutdown(sock, SHUT_WR); #else shutdown(sock, SD_SEND); #endif continue; - } else if (n < 0) { - perror("read from stdin"); - return false; } - // On Windows, SocketSelect ends up setting sock to non-blocking. + // TODO(davidben): On Windows, |WSAEventSocket| sets |sock| to non- + // blocking and forbids setting it back to blocking. Rather than toggle + // the blocking state, loop waiting for the socket to be writable. #if !defined(OPENSSL_WINDOWS) if (!SocketSetNonBlocking(sock, false)) { return false; } #endif - int ssl_ret = SSL_write(ssl, buffer, n); + int ssl_ret = SSL_write(ssl, buffer, static_cast(n)); if (ssl_ret <= 0) { int ssl_err = SSL_get_error(ssl, ssl_ret); PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret); return false; - } else if (ssl_ret != n) { + } else if (ssl_ret != static_cast(n)) { fprintf(stderr, "Short write from SSL_write.\n"); return false; }