From fb4d09470dc4be674de810dbfbf2d3764e2970ba Mon Sep 17 00:00:00 2001
From: Benoit Jacob
Date: Thu, 19 Dec 2024 12:31:42 -0500
Subject: [PATCH] Ukernel lowering for data-tiled `multi_mma` with
 `mfma_i32_16x16x32_i8` (#19522)

This finishes implementing an initial ukernel for `multi_mma` for
`DataTiledMMAAttr` with `kind = mfma_i32_16x16x32_i8`.

The ukernel takes unroll and subgroup parameters as function parameters.
The idea is that once inlining works as intended, these function
parameters will be constants and the optimized code will be the same as
if we had hardcoded specific values. This inlining isn't happening at
the moment; that is a bug that we should fix first. Inlining does work
in LLVMCPU, so the problem is probably something missing in LLVMGPU.

The ukernel file has a comment with a few TODOs for getting from this
initial naive ukernel to something faster: first fix the above-mentioned
inlining problem, then use shared memory, then improve instruction
scheduling.

Signed-off-by: Benoit Jacob
---
 .../target/ROCM/builtins/ukernel/BUILD.bazel  |  8 +--
 .../ROCM/builtins/ukernel/CMakeLists.txt      |  8 +--
 .../target/ROCM/builtins/ukernel/common.h     |  1 -
 ...uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c | 65 +++++++++++++++++++
 ...i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c | 53 ---------------
 .../test/config_ukernel_multi_mma_gfx942.mlir |  4 +-
 .../Codegen/Common/GPU/GPULowerToUKernels.cpp | 52 +++++++++++++--
 .../compiler/Codegen/Common/GPU/Passes.td     |  2 +
 .../GPU/test/gpu_lower_to_ukernels.mlir       | 32 +++++++--
 .../Dialect/GPU/Transforms/Transforms.cpp     | 11 +---
 .../test/distribute_mma_to_lanes.mlir         |  2 +-
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp  |  3 +
 .../LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp   | 13 +---
 tests/e2e/matmul/CMakeLists.txt               | 30 +++++++++
 14 files changed, 189 insertions(+), 95 deletions(-)
 create mode 100644 compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c
 delete mode 100644 compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c

diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel b/compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel
index 840d45fc27cb..7bcedce7e29f 100644
--- a/compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel
@@ -60,19 +60,19 @@ argmax_bc_files = [
 ]
 
 iree_amdgpu_bitcode_library(
-    name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942",
+    name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_gfx942",
     srcs = [
         "common.h",
-        "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c",
+        "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c",
     ],
-    out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
+    out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
     gpu_arch = "gfx942",
 )
 
 iree_c_embed_data(
     name = "iree_uk_amdgpu_bitcode",
     srcs = argmax_bc_files + [
-        "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
+        "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
     ],
     c_file_output = "iree_uk_amdgpu_bitcode.c",
     flatten = True,
diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
index ad1a19028a5b..97962aaff481 100644
--- a/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
@@ -208,14 +208,14 @@ iree_amdgpu_bitcode_library(
 
 iree_amdgpu_bitcode_library(
   NAME
-    iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942
+    iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_gfx942
   GPU_ARCH
     gfx942
   SRCS
     "common.h"
-    "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c"
+    "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c"
   OUT
-    "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
+    "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
 )
 
 iree_c_embed_data(
@@ -238,7 +238,7 @@ iree_c_embed_data(
     "iree_uk_amdgpu_argmax_f32i64.gfx1100.bc"
     "iree_uk_amdgpu_argmax_f32i64.gfx90a.bc"
     "iree_uk_amdgpu_argmax_f32i64.gfx942.bc"
-    "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
+    "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
   C_FILE_OUTPUT
     "iree_uk_amdgpu_bitcode.c"
   H_FILE_OUTPUT
diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/common.h b/compiler/plugins/target/ROCM/builtins/ukernel/common.h
index 14b65a253c5d..d046986cc9b5 100644
--- a/compiler/plugins/target/ROCM/builtins/ukernel/common.h
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/common.h
@@ -61,7 +61,6 @@ typedef __UINT64_TYPE__ uint64_t;
 // Vector typedefs
 //===----------------------------------------------------------------------===//
 
-typedef __attribute__((__vector_size__(8 * 2))) int64_t int64x2_t;
 typedef __attribute__((__vector_size__(4 * 4))) int32_t int32x4_t;
 
 //===----------------------------------------------------------------------===//
diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c
new file mode 100644
index 000000000000..9029a86ddb59
--- /dev/null
+++ b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c
@@ -0,0 +1,65 @@
+// Copyright 2024 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"
+
+// Very naive kernel. TODO(bjacob):
+// 1. Inlining: the `always_inline` attribute here is correctly preserved in
+//    the bitcode, but isn't having the intended effect of inlining calls to
+//    this function. Making that work is key as various function parameters
+//    (e.g. `unroll_m`) are meant to be constants.
+// 2. Shared memory: can't allocate it within the microkernel (which is just a
+//    helper device function, not the actual amdgpu_kernel). Need to get it
+//    passed down here as a `T [[clang::address_space(3)]] *` parameter.
+// 3. Better scheduling via either barrier intrinsics or inline assembly.
+// 4. Subgroups1x4 being asymmetric is a historical accident... should be 2x2.
+[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
+    const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
+    int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size,
+    int32_t unroll_m, int32_t subgroups_m, int32_t unroll_n,
+    int32_t subgroups_n, int32_t unroll_k) {
+  /*
+  TODO(bjacob): re-enable this once inlining works.
+  // Load existing accumulators. This is a VLA, but should become fixed-size
+  // once this function is inlined and unroll_* factors become constants.
+  int32x4_t c[unroll_m][unroll_n];
+  */
+  // Load existing accumulators.
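+  // Until inlining makes the unroll factors constants, use a fixed-size
+  // accumulator array sized for the one supported configuration (8x2), and
+  // trap on anything larger rather than overflow it.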
+ if (unroll_m > 8 || unroll_n > 2) { + __builtin_trap(); + } + int32x4_t c[8][2]; + int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset); + for (int m = 0; m < unroll_m; ++m) { + for (int n = 0; n < unroll_n; ++n) { + c[m][n] = c_global[64 * (m * unroll_n + n)]; + } + } + + // Arithmetic loop. + const int64_t *a_global = (const int64_t *)(a_buffer + a_offset); + const int64_t *b_global = (const int64_t *)(b_buffer + b_offset); + for (int k_outer = 0; k_outer < k_size; ++k_outer) { + for (int m = 0; m < unroll_m; ++m) { + for (int n = 0; n < unroll_n; ++n) { + for (int k = 0; k < unroll_k; ++k) { + c[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8( + a_global[64 * unroll_k * m + k], b_global[64 * unroll_k * n + k], + c[m][n], 0, 0, 0); + } + } + } + a_global += 64 * unroll_m * subgroups_m * unroll_k; + b_global += 64 * unroll_n * subgroups_n * unroll_k; + } + + // Store accumulators. + for (int m = 0; m < unroll_m; ++m) { + for (int n = 0; n < unroll_n; ++n) { + c_global[64 * (m * unroll_n + n)] = c[m][n]; + } + } +} diff --git a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c b/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c deleted file mode 100644 index 7d0e2643050e..000000000000 --- a/compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h" - -// Very naive kernel. TODO(bjacob): -// 1. Shared memory: can't allocate it within the microkernel (which is just a -// helper device function, not the actual amdgpu_kernel). Need to get it -// passed down here as a `T [[clang::address_space(3)]] *` parameter. -// 2. Better scheduling via either barrier intrinsics or inline assemby. -// 3. Subgroups1x4 being asymmetric is a historical accident... should be 2x2. -[[clang::always_inline]] void -iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4( - const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer, - int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int64_t k_size) { - int tid = __builtin_amdgcn_workitem_id_x(); - - // Load existing accumulators. - int32x4_t acc[8][2] = {{0}}; - int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset); - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 2; ++j) { - acc[i][j] = c_global[256 * (2 * i + j) + tid]; - } - } - - // Arithmetic loop. - const int64x2_t *a_global = - (const int64x2_t *)(a_buffer + a_offset) + (tid % 64); - const int64x2_t *b_global = (const int64x2_t *)(b_buffer + b_offset) + tid; - for (int k_outer = 0; k_outer < k_size; ++k_outer) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 2; ++j) { - for (int k = 0; k < 2; ++k) { - acc[i][j] = __builtin_amdgcn_mfma_i32_16x16x32_i8( - a_global[64 * i][k], b_global[256 * j][k], acc[i][j], 0, 0, 0); - } - } - } - a_global += 512; - b_global += 512; - } - - // Store accumulators. 
-  for (int i = 0; i < 8; ++i) {
-    for (int j = 0; j < 2; ++j) {
-      c_global[256 * (2 * i + j) + tid] = acc[i][j];
-    }
-  }
-}
diff --git a/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir b/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir
index 646418f80666..2fd78139ed59 100644
--- a/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir
+++ b/compiler/plugins/target/ROCM/test/config_ukernel_multi_mma_gfx942.mlir
@@ -23,7 +23,7 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x4x16x2x8xi8>,
 // CHECK-LABEL: @multi_mma_mfma_i32_16x16x32_i8
 // CHECK:      iree_gpu.multi_mma
-// CHECK-SAME:   #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
+// CHECK-SAME:   #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
 // CHECK-NOT:    promote_operands
 // CHECK-SAME:   reduction = [0, 0, 0]
-// CHECK-SAME:   #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4"
+// CHECK-SAME:   #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPULowerToUKernels.cpp
@@ ... @@ static FailureOr<IREE::Codegen::UKernelOpInterface>
 matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) {
   Value input = op.getDpsInputOperand(0)->get();
-  auto inputType = cast<ShapedType>(input.getType());
   Value index = op.getDpsInitOperand(1)->get();
   auto indexType = cast<ShapedType>(index.getType());
-  std::string suffix;
-  llvm::raw_string_ostream(suffix)
-      << inputType.getElementType() << indexType.getElementType();
   auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
   if (!loweringConfig) {
     return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
@@ -84,6 +81,50 @@ struct LowerArgmaxToUKernelPattern : OpRewritePattern<linalg::GenericOp> {
   }
 };
 
+struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
+  LowerMultiMmaToUKernelPattern(MLIRContext *context)
+      : OpRewritePattern<IREE::GPU::MultiMmaOp>(context) {}
+
+  LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
+    if (!loweringConfig) {
+      return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
+    }
+    IREE::GPU::UKernelConfigAttr ukernelAttr =
+        IREE::GPU::getUkernelSpec(loweringConfig);
+    if (!ukernelAttr) {
+      return rewriter.notifyMatchFailure(op, "no ukernel selected for this op");
+    }
+    auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
+    if (!mma) {
+      return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
+    }
+    auto castIndexToI32 = [&](Value val) {
+      return rewriter.create<arith::IndexCastOp>(op.getLoc(),
+                                                 rewriter.getI32Type(), val);
+    };
+    auto constI32 = [&](int val) {
+      return rewriter.create<arith::ConstantIntOp>(op.getLoc(), val,
+                                                   rewriter.getI32Type());
+    };
+    Value k = castIndexToI32(
+        rewriter.create<tensor::DimOp>(op.getLoc(), op.getLhs(), 1));
+    Value unrollM = constI32(mma.getUnrollM());
+    Value subgroupsM = constI32(mma.getSubgroupsM());
+    Value unrollN = constI32(mma.getUnrollN());
+    Value subgroupsN = constI32(mma.getSubgroupsN());
+    Value unrollK = constI32(mma.getUnrollK());
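+    // The scalar operands are passed in the same order as the trailing
+    // parameters of the ukernel's C entry point (k_size, unroll_m,
+    // subgroups_m, unroll_n, subgroups_n, unroll_k); the tensor ins/outs
+    // correspond to its leading (buffer, offset) parameter pairs once this
+    // generic ukernel op is lowered to a call.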
+    rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
+        op, TypeRange{op.getAccType()}, ukernelAttr.getName(),
+        ValueRange{op.getLhs(), op.getRhs()}, op.getAcc(),
+        ValueRange{k, unrollM, subgroupsM, unrollN, subgroupsN, unrollK},
+        ukernelAttr.getDefAttrs(),
+        /*strided_outer_dims=*/rewriter.getIndexAttr(0));
+    return success();
+  }
+};
+
 struct GPULowerToUKernelsPass final
     : impl::GPULowerToUKernelsPassBase<GPULowerToUKernelsPass> {
   void runOnOperation() override {
@@ -101,7 +142,8 @@ struct GPULowerToUKernelsPass final
     // evidence that it is difficult for codegen to consistently approach
     // microkernels performance, and that consideration overrides the benefit of
     // fusions for these ops.
-    patterns.insert<LowerArgmaxToUKernelPattern>(context);
+    patterns.add<LowerArgmaxToUKernelPattern, LowerMultiMmaToUKernelPattern>(
+        context);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(patterns)))) {
       return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index ff2b2b94f9b2..24552cbdfee0 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -111,6 +111,8 @@ def GPULowerToUKernelsPass :
   let dependentDialects = [
     "::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect",
     "::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
+    "::mlir::arith::ArithDialect",
+    "::mlir::tensor::TensorDialect",
   ];
 }
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
index 7acab19f945a..bc9331fea2cc 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_lower_to_ukernels.mlir
@@ -1,9 +1,7 @@
 // RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s
 
 #config = #iree_gpu.lowering_config<{ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_argmax_f32i64", def_attrs = {vm.import.module = "rocm"}>}>
-func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
+func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> {
   %c0_i64 = arith.constant 0 : i64
   %cst = arith.constant 0xFF800000 : f32
   %0 = tensor.empty() : tensor<1xi64>
@@ -42,9 +40,7 @@ func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tenso
 
 // -----
 
-func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
-  hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
-} {
+func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> {
   %c0_i64 = arith.constant 0 : i64
   %cst = arith.constant 0xFF800000 : f32
   %0 = tensor.empty() : tensor<1xi64>
@@ -70,3 +66,27 @@ func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> te
 //CHECK-LABEL: func @argmax_f32i64_without_selected_ukernel(
 // CHECK-NOT: iree_codegen.ukernel.generic
 // CHECK: linalg.generic
+
+// -----
+
+func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x1x1x2x8xi8>, %b : tensor<1x2x1x2x1x1x2x8xi8>, %c : tensor<1x1x1x8x2x1x1x4xi32>) -> tensor<1x1x1x8x2x1x1x4xi32> {
+  %d = iree_gpu.multi_mma %a, %b, %c {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
+    kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2>,
+    lowering_config = #iree_gpu.lowering_config<{
+      reduction = [0, 0, 0],
+      ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", def_attrs = {vm.import.module = "rocm"}>,
+      workgroup = [1, 1, 0]}>
+  } : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
+  return %d : tensor<1x1x1x8x2x1x1x4xi32>
+}
+
+// CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8(
+// CHECK-DAG: %c2_i32 = arith.constant 2 : i32
+// CHECK-DAG: %c8_i32 = arith.constant 8 : i32
+// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
+// CHECK-DAG: %c4_i32 = arith.constant 4 : i32
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic
+// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
+// CHECK-SAME: (%c2_i32, %c8_i32, %c1_i32, %c2_i32, %c4_i32, %c2_i32 : i32, i32, i32, i32, i32, i32)
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 75bf5e51d54c..659f2a9487a1 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -620,16 +620,11 @@ distributeMultiMmaOp(RewriterBase &rewriter, IREE::GPU::MultiMmaOp mmaOp,
                                accStrides);
 
   // Step 3. Create the new multi_mma op.
-  auto newKind = mmaOp.getKind();
-  if (auto dataTiledMma = dyn_cast<DataTiledMMAAttr>(newKind)) {
-    newKind = DataTiledMMAAttr::get(
-        context, dataTiledMma.getIntrinsic(), dataTiledMma.getUnrollM(),
-        /*subgroups_m=*/1, dataTiledMma.getUnrollN(),
-        /*subgroups_n=*/1, dataTiledMma.getUnrollK());
-  }
   auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
       loc, lhsSlice, rhsSlice, accSlice, mmaOp.getIndexingMaps(),
-      mmaOp.getIteratorTypes(), newKind);
+      mmaOp.getIteratorTypes(), mmaOp.getKind());
+
+  newMmaOp->setDiscardableAttrs(mmaOp->getDiscardableAttrDictionary());
 
   // Step 4. Insert the result of the multi_mma using the same offsets/sizes as
   // the accumulator slice.
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
index 07729a11e2b5..a5a0ff14e9cb 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir
@@ -471,7 +471,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<
 // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
 // CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
 // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
-// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_k = 4>}
+// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, unroll_k = 4>}
 // CHECK-SAME: : tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32>
 // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]]
 // CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index d460a1b9f56b..f8399d3c69a2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -410,6 +410,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
   }
   funcPassManager.addPass(IREE::GPU::createDistributeMmaToLanesPass());
 
+  // Step 4.5. Things that need to happen right after distribution to threads.
+  funcPassManager.addPass(createGPULowerToUKernelsPass());
+
   // Normalize loop bounds for later lowerings.
funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass( NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false, diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp index 453669db7426..8d81cf78ec61 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Utils/LLVMGPUSelectUKernels.cpp @@ -42,17 +42,8 @@ getUKernelNameAndSuffixForMultiMma(IREE::GPU::MultiMmaOp op) { if (!mma) { return {}; // Only handling DataTiledMMAAttr for now. } - std::string suffix{ - stringifyMMAIntrinsic(mma.getIntrinsic().getValue()).lower()}; - if (mma.getUnrollM() != 1 || mma.getUnrollN() != 1 || mma.getUnrollK() != 1) { - suffix += llvm::formatv("_unroll{}x{}x{}", mma.getUnrollM(), - mma.getUnrollN(), mma.getUnrollK()); - } - if (mma.getSubgroupsM() != 1 || mma.getSubgroupsN() != 1) { - suffix += llvm::formatv("_subgroups{}x{}", mma.getSubgroupsM(), - mma.getSubgroupsN()); - } - return {"multi_mma", suffix}; + return {"multi_mma", + stringifyMMAIntrinsic(mma.getIntrinsic().getValue()).lower()}; } // Returns ukernel name and suffix for any op. Empty name = no ukernel. diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index b744d346ebef..cf1ed28038d1 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -1600,6 +1600,36 @@ iree_generated_e2e_runner_test( "requires-gpu-cdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_cdna3_dt_uk_i8 + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=i8" + "--acc_type=i32" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + "--iree-opt-data-tiling" + "--iree-global-opt-experimental-rocm-data-tiling" + "--iree-global-opt-enable-early-materialization=true" + "--iree-hip-enable-ukernels=multi_mma" + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + iree_generated_e2e_runner_test( NAME e2e_matmul_cdna3_dt_f32