Collapse dims when producer is unpack op (iree-org#17725)
iree-org#17594

---------

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
IanWood1 authored Jul 21, 2024
1 parent b0512e2 commit 0d0b989
Showing 2 changed files with 67 additions and 16 deletions.
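For context: the CollapseDimensions pass this patch touches rewrites the root linalg.generic of a dispatch so that a contiguous run of collapsible loops becomes a single loop, with tensor.collapse_shape / tensor.expand_shape mediating at the boundary (and later hoisted out of the dispatch region). Below is a minimal before/after sketch on a plain elementwise op; the function names and shapes are invented for illustration and the flow.dispatch.region wrapping is omitted.

#id2 = affine_map<(d0, d1) -> (d0, d1)>
#id1 = affine_map<(d0) -> (d0)>

// Before: two parallel loops over a 4x8 iteration space.
util.func public @before_collapse(%src: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %init = tensor.empty() : tensor<4x8xf32>
  %0 = linalg.generic {indexing_maps = [#id2, #id2],
                       iterator_types = ["parallel", "parallel"]}
      ins(%src : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %a = arith.addf %in, %in : f32
    linalg.yield %a : f32
  } -> tensor<4x8xf32>
  util.return %0 : tensor<4x8xf32>
}

// After collapsing the [d0, d1] sequence: one parallel loop of size 32, with the
// reshapes restoring the original 4x8 type at the boundary.
util.func public @after_collapse(%src: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %flat = tensor.collapse_shape %src [[0, 1]] : tensor<4x8xf32> into tensor<32xf32>
  %init = tensor.empty() : tensor<32xf32>
  %0 = linalg.generic {indexing_maps = [#id1, #id1],
                       iterator_types = ["parallel"]}
      ins(%flat : tensor<32xf32>) outs(%init : tensor<32xf32>) {
  ^bb0(%in: f32, %out: f32):
    %a = arith.addf %in, %in : f32
    linalg.yield %a : f32
  } -> tensor<32xf32>
  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [4, 8] : tensor<32xf32> into tensor<4x8xf32>
  util.return %1 : tensor<4x8xf32>
}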
@@ -8,7 +8,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/FormDispatchRegions.h"
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Analysis/SliceAnalysis.h"
@@ -105,9 +105,26 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
(rDimsSet.count(prePos) && rDimsSet.count(nextPos));
};

// Find all dims that are used to index operands whose producers live inside
// the dispatch region; these dims must be preserved and cannot be collapsed.
auto regionOp = cast<DispatchRegionOp>(genericOp->getParentOp());
llvm::SmallSet<unsigned, 8> preservedDims;
for (OpOperand *operand : genericOp.getDpsInputOperands()) {
auto definingOp = operand->get().getDefiningOp();
if (!definingOp ||
definingOp->getParentOfType<DispatchRegionOp>() != regionOp)
continue;
for (AffineExpr expr :
genericOp.getMatchingIndexingMap(operand).getResults()) {
preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
}
}

ReassociationIndices range;
AffineExpr preExpr;
// Find the largest sequence of dimensions that is
// - not used to index operands produced inside the dispatch,
// AND
// - either preserved in all maps, or
// - completely absent from them.
// This sequence can be collapsed. To find the sequence,
@@ -119,16 +136,18 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
// and repeat till the last element of sequence and the next result
// expression is not found as a sequence in all maps.
for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
unsigned position = cast<AffineDimExpr>(nextExpr).getPosition();
if (!range.empty()) {
if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
if (preservedDims.contains(position) ||
!hasAllMapsSameSequence(preExpr, nextExpr) ||
!hasSameIteratorType(preExpr, nextExpr)) {
if (range.size() > 1) {
contiguousLoops.push_back({range.begin(), range.end()});
}
range.clear();
}
}
range.push_back(cast<AffineDimExpr>(nextExpr).getPosition());
range.push_back(position);
preExpr = nextExpr;
}
if (range.size() > 1)
@@ -167,6 +186,15 @@ static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
return false;
}

// TODO(#17948) GPU codegen fails when trying to collapse the
// dimensions of an elementwise op in the case of elementwise(contraction).
// For now, don't collapse when there is a linalgOp producer.
if (llvm::any_of(genericOp.getDpsInputs(), [](Value val) -> bool {
return val.getDefiningOp<linalg::LinalgOp>();
})) {
return false;
}

// TODO(guray) Collapsing caused performance regression in a cpu
// benchmark, so we disable it.
if (genericOp.hasIndexSemantics())
@@ -176,7 +204,6 @@ static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
}

/// Traverses all the Ops in DispatchRegionOps and finds linalg.generic Op
/// without any producers.
static FailureOr<linalg::GenericOp>
findRootGenericOp(DispatchRegionOp regionOp) {
if (!llvm::hasSingleElement(regionOp.getBody())) {
@@ -200,15 +227,6 @@ findRootGenericOp(DispatchRegionOp regionOp) {
}
}

// Check that the operands of the generic op are defined outside the dispatch.
for (OpOperand *inputOperands : collapsibleOp.getDpsInputOperands()) {
Operation *definingOp = inputOperands->get().getDefiningOp();
if (definingOp &&
definingOp->getParentOfType<DispatchRegionOp>() == regionOp) {
return failure();
}
}

// Check that the output is either a `tensor.empty` or a `linalg.fill` op by
// traversing the operations that define the `init` operands of the
// `collapsibleOp`.
@@ -398,11 +416,11 @@ hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
return newDispatchOp;
}

/// Traverses DispatchRegionOps to find linalg genericOps that has no
/// producers and tries to collapse its dimensions.
/// Traverses DispatchRegionOps to find linalg.generic ops and collapses
/// dimensions that do not index operands produced inside the dispatch.
static bool collapseDimensions(IRRewriter &rewriter,
DispatchRegionOp &regionOp) {
// Step 1. Find the root linalg.generic Op with no producer
// Step 1. Find the root linalg.generic Op
std::optional<linalg::GenericOp> genericOp = findRootGenericOp(regionOp);
if (!genericOp.has_value())
return false;
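The C++ change above relaxes the old restriction that the root generic must have no producers inside the dispatch: instead of bailing out in findRootGenericOp, getCollapsibleLoops now records the dims used to index operands produced inside the region in preservedDims and excludes them from the collapsible sequences (collapsing is still disabled outright when an input comes from another linalg op, per the TODO(#17948) guard). Read against the new @unpack_collapse test below, the root generic's indexing maps are:

// Maps of the root linalg.generic in @unpack_collapse (copied from the test).
#map  = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>  // %arg0 and the result
#map1 = affine_map<(d0, d1, d2, d3) -> (d1)>              // the two 320-element operands
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>          // the tensor.unpack result

Because the tensor<2x320xf32> operand is produced by the tensor.unpack inside the same flow.dispatch.region and is indexed through #map2, d0 and d1 land in preservedDims, and only the contiguous [d2, d3] range is returned as collapsible.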
@@ -24,3 +24,36 @@ util.func public @do_not_collapse_cst_in_place(%arg0: tensor<1x1x2304xf32>) {
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: ins(%[[COLLAPSED_ARG0]], %[[COLLAPSED_CST]]
// CHECK: flow.return %[[RES]]


// -----
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d1)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
util.func public @unpack_collapse(%arg0: tensor<2x320x128x128xf32>, %arg1: tensor<320xf32>, %arg2: tensor<320xf32>, %arg3: tensor<1x5x2x64xf32>) -> tensor<2x320x128x128xf16> {
%dispatch = flow.dispatch.region -> (tensor<2x320x128x128xf16>) {
%0 = tensor.empty() : tensor<2x320xf32>
%unpack = tensor.unpack %arg3 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [2, 64] into %0 : tensor<1x5x2x64xf32> -> tensor<2x320xf32>
%1 = tensor.empty() : tensor<2x320x128x128xf16>
%2 = linalg.generic {
indexing_maps = [#map, #map1, #map2, #map1, #map],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%arg0, %arg1, %unpack, %arg2 : tensor<2x320x128x128xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>)
outs(%1 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f16):
%3 = arith.addf %in_1, %in_2 : f32
%4 = arith.addf %in, %in_0 : f32
%5 = arith.truncf %3 : f32 to f16
%6 = arith.truncf %4 : f32 to f16
%7 = arith.addf %6, %5 : f16
linalg.yield %7 : f16
} -> tensor<2x320x128x128xf16>
flow.return %2 : tensor<2x320x128x128xf16>
}
util.return %dispatch : tensor<2x320x128x128xf16>
}

// CHECK-LABEL: util.func public @unpack_collapse
// CHECK: %[[GEN:.+]] = linalg.generic
// CHECK-SAME: tensor<2x320x16384xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>
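
For reference, a hand-written sketch of roughly what the collapsed op matched by the CHECK lines could look like. The unpack result is passed in as a plain argument here, the flow.dispatch.region is omitted, the collapse_shape/expand_shape are shown inline rather than hoisted out of the dispatch, and all names are invented, so treat this as illustrative rather than the pass's verbatim output.

#cmap  = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#cmap1 = affine_map<(d0, d1, d2) -> (d1)>
#cmap2 = affine_map<(d0, d1, d2) -> (d0, d1)>

util.func public @unpack_collapse_sketch(%arg0: tensor<2x320x128x128xf32>, %arg1: tensor<320xf32>, %arg2: tensor<320xf32>, %unpack: tensor<2x320xf32>) -> tensor<2x320x128x128xf16> {
  // d2 and d3 collapse into one loop of size 128 * 128 = 16384; d0 and d1 stay
  // expanded because they index the operand produced by the in-dispatch unpack.
  %collapsed = tensor.collapse_shape %arg0 [[0], [1], [2, 3]] : tensor<2x320x128x128xf32> into tensor<2x320x16384xf32>
  %empty = tensor.empty() : tensor<2x320x16384xf16>
  %gen = linalg.generic {
      indexing_maps = [#cmap, #cmap1, #cmap2, #cmap1, #cmap],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%collapsed, %arg1, %unpack, %arg2 : tensor<2x320x16384xf32>, tensor<320xf32>, tensor<2x320xf32>, tensor<320xf32>)
      outs(%empty : tensor<2x320x16384xf16>) {
  ^bb0(%in: f32, %in_0: f32, %in_1: f32, %in_2: f32, %out: f16):
    %3 = arith.addf %in_1, %in_2 : f32
    %4 = arith.addf %in, %in_0 : f32
    %5 = arith.truncf %3 : f32 to f16
    %6 = arith.truncf %4 : f32 to f16
    %7 = arith.addf %6, %5 : f16
    linalg.yield %7 : f16
  } -> tensor<2x320x16384xf16>
  %expanded = tensor.expand_shape %gen [[0], [1], [2, 3]] output_shape [2, 320, 128, 128] : tensor<2x320x16384xf16> into tensor<2x320x128x128xf16>
  util.return %expanded : tensor<2x320x128x128xf16>
}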
