Skip to content

Commit

Permalink
Add a Flow specific canonicalizer pass (iree-org#17836)
Browse files Browse the repository at this point in the history
Certain patterns that are borderline canonicalizations, or that are better
suited as a canonical form for particular phases, benefit from a
phase-specific canonicalization pass. The only pattern added here for now is
consecutive insert/extract slice folding, which is always beneficial in
Flow but not in Codegen.
  • Loading branch information
qedawkins authored Jul 12, 2024
1 parent f07c96c commit 05dfe0b
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_compiler_cc_library(
srcs = [
"AnnotateDispatches.cpp",
"BubbleUpExpandShapes.cpp",
"Canonicalizer.cpp",
"CaptureDynamicDims.cpp",
"CleanupTensorShapes.cpp",
"CloneProducersIntoDispatchRegions.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ iree_cc_library(
SRCS
"AnnotateDispatches.cpp"
"BubbleUpExpandShapes.cpp"
"Canonicalizer.cpp"
"CaptureDynamicDims.cpp"
"CleanupTensorShapes.cpp"
"CloneProducersIntoDispatchRegions.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler::IREE::Flow {

#define GEN_PASS_DEF_CANONICALIZERPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

namespace {

/// Canonicalize operations in nested regions.
struct CanonicalizerPass
    : public impl::CanonicalizerPassBase<CanonicalizerPass> {
  using IREE::Flow::impl::CanonicalizerPassBase<
      CanonicalizerPass>::CanonicalizerPassBase;

  /// Build the frozen pattern set once per pass instance: every registered
  /// dialect/op canonicalization plus a small number of Flow-specific
  /// patterns.
  LogicalResult initialize(MLIRContext *context) override {
    // Mirror the configuration defaults of the upstream canonicalizer pass.
    config.useTopDownTraversal = true;
    config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Normal;

    RewritePatternSet allPatterns(context);
    for (auto *dialect : context->getLoadedDialects()) {
      dialect->getCanonicalizationPatterns(allPatterns);
    }
    for (RegisteredOperationName registeredOp :
         context->getRegisteredOperations()) {
      registeredOp.getCanonicalizationPatterns(allPatterns, context);
    }

    // Flow-specific additions for common canonicalization: merging
    // consecutive insert/extract slice ops is desirable at this phase
    // (but not necessarily in Codegen).
    tensor::populateMergeConsecutiveInsertExtractSlicePatterns(allPatterns);

    patterns =
        std::make_shared<FrozenRewritePatternSet>(std::move(allPatterns));
    return success();
  }

  void runOnOperation() override {
    // Canonicalization is best-effort, so non-convergence only becomes a
    // pass failure when convergence testing is explicitly requested.
    LogicalResult converged =
        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
    if (testConvergence && failed(converged)) {
      getOperation()->emitError("Canonicalizer failed to converge");
      return signalPassFailure();
    }
  }

  // Greedy driver configuration shared by all runs of this pass instance.
  GreedyRewriteConfig config;
  // Frozen pattern set built in `initialize`; shared_ptr keeps runs cheap.
  std::shared_ptr<const FrozenRewritePatternSet> patterns;
};

} // namespace

} // namespace mlir::iree_compiler::IREE::Flow
22 changes: 11 additions & 11 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ using FunctionLikeNest =
static void addCleanupPatterns(OpPassManager &passManager) {
FunctionLikeNest(passManager)
// Standard MLIR cleanup.
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// Simplify util.global accesses; this can help with data flow tracking as
Expand Down Expand Up @@ -161,14 +161,14 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
ElementwiseOpFusionPassOptions{
clEnableElementWiseFuseMultiReduction});
})
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// 2. Bubble up expand_shape ops (or sink collapse_shape ops) to get
// elementwise operation into higher dimensions for more fusion
// opportunities.
.addPass(IREE::Flow::createBubbleUpExpandShapesPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// 3. Perform elementwise operation fusion again (now with higher
Expand All @@ -178,13 +178,13 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
ElementwiseOpFusionPassOptions{
clEnableElementWiseFuseMultiReduction});
})
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// 4. After elementwise operation fusion sink reshapes that block
// producer-consumer fusion.
.addPass(IREE::Flow::createSinkReshapesPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);
}

Expand Down Expand Up @@ -231,7 +231,7 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
// acts as a contiguous view of the tensor
// - Apply tensor -> flow patterns
.addPass(IREE::Flow::createConvertTensorToFlowPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
/// Creates the workgroup count region where the materialized computation
/// is derived as a program slice of the body of the dispatch. This method
/// - Computes the `workload` to use for the `workgroupsOp`, which are
Expand All @@ -252,7 +252,7 @@ void addDispatchRegionCreationPasses(OpPassManager &passManager,
FunctionLikeNest(passManager)
// Preprocess the input to a form more amenable for fusion.
.addPass(IREE::Flow::createFusionPreprocessingPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);

addDispatchRegionCreationPreprocessingPasses(passManager);
Expand All @@ -261,7 +261,7 @@ void addDispatchRegionCreationPasses(OpPassManager &passManager,
.addPass(IREE::Flow::createFuseMultiUseElementwiseProducerPass)
.addPredicatedPass(clDetensoring,
[&]() { return mlir::createLinalgDetensorizePass(); })
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
.addPredicatedPass(clCollapseReductionDims,
IREE::Flow::createCollapseReductionDimensionsPass)
Expand Down Expand Up @@ -317,7 +317,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,

FunctionLikeNest(passManager)
.addPass(IREE::Flow::createCaptureDynamicDimsPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
.addPass([&]() {
return IREE::Flow::createInitializeEmptyTensorsPass(
Expand Down Expand Up @@ -357,7 +357,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
IREE::Util::createStripDebugOpsPass());

// Cleanup identity ops that clutter up the IR and canonicalize.
FunctionLikeNest(passManager).addPass(mlir::createCanonicalizerPass);
FunctionLikeNest(passManager).addPass(IREE::Flow::createCanonicalizerPass);

// Deduplicate executables created from dispatch regions.
// Note: this only deduplicates equivalent executables. We could in addition
Expand Down Expand Up @@ -424,7 +424,7 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
// Cleanup executable contents.
{
auto executablePassManager = passManager.nest<IREE::Flow::ExecutableOp>();
executablePassManager.addPass(mlir::createCanonicalizerPass());
executablePassManager.addPass(IREE::Flow::createCanonicalizerPass());
executablePassManager.addPass(mlir::createCSEPass());
}

Expand Down
9 changes: 9 additions & 0 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def BubbleUpExpandShapesPass :
];
}

// Flow-level canonicalization. Runs the greedy canonicalization driver with
// all registered canonicalization patterns plus patterns that are only
// desirable as a canonical form within the Flow phase.
def CanonicalizerPass :
    Pass<"iree-flow-canonicalize", ""> {
  let summary = "Flow specific canonicalization pass";
  let options = [
    Option<"testConvergence", "test-convergence", "bool",
           /*default=*/"false", "Fails if the patterns fail to converge">
  ];
}

def CaptureDynamicDimsPass :
Pass<"iree-flow-capture-dynamic-dims", ""> {
let summary = "Captures dynamic shape dimensions required by dispatch operands/results and control flow operations.";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt %s --iree-flow-convert-region-to-workgroups -canonicalize -cse -split-input-file | FileCheck %s
// RUN: iree-opt %s --iree-flow-convert-region-to-workgroups --iree-flow-canonicalize -cse -split-input-file | FileCheck %s

// CHECK-LABEL: util.func public @foo(
// CHECK: %[[argA:.*]]: tensor<?x?xf32>, %[[argB:.*]]: tensor<5x10xf32>, %[[argC:.*]]: tensor<10x11xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, canonicalize, cse)" %s | FileCheck %s
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-convert-tensor-to-flow, canonicalize, iree-flow-materialize-default-workgroup-count-region), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s
util.func public @tile_matmul_alone(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%1 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions, iree-flow-clone-producers-into-dispatch-regions,iree-flow-convert-dispatch-regions-to-workgroups), cse, canonicalize, cse)" %s | FileCheck %s
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-form-dispatch-regions, iree-flow-clone-producers-into-dispatch-regions,iree-flow-convert-dispatch-regions-to-workgroups), cse, iree-flow-canonicalize, cse)" %s | FileCheck %s

util.func public @no_fuse_quantized(%arg0 : tensor<?x113x113x64xi8>, %arg1 : tensor<3x3x64xi8>,
%arg2 : i32, %arg3 : i32) -> tensor<?x56x56x64xi8> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-interchange-transpose-generic-ops,iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-convert-dispatch-regions-to-workgroups, canonicalize, cse))" --mlir-print-local-scope %s | FileCheck %s
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-flow-interchange-transpose-generic-ops,iree-flow-form-dispatch-regions{aggressive-fusion=true}, iree-flow-convert-dispatch-regions-to-workgroups, iree-flow-canonicalize, cse))" --mlir-print-local-scope %s | FileCheck %s

util.func @fuse_conv(%arg0 : tensor<2x130x130x16xf32>, %arg1 : tensor<3x3x16x320xf32>) -> tensor<2x320x128x128xf32> {
%empty = tensor.empty() : tensor<2x128x128x320xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-interchange-transpose-generic-ops --canonicalize -cse --mlir-print-local-scope %s | FileCheck %s
// RUN: iree-opt --split-input-file --verify-diagnostics --iree-flow-interchange-transpose-generic-ops --iree-flow-canonicalize -cse --mlir-print-local-scope %s | FileCheck %s

util.func @supported_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x16x320xf16>) -> tensor<2x320x128x128xf16> {
%empty = tensor.empty() : tensor<2x128x128x320xf32>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice --canonicalize %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice=skip-one-linalg-use-case --canonicalize %s | FileCheck %s --check-prefix=SKIP
// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice --iree-flow-canonicalize %s | FileCheck %s
// RUN: iree-opt --split-input-file --iree-flow-tensor-pad-to-tensor-insert-slice=skip-one-linalg-use-case --iree-flow-canonicalize %s | FileCheck %s --check-prefix=SKIP

util.func public @tensor_pad(%arg0 : tensor<?x?xf32>, %arg1 : tensor<f32>, %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
Expand Down
14 changes: 7 additions & 7 deletions compiler/src/iree/compiler/GlobalOptimization/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void buildGlobalOptimizationPassPipeline(
return createFuseDequantizationMatmulPass(
clEnableQuantizedMatmulReassociation);
})
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
// Propagate transposes immediately before set encoding/data tiling
// because transpose propagation cannot take an opinion on the preferred
Expand All @@ -150,13 +150,13 @@ void buildGlobalOptimizationPassPipeline(
return createPropagateLinalgTransposePass(
transformOptions.options.aggressiveTransposePropagation);
})
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);

if (clEnableFuseHorizontalContractions) {
FunctionLikeNest(mainPassManager)
.addPass(createFuseHorizontalContractionsPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);
}

Expand All @@ -171,7 +171,7 @@ void buildGlobalOptimizationPassPipeline(
if (clEnableEarlyMaterialization) {
mainPassManager.addPass(createMaterializeHomogeneousEncodingsPass());
}
mainPassManager.addPass(createCanonicalizerPass());
mainPassManager.addPass(IREE::Flow::createCanonicalizerPass());
mainPassManager.addPass(createCSEPass());
mainPassManager.addPass(createSimplifyPackUnpackPass());
FunctionLikeNest(mainPassManager).addPass(createDataLayoutPropagationPass);
Expand All @@ -183,7 +183,7 @@ void buildGlobalOptimizationPassPipeline(
// Hoist loop invariants (e.g. from scf loops) with zero-trip-check.
FunctionLikeNest(mainPassManager)
.addPass(createGlobalLoopInvariantCodeMotionPass)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);

// Simplify util.global accesses early on; this can help with dispatch
Expand All @@ -196,7 +196,7 @@ void buildGlobalOptimizationPassPipeline(
mainPassManager.addPass(IREE::Util::createApplyPatternsPass());
mainPassManager.addPass(IREE::Util::createFoldGlobalsPass());
mainPassManager.addPass(IREE::Util::createIPOPass());
mainPassManager.addPass(createCanonicalizerPass());
mainPassManager.addPass(IREE::Flow::createCanonicalizerPass());
mainPassManager.addPass(createCSEPass());

if (transformOptions.options.constExprHoisting) {
Expand All @@ -214,7 +214,7 @@ void buildGlobalOptimizationPassPipeline(
}

FunctionLikeNest(mainPassManager)
.addPass(mlir::createCanonicalizerPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);

FunctionLikeNest(mainPassManager)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-dequantization-matmul{enable-quantized-matmul-reassociation=true},canonicalize))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-dequantization-matmul{enable-quantized-matmul-reassociation=true},iree-flow-canonicalize))" %s | FileCheck %s

util.func public @grouped_quantized_matmul_reassociate(%arg0: tensor<11008x32x128xi4>, %arg1: tensor<32x128xf32>, %arg2: tensor<11008x32xf32>, %arg3: tensor<11008x32xf32>) -> tensor<11008xf32> {
%cst = arith.constant 0.000000e+00 : f32
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-silu-horizontal-matmul,canonicalize))" %s | FileCheck %s
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-global-opt-fuse-silu-horizontal-matmul,iree-flow-canonicalize))" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
util.func public @silu_horizontal_matmul_fusion(%arg0: index, %arg1: tensor<?x5120xf16>, %arg2: tensor<13824x5120xf16>, %arg3: tensor<13824x5120xf16>) -> tensor<?x13824xf16> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: iree-opt --iree-global-opt-raise-special-ops -canonicalize --split-input-file --mlir-print-local-scope %s | FileCheck %s
// RUN: iree-opt --iree-global-opt-raise-special-ops --iree-flow-canonicalize --split-input-file --mlir-print-local-scope %s | FileCheck %s

// CHECK-LABEL: @softmax
// CHECK-SAME: %[[ARG:.+]]: tensor<?x?x?xf32>
Expand Down

0 comments on commit 05dfe0b

Please sign in to comment.