diff options
Diffstat (limited to 'src/common')
-rw-r--r-- | src/common/thread_worker.h | 38 |
1 files changed, 20 insertions, 18 deletions
diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h index 16aa673bd..8272985ff 100644 --- a/src/common/thread_worker.h +++ b/src/common/thread_worker.h @@ -7,7 +7,9 @@ #include <atomic> #include <functional> #include <mutex> +#include <stop_token> #include <string> +#include <thread> #include <type_traits> #include <vector> #include <queue> @@ -34,19 +36,19 @@ class StatefulThreadWorker { public: explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {}) : workers_queued{num_workers}, thread_name{std::move(name)} { - const auto lambda = [this, func] { + const auto lambda = [this, func](std::stop_token stop_token) { Common::SetCurrentThreadName(thread_name.c_str()); { std::conditional_t<with_state, StateType, int> state{func()}; - while (!stop) { + while (!stop_token.stop_requested()) { Task task; { std::unique_lock lock{queue_mutex}; if (requests.empty()) { wait_condition.notify_all(); } - condition.wait(lock, [this] { return stop || !requests.empty(); }); - if (stop) { + condition.wait(lock, stop_token, [this] { return !requests.empty(); }); + if (stop_token.stop_requested()) { break; } task = std::move(requests.front()); @@ -63,21 +65,17 @@ public: ++workers_stopped; wait_condition.notify_all(); }; + threads.reserve(num_workers); for (size_t i = 0; i < num_workers; ++i) { threads.emplace_back(lambda); } } - ~StatefulThreadWorker() { - { - std::unique_lock lock{queue_mutex}; - stop = true; - } - condition.notify_all(); - for (std::thread& thread : threads) { - thread.join(); - } - } + StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete; + StatefulThreadWorker(const StatefulThreadWorker&) = delete; + + StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete; + StatefulThreadWorker(StatefulThreadWorker&&) = delete; void QueueWork(Task work) { { @@ -88,7 +86,12 @@ public: condition.notify_one(); } - void WaitForRequests() { + void WaitForRequests(std::stop_token stop_token = {}) { + std::stop_callback callback(stop_token, [this] { + for (auto& thread : threads) { + thread.request_stop(); + } + }); std::unique_lock lock{queue_mutex}; wait_condition.wait(lock, [this] { return workers_stopped >= workers_queued || work_done >= work_scheduled; @@ -96,17 +99,16 @@ public: } private: - std::vector<std::thread> threads; std::queue<Task> requests; std::mutex queue_mutex; - std::condition_variable condition; + std::condition_variable_any condition; std::condition_variable wait_condition; - std::atomic_bool stop{}; std::atomic<size_t> work_scheduled{}; std::atomic<size_t> work_done{}; std::atomic<size_t> workers_stopped{}; std::atomic<size_t> workers_queued{}; std::string thread_name; + std::vector<std::jthread> threads; }; using ThreadWorker = StatefulThreadWorker<>; |