Skip to content

Commit

Permalink
[Codegen] Ensure hoisted extraction replaced by induction var. (iree-…
Browse files Browse the repository at this point in the history
…org#17975)

This commit teaches the compiler to replace hoisted extraction of IV by
the newly generated IV with the correct shapes. Previously we would
hoist the extraction and replace the IV uses by the hoisted extraction,
however this may not always be correct since the IV's value may be
updated in the loop.

The main motivation of this PR is to fix numerical issue caused by such
case that exists in the attention-cpp pipeline. Although this happens at
the vector level as opposed to the test cases we have for at tensor
level, we can re-use said test. Specific example for this case will be
left in the comment section of this PR.

---------

Signed-off-by: Stanley Winata <stanley.winata@amd.com>
  • Loading branch information
raikonenfnu authored Jul 21, 2024
1 parent 0d0b989 commit 907b2cd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,13 @@ hoistLoopInvariantSubsetAtIterArg(RewriterBase &rewriter,
ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
return {insertion.getSourceOperand().get()};
};

// replaceInitOperandUsesInLoop is set to true S.T we will use new IV
// instead of hoisted out extract.
FailureOr<LoopLikeOpInterface> newLoop =
loopLike.replaceWithAdditionalYields(
rewriter, extraction.getResult(),
/*replaceInitOperandUsesInLoop=*/false, newYieldValuesFn);
/*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
if (failed(newLoop))
return loopLike;
loopLike = *newLoop;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ func.func @subset_hoisting_invariant_tensor(%init: tensor<64x64xf32>, %t: tensor

// CHECK-LABEL: @subset_hoisting_invariant_tensor
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK: tensor.extract_slice
// CHECK: scf.for {{.*}} iter_args(%[[IV:.+]] = {{.*}})
// CHECK: %[[SLICE:.+]] = tensor.extract_slice
// CHECK-NOT: tensor.extract_slice
// CHECK: linalg.add ins(%[[IV]], %[[SLICE]] : {{.*}})
// CHECK: scf.yield
// CHECK: tensor.insert_slice

Expand Down

0 comments on commit 907b2cd

Please sign in to comment.