forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Codegen][IGEMM] Add new pass for IGEMM transformation with reshape propagation (iree-org#18161)
This PR adds a new pass to perform the IGEMM transformation in Codegen. The new pass uses the `Conv2DToIm2colOp` patterns plus some reshape propagation and cleanup patterns. The PR also adds a control function on the `Conv2DToIm2colOp` patterns, in order to avoid transforming configured operations. This separates the `Conv2DToIm2colOp` transformation from the codegen-specific IGEMM pipeline, and addresses an issue with fusions that requires reshape propagation. When there are consumers of the convolution op, the consumer needs to also be collapsed in order to tile and fuse it with the GEMM. Adding reshape propagation is just one solution to the fusion issue. The other potential solution is to allow the im2col op to have multiple M dimensions in its result, and create a multi-M contraction instead of the collapsed version. This second solution is ideal as long as backends are able to handle the multi-M contraction, but it requires more work to change the im2col op semantics. For now this PR fixes the issue, and the alternative solution is left as a TODO. --------- Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
- Loading branch information
Showing
11 changed files
with
251 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
104 changes: 104 additions & 0 deletions
104
compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
// Copyright 2022 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/PassDetail.h" | ||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Codegen/Transforms/Transforms.h" | ||
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" | ||
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Interfaces/FunctionInterfaces.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Pass/PassRegistry.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
namespace { | ||
|
||
using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect; | ||
|
||
class ConvolutionToIGEMMPass | ||
: public ConvolutionToIGEMMBase<ConvolutionToIGEMMPass> { | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>(); | ||
} | ||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
|
||
// Rewrite convolutions into a im2col and GEMM. | ||
{ | ||
auto conv2dToIm2colControlFn = [](Operation *conv) { | ||
// Don't transform convolutions that have a preset lowering config. | ||
if (getLoweringConfig(conv)) { | ||
return false; | ||
} | ||
return true; | ||
}; | ||
RewritePatternSet patterns(&getContext()); | ||
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns( | ||
patterns, conv2dToIm2colControlFn); | ||
if (failed(applyPatternsAndFoldGreedily(getOperation(), | ||
std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
|
||
// The im2col transformation collapses some of the dimensions of the | ||
// convolution operands. Try to push the reshape ops towards the boundaries | ||
// of the function and fold with interface tensor ops. | ||
// | ||
// TODO(Max191): Allow for the im2col op to have multiple M dimensions, and | ||
// generate a multi-M dim contraction instead of collapsing and | ||
// propagating reshapes. It should ultimately become a pass option to | ||
// decide whether to collapse the contraction dimensions into a single | ||
// M/N/K dimension. | ||
{ | ||
RewritePatternSet bubbleCollapseShapePatterns(context); | ||
linalg::ControlFusionFn bubbleUpExpansionControlFn = | ||
[](OpOperand *fusedOperand) { | ||
Operation *producer = fusedOperand->get().getDefiningOp(); | ||
Operation *consumer = fusedOperand->getOwner(); | ||
|
||
// Block only if one of the operations has a lowering configuration | ||
// which means it likely expects tiling specific to its original | ||
// shape. | ||
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) { | ||
return false; | ||
} | ||
return true; | ||
}; | ||
linalg::populateFoldReshapeOpsByCollapsingPatterns( | ||
bubbleCollapseShapePatterns, bubbleUpExpansionControlFn); | ||
// Add patterns to do some additional cleanup (on top of canonicalizations | ||
// that can be done later) of reshape ops. | ||
tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns); | ||
linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, | ||
context); | ||
tensor::CollapseShapeOp::getCanonicalizationPatterns( | ||
bubbleCollapseShapePatterns, context); | ||
tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns, | ||
context); | ||
tensor::ExpandShapeOp::getCanonicalizationPatterns( | ||
bubbleCollapseShapePatterns, context); | ||
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns); | ||
if (failed(applyPatternsAndFoldGreedily( | ||
getOperation(), std::move(bubbleCollapseShapePatterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>> | ||
createConvolutionToIGEMMPass() { | ||
return std::make_unique<ConvolutionToIGEMMPass>(); | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
compiler/src/iree/compiler/Codegen/Common/test/convolution_to_igemm.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s | ||
|
||
// Test input: a linalg.conv_2d_nhwc_hwcf (1x16x16x4 input, 3x3x4x16 filter, | ||
// unit strides and dilations) accumulating into a zero-filled 1x14x14x16 f32 | ||
// tensor. Its result feeds an all-parallel linalg.generic that truncates each | ||
// element from f32 to f16, so the convolution has a consumer op. | ||
#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)> | ||
func.func public @conv_with_consumer(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf16> { | ||
%cst = arith.constant 0.0 : f32 | ||
%empty = tensor.empty() : tensor<1x14x14x16xf32> | ||
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> | ||
%0 = linalg.conv_2d_nhwc_hwcf | ||
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } | ||
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) | ||
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> | ||
%1 = tensor.empty() : tensor<1x14x14x16xf16> | ||
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x14x14x16xf32>) outs(%1 : tensor<1x14x14x16xf16>) { | ||
^bb0(%in: f32, %out: f16): | ||
%3 = arith.truncf %in : f32 to f16 | ||
linalg.yield %3 : f16 | ||
} -> tensor<1x14x14x16xf16> | ||
return %2 : tensor<1x14x14x16xf16> | ||
} | ||
// CHECK: func.func public @conv_with_consumer | ||
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col | ||
// CHECK-SAME: : tensor<1x196x36xf32>) -> tensor<1x196x36xf32> | ||
// CHECK: %[[FILL:.+]] = linalg.fill | ||
// CHECK-SAME: -> tensor<1x196x16xf32> | ||
// CHECK: %[[MATMUL:.+]] = linalg.generic | ||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] | ||
// CHECK: %[[TRUNCF:.+]] = linalg.generic | ||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] | ||
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[TRUNCF]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf16> into tensor<1x14x14x16xf16> | ||
// CHECK: return %[[EXPANDED]] : tensor<1x14x14x16xf16> | ||
|
||
// ----- | ||
|
||
// Test input: a linalg.conv_2d_nhwc_hwcf whose input, filter and init tensors | ||
// are read with flow.dispatch.tensor.load from three HAL interface bindings | ||
// (two readonly, one writeonly) and whose result is written back with | ||
// flow.dispatch.tensor.store to the writeonly binding. | ||
// NOTE(review): #config and #map are defined below but are not referenced by | ||
// any op in this function as shown; confirm whether they are still needed. | ||
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [ | ||
#hal.descriptor_set.layout<0, bindings = [ | ||
#hal.descriptor_set.binding<0, storage_buffer>, | ||
#hal.descriptor_set.binding<1, storage_buffer>, | ||
#hal.descriptor_set.binding<2, storage_buffer> | ||
]> | ||
]> | ||
#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}> | ||
#map = affine_map<(d0, d1) -> (d0, d1)> | ||
module { | ||
func.func @fold_with_interface_tensor() { | ||
%c0 = arith.constant 0 : index | ||
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>> | ||
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>> | ||
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>> | ||
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>> -> tensor<1x16x16x4xf32> | ||
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 4, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>> -> tensor<3x3x4x16xf32> | ||
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>> -> tensor<1x14x14x16xf32> | ||
%cst = arith.constant 0.0 : f32 | ||
%fill = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> | ||
%6 = linalg.conv_2d_nhwc_hwcf | ||
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> } | ||
ins(%3, %4: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) | ||
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> | ||
flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : tensor<1x14x14x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>> | ||
return | ||
} | ||
} | ||
|
||
// CHECK: func.func @fold_with_interface_tensor | ||
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x16x16x4xf32> | ||
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<36x16xf32> | ||
// CHECK-DAG: %[[RES:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x196x16xf32> | ||
// CHECK-DAG: %[[IM2COL:.+]] = iree_linalg_ext.im2col {{.*}} ins(%[[LHS]] : tensor<1x16x16x4xf32>){{.*}}-> tensor<1x196x36xf32> | ||
// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}}outs(%[[RES]] : tensor<1x196x16xf32>) | ||
// CHECK: %[[MATMUL:.+]] = linalg.generic | ||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] | ||
// CHECK-SAME: ins(%[[IM2COL]], %[[RHS]] : tensor<1x196x36xf32>, tensor<36x16xf32>) | ||
// CHECK-SAME: outs(%[[FILL]] : tensor<1x196x16xf32>) { | ||
// CHECK: flow.dispatch.tensor.store %[[MATMUL]] | ||
|
||
// ----- | ||
|
||
// Test input: a linalg.conv_2d_nhwc_hwcf that carries a preset | ||
// `lowering_config = #config` attribute. The expectations that follow this | ||
// function require that no im2col op appears and that the convolution keeps | ||
// its lowering_config. | ||
// NOTE(review): #map is defined below but is not referenced by any op in this | ||
// function as shown; confirm whether it is still needed. | ||
#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)> | ||
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]> | ||
func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> { | ||
%cst = arith.constant 0.0 : f32 | ||
%empty = tensor.empty() : tensor<1x14x14x16xf32> | ||
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> | ||
%0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config, | ||
dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} | ||
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>) | ||
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32> | ||
return %0 : tensor<1x14x14x16xf32> | ||
} | ||
// CHECK: func.func public @conv_with_lowering_config | ||
// CHECK-NOT: iree_linalg_ext.im2col | ||
// CHECK: linalg.conv_2d_nhwc_hwcf | ||
// CHECK-SAME: lowering_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters