Skip to content

Commit

Permalink
[LinalgExt] Drop the unit dims on scatter ops 2/3 (iree-org#19450)
Browse files Browse the repository at this point in the history
This change adds patterns to drop the unit dims of a
`iree_linalg_ext.scatter`'s `%updates` tensor. It only drops the leading
unit dimensions from the portion of `updates` that represents the
indexed dimensions.


See the main issue iree-org#19091

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jan 6, 2025
1 parent 0820f10 commit 340ffbb
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final
linalg::ControlFusionFn controlFoldingReshapes;
};

/// Remove the unit dims from `iree_linalg_ext.scatter`'s `update` operand.
/// The dims in `update` between the batch dims and the contiguous slice
/// represent the indexed dimensions. Only the *leading* unit dims of that
/// indexed region are removed (by collapsing `updates` with a
/// `tensor.collapse_shape`); batch dims and slice dims are left untouched.
struct FoldScatterNonIterationUnitDims final
    : public OpRewritePattern<ScatterOp> {
  FoldScatterNonIterationUnitDims(MLIRContext *context,
                                  linalg::ControlDropUnitDims options,
                                  PatternBenefit benefit = 1)
      : OpRewritePattern<ScatterOp>(context, benefit),
        options(std::move(options)) {}

  LogicalResult matchAndRewrite(ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    // The rewrite below materializes a reassociative reshape
    // (tensor.collapse_shape), so any other rank-reduction strategy is
    // unsupported.
    if (options.rankReductionStrategy !=
        linalg::ControlDropUnitDims::RankReductionStrategy::
            ReassociativeReshape) {
      return rewriter.notifyMatchFailure(
          scatterOp, "Only reassociative reshape strategy supported");
    }
    // Dims the control function permits dropping; filtered further below to
    // the leading unit indexed dims only.
    llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
    const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();

    // Find the number of leading unit dimensions in the indexed portion of
    // the update slice. The trailing `rankOfContiguousSlice` dims of the
    // slice are the contiguous region copied into `original`; every slice dim
    // before them is an indexed dim.
    int64_t rankOfContiguousSlice =
        scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
    ArrayRef<int64_t> indexedDims =
        scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice);
    // Offset (within the update slice shape) of the last leading unit indexed
    // dim; -1 when the first indexed dim is already non-unit. `drop_back`
    // keeps the underlying data pointer, so the iterator subtraction against
    // the full slice shape is well-defined.
    int64_t numDimsToDrop =
        llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) -
        scatterOp.getUpdateSliceShape().begin() - 1;

    int64_t batchRank = scatterOp.getBatchRank();
    // In the full `updates` shape the slice starts after the batch dims, so
    // the droppable positions are [batchRank, batchRank + numDimsToDrop].
    llvm::erase_if(canDrop, [&](unsigned dimPos) {
      return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
    });
    if (canDrop.empty()) {
      return failure();
    }

    // Shape of `updates` with the dropped unit dims removed.
    SmallVector<int64_t> droppedUpdateShape;
    droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
    for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
      if (!llvm::is_contained(canDrop, idx)) {
        droppedUpdateShape.push_back(dimLen);
      }
    }

    // Removing unit dims always yields a valid collapse reassociation.
    auto reassoc =
        getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
    assert(reassoc.has_value() && "expected reassociation to be valid");
    auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
        scatterOp.getLoc(),
        RankedTensorType::get(droppedUpdateShape,
                              scatterOp.getUpdateType().getElementType()),
        scatterOp.getUpdates(), reassoc.value());

    // Only the `updates` operand changes; the scatter op is modified in
    // place rather than recreated.
    rewriter.modifyOpInPlace(scatterOp, [&]() {
      scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult());
    });
    return success();
  }

private:
  linalg::ControlDropUnitDims options;
};

} // namespace

/// Return the `reassociation` indices to use to collapse the operand when the
Expand Down Expand Up @@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns(
patterns.getContext(), controlFoldingReshapes);
}

/// Default control function for unit-dim folding: every loop dimension of
/// `op` is a candidate for dropping.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
  const unsigned numLoops = cast<LinalgFusionOpInterface>(op).getNumLoops();
  SmallVector<unsigned> allDims;
  allDims.reserve(numLoops);
  for (unsigned dim = 0; dim < numLoops; ++dim) {
    allDims.push_back(dim);
  }
  return allDims;
}

/// Populates `patterns` with the LinalgExt patterns that fold away unit
/// extent dims, configured by `options`.
void populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldScatterNonIterationUnitDims>(context, options);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Default function to drop unit dims for LinalgExt ops.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op);

/// Drop unit extent dims from linalg ext ops
void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options);

/// Helper struct to hold the results of collapsing an operation.
struct CollapseResult {
SmallVector<Value> results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/DispatchCreation/Passes.h"
Expand Down Expand Up @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() {
if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) {
return SmallVector<unsigned>{};
}
if (isa<IREE::LinalgExt::LinalgExtOp>(op)) {
return IREE::LinalgExt::defaultControlDropUnitDims(op);
}
return defaultFn(op);
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
if (failed(
applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,47 @@ module @fold_stream_parameter {
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
// CHECK: util.func public @fold_stream_parameter
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>

// -----

// One leading unit dim in the indexed portion of `updates`
// (?x[1]x2x16x4x128): it is expected to be collapsed away, yielding a
// tensor<?x2x16x4x128xf16> update operand.
util.func public @scatter0(%arg0: tensor<?x1x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x16x4x128xf16>
  util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter0
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-SAME: to tensor<?x2x16x4x128xf16>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[COLLAPSE]]

// -----

// Index depth 2, so both leading unit dims of `updates` (?x[1]x[1]x16x4x128)
// are indexed dims and should be collapsed, yielding tensor<?x16x4x128xf16>.
util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x2xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x16x4x128xf16>
  util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter1
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-SAME: to tensor<?x16x4x128xf16>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[COLLAPSE]]

// -----

// TODO: remove other unit dims.
// TODO: remove other unit dims.
// The unit dims here are batch dims (batch rank 3 from the 1x?x1x2 indices)
// or part of the contiguous slice, not leading indexed dims, so the pattern
// must not fire and no collapse_shape may be emitted.
util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor<?x2x1x4x128xf16>) -> tensor<?x2x1x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor<?x2x1x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x1x4x128xf16>
  util.return %0 : tensor<?x2x1x4x128xf16>
}
// CHECK-LABEL: func public @scatter_noop
// CHECK-NOT: tensor.collapse_shape
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter

0 comments on commit 340ffbb

Please sign in to comment.