summaryrefslogtreecommitdiffstats
path: root/src/shader_recompiler/backend
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler/backend')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp10
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.h5
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.h5
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp46
4 files changed, 56 insertions, 10 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index e70b78a28..5ef637fe7 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) {
if (info.uses_local_invocation_id) {
local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
}
+ if (info.uses_subgroup_mask) {
+ subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
+ subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
+ subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR);
+ subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR);
+ subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR);
+ }
if (info.uses_subgroup_invocation_id ||
- (profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) {
+ (profile.warp_size_potentially_larger_than_guest &&
+ (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
subgroup_local_invocation_id =
DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId);
}
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index 3a686a78c..03c5a6aba 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -97,6 +97,11 @@ public:
Id workgroup_id{};
Id local_invocation_id{};
Id subgroup_local_invocation_id{};
+ Id subgroup_mask_eq{};
+ Id subgroup_mask_lt{};
+ Id subgroup_mask_le{};
+ Id subgroup_mask_gt{};
+ Id subgroup_mask_ge{};
Id instance_id{};
Id instance_index{};
Id base_instance{};
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index 032b0b2f9..712c5e61f 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred);
Id EmitVoteAny(EmitContext& ctx, Id pred);
Id EmitVoteEqual(EmitContext& ctx, Id pred);
Id EmitSubgroupBallot(EmitContext& ctx, Id pred);
+Id EmitSubgroupEqMask(EmitContext& ctx);
+Id EmitSubgroupLtMask(EmitContext& ctx);
+Id EmitSubgroupLeMask(EmitContext& ctx);
+Id EmitSubgroupGtMask(EmitContext& ctx);
+Id EmitSubgroupGeMask(EmitContext& ctx);
Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
Id segmentation_mask);
Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp
index cbc5b1c96..c57bd291d 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp
@@ -6,10 +6,18 @@
namespace Shader::Backend::SPIRV {
namespace {
-Id LargeWarpBallot(EmitContext& ctx, Id ballot) {
+Id WarpExtract(EmitContext& ctx, Id value) {
const Id shift{ctx.Constant(ctx.U32[1], 5)};
const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)};
- return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index);
+ return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index);
+}
+
+Id LoadMask(EmitContext& ctx, Id mask) {
+ const Id value{ctx.OpLoad(ctx.U32[4], mask)};
+ if (!ctx.profile.warp_size_potentially_larger_than_guest) {
+ return ctx.OpCompositeExtract(ctx.U32[1], value, 0U);
+ }
+ return WarpExtract(ctx, value);
}
void SetInBoundsFlag(IR::Inst* inst, Id result) {
@@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) {
return ctx.OpSubgroupAllKHR(ctx.U1, pred);
}
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
- const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
- const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
+ const Id active_mask{WarpExtract(ctx, mask_ballot)};
+ const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
return ctx.OpIEqual(ctx.U1, lhs, active_mask);
}
@@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) {
return ctx.OpSubgroupAnyKHR(ctx.U1, pred);
}
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
- const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
- const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
+ const Id active_mask{WarpExtract(ctx, mask_ballot)};
+ const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value);
}
@@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) {
return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred);
}
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
- const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
- const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
+ const Id active_mask{WarpExtract(ctx, mask_ballot)};
+ const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)};
return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value),
ctx.OpIEqual(ctx.U1, lhs, active_mask));
@@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) {
if (!ctx.profile.warp_size_potentially_larger_than_guest) {
return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U);
}
- return LargeWarpBallot(ctx, ballot);
+ return WarpExtract(ctx, ballot);
+}
+
+Id EmitSubgroupEqMask(EmitContext& ctx) {
+ return LoadMask(ctx, ctx.subgroup_mask_eq);
+}
+
+Id EmitSubgroupLtMask(EmitContext& ctx) {
+ return LoadMask(ctx, ctx.subgroup_mask_lt);
+}
+
+Id EmitSubgroupLeMask(EmitContext& ctx) {
+ return LoadMask(ctx, ctx.subgroup_mask_le);
+}
+
+Id EmitSubgroupGtMask(EmitContext& ctx) {
+ return LoadMask(ctx, ctx.subgroup_mask_gt);
+}
+
+Id EmitSubgroupGeMask(EmitContext& ctx) {
+ return LoadMask(ctx, ctx.subgroup_mask_ge);
}
Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,