From 02e34b0cd64382994ada4be77eb097f64c954597 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 19 Oct 2023 15:51:38 -0700 Subject: [PATCH] Optimize `moveOp[Up,Down]InBlock` functions in `SimplifyGlobalAccesses`. (#15245) Following discussion in [this Discord thread](https://discord.com/channels/689900678990135345/1163611525009920082). Profiling revealed that cleanup patterns in the Stream dialect can be particularly slow for large programs (1000s of ops in a function, with 100s of global constants): ![image](https://github.com/openxla/iree/assets/4010439/4f3f2211-f159-4635-b175-6915e0246c8e) Of particular note, the `moveOpUpInBlock` and `moveOpDownInBlock` functions in `SimplifyGlobalAccesses` were shown to be inefficient: each one loops "can the op move one step? -> move it one step", re-checking the neighboring op and relinking the moved op at every step. This optimizes those functions by caching, once per block, the set of ops that block motion, and by deferring the move until the final location is determined, so each op is moved at most once. Results are: * Stream compilation phase from ~17s to ~15s on llama2_7b_int4_stripped.mlir on my machine * `SimplifyGlobalAccesses` mean time from 10.66ms to 1.45ms (no 1s+ outliers) * `SimplifyGlobalAccesses` median time rose from 113us to 151us (the up-front scan adds overhead for very small blocks) Yellow ("this trace") is baseline, Red ("external trace") is with this PR: ![image](https://github.com/openxla/iree/assets/4010439/8291f319-8e8f-4932-b33a-fb5146bf8698) Structural changes like reordering passes or splitting up blocks/functions could yield larger improvements. 
--- .../Transforms/SimplifyGlobalAccesses.cpp | 41 ++++++++++++++----- .../test/simplify_global_accesses.mlir | 28 +++++++++++++ 2 files changed, 59 insertions(+), 10 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index 563425edd1cc..4a8d68e574b0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp @@ -87,20 +87,40 @@ static bool doesOpBlockMotion(Operation *op) { op->hasTrait(); } -static void moveOpUpInBlock(Block &block, Operation *op) { - while (op->getPrevNode()) { - if (doesOpBlockMotion(op->getPrevNode())) +static SetVector getOpsThatBlockMotion(Block &block) { + SetVector ops; + for (auto &op : block.getOperations()) { + if (doesOpBlockMotion(&op)) + ops.insert(&op); + } + return ops; +} + +static void moveOpUpInBlock(Block &block, Operation *op, + const SetVector &opsThatBlockMotion) { + // Find the earliest node that does not block op motion then move before it. + mlir::Operation *earliestValidNode = op; + while (earliestValidNode->getPrevNode()) { + if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode())) break; - op->moveBefore(op->getPrevNode()); + earliestValidNode = earliestValidNode->getPrevNode(); } + if (earliestValidNode != op) + op->moveBefore(earliestValidNode); } -static void moveOpDownInBlock(Block &block, Operation *op) { - while (op->getNextNode()) { - if (doesOpBlockMotion(op->getNextNode())) +static void +moveOpDownInBlock(Block &block, Operation *op, + const SetVector &opsThatBlockMotion) { + // Find the latest node that does not block op motion then move after it. 
+ mlir::Operation *latestValidNode = op; + while (latestValidNode->getNextNode()) { + if (opsThatBlockMotion.contains(latestValidNode->getNextNode())) break; - op->moveAfter(op->getNextNode()); + latestValidNode = latestValidNode->getNextNode(); } + if (latestValidNode != op) + op->moveAfter(latestValidNode); } // Optimizes the load/store ops for each given bucket. @@ -109,6 +129,7 @@ static bool optimizeBuckets(Block &block, std::map> &buckets) { bool didRemoveAny = false; + auto opsThatBlockMotion = getOpsThatBlockMotion(block); for (auto &bucket : buckets) { // First perform basic load-store forwarding and such. auto &ops = bucket.second; @@ -164,7 +185,7 @@ optimizeBuckets(Block &block, // If the head op is a load we can move that to the top of the block. LLVM_DEBUG(llvm::dbgs() << "moving mutable global " << loadOp.getGlobalName() << " load upward\n"); - moveOpUpInBlock(block, ops.front()); + moveOpUpInBlock(block, ops.front(), opsThatBlockMotion); } if (auto storeOp = dyn_cast(ops.back())) { @@ -172,7 +193,7 @@ optimizeBuckets(Block &block, LLVM_DEBUG(llvm::dbgs() << "moving mutable global " << storeOp.getGlobalName() << " store downward\n"); - moveOpDownInBlock(block, ops.back()); + moveOpDownInBlock(block, ops.back(), opsThatBlockMotion); } } return didRemoveAny; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir index f4e530511b24..9e2077ad3e99 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir @@ -134,3 +134,31 @@ func.func @side_effects() { } func.func private @other_fn() + +// ----- + +util.global private mutable @varA = dense<1> : tensor<2xi32> +util.global private mutable @varB = dense<2> : tensor<2xi32> + +// CHECK-LABEL: @ordering +func.func @ordering() { + %cst_top = 
arith.constant 1 : index + %varA_0 = util.global.load @varA {id = 0} : tensor<2xi32> + util.global.store %varA_0, @varA {id = 0} : tensor<2xi32> + %varB_0 = util.global.load @varB {id = 1} : tensor<2xi32> + util.global.store %varB_0, @varB {id = 1} : tensor<2xi32> + %cst_bottom = arith.constant 2 : index + + // Loads should be moved up (in any order). + // CHECK-DAG: %[[T0:.+]] = util.global.load @varA {id = 0 + // CHECK-DAG: %[[T1:.+]] = util.global.load @varB {id = 1 + // CHECK-NEXT: arith.constant + + // CHECK-NOT: NOT + + // Stores should be moved down (in any order). + // CHECK-NEXT: arith.constant + // CHECK-DAG: util.global.store %[[T0]], @varA {id = 0 + // CHECK-DAG: util.global.store %[[T1]], @varB {id = 1 + return +}