Skip to content

Commit

Permalink
[Codegen][IGEMM] Add new pass for IGEMM transformation with reshape p…
Browse files Browse the repository at this point in the history
…ropagation (iree-org#18161)

This PR adds a new pass to perform the IGEMM transformation in Codegen.
The new pass uses the `Conv2DToIm2colOp` patterns plus some reshape
propagation and cleanup patterns. The PR also adds a control function on
the `Conv2DToIm2colOp` patterns, in order to avoid transforming
configured operations.

This separates the `Conv2DToIm2colOp` transformation from the
codegen-specific IGEMM pipeline, and addresses an issue with fusions
that requires reshape propagation. When there are consumers of the
convolution op, the consumer needs to also be collapsed in order to tile
and fuse it with the GEMM.

Adding reshape propagation is just one solution to the fusion issue. The
other potential solution is to allow the im2col op to have multiple M
dimensions in its result, and create a multi-M contraction instead of
the collapsed version. This second solution is ideal as long as backends
are able to handle the multi-M contraction, but it requires more work to
change the im2col op semantics. For now this PR fixes the issue, and the
alternative solution is left as a TODO.

---------

Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
  • Loading branch information
Max191 authored Aug 9, 2024
1 parent 1fddcd6 commit f0e8cda
Show file tree
Hide file tree
Showing 11 changed files with 251 additions and 6 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ iree_compiler_cc_library(
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"ConvolutionToIGEMM.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeLinalgGeneric.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ iree_cc_library(
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
Expand Down
104 changes: 104 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

namespace {

using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;

class ConvolutionToIGEMMPass
: public ConvolutionToIGEMMBase<ConvolutionToIGEMMPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();

// Rewrite convolutions into a im2col and GEMM.
{
auto conv2dToIm2colControlFn = [](Operation *conv) {
// Don't transform convolutions that have a preset lowering config.
if (getLoweringConfig(conv)) {
return false;
}
return true;
};
RewritePatternSet patterns(&getContext());
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
patterns, conv2dToIm2colControlFn);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}

// The im2col transformation collapses some of the dimensions of the
// convolution operands. Try to push the reshape ops towards the boundaries
// of the function and fold with interface tensor ops.
//
// TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
// generate a multi-M dim contraction instead of collapsing and
// propagating reshapes. It should ultimately become a pass option to
// decide whether to collapse the contraction dimensions into a single
// M/N/K dimension.
{
RewritePatternSet bubbleCollapseShapePatterns(context);
linalg::ControlFusionFn bubbleUpExpansionControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

// Block only if one of the operations has a lowering configuration
// which means it likely expects tiling specific to its original
// shape.
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
return false;
}
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(
bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
// Add patterns to do some additional cleanup (on top of canonicalizations
// that can be done later) of reshape ops.
tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(bubbleCollapseShapePatterns)))) {
return signalPassFailure();
}
}
}
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvolutionToIGEMMPass() {
return std::make_unique<ConvolutionToIGEMMPass>();
}

} // namespace mlir::iree_compiler
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvertToDestinationPassingStylePass(
bool useWARForCooperativeMatrixCodegen = false);

/// Converts convolution operations to a GEMM with an im2col op on the image.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass();

// Decompose affine.apply operations into sub affine.apply that can be
// hoisted in different loops.
std::unique_ptr<Pass> createDecomposeAffineOpsPass();
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def ConvertToDestinationPassingStyle :
];
}

def ConvolutionToIGEMM :
InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
let summary =
"Transforms convolution operations into an implicit GEMM format.";
let constructor = "mlir::iree_compiler::createConvolutionToIGEMMPass()";
}

def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
let summary = "Decompose `affine.apply` operations into sub `affine.apply`";
let description = [{
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ iree_lit_test_suite(
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
"convolution_to_igemm.mlir",
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
"decompose_affine_ops.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ iree_lit_test_suite(
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
"convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
"decompose_conv2d.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s

#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
func.func public @conv_with_consumer(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf16> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<1x14x14x16xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%0 = linalg.conv_2d_nhwc_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%1 = tensor.empty() : tensor<1x14x14x16xf16>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x14x14x16xf32>) outs(%1 : tensor<1x14x14x16xf16>) {
^bb0(%in: f32, %out: f16):
%3 = arith.truncf %in : f32 to f16
linalg.yield %3 : f16
} -> tensor<1x14x14x16xf16>
return %2 : tensor<1x14x14x16xf16>
}
// CHECK: func.func public @conv_with_consumer
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: -> tensor<1x196x16xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK: %[[TRUNCF:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[TRUNCF]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf16> into tensor<1x14x14x16xf16>
// CHECK: return %[[EXPANDED]] : tensor<1x14x14x16xf16>

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @fold_with_interface_tensor() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>> -> tensor<1x16x16x4xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 4, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>> -> tensor<3x3x4x16xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>> -> tensor<1x14x14x16xf32>
%cst = arith.constant 0.0 : f32
%fill = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%6 = linalg.conv_2d_nhwc_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%3, %4: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : tensor<1x14x14x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>>
return
}
}

// CHECK: func.func @fold_with_interface_tensor
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x16x16x4xf32>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<36x16xf32>
// CHECK-DAG: %[[RES:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x196x16xf32>
// CHECK-DAG: %[[IM2COL:.+]] = iree_linalg_ext.im2col {{.*}} ins(%[[LHS]] : tensor<1x16x16x4xf32>){{.*}}-> tensor<1x196x36xf32>
// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}}outs(%[[RES]] : tensor<1x196x16xf32>)
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[RHS]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x196x16xf32>) {
// CHECK: flow.dispatch.tensor.store %[[MATMUL]]

// -----

#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<1x14x14x16xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
// CHECK: func.func public @conv_with_lowering_config
// CHECK-NOT: iree_linalg_ext.im2col
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: lowering_config
4 changes: 2 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1054,8 +1054,8 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
OpPassManager &modulePassManager) {
{
FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPredicatedPass(
clLLVMGPUUseIgemm, IREE::LinalgExt::createConvertConv2DToIm2ColOpPass);
funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm,
createConvolutionToIGEMMPass);
funcPassManager.addPass(createGPUGeneralizeNamedOpsPass);
addCommonTargetExecutablePreprocessingPasses(funcPassManager);
addEncodingToNopPasses(funcPassManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {

namespace {

using ControlFnTy = std::optional<std::function<bool(Operation *)>>;

// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
// and linalg.matmul.
//
Expand Down Expand Up @@ -75,8 +77,16 @@ class ConvertConv2DNhwcHwcf final
public:
using OpRewritePattern::OpRewritePattern;

ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn)
: OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context),
controlFn(controlFn) {}

LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
if (controlFn.has_value() && !controlFn.value()(convOp)) {
return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
}

auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
Expand Down Expand Up @@ -181,6 +191,9 @@ class ConvertConv2DNhwcHwcf final

return success();
}

private:
ControlFnTy controlFn;
};

// For nchw, because the channels are to the left of the image shape dimensions,
Expand All @@ -192,8 +205,16 @@ class ConvertConv2DNchwFchw final
public:
using OpRewritePattern::OpRewritePattern;

ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn)
: OpRewritePattern<linalg::Conv2DNchwFchwOp>(context),
controlFn(controlFn) {}

LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
PatternRewriter &rewriter) const override {
if (controlFn.has_value() && !controlFn.value()(convOp)) {
return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
}

auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
Expand Down Expand Up @@ -296,18 +317,19 @@ class ConvertConv2DNchwFchw final

return success();
}

private:
ControlFnTy controlFn;
};

struct ConvertConv2DToIm2ColOpPass
: ConvertConv2DToIm2ColOpBase<ConvertConv2DToIm2ColOpPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<tensor::TensorDialect, IREE::LinalgExt::IREELinalgExtDialect>();
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(context);
populateConv2DToIm2colOpPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
Expand All @@ -317,6 +339,12 @@ struct ConvertConv2DToIm2ColOpPass

} // namespace

void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
ControlFnTy controlFn) {
patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(
patterns.getContext(), std::move(controlFn));
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass() {
return std::make_unique<ConvertConv2DToIm2ColOpPass>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ createDecomposeWinogradTransformPass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass();

// Patterns to convert linalg convolution ops into a gemm with an im2col
// op and reshapes on the inputs.
void populateConv2DToIm2colOpPatterns(
RewritePatternSet &patterns,
std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);

// Creates a pass to convert linalg convolution ops into a sequence of
// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
// tranformation.
Expand Down

0 comments on commit f0e8cda

Please sign in to comment.