diff --git a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp index 4c2cb2550a7d..d408c057e5de 100644 --- a/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/DecomposePackUnPackOps.cpp @@ -56,7 +56,8 @@ struct LowerPackPattern : public OpRewritePattern { if (controlFn && failed(controlFn.value()(op))) { return failure(); } - FailureOr res = linalg::lowerPack(rewriter, op); + FailureOr res = + linalg::lowerPack(rewriter, op, /*lowerPadLikeWithInsertSlice=*/false); if (failed(res)) { return rewriter.notifyMatchFailure( op, "cannot lower to pad + expand + transpose"); @@ -83,8 +84,8 @@ struct LowerUnPackPattern : public OpRewritePattern { if (controlFn && failed(controlFn.value()(op))) { return failure(); } - FailureOr res = - linalg::lowerUnPack(rewriter, op); + FailureOr res = linalg::lowerUnPack( + rewriter, op, /*lowerUnpadLikeWithExtractSlice=*/false); if (failed(res)) { return rewriter.notifyMatchFailure( op, "cannot lower to empty + transpose + reshape + extract_slice"); diff --git a/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir b/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir index b37fa1a74e9d..9f51fe5ad72a 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/decompose_pack_unpack_ops.mlir @@ -30,8 +30,13 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x // CHECK-ALL-SAME: %[[PAD_VAL:[A-Za-z0-9]+]]: // CHECK-ALL: %[[PAD:.+]] = tensor.pad %[[IN]] low[0, 0] high[3, 1] // CHECK-ALL: tensor.yield %[[PAD_VAL]] -// CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] -// CHECK-ALL: return %[[INSERT]] + +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[PAD]] into %[[OUT]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] +// CHECK: return %[[INSERT]] + +// CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]] output_shape [1, 8, 1, 2] : tensor<8x2xf32> into tensor<1x8x1x2xf32> +// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x8x1x2xf32>) outs(%[[OUT]] : tensor<1x1x8x2xf32>) permutation = [0, 2, 1, 3] +// CHECK-RESHAPE: return %[[TRANS]] // ----- @@ -42,8 +47,13 @@ func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32 // CHECK-ALL-LABEL: func.func @simple_NC_to_CNnc // CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]: // CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-ALL: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK-ALL: return %[[INSERT]] + +// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[IN]] into %[[OUT]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] +// CHECK: return %[[INSERT]] + +// CHECK-RESHAPE: %[[EXPAND:.+]] = tensor.expand_shape %[[IN]] {{\[}}[0, 1], [2, 3]] output_shape [1, 32, 1, 8] : tensor<32x8xf32> into tensor<1x32x1x8xf32> +// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[EXPAND]] : tensor<1x32x1x8xf32>) outs(%[[OUT]] : tensor<1x1x32x8xf32>) permutation = [2, 0, 1, 3] +// CHECK-RESHAPE: return %[[TRANS]] // ----- @@ -132,8 +142,11 @@ func.func @simple_unpack_and_extract_slice(%input: tensor<1x1x8x2xf32>, %output: // CHECK: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1] // CHECK: %[[RES:.+]] = tensor.extract_slice %[[TILE]][0, 0] [5, 1] [1, 1] -// CHECK-RESHAPE: %[[RES:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] - +// CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x8x1x2xf32> +// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x8x2xf32>) outs(%[[EMPTY]] : tensor<1x8x1x2xf32>) permutation = [0, 2, 1, 3] +// CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-RESHAPE: %[[SLICE:.+]] = tensor.extract_slice %[[COLLAPSE]] +// CHECK-RESHAPE: %[[RES:.+]] = linalg.copy ins(%[[SLICE]] : tensor<5x1xf32>) outs(%[[OUT]] : tensor<5x1xf32>) -> tensor<5x1xf32> // CHECK-ALL: return %[[RES:.+]] // ----- @@ -145,8 +158,13 @@ func.func @simple_CNnc_to_NC(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<32x8xf32 // CHECK-ALL-LABEL: func.func @simple_CNnc_to_NC // CHECK-ALL-SAME: %[[IN:[A-Za-z0-9]+]]: // CHECK-ALL-SAME: %[[OUT:[A-Za-z0-9]+]]: -// CHECK-ALL: %[[TILE:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK-ALL: return %[[TILE]] +// CHECK: %[[RESULT:.+]] = tensor.extract_slice %[[IN]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] + +// CHECK-RESHAPE: %[[EMPTY:.+]] = tensor.empty() : tensor<1x32x1x8xf32> +// CHECK-RESHAPE: %[[TRANS:.+]] = linalg.transpose ins(%[[IN]] : tensor<1x1x32x8xf32>) outs(%[[EMPTY]] : tensor<1x32x1x8xf32>) permutation = [1, 2, 0, 3] +// CHECK-RESHAPE: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-RESHAPE: %[[RESULT:.+]] = linalg.copy ins(%[[COLLAPSE]] : tensor<32x8xf32>) outs(%[[OUT]] : tensor<32x8xf32>) -> tensor<32x8xf32> +// CHECK-ALL: return %[[RESULT]] // -----