[LLVMGPU] Add KernelConfig for subgroup reduction attention pipeline (iree-org#19427)

Adds a dumb kernel config selection as a fallback for attention when it
cannot target intrinsics or is too skinny to use them. The config logic is
very simple right now and will improve over time.

Also adds a fix for vector.contract lowering when promotion is needed.
Groverkss authored Dec 11, 2024
1 parent a6da532 commit c315833
Showing 6 changed files with 321 additions and 10 deletions.
220 changes: 214 additions & 6 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
@@ -717,10 +717,9 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
targetSubgroupSize, pipelineConfig);
}

static LogicalResult
setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
IREE::LinalgExt::AttentionOp op) {
static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig(
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
IREE::LinalgExt::AttentionOp op) {
if (target.getWgp().getMma().empty())
return failure();

@@ -753,8 +752,10 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
return failure();
}

// TODO: Do we need a matvec-like attention pipeline? Probably not,
// considering M is generally the largest dimension.
// Bail out on skinny attention.
if (bounds[mDim] <= kVerySkinnyDimThreshold) {
return failure();
}
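// For reference: the attention_20x1x64x4096x64 lit test added in this
// commit has M = 1, which is why it lands on the new subgroup-reduction
// fallback below instead of this intrinsic-based path.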

Value qMatrix = op.getQuery();
Value kMatrix = op.getKey();
@@ -974,6 +975,209 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
workgroupSize, targetSubgroupSize, pipelineConfig);
}

static IREE::GPU::Basis projectBasis(const IREE::GPU::Basis &basis,
ArrayRef<int64_t> projectedDims) {
// Projection simply involves projecting the mapping and keeping the counts.
IREE::GPU::Basis projectedBasis;
projectedBasis.counts = basis.counts;
SetVector<int64_t> projected(projectedDims.begin(), projectedDims.end());
for (auto [dim, map] : llvm::enumerate(basis.mapping)) {
if (projected.contains(dim)) {
continue;
}
projectedBasis.mapping.push_back(map);
}
return projectedBasis;
}
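// For example, with counts [1, 1, 1, 1, 4] and mapping [0, 1, 2, 3, 4],
// projecting out dim 4 (N, as done for the QK config below) keeps all five
// counts and yields mapping [0, 1, 2, 3], while projecting out dim 2 (K1,
// for the PV config) yields mapping [0, 1, 3, 4].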

static LogicalResult
setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
IREE::LinalgExt::AttentionOp op) {
// This configuration is not really smart right now. It just makes sure that
// attention always compiles and tries to distribute the workload across
// threads, subgroups, and workgroups as much as it can.
// TODO: Update this configuration with target information, like the
// WarpReduction pipeline does.
const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();

// Get iteration domain bounds.
OpBuilder b(op);
FailureOr<SmallVector<int64_t>> maybeBounds = op.getStaticLoopRanges();
if (failed(maybeBounds)) {
return failure();
}

SmallVector<int64_t> bounds = maybeBounds.value();

auto opInfo =
IREE::LinalgExt::AttentionOpDetail::get(
op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
.value();

SmallVector<int64_t> parallelDims;
SmallVector<int64_t> reductionDims;
for (auto [dim, itType] : llvm::enumerate(op.getLoopIteratorTypes())) {
switch (itType) {
case utils::IteratorType::parallel:
parallelDims.push_back(dim);
break;
case utils::IteratorType::reduction:
reductionDims.push_back(dim);
break;
}
}

auto distributeDimensionsToBasis = [&bounds](int64_t available,
ArrayRef<int64_t> dims,
IREE::GPU::Basis &basis) {
for (int64_t dim : dims) {
basis.mapping[dim] = dim;
int64_t dimSize = bounds[dim];
if (ShapedType::isDynamic(dimSize)) {
basis.counts[dim] = 1;
continue;
}
int64_t used = std::gcd(available, dimSize);
available /= used;
bounds[dim] /= used;
basis.counts[dim] = used;
}
return available;
};
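// For example, with 64 threads available and a dim of size 4096,
// gcd(64, 4096) = 64, so all 64 threads go to that dim and none remain for
// later dims; a dynamic dim always gets a count of 1 and consumes no threads.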

SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
// Distribute all batch dimensions to workgroups.
for (int64_t dim : opInfo.getBatchDims()) {
workgroupTileSizes[dim] = 1;
bounds[dim] = 1;
}

IREE::GPU::Basis threadBasis = {
SmallVector<int64_t>(opInfo.getDomainRank(), 1),
SmallVector<int64_t>(opInfo.getDomainRank())};
int64_t remainingThreads = targetSubgroupSize;
if (!target.supportsSubgroupShuffle()) {
// If target does not support subgroup shuffles, don't distribute threads on
// reduction dimensions.
distributeDimensionsToBasis(1, reductionDims, threadBasis);
} else {
remainingThreads = distributeDimensionsToBasis(remainingThreads,
reductionDims, threadBasis);
}
remainingThreads =
distributeDimensionsToBasis(remainingThreads, parallelDims, threadBasis);

IREE::GPU::Basis subgroupBasis = {
SmallVector<int64_t>(opInfo.getDomainRank(), 1),
SmallVector<int64_t>(opInfo.getDomainRank())};
int64_t remainingSubgroups = target.getWgp().getSimdsPerWgp().value_or(1);
// TODO: We cannot distribute subgroups on reduction dimensions yet, because
// VectorDistribution does not know how to do workgroup reduction right now.
distributeDimensionsToBasis(1, reductionDims, subgroupBasis);
remainingSubgroups = distributeDimensionsToBasis(remainingSubgroups,
parallelDims, subgroupBasis);
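// For the 20x1x64x4096x64 test added in this commit (batch = 20, M = 1,
// K1 = 64, K2 = 4096, N = 64) on a target with subgroup size 64 and 4 SIMDs
// per workgroup, this gives thread counts [1, 1, 64, 1, 1] (all 64 threads
// on K1) and subgroup counts [1, 1, 1, 1, 4] (all 4 subgroups on N),
// matching the thread_basis and subgroup_basis the lit test expects.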

LDBG("Thread Basis");
LLVM_DEBUG({
llvm::interleaveComma(threadBasis.counts, llvm::dbgs());
llvm::dbgs() << "\n";
llvm::interleaveComma(threadBasis.mapping, llvm::dbgs());
llvm::dbgs() << "\n";
});
LDBG("Subgroup Basis");
LLVM_DEBUG({
llvm::interleaveComma(subgroupBasis.counts, llvm::dbgs());
llvm::dbgs() << "\n";
llvm::interleaveComma(subgroupBasis.mapping, llvm::dbgs());
llvm::dbgs() << "\n";
});

// Tile remaining parallel dimensions to workgroups.
for (int64_t dim : parallelDims) {
if (ShapedType::isDynamic(dim)) {
workgroupTileSizes[dim] = 1;
}
if (bounds[dim] != 1) {
int64_t threadCount = threadBasis.counts[threadBasis.mapping[dim]];
int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
workgroupTileSizes[dim] = threadCount * subgroupCount;
}
}

// Tile remaining reduction dimensions to serial loops.
SmallVector<int64_t> reductionTileSizes(opInfo.getDomainRank(), 0);
for (int64_t dim : opInfo.getK2Dims()) {
if (ShapedType::isDynamic(dim)) {
reductionTileSizes[dim] = 1;
}
if (bounds[dim] != 1) {
int64_t threadCount = threadBasis.counts[threadBasis.mapping[dim]];
int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
reductionTileSizes[dim] = threadCount * subgroupCount;
}
}
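// Continuing the example: the only non-unit parallel dim left is N, which
// gets 1 thread x 4 subgroups, so workgroupTileSizes becomes [1, 0, 0, 0, 4];
// K2 gets a serial tile of 1 x 1 = 1, so reductionTileSizes becomes
// [0, 0, 0, 1, 0]. These match the workgroup/reduction values the new lit
// test checks for.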

int64_t flatWorkgroupSize =
targetSubgroupSize * ShapedType::getNumElements(subgroupBasis.counts);
std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};
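// In the running example: 64 threads per subgroup * (1*1*1*1*4) subgroups
// = 256 threads, so the dispatch uses a flat {256, 1, 1} workgroup.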

MLIRContext *context = op.getContext();

SmallVector<NamedAttribute, 2> attrs;
attrs.emplace_back(StringAttr::get(context, "workgroup"),
b.getI64ArrayAttr(workgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "reduction"),
b.getI64ArrayAttr(reductionTileSizes));

SmallVector<NamedAttribute> qkConfig;
IREE::GPU::setBasis(context, qkConfig, IREE::GPU::TilingLevel::Subgroup,
projectBasis(subgroupBasis, opInfo.getNDims()));
IREE::GPU::setBasis(context, qkConfig, IREE::GPU::TilingLevel::Thread,
projectBasis(threadBasis, opInfo.getNDims()));

SmallVector<NamedAttribute> pvConfig;
IREE::GPU::setBasis(context, pvConfig, IREE::GPU::TilingLevel::Subgroup,
projectBasis(subgroupBasis, opInfo.getK1Dims()));
IREE::GPU::setBasis(context, pvConfig, IREE::GPU::TilingLevel::Thread,
projectBasis(threadBasis, opInfo.getK1Dims()));

SmallVector<NamedAttribute, 2> qkAttrs;
SmallVector<NamedAttribute, 2> pvAttrs;

auto qkConfigDict = b.getDictionaryAttr(qkConfig);
auto pvConfigDict = b.getDictionaryAttr(pvConfig);

auto qkLoweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, qkConfigDict);
auto pvLoweringConfig =
IREE::GPU::LoweringConfigAttr::get(context, pvConfigDict);

qkAttrs.emplace_back(b.getNamedAttr("lowering_config", qkLoweringConfig));
pvAttrs.emplace_back(b.getNamedAttr("lowering_config", pvLoweringConfig));

auto qkAttrDict = b.getDictionaryAttr(qkAttrs);
auto pvAttrDict = b.getDictionaryAttr(pvAttrs);

SmallVector<NamedAttribute, 2> decompositionConfig;
decompositionConfig.emplace_back(
b.getNamedAttr(IREE::LinalgExt::AttentionOp::getQKAttrStr(), qkAttrDict));
decompositionConfig.emplace_back(
b.getNamedAttr(IREE::LinalgExt::AttentionOp::getPVAttrStr(), pvAttrDict));

// Set attention decomposition control config.
op.setDecompositionConfigAttr(b.getDictionaryAttr(decompositionConfig));

auto configDict = b.getDictionaryAttr(attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

return setOpConfigAndEntryPointFnTranslation(
entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
workgroupSize, targetSubgroupSize);

return success();
}

static LogicalResult
setVectorDistributionConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
@@ -1004,6 +1208,10 @@ setVectorDistributionConfig(IREE::GPU::TargetAttr target,

if (auto attnOp = dyn_cast<IREE::LinalgExt::AttentionOp>(computeOp)) {
LDBG("VectorDistribution: trying to find a suitable attention config");
if (succeeded(setAttentionIntrinsicBasedVectorDistributionConfig(
target, entryPoint, attnOp))) {
return success();
}
return setAttentionVectorDistributionConfig(target, entryPoint, attnOp);
}

@@ -24,6 +24,56 @@ namespace mlir::iree_compiler {
//====---------------------------------------------------------------------===//

namespace {

struct PromoteContractOperands final
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
Type operandElType = getElementTypeOrSelf(contractOp.getLhsType());
Type resultElType = getElementTypeOrSelf(contractOp.getResultType());

if (operandElType == resultElType) {
return failure();
}

Location loc = contractOp.getLoc();
Value lhs =
promoteToElementType(loc, rewriter, contractOp.getLhs(), resultElType);
Value rhs =
promoteToElementType(loc, rewriter, contractOp.getRhs(), resultElType);

rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(), contractOp.getIndexingMaps(),
contractOp.getIteratorTypes());

return success();
}

Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
Type dstElementType) const {
Type elementType = getElementTypeOrSelf(v.getType());
if (elementType == dstElementType)
return v;

// vector.contract only allows extension on operands.
assert(elementType.getIntOrFloatBitWidth() <=
dstElementType.getIntOrFloatBitWidth() &&
"vector.contract does not allow truncation of operands");

Type promotedType = dstElementType;
if (auto vecType = dyn_cast<VectorType>(v.getType()))
promotedType = vecType.clone(promotedType);

if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
// For integer types, vector.contract only supports signless integer types
// and promotion happens via sign extension.
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
};
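// Example of the promotion: a vector.contract with vector<...xf16> lhs/rhs
// and an f32 accumulator/result gets both operands extended to f32 with
// arith.extf, so the contract handed to the OuterProduct lowering registered
// below has uniform f32 element types; signless integer contracts are
// promoted with arith.extsi instead.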

struct LLVMGPUVectorLoweringPass final
: impl::LLVMGPUVectorLoweringPassBase<LLVMGPUVectorLoweringPass> {
void getDependentDialects(DialectRegistry &registry) const override {
@@ -48,6 +98,8 @@ struct LLVMGPUVectorLoweringPass final
contractLoweringPatterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct));
contractLoweringPatterns.add<PromoteContractOperands>(
funcOp->getContext());
vector::populateVectorMaskOpLoweringPatterns(contractLoweringPatterns);
vector::populateVectorShapeCastLoweringPatterns(contractLoweringPatterns);
vector::populateVectorMultiReductionLoweringPatterns(
@@ -23,6 +23,7 @@ iree_lit_test_suite(
"config_tile_and_fuse.mlir",
"config_vector_distribute_gfx1100.mlir",
"config_vector_distribute_gfx942.mlir",
"config_vector_distribute_reduction_gfx942.mlir",
"config_user_vector_distribute.mlir",
"lowering_scalar_dispatch.mlir",
"pipeline_igemm_tile_and_fuse.mlir",
@@ -20,6 +20,7 @@ iree_lit_test_suite(
"config_user_vector_distribute.mlir"
"config_vector_distribute_gfx1100.mlir"
"config_vector_distribute_gfx942.mlir"
"config_vector_distribute_reduction_gfx942.mlir"
"lowering_scalar_dispatch.mlir"
"pipeline_igemm_tile_and_fuse.mlir"
"pipeline_tile_and_fuse.mlir"
@@ -0,0 +1,49 @@
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --iree-codegen-llvmgpu-use-vector-distribution \
// RUN: --iree-codegen-llvmgpu-use-igemm=false \
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s

// CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>
func.func @attention_20x1x64x4096x64() {
%cst = arith.constant 1.250000e-01 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x1x64xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>>
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<20x1x64xf16>>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [20, 1, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x1x64xf16>> -> tensor<20x1x64xf16>
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0], sizes = [20, 4096, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<20x4096x64xf16>> -> tensor<20x4096x64xf16>
%7 = tensor.empty() : tensor<20x1x64xf16>
%8 = iree_linalg_ext.attention {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
affine_map<(d0, d1, d2, d3, d4) -> ()>,
affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>]}
ins(%4, %5, %6, %cst : tensor<20x1x64xf16>, tensor<20x4096x64xf16>, tensor<20x4096x64xf16>, f16) outs(%7 : tensor<20x1x64xf16>) {
^bb0(%score: f32):
iree_linalg_ext.yield %score : f32
} -> tensor<20x1x64xf16>
flow.dispatch.tensor.store %8, %3, offsets = [0, 0, 0], sizes = [20, 1, 64], strides = [1, 1, 1] : tensor<20x1x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<20x1x64xf16>>
return
}

// CHECK: decomposition_config =
// CHECK-SAME: pv_attrs =
// CHECK-SAME: #iree_gpu.lowering_config
// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1, 4], [0, 1, 3, 4]{{\]}}
// CHECK-SAME: thread_basis = {{\[}}[1, 1, 64, 1, 1], [0, 1, 3, 4]{{\]}}
// CHECK-SAME: qk_attrs =
// CHECK-SAME: #iree_gpu.lowering_config
// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1, 4], [0, 1, 2, 3]{{\]}}
// CHECK-SAME: thread_basis = {{\[}}[1, 1, 64, 1, 1], [0, 1, 2, 3]{{\]}}
// CHECK-SAME: lowering_config =
// CHECK-SAME: #iree_gpu.lowering_config
// CHECK-SAME: reduction = [0, 0, 0, 1, 0]
// CHECK-SAME: workgroup = [1, 0, 0, 0, 4]