summaryrefslogtreecommitdiffstats
path: root/src/video_core/host_shaders/queries_prefix_scan_sum.comp
blob: 8f10e248ee8f9f2066da1bdfed5cf8d1803b1773 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
// SPDX-License-Identifier: GPL-3.0-or-later

#version 460 core

// Subgroup features used by this shader: gl_Subgroup*/gl_NumSubgroups built-ins and
// subgroupBarrier/subgroupMemoryBarrier (basic), subgroupShuffle (shuffle),
// subgroupShuffleUp (shuffle_relative) and subgroupMax (arithmetic).
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require

// The same source is compiled for Vulkan and OpenGL. VULKAN is presumably defined by the
// build when targeting Vulkan (not visible in this file). The macros below hide how the
// two dispatch parameters are declared on each backend.
#ifdef VULKAN

#define HAS_EXTENDED_TYPES 1
// Vulkan: the parameters are members of a push-constant block; UNIFORM(n) expands to nothing.
#define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants {
#define END_PUSH_CONSTANTS };
#define UNIFORM(n)
// NOTE(review): HAS_EXTENDED_TYPES and the BINDING_* macros are not referenced anywhere in
// this shader (the storage blocks use literal bindings 0/1/2); they look like leftovers
// from another shader and may be removable.
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 1

#else // ^^^ Vulkan ^^^ // vvv OpenGL vvv

// OpenGL: GL_NV_gpu_shader5 is optional; it only selects the value of HAS_EXTENDED_TYPES.
#extension GL_NV_gpu_shader5 : enable
#ifdef GL_NV_gpu_shader5
#define HAS_EXTENDED_TYPES 1
#else
#define HAS_EXTENDED_TYPES 0
#endif
// OpenGL: the parameters are plain uniforms at explicit locations.
#define BEGIN_PUSH_CONSTANTS
#define END_PUSH_CONSTANTS
#define UNIFORM(n) layout(location = n) uniform
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 0

#endif

// Dispatch parameters (push constants under Vulkan, plain uniforms under OpenGL).
BEGIN_PUSH_CONSTANTS
// Invocations whose global id is below this value add accumulated_data to their output.
UNIFORM(0) uint max_accumulation_base;
// Global id of the invocation whose result is written back into accumulated_data.
UNIFORM(1) uint accumulation_limit;
END_PUSH_CONSTANTS

// 32 invocations per work group.
layout(local_size_x = 32) in;

// All 64-bit quantities below are stored as uvec2(low word, high word), matching AddUint64.
// One input value per invocation, indexed by gl_GlobalInvocationID.x.
layout(std430, binding = 0) readonly buffer block1 {
    uvec2 input_data[];
};

// One inclusive prefix sum per invocation. Declared coherent; WriteResults reads back
// output_data[max_accumulation_base], which may have been written by another invocation.
layout(std430, binding = 1) coherent buffer block2 {
    uvec2 output_data[];
};

// Accumulated total: added to outputs below max_accumulation_base and overwritten by the
// accumulation_limit invocation in WriteResults.
layout(std430, binding = 2) coherent buffer block3 {
    uvec2 accumulated_data;
};

// Per-subgroup totals of the current work group, indexed by gl_SubgroupID.
// NOTE(review): two entries only cover a 32-wide work group when subgroups have at least 16
// invocations — confirm this holds on every supported device.
shared uvec2 shared_data[2];

// Adds two 64-bit unsigned integers, each stored as uvec2(low word, high word).
// Used instead of native uint64_t so the shader also works on GPUs without 64-bit integer
// support. The carry out of the low word is propagated into the high word; the sum wraps
// modulo 2^64.
uvec2 AddUint64(uvec2 value_1, uvec2 value_2) {
    uint low_carry;
    const uint low_word = uaddCarry(value_1.x, value_2.x, low_carry);
    const uint high_word = value_1.y + value_2.y + low_carry;
    return uvec2(low_word, high_word);
}

// Inclusive prefix sum of 64-bit values (uvec2(low, high)) across the invocations of the
// current subgroup, using Hillis and Steele's algorithm: in the pass with offset i, every
// invocation at or above i adds the partial sum held by invocation (id - i).
//
// Every invocation must execute subgroupShuffleUp, including those that will not use its
// result: a shuffle that reads from an invocation which is not active in that shuffle
// returns an undefined value. The shuffle is therefore issued unconditionally and only the
// addition is predicated. Call this from subgroup-uniform control flow.
uvec2 subgroupInclusiveAddUint64(uvec2 value) {
    uvec2 result = value;
    for (uint i = 1; i < gl_SubgroupSize; i *= 2) {
        // Partial sum of invocation (gl_SubgroupInvocationID - i); undefined for lanes below i,
        // which discard it.
        const uvec2 other = subgroupShuffleUp(result, i);
        if (i <= gl_SubgroupInvocationID) {
            result = AddUint64(result, other);
        }
    }
    return result;
}

// Subtracts two 64-bit unsigned integers stored as uvec2(low word, high word), propagating
// the borrow out of the low word into the high word (wraps modulo 2^64).
uvec2 SubUint64(uvec2 minuend, uvec2 subtrahend) {
    uint borrow;
    const uint low = usubBorrow(minuend.x, subtrahend.x, borrow);
    return uvec2(low, minuend.y - subtrahend.y - borrow);
}

// Writes down the results to the output buffer and to the accumulation buffer.
//
// result: inclusive prefix sum for gl_GlobalInvocationID.x, as uvec2(low, high).
// - Invocations whose global id is below max_accumulation_base add accumulated_data to
//   their output; the others write their result unchanged.
// - If max_accumulation_base > accumulation_limit, the invocation whose global id equals
//   accumulation_limit stores its result (without base_data) into accumulated_data.
//   NOTE(review): confirm the host expects the un-based value here.
// - Otherwise that invocation stores result - output_data[max_accumulation_base].
// Must be called from work-group-uniform control flow because it executes barrier().
//
// NOTE(review): barrier() only synchronizes a single work group. With several work groups,
// reads of accumulated_data / output_data[max_accumulation_base] in other groups are not
// ordered against the writes made here.
void WriteResults(uvec2 result) {
    uint current_global_id = gl_GlobalInvocationID.x;
    uvec2 base_data = current_global_id < max_accumulation_base ? accumulated_data : uvec2(0);
    // Carry-propagating add: a plain uvec2 '+' adds the words independently and loses the
    // carry out of the low word.
    output_data[current_global_id] = AddUint64(result, base_data);
    // Same as `>= accumulation_limit + 1`, but without wrapping to 0 when
    // accumulation_limit == 0xFFFFFFFF. The condition is uniform (dispatch parameters).
    if (max_accumulation_base > accumulation_limit) {
        // Every invocation of the work group must have read accumulated_data above before the
        // limit invocation overwrites it.
        barrier();
        if (current_global_id == accumulation_limit) {
            accumulated_data = result;
        }
        return;
    }
    // We have that ugly case in which the accumulation data is reset in the middle somewhere.
    // The memory barrier has to precede barrier() so the output_data write above is visible to
    // the other invocations of the work group once they pass the barrier.
    groupMemoryBarrier();
    barrier();
    if (current_global_id == accumulation_limit) {
        uvec2 value_1 = output_data[max_accumulation_base];
        // True 64-bit subtraction. Adding the word-wise negation (-value_1) is not a 64-bit
        // negation: it leaves the high word off by one whenever value_1.x != 0.
        accumulated_data = SubUint64(result, value_1);
    }
}

// Entry point: every invocation loads one 64-bit value (uvec2(low, high)) from input_data at
// its global id, computes an inclusive prefix sum and hands it to WriteResults.
//
// The strategy depends only on the dispatch size (all branch conditions are uniform):
//   * total_work <= gl_SubgroupSize: one subgroup scan is enough.
//   * Case 1: the per-subgroup totals fit in one subgroup; they are scanned with a second
//     subgroup scan and the total of the preceding subgroups is added.
//   * Case 2: the per-subgroup totals are scanned in shared memory with a Hillis-Steele
//     scan taking log2(steps) iterations.
//
// NOTE(review): shared_data is per work group and has only 2 entries, yet Case 1/Case 2 size
// their work with gl_NumWorkGroups.x and Case 2 indexes shared_data by the *global*
// invocation id. Totals from other work groups are never visible through shared memory, so
// multi-work-group dispatches cannot produce a global prefix sum, and Case 2 can index
// shared_data out of bounds. These paths need a redesign (e.g. a second pass over a global
// buffer); only the local defects are fixed below.
void main() {
    uint subgroup_inv_id = gl_SubgroupInvocationID;
    uint subgroup_id = gl_SubgroupID;
    uint last_subgroup_id = subgroupMax(subgroup_inv_id);
    uint current_global_id = gl_GlobalInvocationID.x;
    uint total_work = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
    uvec2 data = input_data[current_global_id];
    // make sure all input data has been loaded
    subgroupBarrier();
    subgroupMemoryBarrier();

    uvec2 result = subgroupInclusiveAddUint64(data);

    // If there are no more queries than invocations in a subgroup, just write down the results.
    if (total_work <= gl_SubgroupSize) { // This condition is constant per dispatch.
        WriteResults(result);
        return;
    }

    // We now have more, so the last active invocation of each subgroup publishes the
    // subgroup's total into shared memory.
    if (subgroup_inv_id == last_subgroup_id) {
        shared_data[subgroup_id] = result;
    }
    // Make the shared writes visible, then wait until every subgroup has published. The memory
    // barrier must come before barrier() to order the writes ahead of the other reads.
    memoryBarrierShared();
    barrier();

    // Case 1: the total work for the grouped results can be calculated in a single subgroup
    // operation (about 1024 queries).
    uint total_extra_work = gl_NumSubgroups * gl_NumWorkGroups.x;
    if (total_extra_work <= gl_SubgroupSize) { // This condition is constant per dispatch.
        if (subgroup_id != 0) {
            // Only the first gl_NumSubgroups entries of shared_data were written; the remaining
            // lanes contribute zero instead of reading past the end of the array.
            uvec2 tmp = uvec2(0);
            if (subgroup_inv_id < gl_NumSubgroups) {
                tmp = shared_data[subgroup_inv_id];
            }
            subgroupBarrier();
            subgroupMemoryBarrierShared();
            tmp = subgroupInclusiveAddUint64(tmp);
            // Sum of the totals of subgroups 0 .. subgroup_id - 1.
            result = AddUint64(result, subgroupShuffle(tmp, subgroup_id - 1));
        }

        WriteResults(result);
        return;
    }

    // Case 2: our work amount is huge, so do it in O(log n) steps.
    // Round total_extra_work up to a power of two: (x & (x - 1)) is non-zero exactly when x is
    // not a power of two. (The former `x ^ (x - 1)` is non-zero for every x >= 1, which doubled
    // `steps` when x was already a power of two.)
    const uint extra = (total_extra_work & (total_extra_work - 1)) != 0 ? 1 : 0;
    const uint steps = 1 << (findMSB(total_extra_work) + extra);
    // Hillis and Steele's algorithm, in place. Each step is split into a read phase and a write
    // phase separated by a barrier; otherwise an invocation could overwrite its slot while
    // another invocation is still reading that slot as its `index - scan_step` operand.
    for (uint scan_step = 1; scan_step < steps; scan_step *= 2) {
        const bool participates = current_global_id < steps && current_global_id >= scan_step;
        uvec2 partial_sum = uvec2(0);
        if (participates) {
            partial_sum = AddUint64(shared_data[current_global_id],
                                    shared_data[current_global_id - scan_step]);
        }
        // steps is constant, so these barriers execute in every invocation of the work group.
        barrier();
        if (participates) {
            shared_data[current_global_id] = partial_sum;
        }
        memoryBarrierShared();
        barrier();
    }
    // Only add results for groups higher than 0
    if (subgroup_id != 0) {
        result = AddUint64(result, shared_data[subgroup_id - 1]);
    }

    // Just write the final results. We are done
    WriteResults(result);
}