summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/common/CMakeLists.txt2
-rw-r--r--src/common/thread_worker.cpp58
-rw-r--r--src/common/thread_worker.h102
-rw-r--r--src/common/unique_function.h62
-rw-r--r--src/tests/CMakeLists.txt1
-rw-r--r--src/tests/common/unique_function.cpp108
6 files changed, 266 insertions, 67 deletions
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt
index a6fa9a85d..e03fffd8d 100644
--- a/src/common/CMakeLists.txt
+++ b/src/common/CMakeLists.txt
@@ -180,7 +180,6 @@ add_library(common STATIC
thread.cpp
thread.h
thread_queue_list.h
- thread_worker.cpp
thread_worker.h
threadsafe_queue.h
time_zone.cpp
@@ -188,6 +187,7 @@ add_library(common STATIC
tiny_mt.h
tree.h
uint128.h
+ unique_function.h
uuid.cpp
uuid.h
vector_math.h
diff --git a/src/common/thread_worker.cpp b/src/common/thread_worker.cpp
deleted file mode 100644
index 8f9bf447a..000000000
--- a/src/common/thread_worker.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-// Copyright 2020 yuzu emulator team
-// Licensed under GPLv2 or any later version
-// Refer to the license.txt file included.
-
-#include "common/thread.h"
-#include "common/thread_worker.h"
-
-namespace Common {
-
-ThreadWorker::ThreadWorker(std::size_t num_workers, const std::string& name) {
- for (std::size_t i = 0; i < num_workers; ++i)
- threads.emplace_back([this, thread_name{std::string{name}}] {
- Common::SetCurrentThreadName(thread_name.c_str());
-
- // Wait for first request
- {
- std::unique_lock lock{queue_mutex};
- condition.wait(lock, [this] { return stop || !requests.empty(); });
- }
-
- while (true) {
- std::function<void()> task;
-
- {
- std::unique_lock lock{queue_mutex};
- condition.wait(lock, [this] { return stop || !requests.empty(); });
- if (stop || requests.empty()) {
- return;
- }
- task = std::move(requests.front());
- requests.pop();
- }
-
- task();
- }
- });
-}
-
-ThreadWorker::~ThreadWorker() {
- {
- std::unique_lock lock{queue_mutex};
- stop = true;
- }
- condition.notify_all();
- for (std::thread& thread : threads) {
- thread.join();
- }
-}
-
-void ThreadWorker::QueueWork(std::function<void()>&& work) {
- {
- std::unique_lock lock{queue_mutex};
- requests.emplace(work);
- }
- condition.notify_one();
-}
-
-} // namespace Common
diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h
index f1859971f..8272985ff 100644
--- a/src/common/thread_worker.h
+++ b/src/common/thread_worker.h
@@ -7,24 +7,110 @@
#include <atomic>
#include <functional>
#include <mutex>
+#include <stop_token>
#include <string>
+#include <thread>
+#include <type_traits>
#include <vector>
#include <queue>
+#include "common/thread.h"
+#include "common/unique_function.h"
+
namespace Common {
-class ThreadWorker final {
+template <class StateType = void>
+class StatefulThreadWorker {
+ static constexpr bool with_state = !std::is_same_v<StateType, void>;
+
+ struct DummyCallable {
+ int operator()() const noexcept {
+ return 0;
+ }
+ };
+
+ using Task =
+ std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>;
+ using StateMaker = std::conditional_t<with_state, std::function<StateType()>, DummyCallable>;
+
public:
- explicit ThreadWorker(std::size_t num_workers, const std::string& name);
- ~ThreadWorker();
- void QueueWork(std::function<void()>&& work);
+ explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {})
+ : workers_queued{num_workers}, thread_name{std::move(name)} {
+ const auto lambda = [this, func](std::stop_token stop_token) {
+ Common::SetCurrentThreadName(thread_name.c_str());
+ {
+ std::conditional_t<with_state, StateType, int> state{func()};
+ while (!stop_token.stop_requested()) {
+ Task task;
+ {
+ std::unique_lock lock{queue_mutex};
+ if (requests.empty()) {
+ wait_condition.notify_all();
+ }
+ condition.wait(lock, stop_token, [this] { return !requests.empty(); });
+ if (stop_token.stop_requested()) {
+ break;
+ }
+ task = std::move(requests.front());
+ requests.pop();
+ }
+ if constexpr (with_state) {
+ task(&state);
+ } else {
+ task();
+ }
+ ++work_done;
+ }
+ }
+ ++workers_stopped;
+ wait_condition.notify_all();
+ };
+ threads.reserve(num_workers);
+ for (size_t i = 0; i < num_workers; ++i) {
+ threads.emplace_back(lambda);
+ }
+ }
+
+ StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete;
+ StatefulThreadWorker(const StatefulThreadWorker&) = delete;
+
+ StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete;
+ StatefulThreadWorker(StatefulThreadWorker&&) = delete;
+
+ void QueueWork(Task work) {
+ {
+ std::unique_lock lock{queue_mutex};
+ requests.emplace(std::move(work));
+ ++work_scheduled;
+ }
+ condition.notify_one();
+ }
+
+ void WaitForRequests(std::stop_token stop_token = {}) {
+ std::stop_callback callback(stop_token, [this] {
+ for (auto& thread : threads) {
+ thread.request_stop();
+ }
+ });
+ std::unique_lock lock{queue_mutex};
+ wait_condition.wait(lock, [this] {
+ return workers_stopped >= workers_queued || work_done >= work_scheduled;
+ });
+ }
private:
- std::vector<std::thread> threads;
- std::queue<std::function<void()>> requests;
+ std::queue<Task> requests;
std::mutex queue_mutex;
- std::condition_variable condition;
- std::atomic_bool stop{};
+ std::condition_variable_any condition;
+ std::condition_variable wait_condition;
+ std::atomic<size_t> work_scheduled{};
+ std::atomic<size_t> work_done{};
+ std::atomic<size_t> workers_stopped{};
+ std::atomic<size_t> workers_queued{};
+ std::string thread_name;
+ std::vector<std::jthread> threads;
};
+using ThreadWorker = StatefulThreadWorker<>;
+
} // namespace Common
diff --git a/src/common/unique_function.h b/src/common/unique_function.h
new file mode 100644
index 000000000..ca0559071
--- /dev/null
+++ b/src/common/unique_function.h
@@ -0,0 +1,62 @@
+// Copyright 2021 yuzu emulator team
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#pragma once
+
+#include <memory>
+#include <utility>
+
+namespace Common {
+
+/// General purpose function wrapper similar to std::function.
+/// Unlike std::function, the captured values don't have to be copyable.
+/// This class can be moved but not copied.
+template <typename ResultType, typename... Args>
+class UniqueFunction {
+ class CallableBase {
+ public:
+ virtual ~CallableBase() = default;
+ virtual ResultType operator()(Args&&...) = 0;
+ };
+
+ template <typename Functor>
+ class Callable final : public CallableBase {
+ public:
+ Callable(Functor&& functor_) : functor{std::move(functor_)} {}
+ ~Callable() override = default;
+
+ ResultType operator()(Args&&... args) override {
+ return functor(std::forward<Args>(args)...);
+ }
+
+ private:
+ Functor functor;
+ };
+
+public:
+ UniqueFunction() = default;
+
+ template <typename Functor>
+ UniqueFunction(Functor&& functor)
+ : callable{std::make_unique<Callable<Functor>>(std::move(functor))} {}
+
+ UniqueFunction& operator=(UniqueFunction&& rhs) noexcept = default;
+ UniqueFunction(UniqueFunction&& rhs) noexcept = default;
+
+ UniqueFunction& operator=(const UniqueFunction&) = delete;
+ UniqueFunction(const UniqueFunction&) = delete;
+
+ ResultType operator()(Args&&... args) const {
+ return (*callable)(std::forward<Args>(args)...);
+ }
+
+ explicit operator bool() const noexcept {
+ return static_cast<bool>(callable);
+ }
+
+private:
+ std::unique_ptr<CallableBase> callable;
+};
+
+} // namespace Common
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 96bc30cac..c4c012f3d 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -5,6 +5,7 @@ add_executable(tests
common/host_memory.cpp
common/param_package.cpp
common/ring_buffer.cpp
+ common/unique_function.cpp
core/core_timing.cpp
core/network/network.cpp
tests.cpp
diff --git a/src/tests/common/unique_function.cpp b/src/tests/common/unique_function.cpp
new file mode 100644
index 000000000..ac9912738
--- /dev/null
+++ b/src/tests/common/unique_function.cpp
@@ -0,0 +1,108 @@
+// Copyright 2021 yuzu Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
+#include <string>
+
+#include <catch2/catch.hpp>
+
+#include "common/unique_function.h"
+
+namespace {
+struct Noisy {
+ Noisy() : state{"Default constructed"} {}
+ Noisy(Noisy&& rhs) noexcept : state{"Move constructed"} {
+ rhs.state = "Moved away";
+ }
+ Noisy& operator=(Noisy&& rhs) noexcept {
+ state = "Move assigned";
+ rhs.state = "Moved away";
+ }
+ Noisy(const Noisy&) : state{"Copied constructed"} {}
+ Noisy& operator=(const Noisy&) {
+ state = "Copied assigned";
+ }
+
+ std::string state;
+};
+} // Anonymous namespace
+
+TEST_CASE("UniqueFunction", "[common]") {
+ SECTION("Capture reference") {
+ int value = 0;
+ Common::UniqueFunction<void> func = [&value] { value = 5; };
+ func();
+ REQUIRE(value == 5);
+ }
+ SECTION("Capture pointer") {
+ int value = 0;
+ int* pointer = &value;
+ Common::UniqueFunction<void> func = [pointer] { *pointer = 5; };
+ func();
+ REQUIRE(value == 5);
+ }
+ SECTION("Move object") {
+ Noisy noisy;
+ REQUIRE(noisy.state == "Default constructed");
+
+ Common::UniqueFunction<void> func = [noisy = std::move(noisy)] {
+ REQUIRE(noisy.state == "Move constructed");
+ };
+ REQUIRE(noisy.state == "Moved away");
+ func();
+ }
+ SECTION("Move construct function") {
+ int value = 0;
+ Common::UniqueFunction<void> func = [&value] { value = 5; };
+ Common::UniqueFunction<void> new_func = std::move(func);
+ new_func();
+ REQUIRE(value == 5);
+ }
+ SECTION("Move assign function") {
+ int value = 0;
+ Common::UniqueFunction<void> func = [&value] { value = 5; };
+ Common::UniqueFunction<void> new_func;
+ new_func = std::move(func);
+ new_func();
+ REQUIRE(value == 5);
+ }
+ SECTION("Default construct then assign function") {
+ int value = 0;
+ Common::UniqueFunction<void> func;
+ func = [&value] { value = 5; };
+ func();
+ REQUIRE(value == 5);
+ }
+ SECTION("Pass arguments") {
+ int result = 0;
+ Common::UniqueFunction<void, int, int> func = [&result](int a, int b) { result = a + b; };
+ func(5, 4);
+ REQUIRE(result == 9);
+ }
+ SECTION("Pass arguments and return value") {
+ Common::UniqueFunction<int, int, int> func = [](int a, int b) { return a + b; };
+ REQUIRE(func(5, 4) == 9);
+ }
+ SECTION("Destructor") {
+ int num_destroyed = 0;
+ struct Foo {
+ Foo(int* num_) : num{num_} {}
+ Foo(Foo&& rhs) : num{std::exchange(rhs.num, nullptr)} {}
+ Foo(const Foo&) = delete;
+
+ ~Foo() {
+ if (num) {
+ ++*num;
+ }
+ }
+
+ int* num = nullptr;
+ };
+ Foo object{&num_destroyed};
+ {
+ Common::UniqueFunction<void> func = [object = std::move(object)] {};
+ REQUIRE(num_destroyed == 0);
+ }
+ REQUIRE(num_destroyed == 1);
+ }
+}