summaryrefslogtreecommitdiffstats
path: root/src/video_core/host_shaders/queries_prefix_scan_sum.comp
blob: 6faa8981f252bb9c4b6c2461bd0c8864fdf62269 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project
// SPDX-License-Identifier: GPL-3.0-or-later

#version 460 core

// Subgroup functionality used by this shader:
//  - basic:            gl_SubgroupSize, gl_SubgroupID, gl_SubgroupInvocationID, gl_NumSubgroups,
//                      subgroupBarrier and the subgroup memory barriers
//  - shuffle:          subgroupShuffle
//  - shuffle_relative: subgroupShuffleUp
//  - arithmetic:       subgroupMax
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_shuffle : require
#extension GL_KHR_shader_subgroup_shuffle_relative : require
#extension GL_KHR_shader_subgroup_arithmetic : require

// The same source is built for Vulkan and OpenGL. The macros below hide the difference in how
// the per-dispatch parameters are declared: a push constant block on Vulkan, loose uniforms
// with explicit locations on OpenGL.
#ifdef VULKAN

#define HAS_EXTENDED_TYPES 1
// On Vulkan the parameters are members of a single push constant block, so UNIFORM(n)
// expands to nothing and the BEGIN/END macros open and close the block.
#define BEGIN_PUSH_CONSTANTS layout(push_constant) uniform PushConstants {
#define END_PUSH_CONSTANTS };
#define UNIFORM(n)
// NOTE(review): the BINDING_* macros are not referenced in this file as shown; the buffer
// declarations use literal binding numbers.
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 1

#else // ^^^ Vulkan ^^^ // vvv OpenGL vvv

// HAS_EXTENDED_TYPES reflects whether GL_NV_gpu_shader5 is available. It is not consulted by
// the code visible in this file.
#extension GL_NV_gpu_shader5 : enable
#ifdef GL_NV_gpu_shader5
#define HAS_EXTENDED_TYPES 1
#else
#define HAS_EXTENDED_TYPES 0
#endif
// On OpenGL there is no enclosing block; each parameter becomes a standalone uniform at
// location n.
#define BEGIN_PUSH_CONSTANTS
#define END_PUSH_CONSTANTS
#define UNIFORM(n) layout(location = n) uniform
#define BINDING_INPUT_BUFFER 0
#define BINDING_OUTPUT_IMAGE 0

#endif

// Per-dispatch parameters.
//  min_accumulation_base: result indices below this value select accumulated_data as their base
//                         in WriteResults; it is also compared against accumulation_limit + 1
//                         to choose between the simple and the "reset" accumulation paths.
//  max_accumulation_base: in the reset path, output_data[max_accumulation_base - 1] is
//                         subtracted from the value stored into accumulated_data.
//  accumulation_limit:    index of the result that is stored into accumulated_data; main() also
//                         uses it as the amount of work when deciding whether a single subgroup
//                         is enough.
//  buffer_offset:         element offset added to every input_data read and output_data write.
BEGIN_PUSH_CONSTANTS
UNIFORM(0) uint min_accumulation_base;
UNIFORM(1) uint max_accumulation_base;
UNIFORM(2) uint accumulation_limit;
UNIFORM(3) uint buffer_offset;
END_PUSH_CONSTANTS

// Number of consecutive elements each invocation loads, scans and writes.
#define LOCAL_RESULTS 8
// Number of elements covered by one workgroup: QUERIES_PER_INVOC / LOCAL_RESULTS invocations
// (2048 / 8 = 256) each handle LOCAL_RESULTS elements.
#define QUERIES_PER_INVOC 2048

layout(local_size_x = QUERIES_PER_INVOC / LOCAL_RESULTS) in;

// Values to be summed. Each element is a 64-bit quantity split into two 32-bit words
// (.x = low word, .y = high word, as consumed by AddUint64).
layout(std430, binding = 0) readonly buffer block1 {
    uvec2 input_data[];
};

// Prefix sums written by WriteResults; also read back there in the reset path.
layout(std430, binding = 1) coherent buffer block2 {
    uvec2 output_data[];
};

// Single 64-bit accumulator that WriteResults reads at entry and overwrites at exit.
layout(std430, binding = 2) coherent buffer block3 {
    uvec2 accumulated_data;
};

// Scratch space where each subgroup publishes its total so later subgroups can add it.
// It is indexed by subgroup id in main().
// NOTE(review): the size of 128 is a fixed upper bound on that index — confirm it covers the
// largest subgroup count the host can dispatch with.
shared uvec2 shared_data[128];

// Adds two 64-bit unsigned integers for GPUs that lack a native uint64 type.
// Each operand is packed in a uvec2: .x holds the low 32 bits, .y the high 32 bits.
// A carry out of the high word is dropped, so the sum wraps modulo 2^64.
uvec2 AddUint64(uvec2 value_1, uvec2 value_2) {
    // uaddCarry returns the wrapped low-word sum and writes 1 to low_carry on overflow.
    uint low_carry = 0;
    const uint low_word = uaddCarry(value_1.x, value_2.x, low_carry);
    const uint high_word = value_1.y + value_2.y + low_carry;
    return uvec2(low_word, high_word);
}

// Inclusive prefix sum of 64-bit values across the subgroup, using Hillis and Steele's
// algorithm: on every pass each lane adds the partial sum held `stride` lanes below it,
// and the stride doubles until it reaches the subgroup size.
uvec2 subgroupInclusiveAddUint64(uvec2 value) {
    uvec2 running_sum = value;
    uint stride = 1;
    while (stride < gl_SubgroupSize) {
        // Value of the lane at (gl_SubgroupInvocationID - stride).
        const uvec2 lower_lane_sum = subgroupShuffleUp(running_sum, stride);
        // Lanes with an index below the stride have no such neighbour and keep their value.
        if (gl_SubgroupInvocationID >= stride) {
            running_sum = AddUint64(running_sum, lower_lane_sum);
        }
        stride <<= 1;
    }
    return running_sum;
}

// Writes this invocation's results to the output buffer and updates the accumulation buffer.
//
// results: the LOCAL_RESULTS prefix sums owned by this invocation. Element i belongs to
//          result index gl_LocalInvocationID.x * LOCAL_RESULTS + i. The array is a by-value
//          copy, so the adjustments made here are not visible to the caller.
//
// Results whose index is below min_accumulation_base are offset by the value found in
// accumulated_data on entry. Afterwards the invocation that owns result index
// accumulation_limit stores the new accumulated value.
//
// Contains a barrier() on one path, so it must be called from workgroup-uniform control flow.
void WriteResults(uvec2 results[LOCAL_RESULTS]) {
    const uint current_id = gl_LocalInvocationID.x;
    const uint first_result = current_id * LOCAL_RESULTS;
    // NOTE(review): every invocation reads accumulated_data here and one invocation overwrites
    // it below without an intervening barrier on the early-return path — confirm the host
    // dispatch pattern makes this ordering safe.
    const uvec2 accum = accumulated_data;
    for (uint i = 0; i < LOCAL_RESULTS; i++) {
        const uvec2 base_data = first_result + i < min_accumulation_base ? accum : uvec2(0, 0);
        // The sum has to be stored back: AddUint64 returns its result and does not modify
        // its arguments.
        results[i] = AddUint64(results[i], base_data);
        output_data[buffer_offset + first_result + i] = results[i];
    }
    // Locate the invocation (base_id) and local slot (index) that hold result number
    // accumulation_limit.
    uint index = accumulation_limit % LOCAL_RESULTS;
    uint base_id = accumulation_limit / LOCAL_RESULTS;
    // Both operands are per-dispatch constants, so this branch is taken uniformly.
    if (min_accumulation_base >= accumulation_limit + 1) {
        if (current_id == base_id) {
            accumulated_data = results[index];
        }
        return;
    }
    // We have that ugly case in which the accumulation data is reset in the middle somewhere.
    // Wait for the output_data writes above before reading one of them back.
    barrier();
    groupMemoryBarrier();

    if (current_id == base_id) {
        // NOTE(review): this read does not add buffer_offset, unlike the writes above —
        // confirm that is intended.
        uvec2 reset_value = output_data[max_accumulation_base - 1];
        // Calculate two complement / negate manually
        reset_value = AddUint64(uvec2(1,0), ~reset_value);
        // results[index] - output_data[max_accumulation_base - 1]
        accumulated_data = AddUint64(results[index], reset_value);
    }
}

// Entry point. Computes an inclusive prefix sum over the 64-bit values in input_data in three
// stages (per invocation, per subgroup, then across subgroups through shared memory) and hands
// the results to WriteResults.
void main() {
    const uint subgroup_inv_id = gl_SubgroupInvocationID;
    // Subgroup index offset by the workgroup index.
    // NOTE(review): this value indexes shared_data and is used as a shuffle lane below, which
    // only lines up with per-workgroup data when gl_WorkGroupID.x is 0 — confirm the host
    // dispatches a single workgroup.
    const uint subgroup_id = gl_SubgroupID + gl_WorkGroupID.x * gl_NumSubgroups;
    // Despite the name, this is the highest invocation index among the active invocations of
    // this subgroup, not a subgroup index.
    const uint last_subgroup_id = subgroupMax(subgroup_inv_id);
    const uint current_id = gl_LocalInvocationID.x;
    const uint total_work = accumulation_limit;
    const uint last_result_id = LOCAL_RESULTS - 1;
    // Stage 1: load LOCAL_RESULTS consecutive inputs and scan them serially.
    uvec2 data[LOCAL_RESULTS];
    for (uint i = 0; i < LOCAL_RESULTS; i++) {
        data[i] = input_data[buffer_offset + current_id * LOCAL_RESULTS + i];
    }
    uvec2 results[LOCAL_RESULTS];
    results[0] = data[0];
    for (uint i = 1; i < LOCAL_RESULTS; i++) {
        results[i] = AddUint64(data[i], results[i - 1]);
    }
    // make sure all input data has been loaded
    subgroupBarrier();
    subgroupMemoryBarrier();

    // Stage 2: on the last local result, do a subgroup inclusive scan sum. Afterwards
    // results[last_result_id] holds the sum of everything up to and including this invocation
    // within the subgroup.
    results[last_result_id] = subgroupInclusiveAddUint64(results[last_result_id]);
    // get the scanned last result from the invocation just below this one in the subgroup
    uvec2 result_behind = subgroupShuffleUp(results[last_result_id], 1);
    // Invocation 0 has no predecessor. For the others, add the predecessor's total to every
    // local result except the last one, which already includes it from the scan above.
    if (subgroup_inv_id != 0) {
        for (uint i = 1; i < LOCAL_RESULTS; i++) {
            results[i - 1] = AddUint64(results[i - 1], result_behind);
        }
    }

    // if we had no more queries than one subgroup covers, just write down the results.
    if (total_work <= gl_SubgroupSize * LOCAL_RESULTS) { // This condition is constant per dispatch.
        WriteResults(results);
        return;
    }

    // Stage 3: we now have more, so let's publish each subgroup's total in shared memory.
    // Only the last invocation of the subgroup writes, since it holds the subgroup total.
    if (subgroup_inv_id == last_subgroup_id) {
        shared_data[subgroup_id] = results[last_result_id];
    }
    // wait until every subgroup has stored its total
    barrier();
    memoryBarrierShared();

    // only if it's not the first subgroup
    if (subgroup_id != 0) {
        // each invocation loads the total of the subgroup whose index equals its lane index
        uvec2 tmp = shared_data[subgroup_inv_id];
        subgroupBarrier();
        subgroupMemoryBarrierShared();
        // after this scan, lane k holds the sum of the totals of subgroups 0..k
        tmp = subgroupInclusiveAddUint64(tmp);
        // lane (subgroup_id - 1) therefore holds the sum of all preceding subgroups
        uvec2 shuffled_result = subgroupShuffle(tmp, subgroup_id - 1);
        for (uint i = 0; i < LOCAL_RESULTS; i++) {
            results[i] = AddUint64(results[i], shuffled_result);
        }
    }
    WriteResults(results);
}