From a0d48382cb40d48104b1d5a4279ef9ef17a8d25f Mon Sep 17 00:00:00 2001 From: James Price Date: Thu, 11 Mar 2021 15:57:21 +0000 Subject: [PATCH] [spirv-writer] Handle entry point parameters Add a sanitizing transform to hoist entry point parameters out as global variables. Bug: tint:509 Change-Id: Ic18f69386a58d82ee11571fa9ec0c54cb5bdf2cf Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44083 Commit-Queue: James Price Auto-Submit: James Price Reviewed-by: Ben Clayton --- BUILD.gn | 1 + src/CMakeLists.txt | 1 + src/transform/spirv.cc | 113 ++++++++++++++++++- src/transform/spirv.h | 3 + src/transform/spirv_test.cc | 71 ++++++++++++ src/writer/spirv/builder_entry_point_test.cc | 102 +++++++++++++++++ 6 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 src/writer/spirv/builder_entry_point_test.cc diff --git a/BUILD.gn b/BUILD.gn index 1ac5dacd969..07c6aa40b64 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -987,6 +987,7 @@ source_set("tint_unittests_spv_writer_src") { "src/writer/spirv/builder_call_test.cc", "src/writer/spirv/builder_constructor_expression_test.cc", "src/writer/spirv/builder_discard_test.cc", + "src/writer/spirv/builder_entry_point_test.cc", "src/writer/spirv/builder_format_conversion_test.cc", "src/writer/spirv/builder_function_decoration_test.cc", "src/writer/spirv/builder_function_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d3e94d4c6d7..d5903bf2ed4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -651,6 +651,7 @@ if(${TINT_BUILD_TESTS}) writer/spirv/builder_call_test.cc writer/spirv/builder_constructor_expression_test.cc writer/spirv/builder_discard_test.cc + writer/spirv/builder_entry_point_test.cc writer/spirv/builder_format_conversion_test.cc writer/spirv/builder_function_decoration_test.cc writer/spirv/builder_function_test.cc diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 37c002b913f..d990f55729c 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -29,9 +29,118 @@ Spirv::~Spirv() = default; Transform::Output Spirv::Run(const Program* in) { ProgramBuilder out; CloneContext ctx(&out, in); - HandleSampleMaskBuiltins(ctx); + HandleEntryPointIOTypes(ctx); ctx.Clone(); - return Output{Program(std::move(out))}; + + // TODO(jrprice): Look into combining these transforms into a single clone. + Program tmp(std::move(out)); + + ProgramBuilder out2; + CloneContext ctx2(&out2, &tmp); + HandleSampleMaskBuiltins(ctx2); + ctx2.Clone(); + + return Output{Program(std::move(out2))}; +} + +void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const { + // Hoist entry point parameters, return values, and struct members out to + // global variables. Declare and construct struct parameters in the function + // body. Replace entry point return statements with variable assignments. + // + // Before: + // ``` + // struct FragmentInput { + // [[builtin(sample_index)]] sample_index : u32; + // [[builtin(sample_mask_in)]] sample_mask_in : u32; + // }; + // struct FragmentOutput { + // [[builtin(frag_depth)]] depth: f32; + // [[builtin(sample_mask_out)]] mask_out : u32; + // }; + // + // [[stage(fragment)]] + // fn fs_main( + // [[builtin(frag_coord)]] coord : vec4, + // samples : FragmentInput + // ) -> FragmentOutput { + // return FragmentOutput(1.0, samples.sample_mask_in); + // } + // ``` + // + // After: + // ``` + // struct FragmentInput { + // sample_index : u32; + // sample_mask_in : u32; + // }; + // struct FragmentOutput { + // depth: f32; + // mask_out : u32; + // }; + // + // [[builtin(frag_coord)]] var coord : vec4, + // [[builtin(sample_index)]] var sample_index : u32, + // [[builtin(sample_mask_in)]] var sample_mask_in : u32, + // [[builtin(frag_depth)]] var depth: f32; + // [[builtin(sample_mask_out)]] var mask_out : u32; + // + // [[stage(fragment)]] + // fn fs_main() -> void { + // const samples : FragmentInput(sample_index, sample_mask_in); + // depth = 1.0; + // mask_out = samples.sample_mask_in; + // return; + // } + // ``` + + // TODO(jrprice): Hoist struct members decorated as entry point IO types out + // of struct declarations, and redeclare the structs without the decorations. + + for (auto* func : ctx.src->AST().Functions()) { + if (!func->IsEntryPoint()) { + continue; + } + + for (auto* param : func->params()) { + // TODO(jrprice): Handle structures by moving the declaration and + // construction to the function body. + if (param->type()->Is()) { + TINT_ICE(ctx.dst->Diagnostics()) + << "structures as entry point parameters are not yet supported"; + continue; + } + + // Create a new symbol for the global variable. + auto var_symbol = ctx.dst->Symbols().New(); + // Create the global variable. + ctx.dst->Global(var_symbol, ctx.Clone(param->type()), + ast::StorageClass::kInput, nullptr, + ctx.Clone(param->decorations())); + + // Replace all uses of the function parameter with the global variable. + for (auto* user : ctx.src->Sem().Get(param)->Users()) { + ctx.Replace(user->Declaration(), + ctx.dst->Expr(var_symbol)); + } + } + + // TODO(jrprice): Hoist the return type out to a global variable, and + // replace return statements with variable assignments. + if (!func->return_type()->Is()) { + TINT_ICE(ctx.dst->Diagnostics()) + << "entry point return values are not yet supported"; + continue; + } + + // Rewrite the function header to remove the parameters. + // TODO(jrprice): Change return type to void when return values are handled. + auto* new_func = ctx.dst->create( + func->source(), ctx.Clone(func->symbol()), ast::VariableList{}, + ctx.Clone(func->return_type()), ctx.Clone(func->body()), + ctx.Clone(func->decorations())); + ctx.Replace(func, new_func); + } } void Spirv::HandleSampleMaskBuiltins(CloneContext& ctx) const { diff --git a/src/transform/spirv.h b/src/transform/spirv.h index cb67e738ec8..5231f0e9f87 100644 --- a/src/transform/spirv.h +++ b/src/transform/spirv.h @@ -39,6 +39,9 @@ class Spirv : public Transform { Output Run(const Program* program) override; private: + /// Hoist entry point parameters, return values, and struct members out to + /// global variables. + void HandleEntryPointIOTypes(CloneContext& ctx) const; /// Change type of sample mask builtin variables to single element arrays. void HandleSampleMaskBuiltins(CloneContext& ctx) const; }; diff --git a/src/transform/spirv_test.cc b/src/transform/spirv_test.cc index d1876d883de..2e66ce2dc67 100644 --- a/src/transform/spirv_test.cc +++ b/src/transform/spirv_test.cc @@ -22,6 +22,46 @@ namespace { using SpirvTest = TransformTest; +TEST_F(SpirvTest, HandleEntryPointIOTypes_Parameters) { + auto* src = R"( +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, + [[location(1)]] loc1 : f32) -> void { + var col : f32 = (coord.x * loc1); +} + +[[stage(compute)]] +fn compute_main([[builtin(local_invocation_id)]] local_id : vec3, + [[builtin(local_invocation_index)]] local_index : u32) -> void { + var id_x : u32 = local_id.x; +} +)"; + + auto* expect = R"( +[[builtin(frag_coord)]] var tint_symbol_1 : vec4; + +[[location(1)]] var tint_symbol_2 : f32; + +[[builtin(local_invocation_id)]] var tint_symbol_6 : vec3; + +[[builtin(local_invocation_index)]] var tint_symbol_7 : u32; + +[[stage(fragment)]] +fn frag_main() -> void { + var col : f32 = (tint_symbol_1.x * tint_symbol_2); +} + +[[stage(compute)]] +fn compute_main() -> void { + var id_x : u32 = tint_symbol_6.x; +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + TEST_F(SpirvTest, HandleSampleMaskBuiltins_Basic) { auto* src = R"( [[builtin(sample_index)]] var sample_index : u32; @@ -98,6 +138,37 @@ fn main() -> void { EXPECT_EQ(expect, str(got)); } +// Test that different transforms within the sanitizer interact correctly. +TEST_F(SpirvTest, MultipleTransforms) { + // TODO(jrprice): Make `mask_out` a return value when supported. + auto* src = R"( +[[builtin(sample_mask_out)]] var mask_out : u32; + +[[stage(fragment)]] +fn main([[builtin(sample_index)]] sample_index : u32, + [[builtin(sample_mask_in)]] mask_in : u32) -> void { + mask_out = mask_in; +} +)"; + + auto* expect = R"( +[[builtin(sample_index)]] var tint_symbol_1 : u32; + +[[builtin(sample_mask_in)]] var tint_symbol_2 : array; + +[[builtin(sample_mask_out)]] var mask_out : array; + +[[stage(fragment)]] +fn main() -> void { + mask_out[0] = tint_symbol_2[0]; +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint diff --git a/src/writer/spirv/builder_entry_point_test.cc b/src/writer/spirv/builder_entry_point_test.cc new file mode 100644 index 00000000000..c9f782d2b93 --- /dev/null +++ b/src/writer/spirv/builder_entry_point_test.cc @@ -0,0 +1,102 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" +#include "src/ast/builtin.h" +#include "src/ast/builtin_decoration.h" +#include "src/ast/location_decoration.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/storage_class.h" +#include "src/ast/variable.h" +#include "src/program.h" +#include "src/type/f32_type.h" +#include "src/type/vector_type.h" +#include "src/writer/spirv/builder.h" +#include "src/writer/spirv/spv_dump.h" +#include "src/writer/spirv/test_helper.h" + +namespace tint { +namespace writer { +namespace spirv { +namespace { + +using BuilderTest = TestHelper; + +TEST_F(BuilderTest, EntryPoint_Parameters) { + // [[stage(fragment)]] + // fn frag_main([[builtin(frag_coord)]] coord : vec4, + // [[location(1)]] loc1 : f32) -> void { + // var col : f32 = (coord.x * loc1); + // } + auto* f32 = ty.f32(); + auto* vec4 = ty.vec4(); + auto* coord = Var("coord", vec4, ast::StorageClass::kInput, nullptr, + {create(ast::Builtin::kFragCoord)}); + auto* loc1 = Var("loc1", f32, ast::StorageClass::kInput, nullptr, + {create(1u)}); + auto* mul = Mul(Expr(MemberAccessor("coord", "x")), Expr("loc1")); + auto* col = Var("col", f32, ast::StorageClass::kFunction, mul, {}); + Func("frag_main", ast::VariableList{coord, loc1}, ty.void_(), + ast::StatementList{WrapInStatement(col)}, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kFragment), + }); + + spirv::Builder& b = SanitizeAndBuild(); + + ASSERT_TRUE(b.Build()); + + // Test that "coord" and "loc1" get hoisted out to global variables with the + // Input storage class, retaining their decorations. + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %9 "frag_main" %1 %5 +OpExecutionMode %9 OriginUpperLeft +OpName %1 "tint_symbol_1" +OpName %5 "tint_symbol_2" +OpName %9 "frag_main" +OpName %17 "col" +OpDecorate %1 BuiltIn FragCoord +OpDecorate %5 Location 1 +%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 4 +%2 = OpTypePointer Input %3 +%1 = OpVariable %2 Input +%6 = OpTypePointer Input %4 +%5 = OpVariable %6 Input +%8 = OpTypeVoid +%7 = OpTypeFunction %8 +%11 = OpTypeInt 32 0 +%12 = OpConstant %11 0 +%18 = OpTypePointer Function %4 +%19 = OpConstantNull %4 +%9 = OpFunction %8 None %7 +%10 = OpLabel +%17 = OpVariable %18 Function %19 +%13 = OpAccessChain %6 %1 %12 +%14 = OpLoad %4 %13 +%15 = OpLoad %4 %5 +%16 = OpFMul %4 %14 %15 +OpStore %17 %16 +OpReturn +OpFunctionEnd +)"); +} + +} // namespace +} // namespace spirv +} // namespace writer +} // namespace tint