
Commit

Ukernel lowering for data-tiled multi_mma with `mfma_i32_16x16x32_i8` (iree-org#19522)

This finishes implementing an initial ukernel for `multi_mma` for
`DataTiledMMAAttr` with `kind = mfma_i32_16x16x32_i8`.

The ukernel takes unroll and subgroup parameters as function parameters.
The idea is that once inlining works as intended, these function
parameters will be constants and the optimized code will be the same as
if we had hardcoded specific values. This inlining isn't happening at
the moment; that's a bug we should fix first. Inlining does work in
LLVMCPU, so something is probably missing in LLVMGPU.
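
For illustration, with the `unroll8x2x2`, `subgroups1x4` configuration
exercised in the tests below, the call site is expected to look roughly
like this sketch (the wrapper `call_multi_mma_ukernel`, the raw pointers,
and the zero offsets are hypothetical placeholders, not actual compiler
output):

#include <stdint.h>

// Prototype as defined in iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c.
void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
    const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
    int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size,
    int32_t unroll_m, int32_t subgroups_m, int32_t unroll_n,
    int32_t subgroups_n, int32_t unroll_k);

// Hypothetical caller: once the call below is inlined, the constant trailing
// arguments propagate and the ukernel's variable-bound loops become
// fixed-size, fully unrolled code.
void call_multi_mma_ukernel(const int8_t *a, const int8_t *b, int32_t *c,
                            int32_t k_size) {
  iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
      a, /*a_offset=*/0, b, /*b_offset=*/0, c, /*c_offset=*/0, k_size,
      /*unroll_m=*/8, /*subgroups_m=*/1, /*unroll_n=*/2,
      /*subgroups_n=*/4, /*unroll_k=*/2);
}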

The ukernel file has a comment with a few TODOs to get from this initial
naive ukernel to something faster. The first step is to fix the
above-mentioned inlining problem, then use shared memory, then improve
instruction scheduling.

Signed-off-by: Benoit Jacob <jacob.benoit.1@gmail.com>
bjacob authored Dec 19, 2024
1 parent 5c4bc67 commit fb4d094
Showing 14 changed files with 189 additions and 95 deletions.
8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel
@@ -60,19 +60,19 @@ argmax_bc_files = [
]

iree_amdgpu_bitcode_library(
name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942",
name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_gfx942",
srcs = [
"common.h",
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c",
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c",
],
out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
gpu_arch = "gfx942",
)

iree_c_embed_data(
name = "iree_uk_amdgpu_bitcode",
srcs = argmax_bc_files + [
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
],
c_file_output = "iree_uk_amdgpu_bitcode.c",
flatten = True,
8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
@@ -208,14 +208,14 @@ iree_amdgpu_bitcode_library(

iree_amdgpu_bitcode_library(
NAME
iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942
iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_gfx942
GPU_ARCH
gfx942
SRCS
"common.h"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c"
OUT
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
)

iree_c_embed_data(
@@ -238,7 +238,7 @@ iree_c_embed_data(
"iree_uk_amdgpu_argmax_f32i64.gfx1100.bc"
"iree_uk_amdgpu_argmax_f32i64.gfx90a.bc"
"iree_uk_amdgpu_argmax_f32i64.gfx942.bc"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
C_FILE_OUTPUT
"iree_uk_amdgpu_bitcode.c"
H_FILE_OUTPUT
1 change: 0 additions & 1 deletion compiler/plugins/target/ROCM/builtins/ukernel/common.h
@@ -61,7 +61,6 @@ typedef __UINT64_TYPE__ uint64_t;
// Vector typedefs
//===----------------------------------------------------------------------===//

typedef __attribute__((__vector_size__(8 * 2))) int64_t int64x2_t;
typedef __attribute__((__vector_size__(4 * 4))) int32_t int32x4_t;

//===----------------------------------------------------------------------===//
65 changes: 65 additions & 0 deletions compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c
@@ -0,0 +1,65 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"

// Very naive kernel. TODO(bjacob):
// 1. Inlining: the `always_inline` attribute here is correctly preserved in
// the bitcode, but isn't having the intended effect of inlining calls to
// this function. Making that work is key as various function parameters
// (e.g. `unroll_m`) are meant to be constants.
// 2. Shared memory: can't allocate it within the microkernel (which is just a
// helper device function, not the actual amdgpu_kernel). Need to get it
// passed down here as a `T [[clang::address_space(3)]] *` parameter.
// 3. Better scheduling via either barrier intrinsics or inline assembly.
// 4. Subgroups1x4 being asymmetric is a historical accident... should be 2x2.
[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size,
int32_t unroll_m, int32_t subgroups_m, int32_t unroll_n,
int32_t subgroups_n, int32_t unroll_k) {
/*
TODO(bjacob): reenable this once inlining works.
// Load existing accumulators. This is a VLA, but should become fixed-size
// once this function is inlined and unroll_* factors become constants.
int32x4_t c[unroll_m][unroll_n];
*/
// Load existing accumulators.
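// Until inlining makes unroll_m/unroll_n compile-time constants, use a
// fixed-size array bounded by the largest supported factors (8x2) and trap
// if they are exceeded.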
if (unroll_m > 8 || unroll_n > 2) {
__builtin_trap();
}
int32x4_t c[8][2];
int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset);
for (int m = 0; m < unroll_m; ++m) {
for (int n = 0; n < unroll_n; ++n) {
c[m][n] = c_global[64 * (m * unroll_n + n)];
}
}

// Arithmetic loop.
const int64_t *a_global = (const int64_t *)(a_buffer + a_offset);
const int64_t *b_global = (const int64_t *)(b_buffer + b_offset);
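// Each k_outer iteration issues unroll_m * unroll_n * unroll_k MFMAs, then
// advances A and B past one full data-tiled block (the factor of 64 is the
// gfx942 wavefront size).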
for (int k_outer = 0; k_outer < k_size; ++k_outer) {
for (int m = 0; m < unroll_m; ++m) {
for (int n = 0; n < unroll_n; ++n) {
for (int k = 0; k < unroll_k; ++k) {
c[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8(
a_global[64 * unroll_k * m + k], b_global[64 * unroll_k * n + k],
c[m][n], 0, 0, 0);
}
}
}
a_global += 64 * unroll_m * subgroups_m * unroll_k;
b_global += 64 * unroll_n * subgroups_n * unroll_k;
}

// Store accumulators.
for (int m = 0; m < unroll_m; ++m) {
for (int n = 0; n < unroll_n; ++n) {
c_global[64 * (m * unroll_n + n)] = c[m][n];
}
}
}

This file was deleted: compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c

@@ -23,7 +23,7 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x4x16x2x8xi8>,

// CHECK-LABEL: @multi_mma_mfma_i32_16x16x32_i8
// CHECK: iree_gpu.multi_mma
// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
// CHECK-NOT: promote_operands
// CHECK-SAME: reduction = [0, 0, 0]
// CHECK-SAME: #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4"
// CHECK-SAME: #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
@@ -9,6 +9,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -33,12 +34,8 @@ namespace {
static FailureOr<IREE::Codegen::UKernelOpInterface>
matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) {
Value input = op.getDpsInputOperand(0)->get();
auto inputType = cast<ShapedType>(input.getType());
Value index = op.getDpsInitOperand(1)->get();
auto indexType = cast<ShapedType>(index.getType());
std::string suffix;
llvm::raw_string_ostream(suffix)
<< inputType.getElementType() << indexType.getElementType();
auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
if (!loweringConfig) {
return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
@@ -84,6 +81,50 @@ struct LowerArgmaxToUKernelPattern : OpRewritePattern<linalg::GenericOp> {
}
};

struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
LowerMultiMmaToUKernelPattern(MLIRContext *context)
: OpRewritePattern<IREE::GPU::MultiMmaOp>(context) {}

LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp op,
PatternRewriter &rewriter) const override {
auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
if (!loweringConfig) {
return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
}
IREE::GPU::UKernelConfigAttr ukernelAttr =
IREE::GPU::getUkernelSpec(loweringConfig);
if (!ukernelAttr) {
return rewriter.notifyMatchFailure(op, "no ukernel selected for this op");
}
auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
if (!mma) {
return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
}
auto castIndexToI32 = [&](Value val) {
return rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getI32Type(), val);
};
auto constI32 = [&](int val) {
return rewriter.create<arith::ConstantIntOp>(op.getLoc(), val,
rewriter.getI32Type());
};
Value k = castIndexToI32(
rewriter.create<tensor::DimOp>(op.getLoc(), op.getLhs(), 1));
Value unrollM = constI32(mma.getUnrollM());
Value subgroupsM = constI32(mma.getSubgroupsM());
Value unrollN = constI32(mma.getUnrollN());
Value subgroupsN = constI32(mma.getSubgroupsN());
Value unrollK = constI32(mma.getUnrollK());
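// The order of these i32 operands must match the trailing parameters of the
// ukernel function: (k_size, unroll_m, subgroups_m, unroll_n, subgroups_n,
// unroll_k).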
rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
op, TypeRange{op.getAccType()}, ukernelAttr.getName(),
ValueRange{op.getLhs(), op.getRhs()}, op.getAcc(),
ValueRange{k, unrollM, subgroupsM, unrollN, subgroupsN, unrollK},
ukernelAttr.getDefAttrs(),
/*strided_outer_dims=*/rewriter.getIndexAttr(0));
return success();
}
};

struct GPULowerToUKernelsPass final
: impl::GPULowerToUKernelsPassBase<GPULowerToUKernelsPass> {
void runOnOperation() override {
@@ -101,7 +142,8 @@ struct GPULowerToUKernelsPass final
// evidence that it is difficult for codegen to consistently approach
// microkernels performance, and that consideration overrides the benefit of
// fusions for these ops.
patterns.insert<LowerArgmaxToUKernelPattern>(context);
patterns.add<LowerArgmaxToUKernelPattern, LowerMultiMmaToUKernelPattern>(
context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -111,6 +111,8 @@ def GPULowerToUKernelsPass :
let dependentDialects = [
"::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::arith::ArithDialect",
"::mlir::tensor::TensorDialect",
];
}

@@ -1,9 +1,7 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s

#config = #iree_gpu.lowering_config<{ukernel = #iree_gpu.ukernel_config<name = "some_ukernel", def_attrs = {vm.import.module = "rocm"}>}>
func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1xi64>
@@ -42,9 +40,7 @@ func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tenso

// -----

func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1xi64>
@@ -70,3 +66,27 @@ func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> te
//CHECK-LABEL: func @argmax_f32i64_without_selected_ukernel(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic

// -----

func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x1x1x2x8xi8>, %b : tensor<1x2x1x2x1x1x2x8xi8>, %c : tensor<1x1x1x8x2x1x1x4xi32>) -> tensor<1x1x1x8x2x1x1x4xi32> {
%d = iree_gpu.multi_mma %a, %b, %c {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2>,
lowering_config = #iree_gpu.lowering_config<{
reduction = [0, 0, 0],
ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", def_attrs = {vm.import.module = "rocm"}>,
workgroup = [1, 1, 0]}>
} : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
return %d : tensor<1x1x1x8x2x1x1x4xi32>
}

// CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8(
// CHECK-DAG: %c2_i32 = arith.constant 2 : i32
// CHECK-DAG: %c8_i32 = arith.constant 8 : i32
// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
// CHECK-DAG: %c4_i32 = arith.constant 4 : i32
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic
// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
// CHECK-SAME: (%c2_i32, %c8_i32, %c1_i32, %c2_i32, %c4_i32, %c2_i32 : i32, i32, i32, i32, i32, i32)
@@ -620,16 +620,11 @@ distributeMultiMmaOp(RewriterBase &rewriter, IREE::GPU::MultiMmaOp mmaOp,
accStrides);

// Step 3. Create the new multi_mma op.
auto newKind = mmaOp.getKind();
if (auto dataTiledMma = dyn_cast<DataTiledMMAAttr>(newKind)) {
newKind = DataTiledMMAAttr::get(
context, dataTiledMma.getIntrinsic(), dataTiledMma.getUnrollM(),
/*subgroups_m=*/1, dataTiledMma.getUnrollN(),
/*subgroups_n=*/1, dataTiledMma.getUnrollK());
}
auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
loc, lhsSlice, rhsSlice, accSlice, mmaOp.getIndexingMaps(),
mmaOp.getIteratorTypes(), newKind);
mmaOp.getIteratorTypes(), mmaOp.getKind());

newMmaOp->setDiscardableAttrs(mmaOp->getDiscardableAttrDictionary());

// Step 4. Insert the result of the multi_mma using the same offsets/sizes as
// the accumulator slice.
@@ -471,7 +471,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_k = 4>}
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, unroll_k = 4>}
// CHECK-SAME: : tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -410,6 +410,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
}
funcPassManager.addPass(IREE::GPU::createDistributeMmaToLanesPass());

// Step 4.5. Things that need to happen right after distribution to threads.
funcPassManager.addPass(createGPULowerToUKernelsPass());

// Normalize loop bounds for later lowerings.
funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false,
@@ -42,17 +42,8 @@
if (!mma) {
return {}; // Only handling DataTiledMMAAttr for now.
}
std::string suffix{
stringifyMMAIntrinsic(mma.getIntrinsic().getValue()).lower()};
if (mma.getUnrollM() != 1 || mma.getUnrollN() != 1 || mma.getUnrollK() != 1) {
suffix += llvm::formatv("_unroll{}x{}x{}", mma.getUnrollM(),
mma.getUnrollN(), mma.getUnrollK());
}
if (mma.getSubgroupsM() != 1 || mma.getSubgroupsN() != 1) {
suffix += llvm::formatv("_subgroups{}x{}", mma.getSubgroupsM(),
mma.getSubgroupsN());
}
return {"multi_mma", suffix};
return {"multi_mma",
stringifyMMAIntrinsic(mma.getIntrinsic().getValue()).lower()};
}

// Returns ukernel name and suffix for any op. Empty name = no ukernel.