Skip to content

Commit

Permalink
Optimize moveOp[Up,Down]InBlock functions in `SimplifyGlobalAccesse…
Browse files Browse the repository at this point in the history
…s`. (iree-org#15245)

Following discussion in [this Discord
thread](https://discord.com/channels/689900678990135345/1163611525009920082).

Profiling revealed that cleanup patterns in the Stream dialect can be
particularly slow for large programs (1000s of ops in a function, with
100s of global constants):

![image](https://github.com/openxla/iree/assets/4010439/4f3f2211-f159-4635-b175-6915e0246c8e)

Of particular note, the `moveOpUpInBlock` and `moveOpDownInBlock`
functions in `SimplifyGlobalAccesses` were shown to be inefficient (loop
around "can move down -> move down"). This optimizes those functions by
caching the set of ops in each block that block motion and deferring op
movement until the final movement location is determined.

Results are:
* Stream compilation phase from ~17s to ~15s on
llama2_7b_int4_stripped.mlir on my machine
* `SimplifyGlobalAccesses` mean time from 10.66ms to 1.45ms (no 1s+
outliers)
* `SimplifyGlobalAccesses` median time from 113us to 151us (more
overhead for very small blocks)

Yellow ("this trace") is baseline, Red ("external trace") is with this
PR:

![image](https://github.com/openxla/iree/assets/4010439/8291f319-8e8f-4932-b33a-fb5146bf8698)

Structural changes like reordering passes or splitting up
blocks/functions could yield larger improvements.
  • Loading branch information
ScottTodd authored Oct 19, 2023
1 parent df00df9 commit 02e34b0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,40 @@ static bool doesOpBlockMotion(Operation *op) {
op->hasTrait<OpTrait::IsTerminator>();
}

static void moveOpUpInBlock(Block &block, Operation *op) {
while (op->getPrevNode()) {
if (doesOpBlockMotion(op->getPrevNode()))
static SetVector<Operation *> getOpsThatBlockMotion(Block &block) {
SetVector<Operation *> ops;
for (auto &op : block.getOperations()) {
if (doesOpBlockMotion(&op))
ops.insert(&op);
}
return ops;
}

static void moveOpUpInBlock(Block &block, Operation *op,
const SetVector<Operation *> &opsThatBlockMotion) {
// Find the earliest node that does not block op motion then move before it.
mlir::Operation *earliestValidNode = op;
while (earliestValidNode->getPrevNode()) {
if (opsThatBlockMotion.contains(earliestValidNode->getPrevNode()))
break;
op->moveBefore(op->getPrevNode());
earliestValidNode = earliestValidNode->getPrevNode();
}
if (earliestValidNode != op)
op->moveBefore(earliestValidNode);
}

static void moveOpDownInBlock(Block &block, Operation *op) {
while (op->getNextNode()) {
if (doesOpBlockMotion(op->getNextNode()))
static void
moveOpDownInBlock(Block &block, Operation *op,
const SetVector<Operation *> &opsThatBlockMotion) {
// Find the latest node that does not block op motion then move after it.
mlir::Operation *latestValidNode = op;
while (latestValidNode->getNextNode()) {
if (opsThatBlockMotion.contains(latestValidNode->getNextNode()))
break;
op->moveAfter(op->getNextNode());
latestValidNode = latestValidNode->getNextNode();
}
if (latestValidNode != op)
op->moveAfter(latestValidNode);
}

// Optimizes the load/store ops for each given bucket.
Expand All @@ -109,6 +129,7 @@ static bool
optimizeBuckets(Block &block,
std::map<StringRef, SmallVector<Operation *>> &buckets) {
bool didRemoveAny = false;
auto opsThatBlockMotion = getOpsThatBlockMotion(block);
for (auto &bucket : buckets) {
// First perform basic load-store forwarding and such.
auto &ops = bucket.second;
Expand Down Expand Up @@ -164,15 +185,15 @@ optimizeBuckets(Block &block,
// If the head op is a load we can move that to the top of the block.
LLVM_DEBUG(llvm::dbgs() << "moving mutable global "
<< loadOp.getGlobalName() << " load upward\n");
moveOpUpInBlock(block, ops.front());
moveOpUpInBlock(block, ops.front(), opsThatBlockMotion);
}
if (auto storeOp =
dyn_cast<IREE::Util::GlobalStoreOpInterface>(ops.back())) {
// If the tail op is a store we can move that to the bottom of the block.
LLVM_DEBUG(llvm::dbgs()
<< "moving mutable global " << storeOp.getGlobalName()
<< " store downward\n");
moveOpDownInBlock(block, ops.back());
moveOpDownInBlock(block, ops.back(), opsThatBlockMotion);
}
}
return didRemoveAny;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,31 @@ func.func @side_effects() {
}

func.func private @other_fn()

// -----

util.global private mutable @varA = dense<1> : tensor<2xi32>
util.global private mutable @varB = dense<2> : tensor<2xi32>

// CHECK-LABEL: @ordering
func.func @ordering() {
%cst_top = arith.constant 1 : index
%varA_0 = util.global.load @varA {id = 0} : tensor<2xi32>
util.global.store %varA_0, @varA {id = 0} : tensor<2xi32>
%varB_0 = util.global.load @varB {id = 1} : tensor<2xi32>
util.global.store %varB_0, @varB {id = 1} : tensor<2xi32>
%cst_bottom = arith.constant 2 : index

// Loads should be moved up (in any order).
// CHECK-DAG: %[[T0:.+]] = util.global.load @varA {id = 0
// CHECK-DAG: %[[T1:.+]] = util.global.load @varB {id = 1
// CHECK-NEXT: arith.constant

// CHECK-NOT: NOT

// Stores should be moved down (in any order).
// CHECK-NEXT: arith.constant
// CHECK-DAG: util.global.store %[[T0]], @varA {id = 0
// CHECK-DAG: util.global.store %[[T1]], @varB {id = 1
return
}

0 comments on commit 02e34b0

Please sign in to comment.