diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp index 563425edd1cc..4a8d68e574b0 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/SimplifyGlobalAccesses.cpp @@ -87,20 +87,40 @@ static bool doesOpBlockMotion(Operation *op) { op->hasTrait(); } -static void moveOpUpInBlock(Block &block, Operation *op) { - while (op->getPrevNode()) { - if (doesOpBlockMotion(op->getPrevNode())) +static SetVector getOpsThatBlockMotion(Block &block) { + SetVector ops; + for (auto &op : block.getOperations()) { + if (doesOpBlockMotion(&op)) + ops.insert(&op); + } + return ops; +} + +static void moveOpUpInBlock(Block &block, Operation *op, + const SetVector &opsThatBlockMotion) { + // Find the earliest node that does not block op motion then move before it. + mlir::Operation *earliestValidNode = op; + while (earliestValidNode->getPrevNode()) { + if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode())) break; - op->moveBefore(op->getPrevNode()); + earliestValidNode = earliestValidNode->getPrevNode(); } + if (earliestValidNode != op) + op->moveBefore(earliestValidNode); } -static void moveOpDownInBlock(Block &block, Operation *op) { - while (op->getNextNode()) { - if (doesOpBlockMotion(op->getNextNode())) +static void +moveOpDownInBlock(Block &block, Operation *op, + const SetVector &opsThatBlockMotion) { + // Find the latest node that does not block op motion then move after it. + mlir::Operation *latestValidNode = op; + while (latestValidNode->getNextNode()) { + if (opsThatBlockMotion.contains(latestValidNode->getNextNode())) break; - op->moveAfter(op->getNextNode()); + latestValidNode = latestValidNode->getNextNode(); } + if (latestValidNode != op) + op->moveAfter(latestValidNode); } // Optimizes the load/store ops for each given bucket. @@ -109,6 +129,7 @@ static bool optimizeBuckets(Block &block, std::map> &buckets) { bool didRemoveAny = false; + auto opsThatBlockMotion = getOpsThatBlockMotion(block); for (auto &bucket : buckets) { // First perform basic load-store forwarding and such. auto &ops = bucket.second; @@ -164,7 +185,7 @@ optimizeBuckets(Block &block, // If the head op is a load we can move that to the top of the block. LLVM_DEBUG(llvm::dbgs() << "moving mutable global " << loadOp.getGlobalName() << " load upward\n"); - moveOpUpInBlock(block, ops.front()); + moveOpUpInBlock(block, ops.front(), opsThatBlockMotion); } if (auto storeOp = dyn_cast(ops.back())) { @@ -172,7 +193,7 @@ optimizeBuckets(Block &block, LLVM_DEBUG(llvm::dbgs() << "moving mutable global " << storeOp.getGlobalName() << " store downward\n"); - moveOpDownInBlock(block, ops.back()); + moveOpDownInBlock(block, ops.back(), opsThatBlockMotion); } } return didRemoveAny; diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir index f4e530511b24..9e2077ad3e99 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/simplify_global_accesses.mlir @@ -134,3 +134,31 @@ func.func @side_effects() { } func.func private @other_fn() + +// ----- + +util.global private mutable @varA = dense<1> : tensor<2xi32> +util.global private mutable @varB = dense<2> : tensor<2xi32> + +// CHECK-LABEL: @ordering +func.func @ordering() { + %cst_top = arith.constant 1 : index + %varA_0 = util.global.load @varA {id = 0} : tensor<2xi32> + util.global.store %varA_0, @varA {id = 0} : tensor<2xi32> + %varB_0 = util.global.load @varB {id = 1} : tensor<2xi32> + util.global.store %varB_0, @varB {id = 1} : tensor<2xi32> + %cst_bottom = arith.constant 2 : index + + // Loads should be moved up (in any order). + // CHECK-DAG: %[[T0:.+]] = util.global.load @varA {id = 0 + // CHECK-DAG: %[[T1:.+]] = util.global.load @varB {id = 1 + // CHECK-NEXT: arith.constant + + // CHECK-NOT: NOT + + // Stores should be moved down (in any order). + // CHECK-NEXT: arith.constant + // CHECK-DAG: util.global.store %[[T0]], @varA {id = 0 + // CHECK-DAG: util.global.store %[[T1]], @varB {id = 1 + return +}