diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 73d306bd7f6e..7cac15a828df 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -651,7 +651,8 @@ isFusableWithProducer(OpOperand &operand, } // Don't fuse attention with it's producer - if (isa(consumer)) { + // TODO: Enable scatter fusion when supported by backends. + if (isa(consumer)) { return false; } diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir index e1bc91eafa80..1575cf7a46b4 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir @@ -37,9 +37,9 @@ util.func public @linalgext_scatter_dispatch() -> tensor<8192x16x8x128xf32> { } // CHECK-LABEL: util.func public @linalgext_scatter_dispatch +// CHECK-DAG: %[[INDICES:.+]] = flow.dispatch.region +// CHECK-DAG: %[[UPDATE:.+]] = flow.dispatch.region // CHECK: %[[RESULT:.+]] = flow.dispatch.region -// CHECK: %[[INDICES:.+]] = linalg.generic -// CHECK: %[[UPDATE:.+]] = linalg.generic // CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter // CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) // CHECK: flow.return %[[SCATTER_RESULT]]