Skip to content

Commit

Permalink
Revert "Enable scatter fusion with index operand. (iree-org#19198)" (i…
Browse files Browse the repository at this point in the history
…ree-org#19535)

This reverts commit 4c00a22.

Seems to be cause of iree-org#19533
  • Loading branch information
MaheshRavishankar authored Dec 20, 2024
1 parent 83af679 commit 07f81f0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,14 @@ matchIteratorTypes(const llvm::SmallBitVector &rootOuterParallelLoop,

// If the candidate is all parallel, then it should be at least as parallel as
// the root.
for (int pos : llvm::seq<int>(0, std::min(candidateOuterParallelLoop.size(),
rootOuterParallelLoop.size()))) {
for (int pos : llvm::seq<int>(0, rootOuterParallelLoop.size())) {
// If we reach the end of the outer loops of the root, break out of the
// loop.
if (!rootOuterParallelLoop.test(pos))
break;
// If the root loop is parallel, the candidate loop should also be parallel.
if (!candidateOuterParallelLoop.test(pos))
if (pos >= candidateOuterParallelLoop.size() ||
!candidateOuterParallelLoop.test(pos))
return false;
}
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -922,35 +922,3 @@ util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
// CHECK-SAME: ins(%[[DISPATCH1]],
// CHECK: flow.return %[[CUSTOM_OP]]
// CHECK: util.return %[[DISPATCH2]]

// -----

util.func @scatter_index_producer_fusion(%arg0 : tensor<?x1xi64>,
%arg1 : index, %arg2 : tensor<?x1x32x8x128xf16>,
%arg3 : tensor<?x32x8x128xf16>) -> tensor<?x32x8x128xf16> {
%empty = tensor.empty(%arg1) : tensor<?x1xi32>
%0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?x1xi64>) outs(%empty : tensor<?x1xi32>) {
^bb0(%in: i64, %out: i32):
%1 = arith.trunci %in : i64 to i32
linalg.yield %1 : i32
} -> tensor<?x1xi32>
%1 = iree_linalg_ext.scatter
dimension_map = [0] unique_indices(true)
ins(%arg2, %0 : tensor<?x1x32x8x128xf16>, tensor<?x1xi32>)
outs(%arg3 : tensor<?x32x8x128xf16>) {
^bb0(%arg6: f16, %arg7: f16):
iree_linalg_ext.yield %arg6 : f16
} -> tensor<?x32x8x128xf16>
util.return %1 : tensor<?x32x8x128xf16>
}
// CHECK-LABEL: func public @scatter_index_producer_fusion
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%{{.+}}, %[[GENERIC]] :
// CHECK: flow.return %[[SCATTER]]
// CHECK: util.return %[[DISPATCH]]

0 comments on commit 07f81f0

Please sign in to comment.