From f0e8cda0776973feb5a8caeca86929211733eb64 Mon Sep 17 00:00:00 2001 From: Max191 <44243577+Max191@users.noreply.github.com> Date: Fri, 9 Aug 2024 08:56:33 -0700 Subject: [PATCH] [Codegen][IGEMM] Add new pass for IGEMM transformation with reshape propagation (#18161) This PR adds a new pass to perform the IGEMM transformation in Codegen. The new pass uses the `Conv2DToIm2colOp` patterns plus some reshape propagation and cleanup patterns. The PR also adds a control function on the `Conv2DToIm2colOp` patterns, in order to avoid transforming configured operations. This separates the `Conv2DToIm2colOp` transformation from the codegen-specific IGEMM pipeline, and addresses an issue with fusions that requires reshape propagation. When there are consumers of the convolution op, the consumer needs to also be collapsed in order to tile and fuse it with the GEMM. Adding reshape propagation is just one solution to the fusion issue. The other potential solution is to allow the im2col op to have multiple M dimensions in its result, and create a multi-M contraction instead of the collapsed version. This second solution is ideal as long as backends are able to handle the multi-M contraction, but it requires more work to change the im2col op semantics. For now this PR fixes the issue, and the alternative solution is left as a TODO. --------- Signed-off-by: Max Dawkins --- .../iree/compiler/Codegen/Common/BUILD.bazel | 1 + .../compiler/Codegen/Common/CMakeLists.txt | 1 + .../Codegen/Common/ConvolutionToIGEMM.cpp | 104 ++++++++++++++++++ .../src/iree/compiler/Codegen/Common/Passes.h | 4 + .../iree/compiler/Codegen/Common/Passes.td | 7 ++ .../compiler/Codegen/Common/test/BUILD.bazel | 1 + .../Codegen/Common/test/CMakeLists.txt | 1 + .../Common/test/convolution_to_igemm.mlir | 92 ++++++++++++++++ .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 4 +- .../Transforms/ConvertConv2DToIm2ColOp.cpp | 36 +++++- .../Dialect/LinalgExt/Transforms/Passes.h | 6 + 11 files changed, 251 insertions(+), 6 deletions(-) create mode 100644 compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp create mode 100644 compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir diff --git a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel index 15cd82b2f145..ec3eaec9c164 100644 --- a/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/BUILD.bazel @@ -95,6 +95,7 @@ iree_compiler_cc_library( "ConvertBf16ArithToF32.cpp", "ConvertBf16ToUInt16Buffers.cpp", "ConvertToDestinationPassingStylePass.cpp", + "ConvolutionToIGEMM.cpp", "DecomposeAffineOpsPass.cpp", "DecomposeConvolutionToLowerDimOps.cpp", "DecomposeLinalgGeneric.cpp", diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 76ed4cdad4fc..44c25aa009f0 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -86,6 +86,7 @@ iree_cc_library( "ConvertBf16ArithToF32.cpp" "ConvertBf16ToUInt16Buffers.cpp" "ConvertToDestinationPassingStylePass.cpp" + "ConvolutionToIGEMM.cpp" "DecomposeAffineOpsPass.cpp" "DecomposeConvolutionToLowerDimOps.cpp" "DecomposeLinalgGeneric.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp new file mode 100644 index 000000000000..7eb600353955 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp @@ -0,0 +1,104 @@ +// Copyright 2022 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Codegen/Common/PassDetail.h" +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Codegen/Transforms/Transforms.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler { + +namespace { + +using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect; + +class ConvolutionToIGEMMPass + : public ConvolutionToIGEMMBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + + // Rewrite convolutions into a im2col and GEMM. + { + auto conv2dToIm2colControlFn = [](Operation *conv) { + // Don't transform convolutions that have a preset lowering config. + if (getLoweringConfig(conv)) { + return false; + } + return true; + }; + RewritePatternSet patterns(&getContext()); + iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns( + patterns, conv2dToIm2colControlFn); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } + + // The im2col transformation collapses some of the dimensions of the + // convolution operands. Try to push the reshape ops towards the boundaries + // of the function and fold with interface tensor ops. + // + // TODO(Max191): Allow for the im2col op to have multiple M dimensions, and + // generate a multi-M dim contraction instead of collapsing and + // propagating reshapes. It should ultimately become a pass option to + // decide whether to collapse the contraction dimensions into a single + // M/N/K dimension. + { + RewritePatternSet bubbleCollapseShapePatterns(context); + linalg::ControlFusionFn bubbleUpExpansionControlFn = + [](OpOperand *fusedOperand) { + Operation *producer = fusedOperand->get().getDefiningOp(); + Operation *consumer = fusedOperand->getOwner(); + + // Block only if one of the operations has a lowering configuration + // which means it likely expects tiling specific to its original + // shape. + if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { + return false; + } + return true; + }; + linalg::populateFoldReshapeOpsByCollapsingPatterns( + bubbleCollapseShapePatterns, bubbleUpExpansionControlFn); + // Add patterns to do some additional cleanup (on top of canonicalizations + // that can be done later) of reshape ops. + tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns); + linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, + context); + tensor::CollapseShapeOp::getCanonicalizationPatterns( + bubbleCollapseShapePatterns, context); + tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, + context); + tensor::ExpandShapeOp::getCanonicalizationPatterns( + bubbleCollapseShapePatterns, context); + populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns); + if (failed(applyPatternsAndFoldGreedily( + getOperation(), std::move(bubbleCollapseShapePatterns)))) { + return signalPassFailure(); + } + } + } +}; + +} // namespace + +std::unique_ptr> +createConvolutionToIGEMMPass() { + return std::make_unique(); +} + +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.h b/compiler/src/iree/compiler/Codegen/Common/Passes.h index 2880477d0a2b..ade155049cfa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.h +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.h @@ -86,6 +86,10 @@ std::unique_ptr> createConvertToDestinationPassingStylePass( bool useWARForCooperativeMatrixCodegen = false); +/// Converts convolution operations to a GEMM with an im2col op on the image. +std::unique_ptr> +createConvolutionToIGEMMPass(); + // Decompose affine.apply operations into sub affine.apply that can be // hoisted in different loops. std::unique_ptr createDecomposeAffineOpsPass(); diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index ed182941c372..5bda8d1d3827 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -70,6 +70,13 @@ def ConvertToDestinationPassingStyle : ]; } +def ConvolutionToIGEMM : + InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> { + let summary = + "Transforms convolution operations into an implicit GEMM format."; + let constructor = "mlir::iree_compiler::createConvolutionToIGEMMPass()"; +} + def DecomposeAffineOps: Pass<"decompose-affine-ops"> { let summary = "Decompose `affine.apply` operations into sub `affine.apply`"; let description = [{ diff --git a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel index f0e3a8e9f8ad..9651d49fbb11 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel @@ -27,6 +27,7 @@ iree_lit_test_suite( "convert_bf16_to_uint16_buffers.mlir", "convert_bf16_arith_to_f32.mlir", "convert_to_destination_passing_style.mlir", + "convolution_to_igemm.mlir", "convolutions.mlir", "erase_dead_alloc_and_stores.mlir", "decompose_affine_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt index 6f1dd785049a..d2b97e2e0a8e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt @@ -23,6 +23,7 @@ iree_lit_test_suite( "convert_bf16_arith_to_f32.mlir" "convert_bf16_to_uint16_buffers.mlir" "convert_to_destination_passing_style.mlir" + "convolution_to_igemm.mlir" "convolutions.mlir" "decompose_affine_ops.mlir" "decompose_conv2d.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir new file mode 100644 index 000000000000..46f30fe01b3c --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir @@ -0,0 +1,92 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s + +#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)> +func.func public @conv_with_consumer(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf16> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<1x14x14x16xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + %0 = linalg.conv_2d_nhwc_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) + outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + %1 = tensor.empty() : tensor<1x14x14x16xf16> + %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x14x14x16xf32>) outs(%1 : tensor<1x14x14x16xf16>) { + ^bb0(%in: f32, %out: f16): + %3 = arith.truncf %in : f32 to f16 + linalg.yield %3 : f16 + } -> tensor<1x14x14x16xf16> + return %2 : tensor<1x14x14x16xf16> +} +// CHECK: func.func public @conv_with_consumer +// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col +// CHECK-SAME: : tensor<1x196x36xf32>) -> tensor<1x196x36xf32> +// CHECK: %[[FILL:.+]] = linalg.fill +// CHECK-SAME: -> tensor<1x196x16xf32> +// CHECK: %[[MATMUL:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK: %[[TRUNCF:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[TRUNCF]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf16> into tensor<1x14x14x16xf16> +// CHECK: return %[[EXPANDED]] : tensor<1x14x14x16xf16> + +// ----- + +#pipeline_layout = #hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer>, + #hal.descriptor_set.binding<2, storage_buffer> + ]> +]> +#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}> +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fold_with_interface_tensor() { + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x16x16x4xf32> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 4, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<3x3x4x16xf32> + %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<1x14x14x16xf32> + %cst = arith.constant 0.0 : f32 + %fill = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + %6 = linalg.conv_2d_nhwc_hwcf + {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } + ins(%3, %4: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) + outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : tensor<1x14x14x16xf32> -> !flow.dispatch.tensor> + return + } +} + +// CHECK: func.func @fold_with_interface_tensor +// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x16x16x4xf32> +// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<36x16xf32> +// CHECK-DAG: %[[RES:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x196x16xf32> +// CHECK-DAG: %[[IM2COL:.+]] = iree_linalg_ext.im2col {{.*}} ins(%[[LHS]] : tensor<1x16x16x4xf32>){{.*}}-> tensor<1x196x36xf32> +// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}}outs(%[[RES]] : tensor<1x196x16xf32>) +// CHECK: %[[MATMUL:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] +// CHECK-SAME: ins(%[[IM2COL]], %[[RHS]] : tensor<1x196x36xf32>, tensor<36x16xf32>) +// CHECK-SAME: outs(%[[FILL]] : tensor<1x196x16xf32>) { +// CHECK: flow.dispatch.tensor.store %[[MATMUL]] + +// ----- + +#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)> +#config = #iree_codegen.lowering_config +func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { + %cst = arith.constant 0.0 : f32 + %empty = tensor.empty() : tensor<1x14x14x16xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + %0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, + dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) + outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> + return %0 : tensor<1x14x14x16xf32> +} +// CHECK: func.func public @conv_with_lowering_config +// CHECK-NOT: iree_linalg_ext.im2col +// CHECK: linalg.conv_2d_nhwc_hwcf +// CHECK-SAME: lowering_config diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 957813aa8ce9..f116b97479dd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1054,8 +1054,8 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl( OpPassManager &modulePassManager) { { FunctionLikeNest funcPassManager(modulePassManager); - funcPassManager.addPredicatedPass( - clLLVMGPUUseIgemm, IREE::LinalgExt::createConvertConv2DToIm2ColOpPass); + funcPassManager.addPredicatedPass(clLLVMGPUUseIgemm, + createConvolutionToIGEMMPass); funcPassManager.addPass(createGPUGeneralizeNamedOpsPass); addCommonTargetExecutablePreprocessingPasses(funcPassManager); addEncodingToNopPasses(funcPassManager); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp index a5bb42a6404b..0e7b3b7708d6 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ConvertConv2DToIm2ColOp.cpp @@ -37,6 +37,8 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) { namespace { +using ControlFnTy = std::optional>; + // Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) // and linalg.matmul. // @@ -75,8 +77,16 @@ class ConvertConv2DNhwcHwcf final public: using OpRewritePattern::OpRewritePattern; + ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn) + : OpRewritePattern(context), + controlFn(controlFn) {} + LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const override { + if (controlFn.has_value() && !controlFn.value()(convOp)) { + return rewriter.notifyMatchFailure(convOp, "controlFn failed."); + } + auto inputType = llvm::cast(convOp.getInputs()[0].getType()); auto filterType = llvm::cast(convOp.getInputs()[1].getType()); auto outputType = llvm::cast(convOp.getOutputs()[0].getType()); @@ -181,6 +191,9 @@ class ConvertConv2DNhwcHwcf final return success(); } + +private: + ControlFnTy controlFn; }; // For nchw, because the channels are to the left of the image shape dimensions, @@ -192,8 +205,16 @@ class ConvertConv2DNchwFchw final public: using OpRewritePattern::OpRewritePattern; + ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn) + : OpRewritePattern(context), + controlFn(controlFn) {} + LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp, PatternRewriter &rewriter) const override { + if (controlFn.has_value() && !controlFn.value()(convOp)) { + return rewriter.notifyMatchFailure(convOp, "controlFn failed."); + } + auto inputType = llvm::cast(convOp.getInputs()[0].getType()); auto filterType = llvm::cast(convOp.getInputs()[1].getType()); auto outputType = llvm::cast(convOp.getOutputs()[0].getType()); @@ -296,18 +317,19 @@ class ConvertConv2DNchwFchw final return success(); } + +private: + ControlFnTy controlFn; }; struct ConvertConv2DToIm2ColOpPass : ConvertConv2DToIm2ColOpBase { void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() override { - MLIRContext *context = &getContext(); RewritePatternSet patterns(&getContext()); - patterns.insert(context); + populateConv2DToIm2colOpPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); @@ -317,6 +339,12 @@ struct ConvertConv2DToIm2ColOpPass } // namespace +void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns, + ControlFnTy controlFn) { + patterns.insert( + patterns.getContext(), std::move(controlFn)); +} + std::unique_ptr> createConvertConv2DToIm2ColOpPass() { return std::make_unique(); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h index c5dcdee0ca0a..04c1e8f68c39 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Passes.h @@ -49,6 +49,12 @@ createDecomposeWinogradTransformPass(); std::unique_ptr> createConvertConv2DToIm2ColOpPass(); +// Patterns to convert linalg convolution ops into a gemm with an im2col +// op and reshapes on the inputs. +void populateConv2DToIm2colOpPatterns( + RewritePatternSet &patterns, + std::optional> controlFn = std::nullopt); + // Creates a pass to convert linalg convolution ops into a sequence of // linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd // tranformation.