Fix bssl select loop on Windows.

While |WaitForMultipleObjects| works for both sockets and stdin, the
latter is often a line-buffered console. The |HANDLE| is considered
readable if there are any console events available, but reading blocks
until a full line is available. (In POSIX, line buffering is implemented
in the kernel via termios, which is differently concerning, but does
mean |select| works as expected.)

So that |Wait| reflects final stdin read, we spawn a stdin reader thread
that writes to an in-memory buffer and signals a |WSAEVENT| to
coordinate with the socket. This is kind of silly, but it works.

I tried just writing it to a pipe, but it appears
|WaitForMultipleObjects| does not work on pipes!

Change-Id: I2bfa323fa91aad7d2035bb1fe86ee6f54b85d811
Reviewed-on: https://boringssl-review.googlesource.com/28165
Reviewed-by: Steven Valdez <svaldez@google.com>
Commit-Queue: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
This commit is contained in:
David Benjamin 2018-05-06 02:08:48 -04:00 committed by CQ bot account: commit-bot@chromium.org
parent 2a92847c24
commit 28385db6e1

View File

@ -12,6 +12,12 @@
* OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
* CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
// Suppress MSVC's STL warnings. It flags |std::copy| calls with a raw output
// pointer, on grounds that MSVC cannot check them. Unfortunately, there is no
// way to suppress the warning just on one line. The warning is flagged inside
// the STL itself, so suppressing at the |std::copy| call does not work.
#define _SCL_SECURE_NO_WARNINGS
#include <openssl/base.h>
#include <string>
@ -33,6 +39,13 @@
#include <sys/socket.h>
#include <unistd.h>
#else
#include <algorithm>
#include <deque>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <io.h>
OPENSSL_MSVC_PRAGMA(warning(push, 3))
#include <winsock2.h>
@ -347,64 +360,265 @@ bool SocketSetNonBlocking(int sock, bool is_non_blocking) {
ok = 0 == fcntl(sock, F_SETFL, flags);
#endif
if (!ok) {
fprintf(stderr, "Failed to set socket non-blocking.\n");
PrintSocketError("Failed to set socket non-blocking");
}
return ok;
}
static bool SocketSelect(int sock, bool stdin_open, bool *socket_ready,
bool *stdin_ready) {
#if !defined(OPENSSL_WINDOWS)
fd_set read_fds;
FD_ZERO(&read_fds);
if (stdin_open) {
FD_SET(0, &read_fds);
}
FD_SET(sock, &read_fds);
if (select(sock + 1, &read_fds, NULL, NULL, NULL) <= 0) {
perror("select");
return false;
}
if (FD_ISSET(0, &read_fds)) {
*stdin_ready = true;
}
if (FD_ISSET(sock, &read_fds)) {
// SocketWaiter abstracts waiting for either the socket or stdin to be readable
// between Windows and POSIX.
class SocketWaiter {
public:
explicit SocketWaiter(int sock) : sock_(sock) {}
SocketWaiter(const SocketWaiter &) = delete;
SocketWaiter &operator=(const SocketWaiter &) = delete;
// Init initializes the SocketWaiter. It returns whether it succeeded.
bool Init() { return true; }
// Wait waits for at least on of the socket or stdin or be ready. On success,
// it sets |*socket_ready| and |*stdin_ready| to whether the respective
// objects are readable and returns true. On error, it returns false.
bool Wait(bool *socket_ready, bool *stdin_ready) {
*socket_ready = true;
}
*stdin_ready = false;
return true;
#else
WSAEVENT socket_handle = WSACreateEvent();
if (socket_handle == WSA_INVALID_EVENT ||
WSAEventSelect(sock, socket_handle, FD_READ | FD_CLOSE) != 0) {
WSACloseEvent(socket_handle);
return false;
}
HANDLE read_fds[2];
read_fds[0] = socket_handle;
read_fds[1] = GetStdHandle(STD_INPUT_HANDLE);
switch (
WaitForMultipleObjects(stdin_open ? 2 : 1, read_fds, FALSE, INFINITE)) {
case WAIT_OBJECT_0 + 0:
*socket_ready = true;
break;
case WAIT_OBJECT_0 + 1:
*stdin_ready = true;
break;
case WAIT_TIMEOUT:
break;
default:
WSACloseEvent(socket_handle);
fd_set read_fds;
FD_ZERO(&read_fds);
if (stdin_open_) {
FD_SET(STDIN_FILENO, &read_fds);
}
FD_SET(sock_, &read_fds);
if (select(sock_ + 1, &read_fds, NULL, NULL, NULL) <= 0) {
perror("select");
return false;
}
if (FD_ISSET(STDIN_FILENO, &read_fds)) {
*stdin_ready = true;
}
if (FD_ISSET(sock_, &read_fds)) {
*socket_ready = true;
}
return true;
}
WSACloseEvent(socket_handle);
return true;
#endif
}
// ReadStdin reads at most |max_out| bytes from stdin. On success, it writes
// them to |out| and sets |*out_len| to the number of bytes written. On error,
// it returns false. This method may only be called after |Wait| returned
// stdin was ready.
bool ReadStdin(void *out, size_t *out_len, size_t max_out) {
ssize_t n;
do {
n = read(STDIN_FILENO, out, max_out);
} while (n == -1 && errno == EINTR);
if (n < 0) {
perror("read from stdin");
return false;
}
*out_len = static_cast<size_t>(n);
return true;
}
private:
bool stdin_open_ = true;
int sock_;
};
#else // OPENSSL_WINDOWs
class ScopedWSAEVENT {
public:
ScopedWSAEVENT() = default;
ScopedWSAEVENT(WSAEVENT event) { reset(event); }
ScopedWSAEVENT(const ScopedWSAEVENT &) = delete;
ScopedWSAEVENT(ScopedWSAEVENT &&other) { *this = std::move(other); }
~ScopedWSAEVENT() { reset(); }
ScopedWSAEVENT &operator=(const ScopedWSAEVENT &) = delete;
ScopedWSAEVENT &operator=(ScopedWSAEVENT &&other) { reset(other.release()); }
explicit operator bool() const { return event_ != WSA_INVALID_EVENT; }
WSAEVENT get() const { return event_; }
WSAEVENT release() {
WSAEVENT ret = event_;
event_ = WSA_INVALID_EVENT;
return ret;
}
void reset(WSAEVENT event = WSA_INVALID_EVENT) {
if (event_ != WSA_INVALID_EVENT) {
WSACloseEvent(event_);
}
event_ = event;
}
private:
WSAEVENT event_ = WSA_INVALID_EVENT;
};
// SocketWaiter, on Windows, is more complicated. While |WaitForMultipleObjects|
// works for both sockets and stdin, the latter is often a line-buffered
// console. The |HANDLE| is considered readable if there are any console events
// available, but reading blocks until a full line is available.
//
// So that |Wait| reflects final stdin read, we spawn a stdin reader thread that
// writes to an in-memory buffer and signals a |WSAEVENT| to coordinate with the
// socket.
class SocketWaiter {
public:
explicit SocketWaiter(int sock) : sock_(sock) {}
SocketWaiter(const SocketWaiter &) = delete;
SocketWaiter &operator=(const SocketWaiter &) = delete;
bool Init() {
stdin_ = std::make_shared<StdinState>();
stdin_->event.reset(WSACreateEvent());
if (!stdin_->event) {
PrintSocketError("Error in WSACreateEvent");
return false;
}
// Spawn a thread to block on stdin.
std::shared_ptr<StdinState> state = stdin_;
std::thread thread([state]() {
for (;;) {
uint8_t buf[512];
int ret = _read(0 /* stdin */, buf, sizeof(buf));
if (ret <= 0) {
if (ret < 0) {
perror("read from stdin");
}
// Report the error or EOF to the caller.
std::lock_guard<std::mutex> lock(state->lock);
state->error = ret < 0;
state->open = false;
WSASetEvent(state->event.get());
return;
}
size_t len = static_cast<size_t>(ret);
size_t written = 0;
while (written < len) {
std::unique_lock<std::mutex> lock(state->lock);
// Wait for there to be room in the buffer.
state->cond.wait(lock, [&] { return !state->buffer_full(); });
// Copy what we can and signal to the caller.
size_t todo = std::min(len - written, state->buffer_remaining());
state->buffer.insert(state->buffer.end(), buf + written,
buf + written + todo);
written += todo;
WSASetEvent(state->event.get());
}
}
});
thread.detach();
return true;
}
bool Wait(bool *socket_ready, bool *stdin_ready) {
*socket_ready = true;
*stdin_ready = false;
ScopedWSAEVENT sock_event(WSACreateEvent());
if (!sock_event ||
WSAEventSelect(sock_, sock_event.get(), FD_READ | FD_CLOSE) != 0) {
PrintSocketError("Error waiting for socket");
return false;
}
DWORD count = 1;
WSAEVENT events[2] = {sock_event.get(), WSA_INVALID_EVENT};
if (listen_stdin_) {
events[1] = stdin_->event.get();
count++;
}
switch (WSAWaitForMultipleEvents(count, events, FALSE /* wait all */,
WSA_INFINITE, FALSE /* alertable */)) {
case WSA_WAIT_EVENT_0 + 0:
*socket_ready = true;
return true;
case WSA_WAIT_EVENT_0 + 1:
*stdin_ready = true;
return true;
case WSA_WAIT_TIMEOUT:
return true;
default:
PrintSocketError("Error waiting for events");
return false;
}
}
bool ReadStdin(void *out, size_t *out_len, size_t max_out) {
std::lock_guard<std::mutex> locked(stdin_->lock);
if (stdin_->buffer.empty()) {
// |ReadStdin| may only be called when |Wait| signals it is ready, so
// stdin must have reached EOF or error.
assert(!stdin_->open);
listen_stdin_ = false;
if (stdin_->error) {
return false;
}
*out_len = 0;
return true;
}
bool was_full = stdin_->buffer_full();
// Copy as many bytes as well fit.
*out_len = std::min(max_out, stdin_->buffer.size());
auto begin = stdin_->buffer.begin();
auto end = stdin_->buffer.begin() + *out_len;
std::copy(begin, end, static_cast<uint8_t *>(out));
stdin_->buffer.erase(begin, end);
// Notify the stdin thread if there is more space.
if (was_full && !stdin_->buffer_full()) {
stdin_->cond.notify_one();
}
// If stdin is now waiting for input, clear the event.
if (stdin_->buffer.empty() && stdin_->open) {
WSAResetEvent(stdin_->event.get());
}
return true;
}
private:
struct StdinState {
static constexpr size_t kMaxBuffer = 1024;
StdinState() = default;
StdinState(const StdinState &) = delete;
StdinState &operator=(const StdinState &) = delete;
size_t buffer_remaining() const { return kMaxBuffer - buffer.size(); }
bool buffer_full() const { return buffer_remaining() == 0; }
ScopedWSAEVENT event;
// lock protects the following fields.
std::mutex lock;
// cond notifies the stdin thread that |buffer| is no longer full.
std::condition_variable cond;
std::deque<uint8_t> buffer;
bool open = true;
bool error = false;
};
int sock_;
std::shared_ptr<StdinState> stdin_;
// listen_stdin_ is set to false when we have consumed an EOF or error from
// |stdin_|. This is separate from |stdin_->open| because the signal may not
// have been consumed yet.
bool listen_stdin_ = true;
};
#endif // OPENSSL_WINDOWS
void PrintSSLError(FILE *file, const char *msg, int ssl_err, int ret) {
switch (ssl_err) {
@ -433,47 +647,46 @@ bool TransferData(SSL *ssl, int sock) {
return false;
}
bool stdin_open = true;
SocketWaiter waiter(sock);
if (!waiter.Init()) {
return false;
}
for (;;) {
bool socket_ready = false;
bool stdin_ready = false;
if (!SocketSelect(sock, stdin_open, &socket_ready, &stdin_ready)) {
if (!waiter.Wait(&socket_ready, &stdin_ready)) {
return false;
}
if (stdin_ready) {
uint8_t buffer[512];
ssize_t n;
do {
n = BORINGSSL_READ(0, buffer, sizeof(buffer));
} while (n == -1 && errno == EINTR);
size_t n;
if (!waiter.ReadStdin(buffer, &n, sizeof(buffer))) {
return false;
}
if (n == 0) {
stdin_open = false;
#if !defined(OPENSSL_WINDOWS)
shutdown(sock, SHUT_WR);
#else
shutdown(sock, SD_SEND);
#endif
continue;
} else if (n < 0) {
perror("read from stdin");
return false;
}
// On Windows, SocketSelect ends up setting sock to non-blocking.
// TODO(davidben): On Windows, |WSAEventSocket| sets |sock| to non-
// blocking and forbids setting it back to blocking. Rather than toggle
// the blocking state, loop waiting for the socket to be writable.
#if !defined(OPENSSL_WINDOWS)
if (!SocketSetNonBlocking(sock, false)) {
return false;
}
#endif
int ssl_ret = SSL_write(ssl, buffer, n);
int ssl_ret = SSL_write(ssl, buffer, static_cast<int>(n));
if (ssl_ret <= 0) {
int ssl_err = SSL_get_error(ssl, ssl_ret);
PrintSSLError(stderr, "Error while writing", ssl_err, ssl_ret);
return false;
} else if (ssl_ret != n) {
} else if (ssl_ret != static_cast<int>(n)) {
fprintf(stderr, "Short write from SSL_write.\n");
return false;
}