summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/core/core.cpp2
-rw-r--r--src/core/internal_network/network.cpp89
-rw-r--r--src/core/internal_network/network.h3
3 files changed, 91 insertions, 3 deletions
diff --git a/src/core/core.cpp b/src/core/core.cpp
index f075ae7fa..2d6e61398 100644
--- a/src/core/core.cpp
+++ b/src/core/core.cpp
@@ -406,6 +406,7 @@ struct System::Impl {
gpu_core->NotifyShutdown();
}
+ Network::CancelPendingSocketOperations();
kernel.SuspendApplication(true);
if (services) {
services->KillNVNFlinger();
@@ -427,6 +428,7 @@ struct System::Impl {
debugger.reset();
kernel.Shutdown();
memory.Reset();
+ Network::RestartSocketOperations();
if (auto room_member = room_network.GetRoomMember().lock()) {
Network::GameInfo game_info{};
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp
index 5d28300e6..a983f23ea 100644
--- a/src/core/internal_network/network.cpp
+++ b/src/core/internal_network/network.cpp
@@ -48,15 +48,32 @@ enum class CallType {
using socklen_t = int;
+SOCKET interrupt_socket = static_cast<SOCKET>(-1);
+
+void InterruptSocketOperations() {
+ closesocket(interrupt_socket);
+}
+
+void AcknowledgeInterrupt() {
+ interrupt_socket = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+}
+
void Initialize() {
WSADATA wsa_data;
(void)WSAStartup(MAKEWORD(2, 2), &wsa_data);
+
+ AcknowledgeInterrupt();
}
void Finalize() {
+ InterruptSocketOperations();
WSACleanup();
}
+SOCKET GetInterruptSocket() {
+ return interrupt_socket;
+}
+
sockaddr TranslateFromSockAddrIn(SockAddrIn input) {
sockaddr_in result;
@@ -157,9 +174,42 @@ constexpr int SD_RECEIVE = SHUT_RD;
constexpr int SD_SEND = SHUT_WR;
constexpr int SD_BOTH = SHUT_RDWR;
-void Initialize() {}
+int interrupt_pipe_fd[2] = {-1, -1};
-void Finalize() {}
+void Initialize() {
+ if (pipe(interrupt_pipe_fd) != 0) {
+ LOG_ERROR(Network, "Failed to create interrupt pipe!");
+ }
+ int flags = fcntl(interrupt_pipe_fd[0], F_GETFL);
+ ASSERT_MSG(fcntl(interrupt_pipe_fd[0], F_SETFL, flags | O_NONBLOCK) == 0,
+ "Failed to set nonblocking state for interrupt pipe");
+}
+
+void Finalize() {
+ if (interrupt_pipe_fd[0] >= 0) {
+ close(interrupt_pipe_fd[0]);
+ }
+ if (interrupt_pipe_fd[1] >= 0) {
+ close(interrupt_pipe_fd[1]);
+ }
+}
+
+void InterruptSocketOperations() {
+ u8 value = 0;
+ ASSERT(write(interrupt_pipe_fd[1], &value, sizeof(value)) == 1);
+}
+
+void AcknowledgeInterrupt() {
+ u8 value = 0;
+ ssize_t ret = read(interrupt_pipe_fd[0], &value, sizeof(value));
+ if (ret != 1 && errno != EAGAIN && errno != EWOULDBLOCK) {
+ LOG_ERROR(Network, "Failed to acknowledge interrupt on shutdown");
+ }
+}
+
+SOCKET GetInterruptSocket() {
+ return interrupt_pipe_fd[0];
+}
sockaddr TranslateFromSockAddrIn(SockAddrIn input) {
sockaddr_in result;
@@ -490,6 +540,14 @@ NetworkInstance::~NetworkInstance() {
Finalize();
}
+void CancelPendingSocketOperations() {
+ InterruptSocketOperations();
+}
+
+void RestartSocketOperations() {
+ AcknowledgeInterrupt();
+}
+
std::optional<IPv4Address> GetHostIPv4Address() {
const auto network_interface = Network::GetSelectedNetworkInterface();
if (!network_interface.has_value()) {
@@ -560,7 +618,14 @@ std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
return result;
});
- const int result = WSAPoll(host_pollfds.data(), static_cast<ULONG>(num), timeout);
+ host_pollfds.push_back(WSAPOLLFD{
+ .fd = GetInterruptSocket(),
+ .events = POLLIN,
+ .revents = 0,
+ });
+
+ const int result =
+ WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), timeout);
if (result == 0) {
ASSERT(std::all_of(host_pollfds.begin(), host_pollfds.end(),
[](WSAPOLLFD fd) { return fd.revents == 0; }));
@@ -627,6 +692,24 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
sockaddr_in addr;
socklen_t addrlen = sizeof(addr);
+
+ std::vector<WSAPOLLFD> host_pollfds{
+ WSAPOLLFD{fd, POLLIN, 0},
+ WSAPOLLFD{GetInterruptSocket(), POLLIN, 0},
+ };
+
+ while (true) {
+ const int pollres =
+ WSAPoll(host_pollfds.data(), static_cast<ULONG>(host_pollfds.size()), -1);
+ if (host_pollfds[1].revents != 0) {
+ // Interrupt signaled before a client could be accepted, break
+ return {AcceptResult{}, Errno::AGAIN};
+ }
+ if (pollres > 0) {
+ break;
+ }
+ }
+
const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
if (new_socket == INVALID_SOCKET) {
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h
index c7e20ae34..b7b7d773a 100644
--- a/src/core/internal_network/network.h
+++ b/src/core/internal_network/network.h
@@ -96,6 +96,9 @@ public:
~NetworkInstance();
};
+void CancelPendingSocketOperations();
+void RestartSocketOperations();
+
#ifdef _WIN32
constexpr IPv4Address TranslateIPv4(in_addr addr) {
auto& bytes = addr.S_un.S_un_b;