From a43d893b4d1946fabd7a6c7eb74c63ba7d42cdd5 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Fri, 27 Dec 2024 12:08:55 -0800 Subject: [PATCH] [Dispatch] Disable scatter fusion with producers (#19565) Backends don't currently support scatter fusion and will silently compile incorrect code. This should be turned off in order to prevent backends from generating incorrect results. I don't think any users are running into this currently, but its best to keep it off for now. Similar to https://github.com/iree-org/iree/pull/19535 but for both `indices` and `updates`. --------- Signed-off-by: Ian Wood --- .../iree/compiler/DispatchCreation/FormDispatchRegions.cpp | 3 ++- .../DispatchCreation/test/dispatch_linalg_ext_fusion.mlir | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp index 73d306bd7f6e..7cac15a828df 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp @@ -651,7 +651,8 @@ isFusableWithProducer(OpOperand &operand, } // Don't fuse attention with it's producer - if (isa(consumer)) { + // TODO: Enable scatter fusion when supported by backends. + if (isa(consumer)) { return false; } diff --git a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir index e1bc91eafa80..1575cf7a46b4 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_ext_fusion.mlir @@ -37,9 +37,9 @@ util.func public @linalgext_scatter_dispatch() -> tensor<8192x16x8x128xf32> { } // CHECK-LABEL: util.func public @linalgext_scatter_dispatch +// CHECK-DAG: %[[INDICES:.+]] = flow.dispatch.region +// CHECK-DAG: %[[UPDATE:.+]] = flow.dispatch.region // CHECK: %[[RESULT:.+]] = flow.dispatch.region -// CHECK: %[[INDICES:.+]] = linalg.generic -// CHECK: %[[UPDATE:.+]] = linalg.generic // CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter // CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) // CHECK: flow.return %[[SCATTER_RESULT]]