Skip to content

Commit

Permalink
[NFC][Flow] Remove use of fusion preprocessing when it isnt a preproc…
Browse files Browse the repository at this point in the history
…essing (iree-org#17899)

The Fusion preprocessing pass was used in multiple places, which is
not the intent of the pass. Remove the subsequent usage.  The only
reason for this double usage was for the pattern that moved reduction
dimensions to the innermost. Consolidate that pattern with the pattern
in `InterchangeTransposeGenericPass` (whose name is very convoluted
and does not represent what it actually does).

This commit also includes the following changes:
- Rename `InterchangeTransposeGenericPass` to `TransposeGenericOpsPass`.
- Reoder the passes in `Passes.td` to be alphabetical within each of the
following
  portions
  - Dispatch region preprocessing passes
  - Dispatch region formation passes
  - General flow passes.

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
  • Loading branch information
MaheshRavishankar authored Jul 15, 2024
1 parent 0bc1518 commit 2912a2a
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 335 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ iree_compiler_cc_library(
"InjectDispatchTracing.cpp",
"InjectTensorTracing.cpp",
"InsertDispatchDebugTargets.cpp",
"InterchangeTransposeGenericOps.cpp",
"MaterializeDefaultWorkgroupCountRegion.cpp",
"OutlineConstants.cpp",
"OutlineDispatchExterns.cpp",
Expand All @@ -69,6 +68,7 @@ iree_compiler_cc_library(
"SplitReduction.cpp",
"TensorPadToTensorInsertSlice.cpp",
"TopLevelSCFToCFG.cpp",
"TransposeGenericOps.cpp",
"VerifyInputLegality.cpp",
],
hdrs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ iree_cc_library(
"InjectDispatchTracing.cpp"
"InjectTensorTracing.cpp"
"InsertDispatchDebugTargets.cpp"
"InterchangeTransposeGenericOps.cpp"
"MaterializeDefaultWorkgroupCountRegion.cpp"
"OutlineConstants.cpp"
"OutlineDispatchExterns.cpp"
Expand All @@ -69,6 +68,7 @@ iree_cc_library(
"SplitReduction.cpp"
"TensorPadToTensorInsertSlice.cpp"
"TopLevelSCFToCFG.cpp"
"TransposeGenericOps.cpp"
"VerifyInputLegality.cpp"
DEPS
::PassesIncGen
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

namespace mlir::iree_compiler::IREE::Flow {

#define GEN_PASS_DEF_CONVERTREGIONTOWORKGROUPSPASS
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h.inc"

namespace {

/// Compute the dynamic dims of the given value and add them to the vector.
Expand Down Expand Up @@ -256,26 +253,4 @@ rewriteFlowDispatchRegionToFlowDispatchWorkgroups(
return workgroupsOp;
}

namespace {
struct ConvertRegionToWorkgroupsPass
: public IREE::Flow::impl::ConvertRegionToWorkgroupsPassBase<
ConvertRegionToWorkgroupsPass> {
void runOnOperation() override {
SmallVector<IREE::Flow::DispatchRegionOp> ops;
getOperation()->walk(
[&](IREE::Flow::DispatchRegionOp op) { ops.push_back(op); });

IRRewriter rewriter(getOperation()->getContext());
for (IREE::Flow::DispatchRegionOp regionOp : ops) {
if (failed(rewriteFlowDispatchRegionToFlowDispatchWorkgroups(regionOp,
rewriter))) {
signalPassFailure();
return;
}
}
}
};

} // namespace

} // namespace mlir::iree_compiler::IREE::Flow
Original file line number Diff line number Diff line change
Expand Up @@ -36,39 +36,6 @@ namespace mlir::iree_compiler::IREE::Flow {

namespace {

//===----------------------------------------------------------------------===//
// GenericOpInterchangePattern
//===----------------------------------------------------------------------===//

struct GenericOpInterchangePattern
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
SmallVector<unsigned> interchange;
bool needInterchange = false;
unsigned numParallelLoop = genericOp.getNumParallelLoops();
if (numParallelLoop == 0)
return failure();
for (auto iter : llvm::enumerate(genericOp.getIteratorTypesArray())) {
if (linalg::isParallelIterator(iter.value())) {
interchange.push_back(iter.index());
if (iter.index() >= numParallelLoop)
needInterchange = true;
}
}
// If all the parallel loops are outter loops skip the pattern.
if (!needInterchange)
return failure();
for (auto iter : llvm::enumerate(genericOp.getIteratorTypesArray())) {
if (linalg::isReductionIterator(iter.value())) {
interchange.push_back(iter.index());
}
}
return interchangeGenericOp(rewriter, genericOp, interchange);
}
};

//===----------------------------------------------------------------------===//
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -235,8 +202,7 @@ struct FusionPreprocessingPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<ElementwiseOpInterchangePattern,
FoldSuccessiveTensorInsertSliceOps,
GenericOpInterchangePattern, GatherFusionPattern>(
FoldSuccessiveTensorInsertSliceOps, GatherFusionPattern>(
&getContext());

// Fold away `tensor.dim` operations that can be resolved in terms of its
Expand Down
51 changes: 30 additions & 21 deletions compiler/src/iree/compiler/Dialect/Flow/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,35 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
// producer-consumer fusion.
.addPass(IREE::Flow::createSinkReshapesPass)
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass);
.addPass(mlir::createCSEPass)

// 5. After all the reshape propagations, fuse elementwise operations
// even if the producer has multiple uses.
.addPass(IREE::Flow::createFuseMultiUseElementwiseProducerPass)

// 6. Some more "post elementwise fusion passes".
// a. Detensorize.
// TODO: This is probably not in the right place.
.addPredicatedPass(clDetensoring,
[&]() { return mlir::createLinalgDetensorizePass(); })
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)

// b. For ops with multiple reduction dimensions, collapse the
// reduction dimension.
// TODO: This pass is only needed till all backends can handle
// multiple reduction dimensions.
.addPredicatedPass(clCollapseReductionDims,
IREE::Flow::createCollapseReductionDimensionsPass)

// c. Split reduction operations into parallel and reduction, i.e
// .
.addPass(IREE::Flow::createSplitReductionPass)

// d. Transpose generic ops to
// - help with dispatch region formation.
// - move reduction iterators to be innermost.
.addPass(IREE::Flow::createTransposeGenericOpsPass);
}

// Pipeline to first create `flow.dispatch.region` ops and then lower to
Expand All @@ -207,7 +235,7 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
// Create dispatches for scalar operations as roots
.addPass(IREE::Flow::createFormScalarDispatchesPass)
// Create `flow.dispatch.region` centered around a root and fuse with
// producers
// producers and consumers.
.addPass([&]() {
return IREE::Flow::createFormDispatchRegionsPass(
FormDispatchRegionsPassOptions{
Expand Down Expand Up @@ -256,25 +284,6 @@ void addDispatchRegionCreationPasses(OpPassManager &passManager,
.addPass(mlir::createCSEPass);

addDispatchRegionCreationPreprocessingPasses(passManager);

FunctionLikeNest(passManager)
.addPass(IREE::Flow::createFuseMultiUseElementwiseProducerPass)
.addPredicatedPass(clDetensoring,
[&]() { return mlir::createLinalgDetensorizePass(); })
.addPass(IREE::Flow::createCanonicalizerPass)
.addPass(mlir::createCSEPass)
.addPredicatedPass(clCollapseReductionDims,
IREE::Flow::createCollapseReductionDimensionsPass)
// Split reduction operations into parallel and reduction.
.addPass(IREE::Flow::createSplitReductionPass)
// SplitReductionPass may create reduction dimension that are not the last
// dimension.
.addPass(IREE::Flow::createFusionPreprocessingPass)
// Normalize the input indexing map to make the input indexing map
// identity. This helps fusing named linalg op with a generic op with
// transpose.
.addPass(IREE::Flow::createInterchangeTransposeGenericOpsPass);

addDispatchRegionCreationPasses(passManager);
}

Expand Down
Loading

0 comments on commit 2912a2a

Please sign in to comment.