Skip to content

Commit

Permalink
[LinalgExt] Retire LinalgExt::ReverseOp (iree-org#17866)
Browse files Browse the repository at this point in the history
`LinalgExt::ReverseOp` is only lowered from `stablehlo::ReverseOp`. We
can expand `stablehlo::ReverseOp` to a different pattern and retire
`LinalgExt::ReverseOp`.

Fixes iree-org#16060

---------

Signed-off-by: Alan Li <me@alanli.org>
  • Loading branch information
lialan authored Jul 19, 2024
1 parent 69900ee commit 76cad82
Show file tree
Hide file tree
Showing 19 changed files with 90 additions and 697 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// Implements IREE-specific logic for lowering StableHLO/CHLO dialects to
// LinalgExt dialect.

#include <algorithm>
#include <cmath>
#include <complex>
#include <memory>
Expand Down Expand Up @@ -427,22 +428,52 @@ struct FftOpConversion final : OpConversionPattern<mlir::stablehlo::FftOp> {
/// Lowers `stablehlo.reverse` to a `linalg.generic` with only an output
/// operand. The body gathers from the input tensor: along every dimension
/// listed in the op's `dimensions`, output index `i` reads input index
/// `dimSize - 1 - i`; along all other dimensions it reads index `i`.
struct ReverseOpConversion final
    : OpConversionPattern<mlir::stablehlo::ReverseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(mlir::stablehlo::ReverseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out unless the (converted) operand is a ranked tensor.
    auto ty = dyn_cast<RankedTensorType>(adaptor.getOperands()[0].getType());
    if (!ty)
      return failure();

    // NOTE(review): the lowering reads the original operand and result type
    // (`op.getOperand()` / `op.getType()`), not the adaptor's converted
    // values, so e.g. a `ui32` tensor stays `ui32` in the generated ops.
    // Confirm this is intended.
    Value input = op.getOperand();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    ArrayRef<int64_t> dims = op.getDimensions();
    Location loc = op.getLoc();
    int64_t inputTyRank = inputTy.getRank();

    // Create the (uninitialized) destination tensor with the input's shape;
    // every element is written by the generic op below.
    SmallVector<OpFoldResult> inputMixedSizes =
        tensor::getMixedSizes(rewriter, loc, input);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, inputMixedSizes, inputTy.getElementType());
    SmallVector<AffineMap> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (int64_t i = 0; i < inputTyRank; ++i) {
            Value index =
                nestedBuilder.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (std::find(dims.begin(), dims.end(), i) != dims.end()) {
              // Reversed dimension: read from `dimSize - 1 - index`.
              auto one =
                  nestedBuilder.create<arith::ConstantIndexOp>(nestedLoc, 1);
              Value axisDimSize =
                  nestedBuilder.create<tensor::DimOp>(loc, input, i);
              auto sizeMinusOne = nestedBuilder.create<arith::SubIOp>(
                  nestedLoc, axisDimSize, one);
              index = nestedBuilder.create<arith::SubIOp>(nestedLoc,
                                                          sizeMinusOne, index);
            }
            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(nestedLoc,
                                                extract.getResult());
        });
    return success();
  }
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,12 +495,17 @@ func.func @reverse_dim1(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> {
return %0 : tensor<3x5xi32>
}
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32>
// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
// CHECK-SAME: ins(%[[IN]] : tensor<3x5xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32>
// CHECK: return %[[REV]]

// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xi32>) {
// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index
// CHECK: %[[REV_DIM:.+]] = linalg.index 1 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C1_0:.+]] = arith.constant 1 : index
// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xi32>
// CHECK: %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index
// CHECK: %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xi32>
// CHECK: linalg.yield %[[EXTRACTED]] : i32
// CHECK: return %[[GEN]]
// -----

func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> {
Expand All @@ -512,13 +517,18 @@ func.func @reverse_unsigned(%arg0: tensor<3x5xui32>) -> tensor<3x5xui32> {
// CHECK-LABEL: func.func @reverse_unsigned
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[IN]] : tensor<3x5xui32> to tensor<3x5xi32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xi32>
// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
// CHECK-SAME: dimensions(dense<1> : tensor<1xi64>)
// CHECK-SAME: ins(%[[BITCAST]] : tensor<3x5xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<3x5xi32>) : tensor<3x5xi32>
// CHECK: %[[BITCAST:.+]] = builtin.unrealized_conversion_cast %[[REV]] : tensor<3x5xi32> to tensor<3x5xui32>
// CHECK: return %[[BITCAST]]
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<3x5xui32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<3x5xui32>)
// CHECK: %[[SAME_DIM:.+]] = linalg.index 0 : index
// CHECK: %[[REV_DIM:.+]] = linalg.index 1 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C1_0:.+]] = arith.constant 1 : index
// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C1_0]] : tensor<3x5xui32>
// CHECK: %[[DIMSUB1:.+]] = arith.subi %[[DIM]], %[[C1]] : index
// CHECK: %[[REV_IDX:.+]] = arith.subi %[[DIMSUB1]], %[[REV_DIM]] : index
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[SAME_DIM]], %[[REV_IDX]]] : tensor<3x5xui32>
// CHECK: linalg.yield %[[EXTRACTED]] : ui32
// CHECK: return %[[GEN]]

// -----

Expand All @@ -530,16 +540,32 @@ func.func @reverse_multi_dim(%arg0: tensor<?x?xi32>) -> tensor<?x?xi32> {
} : (tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[IN]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[IN]], %[[C1]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xi32>
// CHECK: %[[REV:.+]] = iree_linalg_ext.reverse
// CHECK-SAME: dimensions(dense<[0, 1]> : tensor<2xi64>)
// CHECK-SAME: ins(%[[IN]] : tensor<?x?xi32>)
// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xi32>) : tensor<?x?xi32>
// CHECK: return %[[REV]]
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[D:.+]] = tensor.dim %[[IN]], %[[C0]] : tensor<?x?xi32>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[D0:.+]] = tensor.dim %[[IN]], %[[C1]] : tensor<?x?xi32>
// CHECK: %[[INIT:.+]] = tensor.empty(%[[D]], %[[D0]]) : tensor<?x?xi32>
// CHECK: %[[GEN:.+]] = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<?x?xi32>) {

// First reverse dimension
// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
// CHECK: %[[C1_1:.+]] = arith.constant 1 : index
// CHECK: %[[C0_2:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0_2]] : tensor<?x?xi32>
// CHECK: %[[DIM0SUB1:.+]] = arith.subi %[[DIM0]], %[[C1_1]] : index
// CHECK: %[[REV_IDX0:.+]] = arith.subi %[[DIM0SUB1]], %[[IDX0]] : index

// Second reverse dimension
// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
// CHECK: %[[C1_4:.+]] = arith.constant 1 : index
// CHECK: %[[C1_5:.+]] = arith.constant 1 : index
// CHECK: %[[DIM1:.+]] = tensor.dim %arg0, %[[C1_5]] : tensor<?x?xi32>
// CHECK: %[[DIM1SUB1:.+]] = arith.subi %[[DIM1]], %[[C1_4]] : index
// CHECK: %[[REV_IDX1:.+]] = arith.subi %[[DIM1SUB1]], %[[IDX1]] : index

// CHECK: %[[EXTRACTED:.+]] = tensor.extract %arg0[%[[REV_IDX0]], %[[REV_IDX1]]] : tensor<?x?xi32>
// CHECK: linalg.yield %[[EXTRACTED]] : i32
// CHECK: return %[[GEN]]

// -----

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,46 +526,6 @@ func.func @reduce_window_max_4x6xf32() {

// -----

// Test input: reverses a 2x3 f32 tile along dimension 0 inside a two-level
// scf.for nest. The outer loop is driven by workgroup id/count 1 (y), the
// inner loop by workgroup id/count 0 (x), each scaled by 64.
func.func @linalg_ext_reverse_dim0() {
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
// Binding 0 is the read-only input, binding 1 the write-only output.
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<2x3xf32>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x3xf32>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%2 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_y]
%3 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_y]
scf.for %arg0 = %2 to %c2 step %3 {
%4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
%5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
scf.for %arg1 = %4 to %c3 step %5 {
%6 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1], sizes = [2, 3], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<2x3xf32>> -> tensor<2x3xf32>
// The reverse writes into a fresh tensor.empty destination.
%7 = tensor.empty() : tensor<2x3xf32>
%8 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%6 : tensor<2x3xf32>) outs(%7 : tensor<2x3xf32>) : tensor<2x3xf32>
// The row offset of the store is the negated outer induction variable.
%9 = affine.apply affine_map<()[s0] -> (-s0)>()[%arg0]
flow.dispatch.tensor.store %8, %1, offsets = [%9, %arg1], sizes = [2, 3], strides = [%c1, %c1] : tensor<2x3xf32> -> !flow.dispatch.tensor<writeonly:tensor<2x3xf32>>
}
}
return
}
// CHECK: func.func @linalg_ext_reverse_dim0()
// CHECK-DAG: %[[IN:.+]] = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer)
// CHECK-DAG: %[[OUT:.+]] = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer)
// CHECK: scf.for %[[IV0:.+]] =
// CHECK: scf.for %[[IV1:.+]] =
// CHECK-DAG: %[[IN_TILE:.+]] = flow.dispatch.tensor.load %[[IN]]
// CHECK-DAG: %[[OUT_TILE:.+]] = flow.dispatch.tensor.load %[[OUT]]
// CHECK: %[[REV_TILE:.+]] = iree_linalg_ext.reverse
// CHECK-SAME: ins(%[[IN_TILE]] : tensor<2x3xf32>)
// CHECK-SAME: outs(%[[OUT_TILE]] : tensor<2x3xf32>)
// CHECK: flow.dispatch.tensor.store %[[REV_TILE]], %[[OUT]]

// -----

func.func @sort1D() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<4xi32>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2170,30 +2170,6 @@ func.func @rank_reducing_no_op_subview() {

// -----

// CHECK-LABEL: func.func @reverse_dim(
// CHECK-DAG: %[[alloc:.*]] = memref.alloc()
// CHECK-DAG: %[[cst:.*]] = bufferization.to_memref
// CHECK: iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>)
// CHECK-SAME: ins(%[[cst]] :
// CHECK-SAME: outs(%[[alloc]] :
// CHECK: %[[load:.*]] = memref.load %[[alloc]]
// CHECK: return %[[load]]
// Test input: reverses a constant 2x3 f32 tensor along dimension 0 into a
// `bufferization.alloc_tensor` destination, then returns the element at
// [%pos, %pos] of the reversed result.
func.func @reverse_dim(%pos: index) -> f32 {
%input = arith.constant dense<[[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]]> : tensor<2x3xf32>

%init = bufferization.alloc_tensor() : tensor<2x3xf32>
%0 = iree_linalg_ext.reverse
dimensions(dense<0> : tensor<1xi64>)
ins(%input : tensor<2x3xf32>)
outs(%init : tensor<2x3xf32>) : tensor<2x3xf32>

%1 = tensor.extract %0[%pos, %pos] : tensor<2x3xf32>
return %1 : f32
}

// -----

// CHECK-LABEL: func.func @fft_tensor(
// CHECK: memref.alloc
// CHECK: memref.alloc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ struct LinalgExtOpInterface

/// Returns false when `op` is an `IREE::LinalgExt::ScatterOp` and true for
/// every other op. `opOperand` and `state` are not consulted.
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
  // TODO: Revisit this for ScatterOp. We can then get rid of
  // `bufferizesToMemoryRead` completely.
  return !isa<IREE::LinalgExt::ScatterOp>(op);
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
Expand Down Expand Up @@ -630,8 +630,6 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
LinalgExtOpInterface<IREE::LinalgExt::PackOp>>(*ctx);
IREE::LinalgExt::UnPackOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::UnPackOp>>(*ctx);
IREE::LinalgExt::ReverseOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::ScanOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::ScanOp>>(*ctx);
IREE::LinalgExt::ScatterOp::attachInterface<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,6 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry &registry) {
OuterParallelAsPartitionableLoops<IREE::LinalgExt::ScatterOp>>(*ctx);
IREE::LinalgExt::SortOp::attachInterface<
AllParallelAsPartitionableLoops<IREE::LinalgExt::SortOp>>(*ctx);
IREE::LinalgExt::ReverseOp::attachInterface<
OuterParallelAsPartitionableLoops<IREE::LinalgExt::ReverseOp>>(*ctx);
IREE::LinalgExt::TopkOp::attachInterface<
AllParallelAsPartitionableLoops<IREE::LinalgExt::TopkOp>>(*ctx);
IREE::LinalgExt::WinogradInputTransformOp::attachInterface<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,41 +46,6 @@ util.func public @linalgext_scatter_fusion() -> tensor<8192x16x8x128xf32> {
// CHECK: %[[GEN2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)



// -----


#map = affine_map<(d0, d1) -> (d0, d1)>
// Test input: an elementwise i64 -> i32 truncation feeds an
// `iree_linalg_ext.reverse` along dimension 0, whose result feeds a second
// elementwise generic.
util.func public @linalgext_reverse_fusion() -> tensor<10x10xi32> {
%0 = tensor.empty() : tensor<10x10xi64>
%1 = tensor.empty() : tensor<10x10xi32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x10xi64>) outs(%1 : tensor<10x10xi32>) {
^bb0(%in: i64, %out: i32):
%7 = arith.trunci %in : i64 to i32
linalg.yield %7 : i32
} -> tensor<10x10xi32>
%3 = tensor.empty() : tensor<10x10xi32>
%4 = iree_linalg_ext.reverse dimensions(dense<0> : tensor<1xi64>) ins(%2 : tensor<10x10xi32>) outs(%3 : tensor<10x10xi32>) : tensor<10x10xi32>

// The reverse must not be fused with its consumer below.
%5 = tensor.empty() : tensor<10x10xi32>
%6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<10x10xi32>) outs(%5 : tensor<10x10xi32>) {
^bb0(%in: i32, %out: i32):
%7 = arith.addi %in, %out : i32
linalg.yield %7 : i32
} -> tensor<10x10xi32>
util.return %6 : tensor<10x10xi32>
}

// CHECK: util.func public @linalgext_reverse_fusion
// CHECK: flow.dispatch.workgroups
// CHECK: %[[SHRUNK:.+]] = linalg.generic
// CHECK: %[[REVERSED:.+]] = iree_linalg_ext.reverse
// CHECK: ins(%[[SHRUNK]] : tensor<10x10xi32>)
// CHECK: flow.dispatch.workgroups
// CHECK: %[[GEN:.+]] = linalg.generic

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
Expand Down
66 changes: 0 additions & 66 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,71 +439,6 @@ ScanOp::reifyResultShapes(OpBuilder &b,
.reifyResultShapes(b, reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// ReverseOp
//===----------------------------------------------------------------------===//

/// Verifies a reverse op: exactly one input and one init, matching element
/// types and ranks, pairwise-compatible static extents, and a `dimensions`
/// list whose entries are unique and lie within [0, rank).
LogicalResult ReverseOp::verify() {
  Operation *op = getOperation();
  if (getNumDpsInputs() != 1)
    return op->emitOpError("expected exactly one input");
  if (getNumDpsInits() != 1)
    return op->emitOpError("expected exactly one output");

  auto inTy = cast<ShapedType>(getInput().getType());
  auto outTy = cast<ShapedType>(getOutput().getType());
  if (inTy.getElementType() != outTy.getElementType())
    return op->emitOpError(
        "expected input/output element types to be identical");

  ArrayRef<int64_t> inShape = inTy.getShape();
  ArrayRef<int64_t> outShape = outTy.getShape();
  // NOTE(review): "expexted" is misspelt; the text is left untouched here in
  // case diagnostics tests match it verbatim.
  if (inShape.size() != outShape.size())
    return op->emitOpError("expexted input/output to have identical ranks");

  // Extents only conflict when both sides are static and differ.
  for (auto [inDim, outDim] : llvm::zip_equal(inShape, outShape)) {
    bool bothStatic =
        !ShapedType::isDynamic(inDim) && !ShapedType::isDynamic(outDim);
    if (bothStatic && inDim != outDim)
      return op->emitOpError("incompatible input/output shapes");
  }

  int64_t rank = getOperandRank();
  llvm::SmallSetVector<int64_t, 4> seen;
  for (auto dim : getDimensionsArray()) {
    if (dim < 0 || dim >= rank)
      return op->emitOpError("all the dimensions must be within [0, ")
             << rank << ")";
    // `insert` reports false when the value is already present.
    if (!seen.insert(dim))
      return op->emitOpError("expected dimensions numbers are all unique");
  }

  return success();
}

/// Forwards to `reifyResultShapes` on this operation viewed as a
/// `LinalgExtOp`.
LogicalResult
ReverseOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto linalgExtOp = cast<LinalgExtOp>(getOperation());
  return linalgExtOp.reifyResultShapes(b, reifiedReturnShapes);
}

/// Returns two maps: a multi-dimensional identity map of the operand rank
/// for the input, followed by a null map for the output.
SmallVector<AffineMap> ReverseOp::getIndexingMapsForOperands() {
  AffineMap inputMap =
      Builder(getContext()).getMultiDimIdentityMap(getOperandRank());
  AffineMap outputMap(nullptr);
  return {inputMap, outputMap};
}

/// Returns a single null map for the op's result.
SmallVector<AffineMap> ReverseOp::getIndexingMapsForResults() {
  SmallVector<AffineMap> maps;
  maps.push_back(AffineMap(nullptr));
  return maps;
}

//===----------------------------------------------------------------------===//
// TopkOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1583,7 +1518,6 @@ Im2colOp::reifyResultShapes(OpBuilder &b,
DEFINE_OP_GET_EFFECTS(ScatterOp)
DEFINE_OP_GET_EFFECTS(SortOp)
DEFINE_OP_GET_EFFECTS(FftOp)
DEFINE_OP_GET_EFFECTS(ReverseOp)
DEFINE_OP_GET_EFFECTS(ScanOp)
DEFINE_OP_GET_EFFECTS(TopkOp)
DEFINE_OP_GET_EFFECTS(PackOp)
Expand Down
Loading

0 comments on commit 76cad82

Please sign in to comment.