summaryrefslogtreecommitdiffstats
path: root/src/common/bounded_threadsafe_queue.h
blob: 21217801e626a9643d68b6d70815568897e5faf2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
// SPDX-FileCopyrightText: Copyright (c) 2020 Erik Rigtorp <erik@rigtorp.se>
// SPDX-License-Identifier: MIT

#pragma once

#include <atomic>
#include <bit>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <new>
#include <stop_token>
#include <type_traits>
#include <utility>

namespace Common {

// Size of one destructive-interference (cache line) region, used to pad and
// align data so that independently-written fields do not share a cache line.
#if defined(__cpp_lib_hardware_interference_size)
constexpr size_t hardware_interference_size = std::hardware_destructive_interference_size;
#else
// Fallback for standard libraries that do not ship the C++17 interference
// constants: 64 bytes is the cache line size on x86-64 and most ARM cores.
constexpr size_t hardware_interference_size = 64;
#endif

/// Bounded multi-producer, single-consumer FIFO queue.
///
/// Each slot carries an atomic "turn" flag: clear means empty, set means a
/// value is stored. Producers claim a slot by incrementing head_; the single
/// consumer claims slots by incrementing tail_. Producers wake the consumer
/// through a condition variable; the consumer wakes a blocked producer via
/// atomic_flag notification.
///
/// NOTE(review): with a plain boolean turn flag, two producers that are a
/// full `capacity` apart in head_ can both observe the same slot as empty and
/// race on it. Rigtorp's original design uses per-slot turn *counters* to
/// exclude this; confirm producers can never lap the consumer by `capacity`
/// elements in this codebase, or restore turn counters.
template <typename T, size_t capacity = 0x400>
class MPSCQueue {
public:
    /// Allocates and default-constructs all slots.
    /// @throws std::bad_alloc if the allocator returns insufficiently aligned
    ///         storage for the over-aligned Slot type.
    explicit MPSCQueue() : allocator{std::allocator<Slot<T>>()} {
        // Allocate one extra slot to prevent false sharing on the last slot
        slots = allocator.allocate(capacity + 1);
        // Allocators are not required to honor alignment for over-aligned types
        // (see http://eel.is/c++draft/allocator.requirements#10) so we verify
        // alignment here
        if (reinterpret_cast<uintptr_t>(slots) % alignof(Slot<T>) != 0) {
            allocator.deallocate(slots, capacity + 1);
            throw std::bad_alloc();
        }
        for (size_t i = 0; i < capacity; ++i) {
            std::construct_at(&slots[i]);
        }
        static_assert(std::has_single_bit(capacity), "capacity must be an integer power of 2");
        static_assert(alignof(Slot<T>) == hardware_interference_size,
                      "Slot must be aligned to cache line boundary to prevent false sharing");
        static_assert(sizeof(Slot<T>) % hardware_interference_size == 0,
                      "Slot size must be a multiple of cache line size to prevent "
                      "false sharing between adjacent slots");
        static_assert(sizeof(MPSCQueue) % hardware_interference_size == 0,
                      "Queue size must be a multiple of cache line size to "
                      "prevent false sharing between adjacent queues");
    }

    /// Destroys every slot (a slot's destructor destroys its stored value if
    /// the turn flag marks it as occupied) and releases the storage.
    ~MPSCQueue() noexcept {
        for (size_t i = 0; i < capacity; ++i) {
            std::destroy_at(&slots[i]);
        }
        allocator.deallocate(slots, capacity + 1);
    }

    // The queue must be both non-copyable and non-movable: slots are
    // referenced concurrently by producers/consumer and padding invariants
    // are asserted against this exact object layout.
    MPSCQueue(const MPSCQueue&) = delete;
    MPSCQueue& operator=(const MPSCQueue&) = delete;

    MPSCQueue(MPSCQueue&&) = delete;
    MPSCQueue& operator=(MPSCQueue&&) = delete;

    /// Copies v into the queue; blocks while the queue is full.
    void Push(const T& v) noexcept {
        static_assert(std::is_nothrow_copy_constructible_v<T>,
                      "T must be nothrow copy constructible");
        emplace(v);
    }

    /// Moves/converts v into the queue; blocks while the queue is full.
    /// Participates in overload resolution only when T is nothrow
    /// constructible from P&&.
    template <typename P, typename = std::enable_if_t<std::is_nothrow_constructible_v<T, P&&>>>
    void Push(P&& v) noexcept {
        emplace(std::forward<P>(v));
    }

    /// Pops the next element into v, blocking until one is available or a
    /// stop is requested on `stop`. On stop, v is left untouched and the
    /// claimed slot is returned to the queue. Must only ever be called from
    /// the single consumer thread.
    void Pop(T& v, std::stop_token stop) noexcept {
        auto const tail = tail_.fetch_add(1);
        auto& slot = slots[idx(tail)];
        if (!slot.turn.test()) {
            std::unique_lock lock{cv_mutex};
            // wait() returns the predicate's value: false means the stop
            // request fired while the slot was still empty. The previous code
            // ignored this and fell through to move from *unconstructed*
            // storage (undefined behavior), leaving tail_ desynchronized.
            if (!cv.wait(lock, stop, [&slot] { return slot.turn.test(); })) {
                // Roll back our claim so a later Pop can consume this slot.
                // Safe: there is exactly one consumer, so no concurrent Pop
                // can have advanced tail_ past us.
                tail_.fetch_sub(1);
                return;
            }
        }
        v = slot.move();
        slot.destroy();
        slot.turn.clear();
        // Wake a producer that may be blocked waiting for this slot to free.
        slot.turn.notify_one();
    }

private:
    /// One queue element plus its occupancy flag, padded to a whole number of
    /// cache lines so adjacent slots never share one.
    template <typename U = T>
    struct Slot {
        ~Slot() noexcept {
            // Only destroy the payload if one is actually stored.
            if (turn.test()) {
                destroy();
            }
        }

        template <typename... Args>
        void construct(Args&&... args) noexcept {
            static_assert(std::is_nothrow_constructible_v<U, Args&&...>,
                          "T must be nothrow constructible with Args&&...");
            std::construct_at(reinterpret_cast<U*>(&storage), std::forward<Args>(args)...);
        }

        void destroy() noexcept {
            static_assert(std::is_nothrow_destructible_v<U>, "T must be nothrow destructible");
            std::destroy_at(reinterpret_cast<U*>(&storage));
        }

        // Exposes the stored value as an rvalue; only valid while turn is set.
        U&& move() noexcept {
            return reinterpret_cast<U&&>(storage);
        }

        // Align to avoid false sharing between adjacent slots
        alignas(hardware_interference_size) std::atomic_flag turn{};
        // Raw, suitably-aligned storage for a U; constructed/destroyed
        // manually so slots can outlive individual elements.
        struct aligned_store {
            struct type {
                alignas(U) unsigned char data[sizeof(U)];
            };
        };
        typename aligned_store::type storage;
    };

    /// Producer side: claims the next head slot, waits (if necessary) for it
    /// to be emptied by the consumer, constructs the value, then marks the
    /// slot occupied and wakes the consumer.
    template <typename... Args>
    void emplace(Args&&... args) noexcept {
        static_assert(std::is_nothrow_constructible_v<T, Args&&...>,
                      "T must be nothrow constructible with Args&&...");
        auto const head = head_.fetch_add(1);
        auto& slot = slots[idx(head)];
        // Block while the slot still holds the value from the previous lap.
        slot.turn.wait(true);
        slot.construct(std::forward<Args>(args)...);
        slot.turn.test_and_set();
        cv.notify_one();
    }

    /// Maps a monotonically increasing cursor onto a slot index
    /// (capacity is a power of two, so masking == modulo).
    constexpr size_t idx(size_t i) const noexcept {
        return i & mask;
    }

    static constexpr size_t mask = capacity - 1;

    // Align to avoid false sharing between head_ and tail_
    alignas(hardware_interference_size) std::atomic<size_t> head_{0};
    alignas(hardware_interference_size) std::atomic<size_t> tail_{0};

    std::mutex cv_mutex;
    std::condition_variable_any cv;

    Slot<T>* slots;
    [[no_unique_address]] std::allocator<Slot<T>> allocator;

    static_assert(std::is_nothrow_copy_assignable_v<T> || std::is_nothrow_move_assignable_v<T>,
                  "T must be nothrow copy or move assignable");

    static_assert(std::is_nothrow_destructible_v<T>, "T must be nothrow destructible");
};

} // namespace Common