diff options
Diffstat (limited to 'src/common')
-rw-r--r-- | src/common/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/common/thread_worker.cpp | 58 | ||||
-rw-r--r-- | src/common/thread_worker.h | 102 | ||||
-rw-r--r-- | src/common/unique_function.h | 62 |
4 files changed, 157 insertions, 67 deletions
diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index a6fa9a85d..e03fffd8d 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -180,7 +180,6 @@ add_library(common STATIC thread.cpp thread.h thread_queue_list.h - thread_worker.cpp thread_worker.h threadsafe_queue.h time_zone.cpp @@ -188,6 +187,7 @@ add_library(common STATIC tiny_mt.h tree.h uint128.h + unique_function.h uuid.cpp uuid.h vector_math.h diff --git a/src/common/thread_worker.cpp b/src/common/thread_worker.cpp deleted file mode 100644 index 8f9bf447a..000000000 --- a/src/common/thread_worker.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright 2020 yuzu emulator team -// Licensed under GPLv2 or any later version -// Refer to the license.txt file included. - -#include "common/thread.h" -#include "common/thread_worker.h" - -namespace Common { - -ThreadWorker::ThreadWorker(std::size_t num_workers, const std::string& name) { - for (std::size_t i = 0; i < num_workers; ++i) - threads.emplace_back([this, thread_name{std::string{name}}] { - Common::SetCurrentThreadName(thread_name.c_str()); - - // Wait for first request - { - std::unique_lock lock{queue_mutex}; - condition.wait(lock, [this] { return stop || !requests.empty(); }); - } - - while (true) { - std::function<void()> task; - - { - std::unique_lock lock{queue_mutex}; - condition.wait(lock, [this] { return stop || !requests.empty(); }); - if (stop || requests.empty()) { - return; - } - task = std::move(requests.front()); - requests.pop(); - } - - task(); - } - }); -} - -ThreadWorker::~ThreadWorker() { - { - std::unique_lock lock{queue_mutex}; - stop = true; - } - condition.notify_all(); - for (std::thread& thread : threads) { - thread.join(); - } -} - -void ThreadWorker::QueueWork(std::function<void()>&& work) { - { - std::unique_lock lock{queue_mutex}; - requests.emplace(work); - } - condition.notify_one(); -} - -} // namespace Common diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h index f1859971f..8272985ff 100644 --- a/src/common/thread_worker.h +++ b/src/common/thread_worker.h @@ -7,24 +7,110 @@ #include <atomic> #include <functional> #include <mutex> +#include <stop_token> #include <string> +#include <thread> +#include <type_traits> #include <vector> #include <queue> +#include "common/thread.h" +#include "common/unique_function.h" + namespace Common { -class ThreadWorker final { +template <class StateType = void> +class StatefulThreadWorker { + static constexpr bool with_state = !std::is_same_v<StateType, void>; + + struct DummyCallable { + int operator()() const noexcept { + return 0; + } + }; + + using Task = + std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>; + using StateMaker = std::conditional_t<with_state, std::function<StateType()>, DummyCallable>; + public: - explicit ThreadWorker(std::size_t num_workers, const std::string& name); - ~ThreadWorker(); - void QueueWork(std::function<void()>&& work); + explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {}) + : workers_queued{num_workers}, thread_name{std::move(name)} { + const auto lambda = [this, func](std::stop_token stop_token) { + Common::SetCurrentThreadName(thread_name.c_str()); + { + std::conditional_t<with_state, StateType, int> state{func()}; + while (!stop_token.stop_requested()) { + Task task; + { + std::unique_lock lock{queue_mutex}; + if (requests.empty()) { + wait_condition.notify_all(); + } + condition.wait(lock, stop_token, [this] { return !requests.empty(); }); + if (stop_token.stop_requested()) { + break; + } + task = std::move(requests.front()); + requests.pop(); + } + if constexpr (with_state) { + task(&state); + } else { + task(); + } + ++work_done; + } + } + ++workers_stopped; + wait_condition.notify_all(); + }; + threads.reserve(num_workers); + for (size_t i = 0; i < num_workers; ++i) { + threads.emplace_back(lambda); + } + } + + StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete; + StatefulThreadWorker(const StatefulThreadWorker&) = delete; + + StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete; + StatefulThreadWorker(StatefulThreadWorker&&) = delete; + + void QueueWork(Task work) { + { + std::unique_lock lock{queue_mutex}; + requests.emplace(std::move(work)); + ++work_scheduled; + } + condition.notify_one(); + } + + void WaitForRequests(std::stop_token stop_token = {}) { + std::stop_callback callback(stop_token, [this] { + for (auto& thread : threads) { + thread.request_stop(); + } + }); + std::unique_lock lock{queue_mutex}; + wait_condition.wait(lock, [this] { + return workers_stopped >= workers_queued || work_done >= work_scheduled; + }); + } private: - std::vector<std::thread> threads; - std::queue<std::function<void()>> requests; + std::queue<Task> requests; std::mutex queue_mutex; - std::condition_variable condition; - std::atomic_bool stop{}; + std::condition_variable_any condition; + std::condition_variable wait_condition; + std::atomic<size_t> work_scheduled{}; + std::atomic<size_t> work_done{}; + std::atomic<size_t> workers_stopped{}; + std::atomic<size_t> workers_queued{}; + std::string thread_name; + std::vector<std::jthread> threads; }; +using ThreadWorker = StatefulThreadWorker<>; + } // namespace Common diff --git a/src/common/unique_function.h b/src/common/unique_function.h new file mode 100644 index 000000000..ca0559071 --- /dev/null +++ b/src/common/unique_function.h @@ -0,0 +1,62 @@ +// Copyright 2021 yuzu emulator team +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once + +#include <memory> +#include <utility> + +namespace Common { + +/// General purpose function wrapper similar to std::function. +/// Unlike std::function, the captured values don't have to be copyable. +/// This class can be moved but not copied. +template <typename ResultType, typename... Args> +class UniqueFunction { + class CallableBase { + public: + virtual ~CallableBase() = default; + virtual ResultType operator()(Args&&...) = 0; + }; + + template <typename Functor> + class Callable final : public CallableBase { + public: + Callable(Functor&& functor_) : functor{std::move(functor_)} {} + ~Callable() override = default; + + ResultType operator()(Args&&... args) override { + return functor(std::forward<Args>(args)...); + } + + private: + Functor functor; + }; + +public: + UniqueFunction() = default; + + template <typename Functor> + UniqueFunction(Functor&& functor) + : callable{std::make_unique<Callable<Functor>>(std::move(functor))} {} + + UniqueFunction& operator=(UniqueFunction&& rhs) noexcept = default; + UniqueFunction(UniqueFunction&& rhs) noexcept = default; + + UniqueFunction& operator=(const UniqueFunction&) = delete; + UniqueFunction(const UniqueFunction&) = delete; + + ResultType operator()(Args&&... args) const { + return (*callable)(std::forward<Args>(args)...); + } + + explicit operator bool() const noexcept { + return static_cast<bool>(callable); + } + +private: + std::unique_ptr<CallableBase> callable; +}; + +} // namespace Common |