Skip to content

Commit

Permalink
[Dispatch] Disable UnpackLikeOp+ExtractSlice fusion. (iree-org#19408)
Browse files Browse the repository at this point in the history
It is no longer needed because unset_encoding ops carries the slicing
semantics. Instead of adding complexity on the checks (whether the
consumer is rank-reducing slice or not), we can disable the fusion at
all.

The revision updates the test cases that were created before the
unset_encoding evolution and add a negative test for the issue.

Fixes iree-org#19386

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Dec 8, 2024
1 parent 62903cc commit 39c56de
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,28 +182,9 @@ static bool isPackLikeOp(Operation *op) {
return isa<IREE::Encoding::SetEncodingOp, tensor::PackOp>(op);
}

/// Returns true if the operation is an `unpack` op or an `unset_encoding` op,
/// or an `extract_slice` op whose source operand matches those criteria,
/// recursively.
/// The idea is that we want to ensure that `extract_slice` ops can't prevent
/// fusion between a `unset_encoding` producer and some linalg consumer. In
/// %0 = unset_encoding ...
/// %1 = extract_slice %0 ...
/// %2 = linalg.generic ins(%1) ...
/// we are not content to be fusing %1 into %0, we also want to be fusing %2,
/// so we want to prevent %1 from acting as a consumer fusion barrier.
static bool isUnpackLikeOpViaExtractSliceOps(Operation *op) {
if (isa<IREE::Encoding::UnsetEncodingOp, tensor::UnPackOp>(op)) {
return true;
}
if (isa<tensor::ExtractSliceOp>(op)) {
Value source = op->getOperand(0);
Operation *producer = source.getDefiningOp();
if (isUnpackLikeOpViaExtractSliceOps(producer)) {
return true;
}
}
return false;
/// Returns true if the operation is an `unpack` op or an `unset_encoding` op.
static bool isUnpackLikeOp(Operation *op) {
return isa<IREE::Encoding::UnsetEncodingOp, tensor::UnPackOp>(op);
}

/// Since `iree_encoding.set_encoding` doesnt have padding semantics a
Expand Down Expand Up @@ -476,18 +457,7 @@ isFusableWithConsumer(OpOperand &fusedOperand,

// Fuse unset_encoding operations with `tensor.extract_slice` and elementwise
// generic ops.
if (isUnpackLikeOpViaExtractSliceOps(producer)) {
// Fuse `unset_encoding` -> `extract_slice` op since they get folded into
// `unpack` on materialization.
if (isa<tensor::ExtractSliceOp>(consumer)) {
auto sliceOp = cast<tensor::ExtractSliceOp>(consumer);
return llvm::all_of(
sliceOp.getMixedOffsets(),
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
llvm::all_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
return isConstantIntValue(ofr, 1);
});
}
if (isUnpackLikeOp(producer)) {
// Fuse `unset_encoding/unpack` -> elementwise operations. Fuse unpack with
// non-overlapping reductions (i.e., the reduction dimension is not packed).
if (auto consumerLinalgOp = dyn_cast<linalg::LinalgOp>(consumer)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1958,17 +1958,15 @@ util.func public @pad_and_set_encoding_op(%arg0 : tensor<?x?xf32>)
// -----

#encoding = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32]>
util.func public @unset_encoding_and_slice(
util.func public @unset_encoding_with_encoded_slice(
%arg0: tensor<?x?xf32, #encoding>,
%arg1 : index, %arg2 : index) -> tensor<?x?xf32> {
%0 = iree_encoding.unset_encoding %arg0
: tensor<?x?xf32, #encoding> -> tensor<?x?xf32>{%arg1, %arg2}
%1 = tensor.extract_slice %0[0, 0] [%arg1, %arg2] [1, 1]
: tensor<?x?xf32> to tensor<?x?xf32>
util.return %1 : tensor<?x?xf32>
util.return %0 : tensor<?x?xf32>
}
// CHECK: #[[ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32]>
// CHECK: util.func public @unset_encoding_and_slice
// CHECK: util.func public @unset_encoding_with_encoded_slice
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #[[ENCODING]]>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
Expand All @@ -1991,8 +1989,7 @@ util.func public @unset_encoding_and_slice(
// CHECK-SAME: sizes = [%[[D0_W]], %[[D1_W]]]
// CHECK-SAME: !flow.dispatch.tensor<readonly:tensor<?x?xf32, #[[ENCODING]]>>{%[[D0_W]], %[[D1_W]]}
// CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding %[[LOAD]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNSET_ENCODING]][0, 0] [%[[ARG0_W]], %[[ARG1_W]]]
// CHECK: flow.dispatch.tensor.store %[[SLICE]], %[[OUTARG]]
// CHECK: flow.dispatch.tensor.store %[[UNSET_ENCODING]], %[[OUTARG]]
// CHECK-SAME: sizes = [%[[ARG0_W]], %[[ARG1_W]]]
// CHECK-SAME: !flow.dispatch.tensor<writeonly:tensor<?x?xf32>>{%[[ARG0_W]], %[[ARG1_W]]}
// CHECK: flow.return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,38 +304,34 @@ util.func public @unset_encoding_elementwise_fusion(
// -----

#encoding = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32]>
util.func public @unset_encoding_slice_elementwise_fusion(
util.func public @unset_encoding_elementwise_fusion(
%arg0: tensor<?x?xf32, #encoding>,
%arg1: tensor<?xf32>, %arg2 : index, %arg3 : index) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%0 = iree_encoding.unset_encoding %arg0
: tensor<?x?xf32, #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32]>> -> tensor<?x?xf32>{%arg2, %arg3}
%1 = tensor.extract_slice %0[0, 0] [%arg2, %arg3] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%2 = tensor.dim %1, %c0 : tensor<?x?xf32>
%3 = tensor.dim %1, %c1 : tensor<?x?xf32>
%4 = tensor.empty(%2, %3) : tensor<?x?xf32>
%5 = linalg.generic {
%1 = tensor.empty(%arg2, %arg3) : tensor<?x?xf32>
%2 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%1, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
outs(%4 : tensor<?x?xf32>) {
ins(%0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
outs(%1 : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
%6 = arith.addf %b0, %b1 : f32
linalg.yield %6 : f32
%3 = arith.addf %b0, %b1 : f32
linalg.yield %3 : f32
} -> tensor<?x?xf32>
util.return %5 : tensor<?x?xf32>
util.return %2 : tensor<?x?xf32>
}
// CHECK: #[[$ENCODING:.+]] = #iree_encoding.encoding<operand_index = 0 : i64, op_type = matmul, element_types = [f32, f32, f32]>
// CHECK-LABEL: util.func public @unset_encoding_slice_elementwise_fusion(
// CHECK-LABEL: util.func public @unset_encoding_elementwise_fusion(
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32, #[[$ENCODING]]>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
// CHECK: %[[RESULT0:.+]] = flow.dispatch.region
// CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding %[[ARG0]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNSET_ENCODING]]
// CHECK: %[[GENERIC:.+]] = linalg.generic {{.*}} ins(%[[SLICE]]
// CHECK: %[[GENERIC:.+]] = linalg.generic {{.*}} ins(%[[UNSET_ENCODING]]
// CHECK: flow.return %[[GENERIC]]
// CHECK: util.return %[[RESULT0]]

Expand Down Expand Up @@ -382,6 +378,22 @@ util.func public @unpack_elementwise_fusion(

// -----

#encoding = #iree_encoding.encoding<operand_index = 2 : index, op_type = matmul, element_types = [f32, f32, f32]>
util.func public @unset_encoding_slice(%arg0: tensor<1x50x384xf32, #encoding>) -> tensor<384xf32> {
%0 = iree_encoding.unset_encoding %arg0 : tensor<1x50x384xf32, #encoding> -> tensor<1x50x384xf32>
%extracted_slice = tensor.extract_slice %0[0, 0, 0] [1, 1, 384] [1, 1, 1] : tensor<1x50x384xf32> to tensor<384xf32>
util.return %extracted_slice : tensor<384xf32>
}
// CHECK-LABEL: util.func public @unset_encoding_slice
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[UNSET_ENCODING:.+]] = iree_encoding.unset_encoding
// CHECK: flow.return %[[UNSET_ENCODING]]
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[RESULT]]
// CHECK: util.return %[[SLICE]]

// -----

util.func public @unpack_non_intersecting_reduction(
%arg0: tensor<?x?x?xf32>,
%arg1: tensor<?xf32>) -> tensor<?xf32> {
Expand Down

0 comments on commit 39c56de

Please sign in to comment.