Skip to content

Commit

Permalink
[GPU] Disable insert/extract slice lowering from pack/unpack by default (iree-org#19590)
Browse files Browse the repository at this point in the history

This PR is a follow-up to
llvm/llvm-project#117340.

It disables the `lowerPadLikeWithInsertSlice` and
`lowerUnpadLikeWithExtractSlice` options so that pack/unpack ops whose high
dimensions are unit dimensions are no longer lowered directly to
`tensor.insert_slice` or `tensor.extract_slice`.

---------

Signed-off-by: jerryyin <zhuoryin@amd.com>
  • Loading branch information
jerryyin authored Jan 3, 2025
1 parent 1e935c4 commit 5a97523
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ struct LowerPackPattern : public OpRewritePattern<tensor::PackOp> {
if (controlFn && failed(controlFn.value()(op))) {
return failure();
}
FailureOr<linalg::LowerPackResult> res = linalg::lowerPack(rewriter, op);
FailureOr<linalg::LowerPackResult> res =
linalg::lowerPack(rewriter, op, /*lowerPadLikeWithInsertSlice=*/false);
if (failed(res)) {
return rewriter.notifyMatchFailure(
op, "cannot lower to pad + expand + transpose");
Expand All @@ -83,8 +84,8 @@ struct LowerUnPackPattern : public OpRewritePattern<tensor::UnPackOp> {
if (controlFn && failed(controlFn.value()(op))) {
return failure();
}
FailureOr<linalg::LowerUnPackOpResult> res =
linalg::lowerUnPack(rewriter, op);
FailureOr<linalg::LowerUnPackOpResult> res = linalg::lowerUnPack(
rewriter, op, /*lowerUnpadLikeWithExtractSlice=*/false);
if (failed(res)) {
return rewriter.notifyMatchFailure(
op, "cannot lower to empty + transpose + reshape + extract_slice");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
// CHECK-ALL-SAME: %[[PAD_VAL:[A-Za-z0-9]+]]:
// CHECK-ALL: %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1]
// CHECK-ALL: tensor.yield %[[PAD_VAL]]
// CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK-ALL: return %[[INSERT]]

// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] output_shape [1, 8, 1, 2] : tensor<8x2xf32> into tensor<1x8x1x2xf32>
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x8x1x2xf32>) outs(%[[OUT]] : tensor<1x1x8x2xf32>) permutation = [0, 2, 1, 3]
// CHECK-RESHAPE: return %[[TRANS]]

// -----

Expand All @@ -42,8 +47,13 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32
// CHECK-ALL-LABEL: func.func @simple_NC_to_CNnc
// CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]:
// CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK-ALL: return %[[INSERT]]

// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK: return %[[INSERT]]

// CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0, 1], [2, 3]] output_shape [1, 32, 1, 8] : tensor<32x8xf32> into tensor<1x32x1x8xf32>
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x32x1x8xf32>) outs(%[[OUT]] : tensor<1x1x32x8xf32>) permutation = [2, 0, 1, 3]
// CHECK-RESHAPE: return %[[TRANS]]

// -----

Expand Down Expand Up @@ -132,8 +142,11 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output:
// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
// CHECK: %[[RES:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1]

// CHECK-RESHAPE: %[[RES:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1]

// CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x2xf32>
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x8x2xf32>) outs(%[[EMPTY]] : tensor<1x8x1x2xf32>) permutation = [0, 2, 1, 3]
// CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-RESHAPE: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]]
// CHECK-RESHAPE: %[[RES:.+]] = linalg.copy ins(%[[SLICE]] : tensor<5x1xf32>) outs(%[[OUT]] : tensor<5x1xf32>) -> tensor<5x1xf32>
// CHECK-ALL: return %[[RES:.+]]

// -----
Expand All @@ -145,8 +158,13 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32
// CHECK-ALL-LABEL: func.func @simple_CNnc_to_NC
// CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]:
// CHECK-ALL: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
// CHECK-ALL: return %[[TILE]]
// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]

// CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x1x8xf32>
// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x32x8xf32>) outs(%[[EMPTY]] : tensor<1x32x1x8xf32>) permutation = [1, 2, 0, 3]
// CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-RESHAPE: %[[RESULT:.+]] = linalg.copy ins(%[[COLLAPSE]] : tensor<32x8xf32>) outs(%[[OUT]] : tensor<32x8xf32>) -> tensor<32x8xf32>
// CHECK-ALL: return %[[RESULT]]

// -----

Expand Down

0 comments on commit 5a97523

Please sign in to comment.