From 4ad00ef95cadc100b5f0a704289b508aec1774a5 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Sat, 29 Jun 2024 12:11:10 -0700 Subject: [PATCH] [Flow] Make sink reshapes changes less conservative. (#17706) While deciding if a reshape needs "sinking", for a `tensor.expand_shape` -> `linalg.*`, first check was to check that the `linalg.*` operation could already fuse with one of its existing producers. That check was broadly aggressive. The fusion only kicks in when the iteration domains match. Eventually the actual dispatch formation logic needs to be commoned to a single place to do this better, but kicking that to a follow up. --------- Signed-off-by: MaheshRavishankar --- .../Dialect/Flow/Transforms/SinkReshapes.cpp | 101 ++++++++++++++---- .../Flow/Transforms/test/sink_reshapes.mlir | 84 +++++++++++++++ 2 files changed, 166 insertions(+), 19 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp index f32d8a96626a..d68320de1ab3 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/SinkReshapes.cpp @@ -64,35 +64,98 @@ static bool shouldSinkExpandShapeOp(OpOperand *opOperand) { if (!isNonNullAndOutsideDispatch({reshapeOp, consumer})) { return false; } + auto consumerGenericOp = dyn_cast(consumer); + if (!consumerGenericOp) { + return false; + } + // Only sink across parallel generic ops for now. + if (consumerGenericOp.getNumParallelLoops() != + consumerGenericOp.getNumLoops()) { + return false; + } // Do not sink reshapes across dequantize operations since they are - // cloned into their producers. + // cloned into their consumers. if (isDequantizationLikeOp(consumer)) { return false; } - // If the op is already fusable with producer using tile and fuse, - // do nothing. - if (llvm::any_of(consumer->getOpOperands(), [](OpOperand &opOperand) { - Operation *currProducer = opOperand.get().getDefiningOp(); - Operation *currConsumer = opOperand.getOwner(); - return isFusableUsingTileAndFuse(currProducer, currConsumer) && - // The check for the producer having a single use is not fully - // worked out. Ideally we can fuse with a producer irrespective - // of number of uses, but is a good thumb rule in practice. - llvm::hasSingleElement(currProducer->getUses()); - })) { + // First check that the expand_shape producer and consumer can be fused. + Operation *reshapeProducer = reshapeOp.getSrc().getDefiningOp(); + if (!reshapeProducer) { return false; } - - // Do not sink if consumer is a contraction/matmul like op. - if (auto linalgConsumerOp = dyn_cast(consumer)) { - if (linalg::isaContractionOpInterface(linalgConsumerOp)) - return false; + if (!isFusableUsingTileAndFuse(reshapeOp.getSrc().getDefiningOp(), + consumer)) { + return false; } - return isFusableUsingTileAndFuse(reshapeOp.getSrc().getDefiningOp(), - consumer); + // If the op is already fusable with producer using tile and fuse, + // do nothing. + for (OpOperand &opOperand : consumer->getOpOperands()) { + Operation *currProducer = opOperand.get().getDefiningOp(); + if (!currProducer) { + continue; + } + + // The check for the producer having a single use is not fully + // worked out. Ideally we can fuse with a producer irrespective + // of number of uses, but is a good thumb rule in practice. + if (!llvm::hasSingleElement(currProducer->getUses())) { + continue; + } + + // Check if a producer can already be tiled and fused with the consumer. + if (!isFusableUsingTileAndFuse(currProducer, consumer)) { + continue; + } + + // There is already a tile-and-fusable producer to fuse with. Still prefer + // fusing with the producer whose parallel iteration space rank matches + // the consumer parallel iteration space rank to avoid loss of parallelism. + if (auto currLinalgProducer = dyn_cast(currProducer)) { + auto reshapeLinalgProducer = dyn_cast(reshapeProducer); + if (!reshapeLinalgProducer) { + // For now we will prefer to fold with Linalg op. So if the reshape + // producer is not a Linalg op, bail. + return false; + } + + // Somehow this logic does not seem to work well when the reshape producer + // is an elementwise operation. For one, should never have a reshape + // "after" an elementwise operation, since bubble expand shape should + // already account for it, and fuse the elementwise producer of reshape + // and the consumer (which is also elementwise). Needs more investigation + // but removes regressions and lit test failures. + if (reshapeLinalgProducer.getNumLoops() == + reshapeLinalgProducer.getNumParallelLoops() && + currLinalgProducer.getNumLoops() != + currLinalgProducer.getNumParallelLoops()) { + return false; + } + + unsigned currConsumerNumParallelLoops = + consumerGenericOp.getNumParallelLoops(); + unsigned currProducerNumParallelLoops = + currLinalgProducer.getNumParallelLoops(); + if (currProducerNumParallelLoops == currConsumerNumParallelLoops) { + // If the producer has same number of parallel loops as consumer, + // then this is the operand to fuse along. So do nothing. + return false; + } + // If the producer has less number of parallel loops as the consumer, + // ignore this operand. + if (currProducerNumParallelLoops < currConsumerNumParallelLoops) { + continue; + } + unsigned reshapeProducerNumParallelLoops = + reshapeLinalgProducer.getNumParallelLoops(); + if (currProducerNumParallelLoops < reshapeProducerNumParallelLoops) { + return false; + } + } + } + return true; } void SinkReshapesPass::runOnOperation() { diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir index 04d0c588d064..92587fd3a407 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/sink_reshapes.mlir @@ -127,3 +127,87 @@ func.func @do_not_sink_across_dequantize_ops(%arg0: tensor) -> tensor<2 // CHECK: %[[DEQUANT:.+]] = linalg.generic // CHECK-SAME: ins(%[[EXPAND]] : // CHECK: return %[[DEQUANT]] + +// ----- + +// Check that reshape sinks based with better estimate of what producers +// -> consumer are fusable. +func.func @better_producer_estimate(%lhs : tensor<2x4096x640xi32>, %rhs : tensor<2x640x640xi32>, + %fill0 : tensor<2x4096x640xi32>, %fill1 : tensor<2x4096xi32>) -> tensor<2x4096x640x1xf16> { + %bmm = linalg.batch_matmul_transpose_b ins(%lhs, %rhs : tensor<2x4096x640xi32>, tensor<2x640x640xi32>) + outs(%fill0 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32> + %reduction = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs : tensor<2x4096x640xi32>) outs(%fill1 : tensor<2x4096xi32>) { + ^bb0(%in: i32, %out: i32): + %12 = arith.addi %in, %out : i32 + linalg.yield %12 : i32 + } -> tensor<2x4096xi32> + %expanded = tensor.expand_shape %bmm [[0], [1], [2, 3]] output_shape [2, 4096, 640, 1] + : tensor<2x4096x640xi32> into tensor<2x4096x640x1xi32> + %empty = tensor.empty() : tensor<2x4096x640x1xf16> + %quant = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%expanded, %reduction : tensor<2x4096x640x1xi32>, tensor<2x4096xi32>) + outs(%empty : tensor<2x4096x640x1xf16>) { + ^bb0(%in: i32, %in_3: i32, %out: f16): + %14 = arith.subi %in, %in_3 : i32 + %16 = arith.sitofp %14 : i32 to f32 + %18 = arith.truncf %16 : f32 to f16 + linalg.yield %18 : f16 + } -> tensor<2x4096x640x1xf16> + return %quant : tensor<2x4096x640x1xf16> +} +// CHECK-LABEL: func @better_producer_estimate( +// CHECK: %[[BMM:.+]] = linalg.batch_matmul_transpose_b +// CHECK: %[[REDUCTION:.+]] = linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[BMM]], %[[REDUCTION]] : +// CHECK: %[[COLLAPSE:.+]] = tensor.expand_shape %[[GENERIC]] +// CHECK: return %[[COLLAPSE]] + +// ----- + +func.func @reduce_broadcast(%arg0: tensor<4x768xf32>, %arg1: tensor<4xf32>, + %arg2: tensor<4xf32>, %arg3: tensor<1x4x768xf32>) -> tensor<1x4x768xf32> { + %cst = arith.constant 9.000000e+00 : f32 + %cst_0 = arith.constant 8.000000e+00 : f32 + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0, %arg1 : tensor<4x768xf32>, tensor<4xf32>) + outs(%arg2 : tensor<4xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %3 = arith.subf %in, %in_2 : f32 + %4 = arith.mulf %3, %3 : f32 + %5 = arith.addf %out, %4 : f32 + linalg.yield %5 : f32 + } -> tensor<4xf32> + %expanded = tensor.expand_shape %0 [[0, 1]] output_shape [1, 4] + : tensor<4xf32> into tensor<1x4xf32> + %1 = tensor.empty() : tensor<1x4x768xf32> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} + ins(%arg3, %expanded : tensor<1x4x768xf32>, tensor<1x4xf32>) + outs(%1 : tensor<1x4x768xf32>) { + ^bb0(%in: f32, %in_2: f32, %out: f32): + %9 = arith.mulf %in, %in_2 : f32 + linalg.yield %9 : f32 + } -> tensor<1x4x768xf32> + return %2 : tensor<1x4x768xf32> +} +// CHECK-LABEL: func @reduce_broadcast( +// CHECK: %[[GENERIC1:.+]] = linalg.generic +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%{{.+}}, %[[GENERIC1]] : +// CHECK: tensor.expand_shape %[[GENERIC2]]