From b3fb1e20e41a7d68257a7f153ddab26e0ff49156 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Mon, 13 Jan 2025 19:32:51 -0800 Subject: [PATCH] [spirv-reader][ir] Handle GLSL 450 MatrixInverse The SPIR-V `MatrixInverse` method does not have an equivalent in WGSL. Add a polyfill for the allowed `2x2`, `3x3` and `4x4` matrix variants. Bug: 42250952 Change-Id: I1f73bdb974ad4bf5f51af886f7776c6174e3a6cb Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/220874 Commit-Queue: dan sinclair Reviewed-by: James Price --- src/tint/lang/spirv/builtin_fn.cc | 3 + src/tint/lang/spirv/builtin_fn.cc.tmpl | 1 + src/tint/lang/spirv/builtin_fn.h | 1 + src/tint/lang/spirv/intrinsic/data.cc | 222 +++--- .../spirv/reader/import_glsl_std450_test.cc | 141 ---- src/tint/lang/spirv/reader/lower/builtins.cc | 227 ++++++- .../lang/spirv/reader/lower/builtins_test.cc | 640 ++++++++++++++++++ src/tint/lang/spirv/reader/parser/BUILD.bazel | 1 + src/tint/lang/spirv/reader/parser/BUILD.cmake | 1 + src/tint/lang/spirv/reader/parser/BUILD.gn | 1 + .../reader/parser/import_glsl_std450_test.cc | 147 ++++ src/tint/lang/spirv/reader/parser/parser.cc | 2 + src/tint/lang/spirv/spirv.def | 4 + src/tint/lang/spirv/writer/printer/printer.cc | 5 + 14 files changed, 1173 insertions(+), 223 deletions(-) create mode 100644 src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc diff --git a/src/tint/lang/spirv/builtin_fn.cc b/src/tint/lang/spirv/builtin_fn.cc index 4ea7ab7e2f..f2e5cc95c8 100644 --- a/src/tint/lang/spirv/builtin_fn.cc +++ b/src/tint/lang/spirv/builtin_fn.cc @@ -110,6 +110,8 @@ const char* str(BuiltinFn i) { return "vector_times_scalar"; case BuiltinFn::kNormalize: return "normalize"; + case BuiltinFn::kInverse: + return "inverse"; case BuiltinFn::kSdot: return "sdot"; case BuiltinFn::kUdot: @@ -164,6 +166,7 @@ tint::core::ir::Instruction::Accesses GetSideEffects(BuiltinFn fn) { case BuiltinFn::kUdot: case BuiltinFn::kNone: case BuiltinFn::kNormalize: + case BuiltinFn::kInverse: break; } return core::ir::Instruction::Accesses{}; diff --git a/src/tint/lang/spirv/builtin_fn.cc.tmpl b/src/tint/lang/spirv/builtin_fn.cc.tmpl index cba1b82c84..4285e63d6f 100644 --- a/src/tint/lang/spirv/builtin_fn.cc.tmpl +++ b/src/tint/lang/spirv/builtin_fn.cc.tmpl @@ -73,6 +73,7 @@ tint::core::ir::Instruction::Accesses GetSideEffects(BuiltinFn fn) { case BuiltinFn::kUdot: case BuiltinFn::kNone: case BuiltinFn::kNormalize: + case BuiltinFn::kInverse: break; } return core::ir::Instruction::Accesses{}; diff --git a/src/tint/lang/spirv/builtin_fn.h b/src/tint/lang/spirv/builtin_fn.h index 0e49a9b8bc..24f65d24ba 100644 --- a/src/tint/lang/spirv/builtin_fn.h +++ b/src/tint/lang/spirv/builtin_fn.h @@ -82,6 +82,7 @@ enum class BuiltinFn : uint8_t { kVectorTimesMatrix, kVectorTimesScalar, kNormalize, + kInverse, kSdot, kUdot, kNone, diff --git a/src/tint/lang/spirv/intrinsic/data.cc b/src/tint/lang/spirv/intrinsic/data.cc index 8e9b6d0adf..6c08d99b6d 100644 --- a/src/tint/lang/spirv/intrinsic/data.cc +++ b/src/tint/lang/spirv/intrinsic/data.cc @@ -1339,13 +1339,19 @@ constexpr MatcherIndex kMatcherIndices[] = { /* [140] */ MatcherIndex(0), /* [141] */ MatcherIndex(10), /* [142] */ MatcherIndex(0), - /* [143] */ MatcherIndex(47), - /* [144] */ MatcherIndex(6), - /* [145] */ MatcherIndex(46), - /* [146] */ MatcherIndex(48), - /* [147] */ MatcherIndex(37), - /* [148] */ MatcherIndex(50), - /* [149] */ MatcherIndex(49), + /* [143] */ MatcherIndex(12), + /* [144] */ MatcherIndex(0), + /* [145] */ MatcherIndex(16), + /* [146] */ MatcherIndex(0), + /* [147] */ MatcherIndex(20), + /* [148] */ MatcherIndex(0), + /* [149] */ MatcherIndex(47), + /* [150] */ MatcherIndex(6), + /* [151] */ MatcherIndex(46), + /* [152] */ MatcherIndex(48), + /* [153] */ MatcherIndex(37), + /* [154] */ MatcherIndex(50), + /* [155] */ MatcherIndex(49), }; static_assert(MatcherIndicesIndex::CanIndex(kMatcherIndices), @@ -2135,7 +2141,7 @@ constexpr ParameterInfo kParameters[] = { { /* [156] */ /* usage */ core::ParameterUsage::kNone, - /* matcher_indices */ MatcherIndicesIndex(147), + /* matcher_indices */ MatcherIndicesIndex(153), }, { /* [157] */ @@ -2927,6 +2933,21 @@ constexpr ParameterInfo kParameters[] = { /* usage */ core::ParameterUsage::kNone, /* matcher_indices */ MatcherIndicesIndex(3), }, + { + /* [315] */ + /* usage */ core::ParameterUsage::kNone, + /* matcher_indices */ MatcherIndicesIndex(143), + }, + { + /* [316] */ + /* usage */ core::ParameterUsage::kNone, + /* matcher_indices */ MatcherIndicesIndex(145), + }, + { + /* [317] */ + /* usage */ core::ParameterUsage::kNone, + /* matcher_indices */ MatcherIndicesIndex(147), + }, }; static_assert(ParameterIndex::CanIndex(kParameters), @@ -2936,55 +2957,55 @@ constexpr TemplateInfo kTemplates[] = { { /* [0] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(146), + /* matcher_indices */ MatcherIndicesIndex(152), /* kind */ TemplateInfo::Kind::kType, }, { /* [1] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [2] */ /* name */ "I", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [3] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [4] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(146), + /* matcher_indices */ MatcherIndicesIndex(152), /* kind */ TemplateInfo::Kind::kType, }, { /* [5] */ /* name */ "A", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [6] */ /* name */ "B", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [7] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [8] */ /* name */ "D", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3002,13 +3023,13 @@ constexpr TemplateInfo kTemplates[] = { { /* [11] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [12] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3026,13 +3047,13 @@ constexpr TemplateInfo kTemplates[] = { { /* [15] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [16] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3050,19 +3071,19 @@ constexpr TemplateInfo kTemplates[] = { { /* [19] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [20] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [21] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(145), + /* matcher_indices */ MatcherIndicesIndex(151), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3086,7 +3107,7 @@ constexpr TemplateInfo kTemplates[] = { { /* [25] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3098,7 +3119,7 @@ constexpr TemplateInfo kTemplates[] = { { /* [27] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(144), + /* matcher_indices */ MatcherIndicesIndex(150), /* kind */ TemplateInfo::Kind::kNumber, }, { @@ -3110,55 +3131,55 @@ constexpr TemplateInfo kTemplates[] = { { /* [29] */ /* name */ "B", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [30] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [31] */ /* name */ "I", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [32] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [33] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [34] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(146), + /* matcher_indices */ MatcherIndicesIndex(152), /* kind */ TemplateInfo::Kind::kType, }, { /* [35] */ /* name */ "C", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [36] */ /* name */ "D", - /* matcher_indices */ MatcherIndicesIndex(143), + /* matcher_indices */ MatcherIndicesIndex(149), /* kind */ TemplateInfo::Kind::kType, }, { /* [37] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(145), + /* matcher_indices */ MatcherIndicesIndex(151), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3194,7 +3215,7 @@ constexpr TemplateInfo kTemplates[] = { { /* [43] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(145), + /* matcher_indices */ MatcherIndicesIndex(151), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3212,13 +3233,13 @@ constexpr TemplateInfo kTemplates[] = { { /* [46] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(146), + /* matcher_indices */ MatcherIndicesIndex(152), /* kind */ TemplateInfo::Kind::kType, }, { /* [47] */ /* name */ "S", - /* matcher_indices */ MatcherIndicesIndex(148), + /* matcher_indices */ MatcherIndicesIndex(154), /* kind */ TemplateInfo::Kind::kType, }, { @@ -3230,7 +3251,7 @@ constexpr TemplateInfo kTemplates[] = { { /* [49] */ /* name */ "T", - /* matcher_indices */ MatcherIndicesIndex(149), + /* matcher_indices */ MatcherIndicesIndex(155), /* kind */ TemplateInfo::Kind::kType, }, }; @@ -4859,6 +4880,39 @@ constexpr OverloadInfo kOverloads[] = { { /* [147] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), + /* num_parameters */ 1, + /* num_explicit_templates */ 0, + /* num_templates */ 1, + /* templates */ TemplateIndex(21), + /* parameters */ ParameterIndex(315), + /* return_matcher_indices */ MatcherIndicesIndex(143), + /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), + }, + { + /* [148] */ + /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), + /* num_parameters */ 1, + /* num_explicit_templates */ 0, + /* num_templates */ 1, + /* templates */ TemplateIndex(21), + /* parameters */ ParameterIndex(316), + /* return_matcher_indices */ MatcherIndicesIndex(145), + /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), + }, + { + /* [149] */ + /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), + /* num_parameters */ 1, + /* num_explicit_templates */ 0, + /* num_templates */ 1, + /* templates */ TemplateIndex(21), + /* parameters */ ParameterIndex(317), + /* return_matcher_indices */ MatcherIndicesIndex(147), + /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), + }, + { + /* [150] */ + /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 3, /* num_explicit_templates */ 0, /* num_templates */ 1, @@ -4868,7 +4922,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [148] */ + /* [151] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 3, /* num_explicit_templates */ 0, @@ -4879,7 +4933,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [149] */ + /* [152] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 1, /* num_explicit_templates */ 0, @@ -4890,7 +4944,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [150] */ + /* [153] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 1, /* num_explicit_templates */ 0, @@ -4901,7 +4955,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [151] */ + /* [154] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -4912,7 +4966,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [152] */ + /* [155] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 4, /* num_explicit_templates */ 0, @@ -4923,7 +4977,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [153] */ + /* [156] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 6, /* num_explicit_templates */ 0, @@ -4934,7 +4988,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [154] */ + /* [157] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 3, /* num_explicit_templates */ 0, @@ -4945,7 +4999,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [155] */ + /* [158] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 4, /* num_explicit_templates */ 0, @@ -4956,7 +5010,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [156] */ + /* [159] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -4967,7 +5021,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [157] */ + /* [160] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -4978,7 +5032,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [158] */ + /* [161] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -4989,7 +5043,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [159] */ + /* [162] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -5000,7 +5054,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [160] */ + /* [163] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -5011,7 +5065,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [161] */ + /* [164] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 2, /* num_explicit_templates */ 0, @@ -5022,7 +5076,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [162] */ + /* [165] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 3, /* num_explicit_templates */ 0, @@ -5033,7 +5087,7 @@ constexpr OverloadInfo kOverloads[] = { /* const_eval_fn */ ConstEvalFunctionIndex(/* invalid */), }, { - /* [163] */ + /* [166] */ /* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), /* num_parameters */ 3, /* num_explicit_templates */ 0, @@ -5053,91 +5107,91 @@ constexpr IntrinsicInfo kBuiltins[] = { /* [0] */ /* fn array_length[I : u32, A : access](ptr, I) -> u32 */ /* num overloads */ 1, - /* overloads */ OverloadIndex(151), + /* overloads */ OverloadIndex(154), }, { /* [1] */ /* fn atomic_and[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [2] */ /* fn atomic_compare_exchange[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, U, T, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(153), + /* overloads */ OverloadIndex(156), }, { /* [3] */ /* fn atomic_exchange[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [4] */ /* fn atomic_iadd[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [5] */ /* fn atomic_isub[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [6] */ /* fn atomic_load[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(154), + /* overloads */ OverloadIndex(157), }, { /* [7] */ /* fn atomic_or[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [8] */ /* fn atomic_smax[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [9] */ /* fn atomic_smin[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [10] */ /* fn atomic_store[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) */ /* num overloads */ 1, - /* overloads */ OverloadIndex(155), + /* overloads */ OverloadIndex(158), }, { /* [11] */ /* fn atomic_umax[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [12] */ /* fn atomic_umin[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [13] */ /* fn atomic_xor[T : iu32, U : u32, S : workgroup_or_storage](ptr, read_write>, U, U, T) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(152), + /* overloads */ OverloadIndex(155), }, { /* [14] */ /* fn dot[N : num, T : f32_f16](vec, vec) -> T */ /* num overloads */ 1, - /* overloads */ OverloadIndex(156), + /* overloads */ OverloadIndex(159), }, { /* [15] */ @@ -5335,19 +5389,19 @@ constexpr IntrinsicInfo kBuiltins[] = { /* [26] */ /* fn matrix_times_matrix[T : f32_f16, K : num, C : num, R : num](mat, mat) -> mat */ /* num overloads */ 1, - /* overloads */ OverloadIndex(157), + /* overloads */ OverloadIndex(160), }, { /* [27] */ /* fn matrix_times_scalar[T : f32_f16, N : num, M : num](mat, T) -> mat */ /* num overloads */ 1, - /* overloads */ OverloadIndex(158), + /* overloads */ OverloadIndex(161), }, { /* [28] */ /* fn matrix_times_vector[T : f32_f16, N : num, M : num](mat, vec) -> vec */ /* num overloads */ 1, - /* overloads */ OverloadIndex(159), + /* overloads */ OverloadIndex(162), }, { /* [29] */ @@ -5369,38 +5423,46 @@ constexpr IntrinsicInfo kBuiltins[] = { /* fn select[T : scalar](bool, T, T) -> T */ /* fn select[N : num, T : scalar](vec, vec, vec) -> vec */ /* num overloads */ 2, - /* overloads */ OverloadIndex(147), + /* overloads */ OverloadIndex(150), }, { /* [31] */ /* fn vector_times_matrix[T : f32_f16, N : num, M : num](vec, mat) -> vec */ /* num overloads */ 1, - /* overloads */ OverloadIndex(160), + /* overloads */ OverloadIndex(163), }, { /* [32] */ /* fn vector_times_scalar[T : f32_f16, N : num](vec, T) -> vec */ /* num overloads */ 1, - /* overloads */ OverloadIndex(161), + /* overloads */ OverloadIndex(164), }, { /* [33] */ /* fn normalize[T : f32_f16](T) -> T */ /* fn normalize[N : num, T : f32_f16](vec) -> vec */ /* num overloads */ 2, - /* overloads */ OverloadIndex(149), + /* overloads */ OverloadIndex(152), }, { /* [34] */ + /* fn inverse[T : f32_f16](mat2x2) -> mat2x2 */ + /* fn inverse[T : f32_f16](mat3x3) -> mat3x3 */ + /* fn inverse[T : f32_f16](mat4x4) -> mat4x4 */ + /* num overloads */ 3, + /* overloads */ OverloadIndex(147), + }, + { + /* [35] */ /* fn sdot(u32, u32, u32) -> i32 */ /* num overloads */ 1, - /* overloads */ OverloadIndex(162), + /* overloads */ OverloadIndex(165), }, { - /* [35] */ + /* [36] */ /* fn udot(u32, u32, u32) -> u32 */ /* num overloads */ 1, - /* overloads */ OverloadIndex(163), + /* overloads */ OverloadIndex(166), }, }; diff --git a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc index 47b240ad77..8fad3a0ef3 100644 --- a/src/tint/lang/spirv/reader/import_glsl_std450_test.cc +++ b/src/tint/lang/spirv/reader/import_glsl_std450_test.cc @@ -1412,146 +1412,5 @@ INSTANTIATE_TEST_SUITE_P( DeterminantData{"m3x3f1", "mat3x3(vec3(50.0f, 60.0f, 70.0f))"}, DeterminantData{"m4x4f1", "mat4x4(vec4(50.0f))"})); -TEST_F(SpirvReaderTest, DISABLED_GlslStd450_MatrixInverse_mat2x2) { - EXPECT_IR(Preamble() + R"( - %1 = OpExtInst %mat2v2float %glsl MatrixInverse %mat2v2float_50_60 - OpReturn - OpFunctionEnd - )", - R"( -%main = @compute @workgroup_size(1u, 1u, 1u) func():void { - $B1: { - let s = (1.0f / determinant(m2x2f1)); - let x_1 = mat2x2f(vec2f((s * m2x2f1[1u][1u]), - (-(s) * m2x2f1[0u][1u])), - vec2f((-(s) * m2x2f1[1u][0u]), - (s * m2x2f1[0u][0u]))); - } -} -)"); -} - -TEST_F(SpirvReaderTest, DISABLED_GlslStd450_MatrixInverse_mat3x3) { - EXPECT_IR(Preamble() + R"( - %1 = OpExtInst %mat3v3float %glsl MatrixInverse %mat3v3float_50_60_70 - OpReturn - OpFunctionEnd - )", - R"( -%main = @compute @workgroup_size(1u, 1u, 1u) func():void { - $B1: { - let s = (1.0f / determinant(m3x3f1)); - let x_1 = (s * mat3x3f(vec3f(((m3x3f1[1u][1u] * m3x3f1[2u][2u]) - - (m3x3f1[1u][2u] * m3x3f1[2u][1u])), ((m3x3f1[0u][2u] * m3x3f1[2u][1u]) - (m3x3f1[0u][1u] - * m3x3f1[2u][2u])), ((m3x3f1[0u][1u] * m3x3f1[1u][2u]) - (m3x3f1[0u][2u] * - m3x3f1[1u][1u]))), vec3f(((m3x3f1[1u][2u] * m3x3f1[2u][0u]) - (m3x3f1[1u][0u] * - m3x3f1[2u][2u])), ((m3x3f1[0u][0u] * m3x3f1[2u][2u]) - (m3x3f1[0u][2u] * - m3x3f1[2u][0u])), ((m3x3f1[0u][2u] * m3x3f1[1u][0u]) - (m3x3f1[0u][0u] * - m3x3f1[1u][2u]))), vec3f(((m3x3f1[1u][0u] * m3x3f1[2u][1u]) - (m3x3f1[1u][1u] * - m3x3f1[2u][0u])), ((m3x3f1[0u][1u] * m3x3f1[2u][0u]) - (m3x3f1[0u][0u] * - m3x3f1[2u][1u])), ((m3x3f1[0u][0u] * m3x3f1[1u][1u]) - (m3x3f1[0u][1u] * - m3x3f1[1u][0u]))))); - } -} -)"); -} - -TEST_F(SpirvReaderTest, DISABLED_GlslStd450_MatrixInverse_mat4x4) { - EXPECT_IR(Preamble() + R"( - %1 = OpExtInst %mat4v4float %glsl MatrixInverse %mat4v4float_50_50_50_50 - OpReturn - OpFunctionEnd - )", - R"( -%main = @compute @workgroup_size(1u, 1u, 1u) func():void { - $B1: { - let s = (1.0f / determinant(m4x4f1)); - let x_1 = (s * mat4x4f(vec4f((((m4x4f1[1u][1u] * ((m4x4f1[2u][2u] * - m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][2u]))) - (m4x4f1[1u][2u] * - ((m4x4f1[2u][1u] * m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][1u])))) + - (m4x4f1[1u][3u] * ((m4x4f1[2u][1u] * m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * - m4x4f1[3u][1u])))), (((-(m4x4f1[0u][1u]) * ((m4x4f1[2u][2u] * m4x4f1[3u][3u]) - - (m4x4f1[2u][3u] * m4x4f1[3u][2u]))) + (m4x4f1[0u][2u] * ((m4x4f1[2u][1u] * - m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][1u])))) - (m4x4f1[0u][3u] * - ((m4x4f1[2u][1u] * m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * m4x4f1[3u][1u])))), - (((m4x4f1[0u][1u] * ((m4x4f1[1u][2u] * m4x4f1[3u][3u]) - (m4x4f1[1u][3u] * - m4x4f1[3u][2u]))) - (m4x4f1[0u][2u] * ((m4x4f1[1u][1u] * m4x4f1[3u][3u]) - - (m4x4f1[1u][3u] * m4x4f1[3u][1u])))) + (m4x4f1[0u][3u] * ((m4x4f1[1u][1u] * - m4x4f1[3u][2u]) - (m4x4f1[1u][2u] * m4x4f1[3u][1u])))), (((-(m4x4f1[0u][1u]) * - ((m4x4f1[1u][2u] * m4x4f1[2u][3u]) - (m4x4f1[1u][3u] * m4x4f1[2u][2u]))) + - (m4x4f1[0u][2u] * ((m4x4f1[1u][1u] * m4x4f1[2u][3u]) - (m4x4f1[1u][3u] * - m4x4f1[2u][1u])))) - (m4x4f1[0u][3u] * ((m4x4f1[1u][1u] * m4x4f1[2u][2u]) - - (m4x4f1[1u][2u] * m4x4f1[2u][1u]))))), vec4f((((-(m4x4f1[1u][0u]) * ((m4x4f1[2u][2u] - * m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][2u]))) + (m4x4f1[1u][2u] * - ((m4x4f1[2u][0u] * m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][0u])))) - - (m4x4f1[1u][3u] * ((m4x4f1[2u][0u] * m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * - m4x4f1[3u][0u])))), (((m4x4f1[0u][0u] * ((m4x4f1[2u][2u] * m4x4f1[3u][3u]) - - (m4x4f1[2u][3u] * m4x4f1[3u][2u]))) - (m4x4f1[0u][2u] * ((m4x4f1[2u][0u] * - m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][0u])))) + (m4x4f1[0u][3u] * - ((m4x4f1[2u][0u] * m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * m4x4f1[3u][0u])))), - (((-(m4x4f1[0u][0u]) * ((m4x4f1[1u][2u] * m4x4f1[3u][3u]) - (m4x4f1[1u][3u] * - m4x4f1[3u][2u]))) + (m4x4f1[0u][2u] * ((m4x4f1[1u][0u] * m4x4f1[3u][3u]) - - (m4x4f1[1u][3u] * m4x4f1[3u][0u])))) - (m4x4f1[0u][3u] * ((m4x4f1[1u][0u] * - m4x4f1[3u][2u]) - (m4x4f1[1u][2u] * m4x4f1[3u][0u])))), (((m4x4f1[0u][0u] * - ((m4x4f1[1u][2u] * m4x4f1[2u][3u]) - (m4x4f1[1u][3u] * m4x4f1[2u][2u]))) - - (m4x4f1[0u][2u] * ((m4x4f1[1u][0u] * m4x4f1[2u][3u]) - (m4x4f1[1u][3u] * - m4x4f1[2u][0u])))) + (m4x4f1[0u][3u] * ((m4x4f1[1u][0u] * m4x4f1[2u][2u]) - - (m4x4f1[1u][2u] * m4x4f1[2u][0u]))))), vec4f((((m4x4f1[1u][0u] * ((m4x4f1[2u][1u] * - m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][1u]))) - (m4x4f1[1u][1u] * - ((m4x4f1[2u][0u] * m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][0u])))) + - (m4x4f1[1u][3u] * ((m4x4f1[2u][0u] * m4x4f1[3u][1u]) - (m4x4f1[2u][1u] * - m4x4f1[3u][0u])))), (((-(m4x4f1[0u][0u]) * ((m4x4f1[2u][1u] * m4x4f1[3u][3u]) - - (m4x4f1[2u][3u] * m4x4f1[3u][1u]))) + (m4x4f1[0u][1u] * ((m4x4f1[2u][0u] * - m4x4f1[3u][3u]) - (m4x4f1[2u][3u] * m4x4f1[3u][0u])))) - (m4x4f1[0u][3u] * - ((m4x4f1[2u][0u] * m4x4f1[3u][1u]) - (m4x4f1[2u][1u] * m4x4f1[3u][0u])))), - (((m4x4f1[0u][0u] * ((m4x4f1[1u][1u] * m4x4f1[3u][3u]) - (m4x4f1[1u][3u] * - m4x4f1[3u][1u]))) - (m4x4f1[0u][1u] * ((m4x4f1[1u][0u] * m4x4f1[3u][3u]) - - (m4x4f1[1u][3u] * m4x4f1[3u][0u])))) + (m4x4f1[0u][3u] * ((m4x4f1[1u][0u] * - m4x4f1[3u][1u]) - (m4x4f1[1u][1u] * m4x4f1[3u][0u])))), (((-(m4x4f1[0u][0u]) * - ((m4x4f1[1u][1u] * m4x4f1[2u][3u]) - (m4x4f1[1u][3u] * m4x4f1[2u][1u]))) + - (m4x4f1[0u][1u] * ((m4x4f1[1u][0u] * m4x4f1[2u][3u]) - (m4x4f1[1u][3u] * - m4x4f1[2u][0u])))) - (m4x4f1[0u][3u] * ((m4x4f1[1u][0u] * m4x4f1[2u][1u]) - - (m4x4f1[1u][1u] * m4x4f1[2u][0u]))))), vec4f((((-(m4x4f1[1u][0u]) * ((m4x4f1[2u][1u] - * m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * m4x4f1[3u][1u]))) + (m4x4f1[1u][1u] * - ((m4x4f1[2u][0u] * m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * m4x4f1[3u][0u])))) - - (m4x4f1[1u][2u] * ((m4x4f1[2u][0u] * m4x4f1[3u][1u]) - (m4x4f1[2u][1u] * - m4x4f1[3u][0u])))), (((m4x4f1[0u][0u] * ((m4x4f1[2u][1u] * m4x4f1[3u][2u]) - - (m4x4f1[2u][2u] * m4x4f1[3u][1u]))) - (m4x4f1[0u][1u] * ((m4x4f1[2u][0u] * - m4x4f1[3u][2u]) - (m4x4f1[2u][2u] * m4x4f1[3u][0u])))) + (m4x4f1[0u][2u] * - ((m4x4f1[2u][0u] * m4x4f1[3u][1u]) - (m4x4f1[2u][1u] * m4x4f1[3u][0u])))), - (((-(m4x4f1[0u][0u]) * ((m4x4f1[1u][1u] * m4x4f1[3u][2u]) - (m4x4f1[1u][2u] * - m4x4f1[3u][1u]))) + (m4x4f1[0u][1u] * ((m4x4f1[1u][0u] * m4x4f1[3u][2u]) - - (m4x4f1[1u][2u] * m4x4f1[3u][0u])))) - (m4x4f1[0u][2u] * ((m4x4f1[1u][0u] * - m4x4f1[3u][1u]) - (m4x4f1[1u][1u] * m4x4f1[3u][0u])))), (((m4x4f1[0u][0u] * - ((m4x4f1[1u][1u] * m4x4f1[2u][2u]) - (m4x4f1[1u][2u] * m4x4f1[2u][1u]))) - - (m4x4f1[0u][1u] * ((m4x4f1[1u][0u] * m4x4f1[2u][2u]) - (m4x4f1[1u][2u] * - m4x4f1[2u][0u])))) + (m4x4f1[0u][2u] * ((m4x4f1[1u][0u] * m4x4f1[2u][1u]) - - (m4x4f1[1u][1u] * m4x4f1[2u][0u]))))))); - } -} -)"); -} - -TEST_F(SpirvReaderTest, DISABLED_GlslStd450_MatrixInverse_MultipleInScope) { - EXPECT_IR(Preamble() + R"( - %1 = OpExtInst %mat2v2float %glsl MatrixInverse %mat2v2float_50_60 - %2 = OpExtInst %mat2v2float %glsl MatrixInverse %mat2v2float_50_60 - OpReturn - OpFunctionEnd - )", - R"( -%main = @compute @workgroup_size(1u, 1u, 1u) func():void { - $B1: { - let s = (1.0f / determinant(m2x2f1)); - let x_1 = mat2x2f(vec2f((s * m2x2f1[1u][1u]), (-(s) * - m2x2f1[0u][1u])), vec2f((-(s) * m2x2f1[1u][0u]), (s * m2x2f1[0u][0u]))); - let s_1 = (1.0f / determinant(m2x2f1)); - let x_2 = mat2x2f(vec2f((s_1 * m2x2f1[1u][1u]), (-(s_1) * - m2x2f1[0u][1u])), vec2f((-(s_1) * m2x2f1[1u][0u]), (s_1 * m2x2f1[0u][0u]))); - } -} -)"); -} - } // namespace } // namespace tint::spirv::reader diff --git a/src/tint/lang/spirv/reader/lower/builtins.cc b/src/tint/lang/spirv/reader/lower/builtins.cc index c9edc17f0b..b59b33a348 100644 --- a/src/tint/lang/spirv/reader/lower/builtins.cc +++ b/src/tint/lang/spirv/reader/lower/builtins.cc @@ -33,10 +33,10 @@ #include "src/tint/lang/spirv/ir/builtin_call.h" namespace tint::spirv::reader::lower { - namespace { -using namespace tint::core::fluent_types; // NOLINT +using namespace tint::core::fluent_types; // NOLINT +using namespace tint::core::number_suffixes; // NOLINT /// PIMPL state for the transform. struct State { @@ -64,6 +64,9 @@ struct State { case spirv::BuiltinFn::kNormalize: Normalize(builtin); break; + case spirv::BuiltinFn::kInverse: + Inverse(builtin); + break; default: TINT_UNREACHABLE() << "unknown spirv builtin: " << builtin->Func(); } @@ -82,6 +85,226 @@ struct State { }); call->Destroy(); } + + void Inverse(spirv::ir::BuiltinCall* call) { + auto* arg = call->Args()[0]; + auto* mat_ty = arg->Type()->As(); + TINT_ASSERT(mat_ty); + TINT_ASSERT(mat_ty->Columns() == mat_ty->Rows()); + + auto* elem_ty = mat_ty->Type(); + + b.InsertBefore(call, [&] { + auto* det = + b.Call(elem_ty, core::BuiltinFn::kDeterminant, Vector{arg}); + core::ir::Value* one = nullptr; + if (elem_ty->Is()) { + one = b.Constant(1.0_f); + } else if (elem_ty->Is()) { + one = b.Constant(1.0_h); + } else { + TINT_UNREACHABLE(); + } + auto* inv_det = b.Divide(elem_ty, one, det); + + // Returns (m * n) - (o * p) + auto sub_mul2 = [&](auto* m, auto* n, auto* o, auto* p) { + auto* x = b.Multiply(elem_ty, m, n); + auto* y = b.Multiply(elem_ty, o, p); + return b.Subtract(elem_ty, x, y); + }; + + // Returns (m * n) - (o * p) + (q * r) + auto sub_add_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) { + auto* w = b.Multiply(elem_ty, m, n); + auto* x = b.Multiply(elem_ty, o, p); + auto* y = b.Multiply(elem_ty, q, r); + + auto* z = b.Subtract(elem_ty, w, x); + return b.Add(elem_ty, z, y); + }; + + // Returns (m * n) + (o * p) - (q * r) + auto add_sub_mul3 = [&](auto* m, auto* n, auto* o, auto* p, auto* q, auto* r) { + auto* w = b.Multiply(elem_ty, m, n); + auto* x = b.Multiply(elem_ty, o, p); + auto* y = b.Multiply(elem_ty, q, r); + + auto* z = b.Add(elem_ty, w, x); + return b.Subtract(elem_ty, z, y); + }; + + switch (mat_ty->Columns()) { + case 2: { + auto* neg_inv_det = b.Negation(elem_ty, inv_det); + + auto* ma = b.Access(elem_ty, arg, 0_u, 0_u); + auto* mb = b.Access(elem_ty, arg, 0_u, 1_u); + auto* mc = b.Access(elem_ty, arg, 1_u, 0_u); + auto* md = b.Access(elem_ty, arg, 1_u, 1_u); + + auto* r_00 = b.Multiply(elem_ty, inv_det, md); + auto* r_01 = b.Multiply(elem_ty, neg_inv_det, mb); + auto* r_10 = b.Multiply(elem_ty, neg_inv_det, mc); + auto* r_11 = b.Multiply(elem_ty, inv_det, ma); + + auto* r1 = b.Construct(ty.vec2(elem_ty), r_00, r_01); + auto* r2 = b.Construct(ty.vec2(elem_ty), r_10, r_11); + b.ConstructWithResult(call->DetachResult(), r1, r2); + break; + } + case 3: { + auto* ma = b.Access(elem_ty, arg, 0_u, 0_u); + auto* mb = b.Access(elem_ty, arg, 0_u, 1_u); + auto* mc = b.Access(elem_ty, arg, 0_u, 2_u); + auto* md = b.Access(elem_ty, arg, 1_u, 0_u); + auto* me = b.Access(elem_ty, arg, 1_u, 1_u); + auto* mf = b.Access(elem_ty, arg, 1_u, 2_u); + auto* mg = b.Access(elem_ty, arg, 2_u, 0_u); + auto* mh = b.Access(elem_ty, arg, 2_u, 1_u); + auto* mi = b.Access(elem_ty, arg, 2_u, 2_u); + + // e * i - f * h + auto* r_00 = sub_mul2(me, mi, mf, mh); + // c * h - b * i + auto* r_01 = sub_mul2(mc, mh, mb, mi); + // b * f - c * e + auto* r_02 = sub_mul2(mb, mf, mc, me); + + // f * g - d * i + auto* r_10 = sub_mul2(mf, mg, md, mi); + // a * i - c * g + auto* r_11 = sub_mul2(ma, mi, mc, mg); + // c * d - a * f + auto* r_12 = sub_mul2(mc, md, ma, mf); + + // d * h - e * g + auto* r_20 = sub_mul2(md, mh, me, mg); + // b * g - a * h + auto* r_21 = sub_mul2(mb, mg, ma, mh); + // a * e - b * d + auto* r_22 = sub_mul2(ma, me, mb, md); + + auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02); + auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12); + auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22); + + auto* m = b.Construct(mat_ty, r1, r2, r3); + auto* inv = b.Multiply(mat_ty, inv_det, m); + call->Result(0)->ReplaceAllUsesWith(inv->Result(0)); + break; + } + case 4: { + auto* ma = b.Access(elem_ty, arg, 0_u, 0_u); + auto* mb = b.Access(elem_ty, arg, 0_u, 1_u); + auto* mc = b.Access(elem_ty, arg, 0_u, 2_u); + auto* md = b.Access(elem_ty, arg, 0_u, 3_u); + auto* me = b.Access(elem_ty, arg, 1_u, 0_u); + auto* mf = b.Access(elem_ty, arg, 1_u, 1_u); + auto* mg = b.Access(elem_ty, arg, 1_u, 2_u); + auto* mh = b.Access(elem_ty, arg, 1_u, 3_u); + auto* mi = b.Access(elem_ty, arg, 2_u, 0_u); + auto* mj = b.Access(elem_ty, arg, 2_u, 1_u); + auto* mk = b.Access(elem_ty, arg, 2_u, 2_u); + auto* ml = b.Access(elem_ty, arg, 2_u, 3_u); + auto* mm = b.Access(elem_ty, arg, 3_u, 0_u); + auto* mn = b.Access(elem_ty, arg, 3_u, 1_u); + auto* mo = b.Access(elem_ty, arg, 3_u, 2_u); + auto* mp = b.Access(elem_ty, arg, 3_u, 3_u); + + // kplo = k * p - l * o + auto* kplo = sub_mul2(mk, mp, ml, mo); + // jpln = j * p - l * n + auto* jpln = sub_mul2(mj, mp, ml, mn); + // jokn = j * o - k * n; + auto* jokn = sub_mul2(mj, mo, mk, mn); + // gpho = g * p - h * o + auto* gpho = sub_mul2(mg, mp, mh, mo); + // fphn = f * p - h * n + auto* fphn = sub_mul2(mf, mp, mh, mn); + // fogn = f * o - g * n; + auto* fogn = sub_mul2(mf, mo, mg, mn); + // glhk = g * l - h * k + auto* glhk = sub_mul2(mg, ml, mh, mk); + // flhj = f * l - h * j + auto* flhj = sub_mul2(mf, ml, mh, mj); + // fkgj = f * k - g * j; + auto* fkgj = sub_mul2(mf, mk, mg, mj); + // iplm = i * p - l * m + auto* iplm = sub_mul2(mi, mp, ml, mm); + // iokm = i * o - k * m + auto* iokm = sub_mul2(mi, mo, mk, mm); + // ephm = e * p - h * m; + auto* ephm = sub_mul2(me, mp, mh, mm); + // eogm = e * o - g * m + auto* eogm = sub_mul2(me, mo, mg, mm); + // elhi = e * l - h * i + auto* elhi = sub_mul2(me, ml, mh, mi); + // ekgi = e * k - g * i; + auto* ekgi = sub_mul2(me, mk, mg, mi); + // injm = i * n - j * m + auto* injm = sub_mul2(mi, mn, mj, mm); + // enfm = e * n - f * m + auto* enfm = sub_mul2(me, mn, mf, mm); + // ejfi = e * j - f * i; + auto* ejfi = sub_mul2(me, mj, mf, mi); + + auto* neg_b = b.Negation(elem_ty, mb); + // f * kplo - g * jpln + h * jokn + auto* r_00 = sub_add_mul3(mf, kplo, mg, jpln, mh, jokn); + // -b * kplo + c * jpln - d * jokn + auto* r_01 = add_sub_mul3(neg_b, kplo, mc, jpln, md, jokn); + // b * gpho - c * fphn + d * fogn + auto* r_02 = sub_add_mul3(mb, gpho, mc, fphn, md, fogn); + // -b * glhk + c * flhj - d * fkgj + auto* r_03 = add_sub_mul3(neg_b, glhk, mc, flhj, md, fkgj); + + auto* neg_e = b.Negation(elem_ty, me); + auto* neg_a = b.Negation(elem_ty, ma); + // -e * kplo + g * iplm - h * iokm + auto* r_10 = add_sub_mul3(neg_e, kplo, mg, iplm, mh, iokm); + // a * kplo - c * iplm + d * iokm + auto* r_11 = sub_add_mul3(ma, kplo, mc, iplm, md, iokm); + // -a * gpho + c * ephm - d * eogm + auto* r_12 = add_sub_mul3(neg_a, gpho, mc, ephm, md, eogm); + // a * glhk - c * elhi + d * ekgi + auto* r_13 = sub_add_mul3(ma, glhk, mc, elhi, md, ekgi); + + // e * jpln - f * iplm + h * injm + auto* r_20 = sub_add_mul3(me, jpln, mf, iplm, mh, injm); + // -a * jpln + b * iplm - d * injm + auto* r_21 = add_sub_mul3(neg_a, jpln, mb, iplm, md, injm); + // a * fphn - b * ephm + d * enfm + auto* r_22 = sub_add_mul3(ma, fphn, mb, ephm, md, enfm); + // -a * flhj + b * elhi - d * ejfi + auto* r_23 = add_sub_mul3(neg_a, flhj, mb, elhi, md, ejfi); + + // -e * jokn + f * iokm - g * injm + auto* r_30 = add_sub_mul3(neg_e, jokn, mf, iokm, mg, injm); + // a * jokn - b * iokm + c * injm + auto* r_31 = sub_add_mul3(ma, jokn, mb, iokm, mc, injm); + // -a * fogn + b * eogm - c * enfm + auto* r_32 = add_sub_mul3(neg_a, fogn, mb, eogm, mc, enfm); + // a * fkgj - b * ekgi + c * ejfi + auto* r_33 = sub_add_mul3(ma, fkgj, mb, ekgi, mc, ejfi); + + auto* r1 = b.Construct(ty.vec3(elem_ty), r_00, r_01, r_02, r_03); + auto* r2 = b.Construct(ty.vec3(elem_ty), r_10, r_11, r_12, r_13); + auto* r3 = b.Construct(ty.vec3(elem_ty), r_20, r_21, r_22, r_23); + auto* r4 = b.Construct(ty.vec3(elem_ty), r_30, r_31, r_32, r_33); + + auto* m = b.Construct(mat_ty, r1, r2, r3, r4); + auto* inv = b.Multiply(mat_ty, inv_det, m); + call->Result(0)->ReplaceAllUsesWith(inv->Result(0)); + break; + } + default: { + TINT_UNREACHABLE(); + } + } + }); + call->Destroy(); + } }; } // namespace diff --git a/src/tint/lang/spirv/reader/lower/builtins_test.cc b/src/tint/lang/spirv/reader/lower/builtins_test.cc index 95c814efdf..7f0d1c7baa 100644 --- a/src/tint/lang/spirv/reader/lower/builtins_test.cc +++ b/src/tint/lang/spirv/reader/lower/builtins_test.cc @@ -99,5 +99,645 @@ TEST_F(SpirvParser_BuiltinsTest, Normalize_Vector) { EXPECT_EQ(expect, str()); } +TEST_F(SpirvParser_BuiltinsTest, Inverse_Mat2x2f) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + b.Call(ty.mat2x2(), spirv::BuiltinFn::kInverse, + b.Construct(ty.mat2x2(), b.Splat(ty.vec2(), 10_f), + b.Splat(ty.vec2(), 20_f))); + b.Return(ep); + }); + + auto* src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x2 = construct vec2(10.0f), vec2(20.0f) + %3:mat2x2 = spirv.inverse %2 + ret + } +} +)"; + EXPECT_EQ(src, str()); + Run(Builtins); + + auto* expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x2 = construct vec2(10.0f), vec2(20.0f) + %3:f32 = determinant %2 + %4:f32 = div 1.0f, %3 + %5:f32 = negation %4 + %6:f32 = access %2, 0u, 0u + %7:f32 = access %2, 0u, 1u + %8:f32 = access %2, 1u, 0u + %9:f32 = access %2, 1u, 1u + %10:f32 = mul %4, %9 + %11:f32 = mul %5, %7 + %12:f32 = mul %5, %8 + %13:f32 = mul %4, %6 + %14:vec2 = construct %10, %11 + %15:vec2 = construct %12, %13 + %16:mat2x2 = construct %14, %15 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + +TEST_F(SpirvParser_BuiltinsTest, Inverse_Mat2x2h) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + b.Call(ty.mat2x2(), spirv::BuiltinFn::kInverse, + b.Construct(ty.mat2x2(), b.Splat(ty.vec2(), 10_h), + b.Splat(ty.vec2(), 20_h))); + b.Return(ep); + }); + + auto* src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x2 = construct vec2(10.0h), vec2(20.0h) + %3:mat2x2 = spirv.inverse %2 + ret + } +} +)"; + EXPECT_EQ(src, str()); + Run(Builtins); + + auto* expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x2 = construct vec2(10.0h), vec2(20.0h) + %3:f16 = determinant %2 + %4:f16 = div 1.0h, %3 + %5:f16 = negation %4 + %6:f16 = access %2, 0u, 0u + %7:f16 = access %2, 0u, 1u + %8:f16 = access %2, 1u, 0u + %9:f16 = access %2, 1u, 1u + %10:f16 = mul %4, %9 + %11:f16 = mul %5, %7 + %12:f16 = mul %5, %8 + %13:f16 = mul %4, %6 + %14:vec2 = construct %10, %11 + %15:vec2 = construct %12, %13 + %16:mat2x2 = construct %14, %15 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + +TEST_F(SpirvParser_BuiltinsTest, Inverse_Mat3x3f) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + b.Call( + ty.mat3x3(), spirv::BuiltinFn::kInverse, + b.Construct(ty.mat3x3(), b.Splat(ty.vec3(), 10_f), + b.Splat(ty.vec3(), 20_f), b.Splat(ty.vec3(), 30_f))); + b.Return(ep); + }); + + auto* src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat3x3 = construct vec3(10.0f), vec3(20.0f), vec3(30.0f) + %3:mat3x3 = spirv.inverse %2 + ret + } +} +)"; + EXPECT_EQ(src, str()); + Run(Builtins); + + auto* expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat3x3 = construct vec3(10.0f), vec3(20.0f), vec3(30.0f) + %3:f32 = determinant %2 + %4:f32 = div 1.0f, %3 + %5:f32 = access %2, 0u, 0u + %6:f32 = access %2, 0u, 1u + %7:f32 = access %2, 0u, 2u + %8:f32 = access %2, 1u, 0u + %9:f32 = access %2, 1u, 1u + %10:f32 = access %2, 1u, 2u + %11:f32 = access %2, 2u, 0u + %12:f32 = access %2, 2u, 1u + %13:f32 = access %2, 2u, 2u + %14:f32 = mul %9, %13 + %15:f32 = mul %10, %12 + %16:f32 = sub %14, %15 + %17:f32 = mul %7, %12 + %18:f32 = mul %6, %13 + %19:f32 = sub %17, %18 + %20:f32 = mul %6, %10 + %21:f32 = mul %7, %9 + %22:f32 = sub %20, %21 + %23:f32 = mul %10, %11 + %24:f32 = mul %8, %13 + %25:f32 = sub %23, %24 + %26:f32 = mul %5, %13 + %27:f32 = mul %7, %11 + %28:f32 = sub %26, %27 + %29:f32 = mul %7, %8 + %30:f32 = mul %5, %10 + %31:f32 = sub %29, %30 + %32:f32 = mul %8, %12 + %33:f32 = mul %9, %11 + %34:f32 = sub %32, %33 + %35:f32 = mul %6, %11 + %36:f32 = mul %5, %12 + %37:f32 = sub %35, %36 + %38:f32 = mul %5, %9 + %39:f32 = mul %6, %8 + %40:f32 = sub %38, %39 + %41:vec3 = construct %16, %19, %22 + %42:vec3 = construct %25, %28, %31 + %43:vec3 = construct %34, %37, %40 + %44:mat3x3 = construct %41, %42, %43 + %45:mat3x3 = mul %4, %44 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + +TEST_F(SpirvParser_BuiltinsTest, Inverse_Mat3x3h) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + b.Call( + ty.mat3x3(), spirv::BuiltinFn::kInverse, + b.Construct(ty.mat3x3(), b.Splat(ty.vec3(), 10_h), + b.Splat(ty.vec3(), 20_h), b.Splat(ty.vec3(), 30_h))); + b.Return(ep); + }); + + auto* src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat3x3 = construct vec3(10.0h), vec3(20.0h), vec3(30.0h) + %3:mat3x3 = spirv.inverse %2 + ret + } +} +)"; + EXPECT_EQ(src, str()); + Run(Builtins); + + auto* expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat3x3 = construct vec3(10.0h), vec3(20.0h), vec3(30.0h) + %3:f16 = determinant %2 + %4:f16 = div 1.0h, %3 + %5:f16 = access %2, 0u, 0u + %6:f16 = access %2, 0u, 1u + %7:f16 = access %2, 0u, 2u + %8:f16 = access %2, 1u, 0u + %9:f16 = access %2, 1u, 1u + %10:f16 = access %2, 1u, 2u + %11:f16 = access %2, 2u, 0u + %12:f16 = access %2, 2u, 1u + %13:f16 = access %2, 2u, 2u + %14:f16 = mul %9, %13 + %15:f16 = mul %10, %12 + %16:f16 = sub %14, %15 + %17:f16 = mul %7, %12 + %18:f16 = mul %6, %13 + %19:f16 = sub %17, %18 + %20:f16 = mul %6, %10 + %21:f16 = mul %7, %9 + %22:f16 = sub %20, %21 + %23:f16 = mul %10, %11 + %24:f16 = mul %8, %13 + %25:f16 = sub %23, %24 + %26:f16 = mul %5, %13 + %27:f16 = mul %7, %11 + %28:f16 = sub %26, %27 + %29:f16 = mul %7, %8 + %30:f16 = mul %5, %10 + %31:f16 = sub %29, %30 + %32:f16 = mul %8, %12 + %33:f16 = mul %9, %11 + %34:f16 = sub %32, %33 + %35:f16 = mul %6, %11 + %36:f16 = mul %5, %12 + %37:f16 = sub %35, %36 + %38:f16 = mul %5, %9 + %39:f16 = mul %6, %8 + %40:f16 = sub %38, %39 + %41:vec3 = construct %16, %19, %22 + %42:vec3 = construct %25, %28, %31 + %43:vec3 = construct %34, %37, %40 + %44:mat3x3 = construct %41, %42, %43 + %45:mat3x3 = mul %4, %44 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + +TEST_F(SpirvParser_BuiltinsTest, Inverse_Mat4x4f) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + b.Call( + ty.mat4x4(), spirv::BuiltinFn::kInverse, + b.Construct(ty.mat4x4(), b.Splat(ty.vec4(), 10_f), + b.Splat(ty.vec4(), 20_f), b.Splat(ty.vec4(), 30_f), + b.Splat(ty.vec4(), 40_f))); + b.Return(ep); + }); + + auto* src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat4x4 = construct vec4(10.0f), vec4(20.0f), vec4(30.0f), vec4(40.0f) + %3:mat4x4 = spirv.inverse %2 + ret + } +} +)"; + EXPECT_EQ(src, str()); + Run(Builtins); + + auto* expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat4x4 = construct vec4(10.0f), vec4(20.0f), vec4(30.0f), vec4(40.0f) + %3:f32 = determinant %2 + %4:f32 = div 1.0f, %3 + %5:f32 = access %2, 0u, 0u + %6:f32 = access %2, 0u, 1u + %7:f32 = access %2, 0u, 2u + %8:f32 = access %2, 0u, 3u + %9:f32 = access %2, 1u, 0u + %10:f32 = access %2, 1u, 1u + %11:f32 = access %2, 1u, 2u + %12:f32 = access %2, 1u, 3u + %13:f32 = access %2, 2u, 0u + %14:f32 = access %2, 2u, 1u + %15:f32 = access %2, 2u, 2u + %16:f32 = access %2, 2u, 3u + %17:f32 = access %2, 3u, 0u + %18:f32 = access %2, 3u, 1u + %19:f32 = access %2, 3u, 2u + %20:f32 = access %2, 3u, 3u + %21:f32 = mul %15, %20 + %22:f32 = mul %16, %19 + %23:f32 = sub %21, %22 + %24:f32 = mul %14, %20 + %25:f32 = mul %16, %18 + %26:f32 = sub %24, %25 + %27:f32 = mul %14, %19 + %28:f32 = mul %15, %18 + %29:f32 = sub %27, %28 + %30:f32 = mul %11, %20 + %31:f32 = mul %12, %19 + %32:f32 = sub %30, %31 + %33:f32 = mul %10, %20 + %34:f32 = mul %12, %18 + %35:f32 = sub %33, %34 + %36:f32 = mul %10, %19 + %37:f32 = mul %11, %18 + %38:f32 = sub %36, %37 + %39:f32 = mul %11, %16 + %40:f32 = mul %12, %15 + %41:f32 = sub %39, %40 + %42:f32 = mul %10, %16 + %43:f32 = mul %12, %14 + %44:f32 = sub %42, %43 + %45:f32 = mul %10, %15 + %46:f32 = mul %11, %14 + %47:f32 = sub %45, %46 + %48:f32 = mul %13, %20 + %49:f32 = mul %16, %17 + %50:f32 = sub %48, %49 + %51:f32 = mul %13, %19 + %52:f32 = mul %15, %17 + %53:f32 = sub %51, %52 + %54:f32 = mul %9, %20 + %55:f32 = mul %12, %17 + %56:f32 = sub %54, %55 + %57:f32 = mul %9, %19 + %58:f32 = mul %11, %17 + %59:f32 = sub %57, %58 + %60:f32 = mul %9, %16 + %61:f32 = mul %12, %13 + %62:f32 = sub %60, %61 + %63:f32 = mul %9, %15 + %64:f32 = mul %11, %13 + %65:f32 = sub %63, %64 + %66:f32 = mul %13, %18 + %67:f32 = mul %14, %17 + %68:f32 = sub %66, %67 + %69:f32 = mul %9, %18 + %70:f32 = mul %10, %17 + %71:f32 = sub %69, %70 + %72:f32 = mul %9, %14 + %73:f32 = mul %10, %13 + %74:f32 = sub %72, %73 + %75:f32 = negation %6 + %76:f32 = mul %10, %23 + %77:f32 = mul %11, %26 + %78:f32 = mul %12, %29 + %79:f32 = sub %76, %77 + %80:f32 = add %79, %78 + %81:f32 = mul %75, %23 + %82:f32 = mul %7, %26 + %83:f32 = mul %8, %29 + %84:f32 = add %81, %82 + %85:f32 = sub %84, %83 + %86:f32 = mul %6, %32 + %87:f32 = mul %7, %35 + %88:f32 = mul %8, %38 + %89:f32 = sub %86, %87 + %90:f32 = add %89, %88 + %91:f32 = mul %75, %41 + %92:f32 = mul %7, %44 + %93:f32 = mul %8, %47 + %94:f32 = add %91, %92 + %95:f32 = sub %94, %93 + %96:f32 = negation %9 + %97:f32 = negation %5 + %98:f32 = mul %96, %23 + %99:f32 = mul %11, %50 + %100:f32 = mul %12, %53 + %101:f32 = add %98, %99 + %102:f32 = sub %101, %100 + %103:f32 = mul %5, %23 + %104:f32 = mul %7, %50 + %105:f32 = mul %8, %53 + %106:f32 = sub %103, %104 + %107:f32 = add %106, %105 + %108:f32 = mul %97, %32 + %109:f32 = mul %7, %56 + %110:f32 = mul %8, %59 + %111:f32 = add %108, %109 + %112:f32 = sub %111, %110 + %113:f32 = mul %5, %41 + %114:f32 = mul %7, %62 + %115:f32 = mul %8, %65 + %116:f32 = sub %113, %114 + %117:f32 = add %116, %115 + %118:f32 = mul %9, %26 + %119:f32 = mul %10, %50 + %120:f32 = mul %12, %68 + %121:f32 = sub %118, %119 + %122:f32 = add %121, %120 + %123:f32 = mul %97, %26 + %124:f32 = mul %6, %50 + %125:f32 = mul %8, %68 + %126:f32 = add %123, %124 + %127:f32 = sub %126, %125 + %128:f32 = mul %5, %35 + %129:f32 = mul %6, %56 + %130:f32 = mul %8, %71 + %131:f32 = sub %128, %129 + %132:f32 = add %131, %130 + %133:f32 = mul %97, %44 + %134:f32 = mul %6, %62 + %135:f32 = mul %8, %74 + %136:f32 = add %133, %134 + %137:f32 = sub %136, %135 + %138:f32 = mul %96, %29 + %139:f32 = mul %10, %53 + %140:f32 = mul %11, %68 + %141:f32 = add %138, %139 + %142:f32 = sub %141, %140 + %143:f32 = mul %5, %29 + %144:f32 = mul %6, %53 + %145:f32 = mul %7, %68 + %146:f32 = sub %143, %144 + %147:f32 = add %146, %145 + %148:f32 = mul %97, %38 + %149:f32 = mul %6, %59 + %150:f32 = mul %7, %71 + %151:f32 = add %148, %149 + %152:f32 = sub %151, %150 + %153:f32 = mul %5, %47 + %154:f32 = mul %6, %65 + %155:f32 = mul %7, %74 + %156:f32 = sub %153, %154 + %157:f32 = add %156, %155 + %158:vec3 = construct %80, %85, %90, %95 + %159:vec3 = construct %102, %107, %112, %117 + %160:vec3 = construct %122, %127, %132, %137 + %161:vec3 = construct %142, %147, %152, %157 + %162:mat4x4 = construct %158, %159, %160, %161 + %163:mat4x4 = mul %4, %162 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + +TEST_F(SpirvParser_BuiltinsTest, Inverse_Mat4x4h) { + auto* ep = b.ComputeFunction("foo"); + + b.Append(ep->Block(), [&] { // + b.Call( + ty.mat4x4(), spirv::BuiltinFn::kInverse, + b.Construct(ty.mat4x4(), b.Splat(ty.vec4(), 10_h), + b.Splat(ty.vec4(), 20_h), b.Splat(ty.vec4(), 30_h), + b.Splat(ty.vec4(), 40_h))); + b.Return(ep); + }); + + auto* src = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat4x4 = construct vec4(10.0h), vec4(20.0h), vec4(30.0h), vec4(40.0h) + %3:mat4x4 = spirv.inverse %2 + ret + } +} +)"; + EXPECT_EQ(src, str()); + Run(Builtins); + + auto* expect = R"( +%foo = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat4x4 = construct vec4(10.0h), vec4(20.0h), vec4(30.0h), vec4(40.0h) + %3:f16 = determinant %2 + %4:f16 = div 1.0h, %3 + %5:f16 = access %2, 0u, 0u + %6:f16 = access %2, 0u, 1u + %7:f16 = access %2, 0u, 2u + %8:f16 = access %2, 0u, 3u + %9:f16 = access %2, 1u, 0u + %10:f16 = access %2, 1u, 1u + %11:f16 = access %2, 1u, 2u + %12:f16 = access %2, 1u, 3u + %13:f16 = access %2, 2u, 0u + %14:f16 = access %2, 2u, 1u + %15:f16 = access %2, 2u, 2u + %16:f16 = access %2, 2u, 3u + %17:f16 = access %2, 3u, 0u + %18:f16 = access %2, 3u, 1u + %19:f16 = access %2, 3u, 2u + %20:f16 = access %2, 3u, 3u + %21:f16 = mul %15, %20 + %22:f16 = mul %16, %19 + %23:f16 = sub %21, %22 + %24:f16 = mul %14, %20 + %25:f16 = mul %16, %18 + %26:f16 = sub %24, %25 + %27:f16 = mul %14, %19 + %28:f16 = mul %15, %18 + %29:f16 = sub %27, %28 + %30:f16 = mul %11, %20 + %31:f16 = mul %12, %19 + %32:f16 = sub %30, %31 + %33:f16 = mul %10, %20 + %34:f16 = mul %12, %18 + %35:f16 = sub %33, %34 + %36:f16 = mul %10, %19 + %37:f16 = mul %11, %18 + %38:f16 = sub %36, %37 + %39:f16 = mul %11, %16 + %40:f16 = mul %12, %15 + %41:f16 = sub %39, %40 + %42:f16 = mul %10, %16 + %43:f16 = mul %12, %14 + %44:f16 = sub %42, %43 + %45:f16 = mul %10, %15 + %46:f16 = mul %11, %14 + %47:f16 = sub %45, %46 + %48:f16 = mul %13, %20 + %49:f16 = mul %16, %17 + %50:f16 = sub %48, %49 + %51:f16 = mul %13, %19 + %52:f16 = mul %15, %17 + %53:f16 = sub %51, %52 + %54:f16 = mul %9, %20 + %55:f16 = mul %12, %17 + %56:f16 = sub %54, %55 + %57:f16 = mul %9, %19 + %58:f16 = mul %11, %17 + %59:f16 = sub %57, %58 + %60:f16 = mul %9, %16 + %61:f16 = mul %12, %13 + %62:f16 = sub %60, %61 + %63:f16 = mul %9, %15 + %64:f16 = mul %11, %13 + %65:f16 = sub %63, %64 + %66:f16 = mul %13, %18 + %67:f16 = mul %14, %17 + %68:f16 = sub %66, %67 + %69:f16 = mul %9, %18 + %70:f16 = mul %10, %17 + %71:f16 = sub %69, %70 + %72:f16 = mul %9, %14 + %73:f16 = mul %10, %13 + %74:f16 = sub %72, %73 + %75:f16 = negation %6 + %76:f16 = mul %10, %23 + %77:f16 = mul %11, %26 + %78:f16 = mul %12, %29 + %79:f16 = sub %76, %77 + %80:f16 = add %79, %78 + %81:f16 = mul %75, %23 + %82:f16 = mul %7, %26 + %83:f16 = mul %8, %29 + %84:f16 = add %81, %82 + %85:f16 = sub %84, %83 + %86:f16 = mul %6, %32 + %87:f16 = mul %7, %35 + %88:f16 = mul %8, %38 + %89:f16 = sub %86, %87 + %90:f16 = add %89, %88 + %91:f16 = mul %75, %41 + %92:f16 = mul %7, %44 + %93:f16 = mul %8, %47 + %94:f16 = add %91, %92 + %95:f16 = sub %94, %93 + %96:f16 = negation %9 + %97:f16 = negation %5 + %98:f16 = mul %96, %23 + %99:f16 = mul %11, %50 + %100:f16 = mul %12, %53 + %101:f16 = add %98, %99 + %102:f16 = sub %101, %100 + %103:f16 = mul %5, %23 + %104:f16 = mul %7, %50 + %105:f16 = mul %8, %53 + %106:f16 = sub %103, %104 + %107:f16 = add %106, %105 + %108:f16 = mul %97, %32 + %109:f16 = mul %7, %56 + %110:f16 = mul %8, %59 + %111:f16 = add %108, %109 + %112:f16 = sub %111, %110 + %113:f16 = mul %5, %41 + %114:f16 = mul %7, %62 + %115:f16 = mul %8, %65 + %116:f16 = sub %113, %114 + %117:f16 = add %116, %115 + %118:f16 = mul %9, %26 + %119:f16 = mul %10, %50 + %120:f16 = mul %12, %68 + %121:f16 = sub %118, %119 + %122:f16 = add %121, %120 + %123:f16 = mul %97, %26 + %124:f16 = mul %6, %50 + %125:f16 = mul %8, %68 + %126:f16 = add %123, %124 + %127:f16 = sub %126, %125 + %128:f16 = mul %5, %35 + %129:f16 = mul %6, %56 + %130:f16 = mul %8, %71 + %131:f16 = sub %128, %129 + %132:f16 = add %131, %130 + %133:f16 = mul %97, %44 + %134:f16 = mul %6, %62 + %135:f16 = mul %8, %74 + %136:f16 = add %133, %134 + %137:f16 = sub %136, %135 + %138:f16 = mul %96, %29 + %139:f16 = mul %10, %53 + %140:f16 = mul %11, %68 + %141:f16 = add %138, %139 + %142:f16 = sub %141, %140 + %143:f16 = mul %5, %29 + %144:f16 = mul %6, %53 + %145:f16 = mul %7, %68 + %146:f16 = sub %143, %144 + %147:f16 = add %146, %145 + %148:f16 = mul %97, %38 + %149:f16 = mul %6, %59 + %150:f16 = mul %7, %71 + %151:f16 = add %148, %149 + %152:f16 = sub %151, %150 + %153:f16 = mul %5, %47 + %154:f16 = mul %6, %65 + %155:f16 = mul %7, %74 + %156:f16 = sub %153, %154 + %157:f16 = add %156, %155 + %158:vec3 = construct %80, %85, %90, %95 + %159:vec3 = construct %102, %107, %112, %117 + %160:vec3 = construct %122, %127, %132, %137 + %161:vec3 = construct %142, %147, %152, %157 + %162:mat4x4 = construct %158, %159, %160, %161 + %163:mat4x4 = mul %4, %162 + ret + } +} +)"; + EXPECT_EQ(expect, str()); +} + } // namespace } // namespace tint::spirv::reader::lower diff --git a/src/tint/lang/spirv/reader/parser/BUILD.bazel b/src/tint/lang/spirv/reader/parser/BUILD.bazel index e99e6d067f..74fa26b6d9 100644 --- a/src/tint/lang/spirv/reader/parser/BUILD.bazel +++ b/src/tint/lang/spirv/reader/parser/BUILD.bazel @@ -86,6 +86,7 @@ cc_library( "constant_test.cc", "function_test.cc", "helper_test.h", + "import_glsl_std450_test.cc", "import_test.cc", "memory_test.cc", "misc_test.cc", diff --git a/src/tint/lang/spirv/reader/parser/BUILD.cmake b/src/tint/lang/spirv/reader/parser/BUILD.cmake index d4c22cbd32..a87f1ebdf6 100644 --- a/src/tint/lang/spirv/reader/parser/BUILD.cmake +++ b/src/tint/lang/spirv/reader/parser/BUILD.cmake @@ -95,6 +95,7 @@ tint_add_target(tint_lang_spirv_reader_parser_test test lang/spirv/reader/parser/constant_test.cc lang/spirv/reader/parser/function_test.cc lang/spirv/reader/parser/helper_test.h + lang/spirv/reader/parser/import_glsl_std450_test.cc lang/spirv/reader/parser/import_test.cc lang/spirv/reader/parser/memory_test.cc lang/spirv/reader/parser/misc_test.cc diff --git a/src/tint/lang/spirv/reader/parser/BUILD.gn b/src/tint/lang/spirv/reader/parser/BUILD.gn index 9f4c271f76..63ce83183a 100644 --- a/src/tint/lang/spirv/reader/parser/BUILD.gn +++ b/src/tint/lang/spirv/reader/parser/BUILD.gn @@ -94,6 +94,7 @@ if (tint_build_unittests) { "constant_test.cc", "function_test.cc", "helper_test.h", + "import_glsl_std450_test.cc", "import_test.cc", "memory_test.cc", "misc_test.cc", diff --git a/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc new file mode 100644 index 0000000000..b69f6a7a9e --- /dev/null +++ b/src/tint/lang/spirv/reader/parser/import_glsl_std450_test.cc @@ -0,0 +1,147 @@ +// Copyright 2024 The Dawn & Tint Authors +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "src/tint/lang/spirv/reader/parser/helper_test.h" + +namespace tint::spirv::reader { +namespace { + +std::string Preamble() { + return R"( + OpCapability Shader + %glsl = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %100 "main" + OpExecutionMode %100 LocalSize 1 1 1 + + %void = OpTypeVoid + %voidfn = OpTypeFunction %void + + %float = OpTypeFloat 32 + + %v2float = OpTypeVector %float 2 + %v3float = OpTypeVector %float 3 + %v4float = OpTypeVector %float 4 + %mat2v2float = OpTypeMatrix %v2float 2 + %mat3v3float = OpTypeMatrix %v3float 3 + %mat4v4float = OpTypeMatrix %v4float 4 + + %float_50 = OpConstant %float 50 + %float_60 = OpConstant %float 60 + %float_70 = OpConstant %float 70 + + %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60 + %v3float_50_60_70 = OpConstantComposite %v3float %float_50 %float_60 %float_70 + %v4float_50_50_50_50 = OpConstantComposite %v4float %float_50 %float_50 %float_50 %float_50 + + %mat2v2float_50_60 = OpConstantComposite %mat2v2float %v2float_50_60 %v2float_50_60 + %mat3v3float_50_60_70 = OpConstantComposite %mat3v3float %v3float_50_60_70 %v3float_50_60_70 %v3float_50_60_70 + %mat4v4float_50_50_50_50 = OpConstantComposite %mat4v4float %v4float_50_50_50_50 %v4float_50_50_50_50 %v4float_50_50_50_50 %v4float_50_50_50_50 + + %100 = OpFunction %void None %voidfn + %entry = OpLabel +)"; +} + +TEST_F(SpirvParserTest, GlslStd450_MatrixInverse_mat2x2) { + EXPECT_IR(Preamble() + R"( + %1 = OpExtInst %mat2v2float %glsl MatrixInverse %mat2v2float_50_60 + %2 = OpCopyObject %mat2v2float %1 + OpReturn + OpFunctionEnd + )", + R"( +%main = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x2 = spirv.inverse mat2x2(vec2(50.0f, 60.0f)) + %3:mat2x2 = let %2 + ret + } +} +)"); +} + +TEST_F(SpirvParserTest, GlslStd450_MatrixInverse_mat3x3) { + EXPECT_IR(Preamble() + R"( + %1 = OpExtInst %mat3v3float %glsl MatrixInverse %mat3v3float_50_60_70 + %2 = OpCopyObject %mat3v3float %1 + OpReturn + OpFunctionEnd + )", + R"( +%main = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat3x3 = spirv.inverse mat3x3(vec3(50.0f, 60.0f, 70.0f)) + %3:mat3x3 = let %2 + ret + } +} +)"); +} + +TEST_F(SpirvParserTest, GlslStd450_MatrixInverse_mat4x4) { + EXPECT_IR(Preamble() + R"( + %1 = OpExtInst %mat4v4float %glsl MatrixInverse %mat4v4float_50_50_50_50 + %2 = OpCopyObject %mat4v4float %1 + OpReturn + OpFunctionEnd + )", + R"( +%main = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat4x4 = spirv.inverse mat4x4(vec4(50.0f)) + %3:mat4x4 = let %2 + ret + } +} +)"); +} + +TEST_F(SpirvParserTest, GlslStd450_MatrixInverse_MultipleInScope) { + EXPECT_IR(Preamble() + R"( + %1 = OpExtInst %mat2v2float %glsl MatrixInverse %mat2v2float_50_60 + %2 = OpExtInst %mat2v2float %glsl MatrixInverse %mat2v2float_50_60 + %3 = OpCopyObject %mat2v2float %1 + %4 = OpCopyObject %mat2v2float %2 + OpReturn + OpFunctionEnd + )", + R"( +%main = @compute @workgroup_size(1u, 1u, 1u) func():void { + $B1: { + %2:mat2x2 = spirv.inverse mat2x2(vec2(50.0f, 60.0f)) + %3:mat2x2 = spirv.inverse mat2x2(vec2(50.0f, 60.0f)) + %4:mat2x2 = let %2 + %5:mat2x2 = let %3 + ret + } +} +)"); +} + +} // namespace +} // namespace tint::spirv::reader diff --git a/src/tint/lang/spirv/reader/parser/parser.cc b/src/tint/lang/spirv/reader/parser/parser.cc index a32a8fbbef..ffd6f70d4c 100644 --- a/src/tint/lang/spirv/reader/parser/parser.cc +++ b/src/tint/lang/spirv/reader/parser/parser.cc @@ -759,6 +759,8 @@ class Parser { switch (ext_opcode) { case GLSLstd450Normalize: return spirv::BuiltinFn::kNormalize; + case GLSLstd450MatrixInverse: + return spirv::BuiltinFn::kInverse; default: break; } diff --git a/src/tint/lang/spirv/spirv.def b/src/tint/lang/spirv/spirv.def index cf7b5fc4ce..db03a8ded4 100644 --- a/src/tint/lang/spirv/spirv.def +++ b/src/tint/lang/spirv/spirv.def @@ -321,6 +321,10 @@ implicit(T: f32_f16, N: num) fn vector_times_scalar(vec, T) -> vec implicit(T: f32_f16) fn normalize(T) -> T implicit(N: num, T: f32_f16) fn normalize(vec) -> vec +implicit(T: f32_f16) fn inverse(mat2x2) -> mat2x2 +implicit(T: f32_f16) fn inverse(mat3x3) -> mat3x3 +implicit(T: f32_f16) fn inverse(mat4x4) -> mat4x4 + //////////////////////////////////////////////////////////////////////////////// // SPV_KHR_integer_dot_product instructions //////////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/lang/spirv/writer/printer/printer.cc b/src/tint/lang/spirv/writer/printer/printer.cc index b69b23cb46..ddc93079da 100644 --- a/src/tint/lang/spirv/writer/printer/printer.cc +++ b/src/tint/lang/spirv/writer/printer/printer.cc @@ -1385,6 +1385,11 @@ class Printer { case spirv::BuiltinFn::kSelect: op = spv::Op::OpSelect; break; + case spirv::BuiltinFn::kInverse: + op = spv::Op::OpExtInst; + operands.push_back(ImportGlslStd450()); + operands.push_back(U32Operand(GLSLstd450MatrixInverse)); + break; case spirv::BuiltinFn::kUdot: module_.PushExtension("SPV_KHR_integer_dot_product"); module_.PushCapability(SpvCapabilityDotProductKHR);