From 340ffbb0995b98da5a0b4db11aacc72c2fb0f204 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 6 Jan 2025 12:51:27 -0800 Subject: [PATCH] [LinalgExt] Drop the unit dims on scatter ops 2/3 (#19450) This change adds patterns to drop the unit dims of a `iree_linalg_ext.scatter`'s `%updates` tensor. It only drops the leading unit dimensions from the portion of `updates` that represents the indexed dimensions. See the main issue https://github.com/iree-org/iree/issues/19091 --------- Signed-off-by: Ian Wood --- .../LinalgExt/Transforms/ReshapeFusion.cpp | 77 +++++++++++++++++++ .../Dialect/LinalgExt/Transforms/Transforms.h | 7 ++ .../DispatchCreation/FoldUnitExtentDims.cpp | 7 ++ .../DispatchCreation/test/fold_unit_dims.mlir | 44 +++++++++++ 4 files changed, 135 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp index e87efd9f2099..901288bd4788 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final linalg::ControlFusionFn controlFoldingReshapes; }; +/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand. +/// The dims in `update` between the batch dims and the continuous slice +/// represent the indexed dimensions. Remove the leading unit dims from the +/// indexed dims. +struct FoldScatterNonIterationUnitDims final + : public OpRewritePattern { + FoldScatterNonIterationUnitDims(MLIRContext *context, + linalg::ControlDropUnitDims options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(std::move(options)) {} + + LogicalResult matchAndRewrite(ScatterOp scatterOp, + PatternRewriter &rewriter) const override { + if (options.rankReductionStrategy != + linalg::ControlDropUnitDims::RankReductionStrategy:: + ReassociativeReshape) { + return rewriter.notifyMatchFailure( + scatterOp, "Only reassociative reshape strategy supported"); + } + llvm::SmallVector canDrop = options.controlFn(scatterOp); + const ArrayRef updateShape = scatterOp.getUpdateType().getShape(); + + // Find the number of leading unit dimensions + int64_t rankOfContiguousSlice = + scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth(); + ArrayRef indexedDims = + scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice); + int64_t numDimsToDrop = + llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) - + scatterOp.getUpdateSliceShape().begin() - 1; + + int64_t batchRank = scatterOp.getBatchRank(); + llvm::erase_if(canDrop, [&](unsigned dimPos) { + return dimPos < batchRank || dimPos > batchRank + numDimsToDrop; + }); + if (canDrop.empty()) { + return failure(); + } + + SmallVector droppedUpdateShape; + droppedUpdateShape.reserve(updateShape.size() - canDrop.size()); + for (auto [idx, dimLen] : llvm::enumerate(updateShape)) { + if (!llvm::is_contained(canDrop, idx)) { + droppedUpdateShape.push_back(dimLen); + } + } + + auto reassoc = + getReassociationIndicesForCollapse(updateShape, droppedUpdateShape); + assert(reassoc.has_value() && "expected reassociation to be valid"); + auto collapseOp = rewriter.create( + scatterOp.getLoc(), + RankedTensorType::get(droppedUpdateShape, + scatterOp.getUpdateType().getElementType()), + scatterOp.getUpdates(), reassoc.value()); + + rewriter.modifyOpInPlace(scatterOp, [&]() { + scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult()); + }); + return success(); + } + +private: + linalg::ControlDropUnitDims options; +}; + } // namespace /// Return the `reassociation` indices to use to collapse the operand when the @@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns( patterns.getContext(), controlFoldingReshapes); } +SmallVector defaultControlDropUnitDims(Operation *op) { + auto fusionOp = cast(op); + return llvm::to_vector(llvm::seq(0, fusionOp.getNumLoops())); +} + +void populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) { + patterns.add(patterns.getContext(), options); +} + } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h index 8bf84cab2574..8da0225e27ef 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps( RewritePatternSet &patterns, const linalg::ControlFusionFn &controlFusionFn); +/// Default function to drop unit dims for for linalgext ops. +SmallVector defaultControlDropUnitDims(Operation *op); + +/// Drop unit extent dims from linalg ext ops +void populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options); + /// Helper struct to hold the results of collapsing an operation. struct CollapseResult { SmallVector results; diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp index ca08c3c04d37..ecaf21777f18 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/Analysis/Explorer.h" #include "iree/compiler/DispatchCreation/Passes.h" @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() { if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) { return SmallVector{}; } + if (isa(op)) { + return IREE::LinalgExt::defaultControlDropUnitDims(op); + } return defaultFn(op); }; linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); + IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, + options); linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); if (failed( applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir index 249a8b1cba4b..62dbba7e59d9 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir @@ -106,3 +106,47 @@ module @fold_stream_parameter { // CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32> // CHECK: util.func public @fold_stream_parameter // CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32> + +// ----- + +util.func public @scatter0(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter0 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-SAME: to tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSE]] + +// ----- + +util.func public @scatter1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter1 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-SAME: to tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSE]] + +// ----- + +// TODO: remove other unit dims. +util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter_noop +// CHECK-NOT: tensor.collapse_shape +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter