diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
index 5c9f3ce03cca..1b3fa5e176f9 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/BUILD.bazel
@@ -96,6 +96,7 @@ iree_compiler_cc_library(
         "//llvm-external-projects/iree-dialects:IREELinalgTransformDialect",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithUtils",
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
index 9c4a07f141d5..e90ef1bee6ae 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
@@ -78,6 +78,7 @@ iree_cc_library(
     IREELinalgTransformDialect
     LLVMSupport
     MLIRAffineDialect
+    MLIRAffineUtils
     MLIRAnalysis
     MLIRArithDialect
     MLIRArithUtils
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
index 6b0445e60ac5..19a0a43f7325 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/CollapseDimensions.cpp
@@ -12,6 +12,8 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -27,6 +29,8 @@
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include <deque>
+
 #define DEBUG_TYPE "iree-flow-collapse-dimensions"
 
 namespace mlir {
@@ -48,47 +52,88 @@ struct CollapseDimensionsPass
 
 /// Searches for the same sequence in all the affine maps and collapses these
 /// dimensions. It only applies this to "parallel" loops without mixing them
-/// with "reduction" types.
+/// with "reduction" types. The `genericOp` is expected to have only projected
+/// permutations as indexing maps (checked by `isEligibleForCollapse`).
 static SmallVector<ReassociationIndices>
 getCollapsibleLoops(linalg::GenericOp genericOp) {
   SmallVector<ReassociationIndices> contiguousLoops;
 
-  SmallVector<unsigned> pDims;
+  SmallVector<unsigned> pDims, rDims;
   genericOp.getParallelDims(pDims);
-  if (pDims.size() < 2)
-    return contiguousLoops;
-
-  llvm::SmallDenseSet<unsigned> pLoops(pDims.begin(), pDims.end());
+  genericOp.getReductionDims(rDims);
+  llvm::SmallDenseSet<unsigned> pDimsSet, rDimsSet;
+  pDimsSet.insert(pDims.begin(), pDims.end());
+  rDimsSet.insert(rDims.begin(), rDims.end());
 
   auto hasAllMapsSameSequence = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+    // Check that in all indexing maps of the `genericOp`, `preExpr` and
+    // `nextExpr` either appear as contiguous results (in that order) or are
+    // both absent. Only then can `preExpr` and `nextExpr` be collapsed.
    for (AffineMap map : genericOp.getIndexingMapsArray()) {
-      bool foundSeq = false;
+      // If the map has no results, there is nothing to check.
+      if (map.getNumResults() == 0) {
+        continue;
+      }
       for (auto [index, resultExpr] : llvm::enumerate(map.getResults())) {
+        // If we find `preExpr`, the next result must be `nextExpr`.
+        if (resultExpr == preExpr) {
+          if (index == map.getNumResults() - 1) {
+            // `preExpr` is the last result, so there is no next result to
+            // match `nextExpr` against.
+            return false;
+          }
+          if (map.getResult(index + 1) != nextExpr) {
+            return false;
+          }
+        }
+        // If we find `nextExpr`, the previous result must be `preExpr`. This
+        // check is mostly redundant with the one above, but it is cheap, so
+        // keep it.
         if (resultExpr == nextExpr) {
-          foundSeq = (index > 0 && preExpr == map.getResult(index - 1));
-          break;
+          if (index == 0) {
+            // `nextExpr` is the first result, so there is no previous result
+            // to match `preExpr` against.
+            return false;
+          }
+          if (map.getResult(index - 1) != preExpr) {
+            return false;
+          }
         }
       }
-      if (!foundSeq)
-        return false;
     }
     return true;
   };
 
+  auto hasSameIteratorType = [&](AffineExpr preExpr, AffineExpr nextExpr) {
+    unsigned prePos = preExpr.cast<AffineDimExpr>().getPosition();
+    unsigned nextPos = nextExpr.cast<AffineDimExpr>().getPosition();
+    return (pDimsSet.count(prePos) && pDimsSet.count(nextPos)) ||
+           (rDimsSet.count(prePos) && rDimsSet.count(nextPos));
+  };
+
   ReassociationIndices range;
   AffineExpr preExpr;
+  // Find the largest sequence of dimensions that is
+  // - either preserved in all maps, or
+  // - completely absent.
+  // Such a sequence can be collapsed. To find the sequence,
+  // 1) take the result expressions of one of the indexing maps,
+  // 2) find a sequence of 2 that is found in all maps,
+  // 3) then take the last element of this sequence and the next result
+  //    expression, and check if this sequence of 2 is found in all maps. If
+  //    so, add it to the sequence (to get a sequence of 3) and repeat until
+  //    the last element of the sequence and the next result expression are
+  //    no longer found as a sequence in all maps.
   for (auto nextExpr : genericOp.getIndexingMapsArray().front().getResults()) {
-    unsigned pos = nextExpr.cast<AffineDimExpr>().getPosition();
     if (!range.empty()) {
-      if (!hasAllMapsSameSequence(preExpr, nextExpr) || !pLoops.count(pos)) {
-        if (range.size() > 1)
+      if (!hasAllMapsSameSequence(preExpr, nextExpr) ||
+          !hasSameIteratorType(preExpr, nextExpr)) {
+        if (range.size() > 1) {
           contiguousLoops.push_back({range.begin(), range.end()});
+        }
         range.clear();
       }
     }
+    range.push_back(nextExpr.cast<AffineDimExpr>().getPosition());
     preExpr = nextExpr;
-    if (pLoops.count(pos))
-      range.push_back(pos);
   }
   if (range.size() > 1)
     contiguousLoops.push_back(range);
@@ -107,22 +152,6 @@ getCollapsibleLoops(linalg::GenericOp genericOp) {
   return contiguousLoops;
 }
 
-/// Collapse possible dimension of the given linalg.generic
-static FailureOr<SmallVector<Value>>
-collapseLinalgGeneric(IRRewriter &rewriter, linalg::GenericOp genericOp,
-                      SmallVector<ReassociationIndices> &collapseIndices) {
-  rewriter.setInsertionPoint(genericOp->getParentOp());
-  FailureOr<SmallVector<Value>> replacements =
-      mlir::linalg::collapseGenericOpIterationDims(genericOp, collapseIndices,
-                                                   rewriter);
-  if (failed(replacements) || replacements->empty()) {
-    return rewriter.notifyMatchFailure(genericOp,
-                                       "failed to collapse dimensions");
-  }
-
-  return replacements;
-}
-
 /// Returns true if the given op is collapsible.
 static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
   // TODO(guray) There is no mechanism to tell the collapsed indexes to
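For intuition, a minimal sketch of what `getCollapsibleLoops` now finds. The shapes here are hypothetical (they mirror the `@multi_reduce_dim` test added below), and the snippet is only a sketch of the rule, not taken verbatim from the pass:

```mlir
#in  = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#out = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
func.func @sketch(%input : tensor<2x32x10x4096xf32>) -> tensor<2x32xf32> {
  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<2x32xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x32xf32>) -> tensor<2x32xf32>
  // (d0, d1) is contiguous in both maps and both dims are "parallel";
  // (d2, d3) is contiguous in #in, absent from #out, and both dims are
  // "reduction". Both pairs collapse, yielding a 2-d generic over
  // tensor<64x40960xf32> with iterator_types = ["parallel", "reduction"].
  %sum = linalg.generic {indexing_maps = [#in, #out],
      iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
      ins(%input : tensor<2x32x10x4096xf32>) outs(%fill : tensor<2x32xf32>) {
  ^bb0(%in: f32, %acc: f32):
    %0 = arith.addf %in, %acc : f32
    linalg.yield %0 : f32
  } -> tensor<2x32xf32>
  return %sum : tensor<2x32xf32>
}
```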
@@ -154,101 +183,298 @@ static bool isEligibleForCollapse(linalg::GenericOp genericOp) {
 /// without any producers.
 static FailureOr<linalg::GenericOp>
 findRootGenericOp(DispatchRegionOp regionOp) {
-  SmallVector<Operation *> computeOps;
-  auto &ops = regionOp.getBody().front().getOperations();
-  for (Operation &op : ops) {
-    if (isa<linalg::GenericOp>(op))
-      computeOps.push_back(&op);
+  if (!llvm::hasSingleElement(regionOp.getBody())) {
+    return failure();
   }
-  // Looking for root without producer
-  if (computeOps.size() != 1 || ops.size() != 2)
+
+  // Check that the yielded values all come from a single `linalg.generic`.
+  auto returnOp =
+      cast<Flow::ReturnOp>(regionOp.getBody().front().getTerminator());
+  auto collapsibleOp = dyn_cast_or_null<linalg::GenericOp>(
+      returnOp->getOperand(0).getDefiningOp());
+  if (!collapsibleOp) {
     return failure();
-  auto genericOp = llvm::dyn_cast<linalg::GenericOp>(computeOps.front());
-  if (!genericOp)
+  }
+  for (auto returnVal : returnOp->getOperands().drop_front()) {
+    if (returnVal.getDefiningOp() != collapsibleOp.getOperation()) {
+      return failure();
+    }
+  }
+
+  // Check that the operands of the generic op are defined outside the
+  // dispatch.
+  for (OpOperand *inputOperand : collapsibleOp.getDpsInputOperands()) {
+    Operation *definingOp = inputOperand->get().getDefiningOp();
+    if (definingOp &&
+        definingOp->getParentOfType<DispatchRegionOp>() == regionOp) {
+      return failure();
+    }
+  }
+
+  // Check that the output is either a `tensor.empty` or a `linalg.fill` op by
+  // traversing the operations that define the `init` operands of the
+  // `collapsibleOp`.
+  std::deque<Operation *> worklist;
+  llvm::SmallDenseSet<Operation *> visited;
+  auto addDefiningOpToWorklist = [&](Value v) {
+    Operation *definingOp = v.getDefiningOp();
+    if (definingOp &&
+        definingOp->getParentOfType<DispatchRegionOp>() == regionOp &&
+        !visited.count(definingOp)) {
+      worklist.push_back(definingOp);
+      visited.insert(definingOp);
+    }
+  };
+  for (Value initOperand : collapsibleOp.getDpsInits()) {
+    addDefiningOpToWorklist(initOperand);
+  }
+
+  while (!worklist.empty()) {
+    Operation *op = worklist.front();
+    worklist.pop_front();
+    if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+      addDefiningOpToWorklist(fillOp.getDpsInitOperand(0)->get());
+      continue;
+    }
+    if (isa<tensor::EmptyOp>(op)) {
+      continue;
+    }
     return failure();
-  return genericOp;
+  }
+  return collapsibleOp;
 }
 
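+// For intuition, the kind of dispatch this accepts looks like the following
+// (hypothetical shapes): the single `flow.return` is fed by one
+// `linalg.generic` whose inputs come from above the dispatch, and whose init
+// chain bottoms out in `tensor.empty`/`linalg.fill`:
+//
+//   %0 = flow.dispatch.region -> (tensor<2x32xf32>) {
+//     %empty = tensor.empty() : tensor<2x32xf32>
+//     %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<2x32xf32>)
+//     %gen = linalg.generic ... ins(%arg : ...) outs(%fill : ...)
+//     flow.return %gen : tensor<2x32xf32>
+//   }
+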
-/// Generate a new dispatch.region and workload according with the collapsed
-/// linalg Generic Op
-static LogicalResult
-generateNewDispatchRegion(IRRewriter &rewriter, DispatchRegionOp regionOp,
-                          SmallVector<Value> collapseResults,
-                          linalg::GenericOp newGenericOp) {
+/// Hoist `tensor.collapse_shape` ops at the beginning of the `dispatchOp`
+/// and `tensor.expand_shape` ops at the end of the `dispatchOp`, out of the
+/// dispatch.
+static FailureOr<DispatchRegionOp>
+hoistTensorReshapesOutOfDispatchRegion(RewriterBase &rewriter,
+                                       DispatchRegionOp dispatchOp) {
+  // Only do this for `dispatchOp`s whose body is a single block.
+  if (!llvm::hasSingleElement(dispatchOp.getBody())) {
+    return failure();
+  }
+  Block &body = dispatchOp.getBody().front();
+  auto returnOp = cast<Flow::ReturnOp>(body.getTerminator());
+
+  // 1. Get the slice of operations within `dispatchOp` that produce the
+  // yielded value.
+  BackwardSliceOptions sliceOptions;
+  sliceOptions.filter = [&](Operation *op) {
+    return op->getParentOfType<DispatchRegionOp>() == dispatchOp;
+  };
+  SetVector<Operation *> slice;
+  getBackwardSlice(returnOp, &slice, sliceOptions);
+
+  // 2. Get the leaf operations that are `tensor.collapse_shape` ops, i.e.
+  // those whose operands are all defined outside the slice.
+  SmallVector<tensor::CollapseShapeOp> leafs;
+  for (Operation *op : slice) {
+    auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op);
+    if (!collapseShapeOp) {
+      continue;
+    }
+    if (llvm::all_of(op->getOperands(), [&](Value operand) {
+          Operation *definingOp = operand.getDefiningOp();
+          return !definingOp || slice.count(definingOp) == 0;
+        })) {
+      leafs.push_back(collapseShapeOp);
+    }
+  }
+
+  // 3. Clone the leaf `tensor.collapse_shape` ops outside the dispatch.
   OpBuilder::InsertionGuard g(rewriter);
-  rewriter.setInsertionPoint(regionOp->getParentOp());
+  rewriter.setInsertionPoint(dispatchOp);
+  for (auto reshapeOp : leafs) {
+    Operation *clonedOp = rewriter.clone(*reshapeOp.getOperation());
+    rewriter.replaceOp(reshapeOp, clonedOp->getResults());
+  }
 
-  auto maybeRegionOp = Flow::wrapOpInDispatchRegion(rewriter, newGenericOp);
-  if (failed(maybeRegionOp))
-    return failure();
+  // 4. From the yielded values, find any that are produced by a
+  // `tensor.expand_shape` operation and move them out of the dispatch. This
+  // needs a new `DispatchRegionOp`: for yielded values produced by a
+  // `tensor.expand_shape`, the result type changes, and the dynamic
+  // dimensions of the result type need to be updated as well.
+  SmallVector<Type> newReturnTypes;
+  SmallVector<Value> newDynamicDims;
+  SmallVector<Value> newYieldVals;
+  SmallVector<SmallVector<ReassociationIndices>> allReassociationIndices;
+  ValueRange dynamicDimsList = dispatchOp.getResultDims();
+  Location loc = dispatchOp.getLoc();
+  for (Value yieldedValue : returnOp->getOperands()) {
+    auto expandShapeOp = yieldedValue.getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      // 4a. Keep the same yield value if the producer is not a
+      // `tensor.expand_shape` op.
+      newReturnTypes.push_back(yieldedValue.getType());
+      newYieldVals.push_back(yieldedValue);
+      continue;
+    }
 
-  // Replace old regionOp with the result of collapse
-  rewriter.replaceOp(regionOp, collapseResults);
+    // 4b. The return type is the same as the type of the source of the
+    // `tensor.expand_shape`.
+    RankedTensorType collapsedShapeType = expandShapeOp.getSrcType();
+    newReturnTypes.push_back(collapsedShapeType);
+    newYieldVals.push_back(expandShapeOp.getSrc());
+    SmallVector<ReassociationIndices> reassociation =
+        expandShapeOp.getReassociationIndices();
+    ArrayRef<int64_t> expandedShape = expandShapeOp.getResultType().getShape();
+
+    // 4c. The dynamic dims of the result shape are obtained by taking the
+    // static shape + dynamic dims and collapsing them with the same
+    // reassociation map as the `tensor.expand_shape`.
+    for (auto [index, shape] :
+         llvm::enumerate(collapsedShapeType.getShape())) {
+      int64_t staticCollapsedShape = 1;
+      SmallVector<OpFoldResult> dynamicCollapsedDims;
+      for (auto collapsedDim : reassociation[index]) {
+        if (expandedShape[collapsedDim] == ShapedType::kDynamic) {
+          dynamicCollapsedDims.push_back(dynamicDimsList.front());
+          dynamicDimsList = dynamicDimsList.drop_front();
+        } else {
+          staticCollapsedShape *= expandedShape[collapsedDim];
+        }
+      }
+
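+      // For example, a group with expanded shape [?, 4, ?] and dynamic sizes
+      // %d0 and %d2 leaves staticCollapsedShape = 4 and
+      // dynamicCollapsedDims = {%d0, %d2}; the collapsed size computed below
+      // is then the equivalent of %d0 * %d2 * 4, materialized via
+      // `affine::makeComposedFoldedAffineApply`.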
+      if (dynamicCollapsedDims.empty()) {
+        // If there are no dynamic dims, there is nothing to do.
+        continue;
+      }
+      SmallVector<AffineExpr> exprs(dynamicCollapsedDims.size());
+      bindSymbolsList(rewriter.getContext(), MutableArrayRef(exprs));
+      AffineExpr multiplyAll = exprs.front();
+      for (auto expr : ArrayRef(exprs).drop_front()) {
+        multiplyAll = multiplyAll * expr;
+      }
+      if (staticCollapsedShape != 1) {
+        multiplyAll = multiplyAll * staticCollapsedShape;
+      }
+      OpFoldResult collapsedShape = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, multiplyAll, dynamicCollapsedDims);
+      newDynamicDims.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, collapsedShape));
+    }
+    allReassociationIndices.emplace_back(std::move(reassociation));
+  }
+
+  // 5. Create the new dispatch op.
+  auto newDispatchOp = rewriter.create<DispatchRegionOp>(
+      loc, newReturnTypes, newDynamicDims, dispatchOp.getWorkload());
+
+  // 5a. Move the body over, but replace the `flow.return` to use the new
+  // yield values.
+  Region &newBody = newDispatchOp.getBody();
+  rewriter.inlineRegionBefore(dispatchOp.getBody(), newBody, newBody.begin());
+  {
+    Operation *terminator = newBody.front().getTerminator();
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(terminator);
+    rewriter.replaceOpWithNewOp<Flow::ReturnOp>(terminator, newYieldVals);
+  }
 
-  return success();
+  // 5b. Move the workgroup count region over.
+  Region &workgroupCountRegion = dispatchOp.getWorkgroupCount();
+  if (!workgroupCountRegion.empty()) {
+    Region &newWorkgroupCountRegion = newDispatchOp.getWorkgroupCount();
+    rewriter.inlineRegionBefore(workgroupCountRegion, newWorkgroupCountRegion,
+                                newWorkgroupCountRegion.begin());
+  }
+
+  // 6. Map the modified result values back to their original shape using
+  // `tensor.expand_shape` operations.
+  ArrayRef<SmallVector<ReassociationIndices>> allReassociationIndicesRef(
+      allReassociationIndices);
+  for (auto [index, returnValue] :
+       llvm::enumerate(newDispatchOp.getResults())) {
+    Value origResult = dispatchOp->getResult(index);
+    if (returnValue.getType() == origResult.getType()) {
+      rewriter.replaceAllUsesWith(origResult, returnValue);
+      continue;
+    }
+    auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, origResult.getType(), returnValue,
+        allReassociationIndicesRef.front());
+    allReassociationIndicesRef = allReassociationIndicesRef.drop_front();
+    rewriter.replaceAllUsesWith(origResult, newExpandShapeOp.getResult());
+  }
+  rewriter.eraseOp(dispatchOp);
+  return newDispatchOp;
 }
 
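Putting the pieces together, here is a sketch of what `hoistTensorReshapesOutOfDispatchRegion` does to a dispatch produced by the collapsing step. The fragments are hypothetical (static shapes for brevity, `%arg` a `tensor<2x32xf32>` function argument); the dynamic-dim arithmetic above only kicks in for `?` dimensions:

```mlir
#map = affine_map<(d0) -> (d0)>
// Before: the reshapes are trapped inside the dispatch.
%0 = flow.dispatch.region -> (tensor<2x32xf32>) {
  %c = tensor.collapse_shape %arg [[0, 1]] : tensor<2x32xf32> into tensor<64xf32>
  %empty = tensor.empty() : tensor<64xf32>
  %g = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%c : tensor<64xf32>) outs(%empty : tensor<64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %in : f32
    linalg.yield %1 : f32
  } -> tensor<64xf32>
  %e = tensor.expand_shape %g [[0, 1]] : tensor<64xf32> into tensor<2x32xf32>
  flow.return %e : tensor<2x32xf32>
}

// After: the collapse is cloned above the dispatch, the expand is recreated
// below it, and the dispatch returns the collapsed tensor<64xf32> directly.
%c = tensor.collapse_shape %arg [[0, 1]] : tensor<2x32xf32> into tensor<64xf32>
%0 = flow.dispatch.region -> (tensor<64xf32>) {
  %empty = tensor.empty() : tensor<64xf32>
  %g = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%c : tensor<64xf32>) outs(%empty : tensor<64xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.addf %in, %in : f32
    linalg.yield %1 : f32
  } -> tensor<64xf32>
  flow.return %g : tensor<64xf32>
}
%res = tensor.expand_shape %0 [[0, 1]] : tensor<64xf32> into tensor<2x32xf32>
```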
 /// Traverses DispatchRegionOps to find linalg.generic ops that have no
 /// producers and tries to collapse their dimensions.
-static LogicalResult collapseDimensions(IRRewriter &rewriter,
-                                        DispatchRegionOp &regionOp) {
+static bool collapseDimensions(IRRewriter &rewriter,
+                               DispatchRegionOp &regionOp) {
   // Step 1. Find the root linalg.generic op with no producer.
   std::optional<linalg::GenericOp> genericOp = findRootGenericOp(regionOp);
   if (!genericOp.has_value())
-    return success();
+    return false;
 
   // Step 2. Check whether it is possible to collapse.
   if (!isEligibleForCollapse(genericOp.value()))
-    return success();
+    return false;
 
   SmallVector<ReassociationIndices> collapseIndices;
   collapseIndices = getCollapsibleLoops(genericOp.value());
   if (collapseIndices.empty())
-    return success();
+    return false;
 
   // Step 3. Collapse dimensions.
-  auto maybeReplacements =
-      collapseLinalgGeneric(rewriter, genericOp.value(), collapseIndices);
-  if (failed(maybeReplacements))
-    return failure();
-  auto expandshapeOp =
-      maybeReplacements->front().getDefiningOp<tensor::ExpandShapeOp>();
-  if (!expandshapeOp)
-    return failure();
-  auto newGenericOp =
-      expandshapeOp.getOperand().getDefiningOp<linalg::GenericOp>();
-  if (!newGenericOp)
-    return failure();
-
-  // Step 4. Generate new dispatch region and replace old one users
-  if (failed(generateNewDispatchRegion(rewriter, regionOp, *maybeReplacements,
-                                       newGenericOp)))
-    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(genericOp.value());
 
-  return success();
+  FailureOr<SmallVector<Value>> maybeReplacements =
+      mlir::linalg::collapseGenericOpIterationDims(genericOp.value(),
+                                                   collapseIndices, rewriter);
+  if (failed(maybeReplacements))
+    return false;
+  rewriter.replaceOp(genericOp.value(), maybeReplacements.value());
+  return true;
 }
 
 void CollapseDimensionsPass::runOnOperation() {
   mlir::FunctionOpInterface funcOp = getOperation();
-  IRRewriter rewriter(funcOp->getContext());
+  MLIRContext *context = funcOp->getContext();
+  IRRewriter rewriter(context);
+
+  SmallVector<DispatchRegionOp> modifiedDispatchOps;
+  funcOp->walk([&](DispatchRegionOp dispatchOp) {
+    if (collapseDimensions(rewriter, dispatchOp)) {
+      modifiedDispatchOps.push_back(dispatchOp);
+    }
+  });
 
-  auto walkResult = funcOp->walk([&](DispatchRegionOp regionOp) {
-    if (failed(collapseDimensions(rewriter, regionOp)))
-      return WalkResult::interrupt();
-    return WalkResult::advance();
+  LLVM_DEBUG({
+    llvm::dbgs() << "[CollapseDims] : After collapsing generic ops: \n";
+    funcOp.print(llvm::dbgs());
+    llvm::dbgs() << "\n";
   });
-  if (walkResult.wasInterrupted()) {
-    funcOp->emitOpError("failed in collapsing dimensions pass");
-    return signalPassFailure();
-  }
 
-  RewritePatternSet canonicalizationPatterns(&getContext());
-  memref::populateResolveRankedShapedTypeResultDimsPatterns(
-      canonicalizationPatterns);
-  tensor::populateFoldTensorEmptyPatterns(canonicalizationPatterns);
-  if (failed(applyPatternsAndFoldGreedily(
-          funcOp, std::move(canonicalizationPatterns)))) {
-    funcOp->emitOpError("failed to apply cleanup patterns");
-    return signalPassFailure();
+  // Move all the `tensor.collapse_shape` leaves and `tensor.expand_shape`
+  // roots of the modified dispatches out of the dispatch.
+  for (auto dispatchOp : modifiedDispatchOps) {
+    Region &body = dispatchOp.getBody();
+    assert(llvm::hasSingleElement(body) && "expected op with a single body");
+    Block &block = body.front();
+    RewritePatternSet moveReshapeOps(&getContext());
+    linalg::FillOp::getCanonicalizationPatterns(moveReshapeOps, context);
+    memref::populateResolveRankedShapedTypeResultDimsPatterns(moveReshapeOps);
+    tensor::populateFoldTensorEmptyPatterns(moveReshapeOps);
+    SmallVector<Operation *> candidateOps;
+    block.walk([&](Operation *op) {
+      if (isa<tensor::CollapseShapeOp, tensor::ExpandShapeOp>(op)) {
+        candidateOps.push_back(op);
+      }
+    });
+    if (failed(
+            applyOpPatternsAndFold(candidateOps, std::move(moveReshapeOps)))) {
+      funcOp.emitOpError(
+          "failed to propagate reshape ops introduced during collapse");
+      return signalPassFailure();
+    }
+
+    if (failed(hoistTensorReshapesOutOfDispatchRegion(
+            rewriter, cast<DispatchRegionOp>(dispatchOp)))) {
+      dispatchOp->emitOpError("failed to hoist reshapes out of dispatch");
+      return signalPassFailure();
+    }
   }
 }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
index 20746e728650..a3f1f6747fcb 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FusionOfTensorOps.cpp
@@ -114,10 +114,16 @@ static bool areFusableOps(MLIRContext *context, OpOperand *fusedOperand) {
   // broadcast this ends up redundantly computing operations without more
   // parallelism.
   if (auto linalgConsumerOp = dyn_cast<linalg::LinalgOp>(consumerOp)) {
-    return linalgConsumerOp.getNumParallelLoops() ==
-               linalgConsumerOp.getNumLoops() ||
-           linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
-               .isPermutation();
+    if (linalgConsumerOp.getNumParallelLoops() ==
+        linalgConsumerOp.getNumLoops()) {
+      return true;
+    }
+    if (linalgConsumerOp.getNumReductionLoops() != 1 ||
+        !linalgConsumerOp.getMatchingIndexingMap(fusedOperand)
+             .isPermutation()) {
+      return false;
+    }
+    return true;
   }
 
   // All other cases don't fuse.
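Before moving on to the tests, a sketch of a consumer that the FusionOfTensorOps change above still allows to fuse. The fragment is hypothetical (softmax-style shapes; `%x`, `%init`, and `%fill` assumed defined): the consumer has exactly one reduction loop and reads the fused operand through a permutation map, so fusing the elementwise producer does not recompute it across a broadcast.

```mlir
#id  = affine_map<(d0, d1) -> (d0, d1)>
#red = affine_map<(d0, d1) -> (d0)>
// Elementwise producer.
%exp = linalg.generic {indexing_maps = [#id, #id],
    iterator_types = ["parallel", "parallel"]}
    ins(%x : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
^bb0(%in: f32, %out: f32):
  %0 = math.exp %in : f32
  linalg.yield %0 : f32
} -> tensor<4x8xf32>
// Consumer: a single reduction loop, and the fused operand's map (#id) is a
// permutation of the iteration space, so fusion is still profitable.
%sum = linalg.generic {indexing_maps = [#id, #red],
    iterator_types = ["parallel", "reduction"]}
    ins(%exp : tensor<4x8xf32>) outs(%fill : tensor<4xf32>) {
^bb0(%in: f32, %acc: f32):
  %0 = arith.addf %in, %acc : f32
  linalg.yield %0 : f32
} -> tensor<4xf32>
```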
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
index 21b6fe95133a..5e4ec9b3eba1 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/collapse_linalg_generic_on_tensors.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-collapse-dimensions))" %s | FileCheck %s
+// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-flow-form-dispatch-regions{fuse-multi-use=true}, iree-flow-clone-producers-into-dispatch-regions, iree-flow-collapse-dimensions, cse))" %s | FileCheck %s
 
 !type = tensor<2x4x8x16x32x64xf32>
 util.global private @"__transpose_10_input" {noinline} = dense<1.0> : !type
@@ -23,14 +23,14 @@ func.func @collapse1() -> !type {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
-// CHECK-LABEL: func.func @collapse1
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
-// CHECK: ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @collapse1
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2x4x8x16x32x64xf32> into tensor<2097152xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<2097152xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]}
+// CHECK-SAME:        ins(%[[IN]] : tensor<2097152xf32>) outs(%[[OUT]] : tensor<2097152xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2, 3, 4, 5]] : tensor<2097152xf32> into tensor<2x4x8x16x32x64xf32>
 
 // -----
 
@@ -58,15 +58,15 @@ func.func @collapse2() -> !type {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @collapse2
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4)>
+// CHECK:       #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+// CHECK-LABEL: func.func @collapse2
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<2x4x8x32x32x64x128xf32> into tensor<8x8x32x32x8192xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<8x8x32x32x8192xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel"]}
+// CHECK-SAME:        ins(%[[IN]] : tensor<8x8x32x32x8192xf32>) outs(%[[OUT]] : tensor<8x8x32x32x8192xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5, 6]] : tensor<8x8x32x32x8192xf32> into tensor<2x4x8x32x32x64x128xf32>
 
 // -----
 
 !type = tensor<2x4x8x16x32x64x128x256xf32>
@@ -93,14 +93,14 @@ func.func @collapse3() -> !type {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-// CHECK-LABEL: func.func @collapse3
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func.func @collapse3
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<2x4x8x16x32x64x128x256xf32> into tensor<8x8x1073741824xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<8x8x1073741824xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel", "reduction", "parallel"]}
+// CHECK-SAME:        ins(%[[IN]] : tensor<8x8x1073741824xf32>) outs(%[[OUT]] : tensor<8x8x1073741824xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3, 4, 5, 6, 7]] : tensor<8x8x1073741824xf32> into tensor<2x4x8x16x32x64x128x256xf32>
 
 // -----
 
@@ -127,15 +127,15 @@ func.func @collapse4() -> !type {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-// CHECK-LABEL: func.func @collapse4
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK:       #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse4
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x8x16x64x64x128x256xf32> into tensor<8x8x16x64x64x32768xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<8x8x16x64x64x32768xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME:        ins(%[[IN]] : tensor<8x8x16x64x64x32768xf32>) outs(%[[OUT]] : tensor<8x8x16x64x64x32768xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x8x16x64x64x32768xf32> into tensor<2x4x8x16x64x64x128x256xf32>
 
 // -----
 
@@ -167,18 +167,18 @@ func.func @collapse5() -> !type {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
-// CHECK-LABEL: func.func @collapse5
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
-// CHECK: ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK:       #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d2, d4, d5)>
+// CHECK:       #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d2, d1, d4, d5)>
+// CHECK-LABEL: func.func @collapse5
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK:         %[[IN1:.+]] = tensor.collapse_shape %[[INPUT1:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK:         %[[IN2:.+]] = tensor.collapse_shape %[[INPUT2:.+]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<2x4x32x32x32x64x128x256xf32> into tensor<8x32x32x32x64x32768xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<8x32x32x32x64x32768xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]], #[[$MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "parallel"]}
+// CHECK-SAME:        ins(%[[IN]], %[[IN1]], %[[IN2]] : tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>, tensor<8x32x32x32x64x32768xf32>) outs(%[[OUT]] : tensor<8x32x32x32x64x32768xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0, 1], [2], [3], [4], [5], [6, 7]] : tensor<8x32x32x32x64x32768xf32> into tensor<2x4x32x32x32x64x128x256xf32>
 
 // -----
 
@@ -205,15 +205,15 @@ func.func @collapse6() -> !type {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
-// CHECK-LABEL: func.func @collapse6
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK:       #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5)>
+// CHECK-LABEL: func.func @collapse6
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x4x8x16x16x64x128xf32> into tensor<32x2x32x16x16x8192xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<32x2x32x16x16x8192xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME:        ins(%[[IN]] : tensor<32x2x32x16x16x8192xf32>) outs(%[[OUT]] : tensor<32x2x32x16x16x8192xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0], [1], [2, 3], [4], [5], [6, 7]] : tensor<32x2x32x16x16x8192xf32> into tensor<32x2x4x8x16x16x64x128xf32>
 
 // -----
 
@@ -239,24 +239,23 @@ func.func @collapse7() -> !type_out {
   return %result: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1)>
+// CHECK:       #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d1, d0)>
 // CHECK-LABEL: func.func @collapse7
-// CHECK: %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
+// CHECK:         %[[IN:.+]] = tensor.collapse_shape %[[INPUT:.+]] {{\[}}[0, 1, 2]] : tensor<2x4x8xf32> into tensor<64xf32>
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<64x16xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]}
+// CHECK-SAME:        ins(%[[IN]] : tensor<64xf32>) outs(%[[OUT]] : tensor<64x16xf32>)
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0, 1, 2], [3]] : tensor<64x16xf32> into tensor<2x4x8x16xf32>
 
 // -----
 
 !type_in = tensor<16x4x32x2xf32>
 !type_out = tensor<8x16x4x32x8x2xf32>
 
-func.func @collapse8() -> !type_out {
+func.func @collapse8(%input : !type_in) -> !type_out {
   %cst = arith.constant 0.000000e+00 : f32
   %c0 = arith.constant 0 : index
-  %input = tensor.empty() : !type_in
   %output = tensor.empty() : !type_out
 
   // Can collapse (d3, d0, d1)
   %6 = linalg.generic { indexing_maps = [
@@ -272,15 +271,16 @@ func.func @collapse8(%input : !type_in) -> !type_out {
   return %6: !type_out
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK:       #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func.func @collapse8
-// CHECK: %[[IN:.+]] = tensor.empty() : tensor<2048x2xf32>
-// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-// CHECK: ins(%[[IN]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
-// CHECK: tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
+// CHECK-SAME:    (%[[IN:.+]]: tensor<16x4x32x2xf32>)
+// CHECK:         %[[COLLAPSE:.+]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1, 2], [3]{{\]}}
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:           %[[OUT:.+]] = tensor.empty() : tensor<8x2048x8x2xf32>
+// CHECK:           linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK-SAME:        ins(%[[COLLAPSE]] : tensor<2048x2xf32>) outs(%[[OUT]] : tensor<8x2048x8x2xf32
+// CHECK:         tensor.expand_shape %[[RES]] {{\[}}[0], [1, 2, 3], [4], [5]] : tensor<8x2048x8x2xf32> into tensor<8x16x4x32x8x2xf32>
 
 // -----
 
@@ -304,7 +304,7 @@ func.func @dont_collapse() -> !type_out {
   return %6: !type_out
 }
 // CHECK-LABEL: func.func @dont_collapse
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK:     linalg.generic {indexing_maps = [#[[$MAP:.+]], #[[$MAP2:.+]]], iterator_types = ["parallel", "parallel", "parallel"]}
 
 // -----
 
@@ -333,11 +333,11 @@ func.func @collapse9() -> !type_out {
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
-// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
+// CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
+// CHECK:       #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d3, d5)>
 // CHECK-LABEL: func.func @collapse9
-// CHECK: %[[RES:.+]] = flow.dispatch.region
-// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
+// CHECK:         %[[RES:.+]] = flow.dispatch.region
+// CHECK:         linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]}
iterator_types = ["parallel", "reduction", "parallel", "parallel", "parallel", "parallel"]} // ----- @@ -345,10 +345,9 @@ func.func @collapse9() -> !type_out { !type_in = tensor<10x10x30xf32> !type_out = tensor<20x10x10x30x20xf32> -func.func @collapse10() -> !type_out { +func.func @collapse10(%input : !type_in) -> !type_out { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index - %input = tensor.empty() : !type_in %output = tensor.empty() : !type_out // Can collapse (d1, d3, d0) @@ -364,21 +363,18 @@ func.func @collapse10() -> !type_out { return %result: !type_out } -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0)> -// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> // CHECK-LABEL: func.func @collapse10 -// CHECK: %[[RES:.+]] = flow.dispatch.region -// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK: %[[RES:.+]] = flow.dispatch.region +// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel"]} // ----- !type_in = tensor<10x20xf32> !type_out = tensor<10x20xf32> -func.func @collapse11() -> !type_out { +func.func @collapse11(%input : !type_in) -> !type_out { %cst = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index - %input = tensor.empty() : !type_in %output = tensor.empty() : !type_out // Can collapse (d1, d0) @@ -394,10 +390,10 @@ func.func @collapse11() -> !type_out { } -// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @collapse11 -// CHECK: %[[RES:.+]] = flow.dispatch.region -// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} +// CHECK: %[[RES:.+]] = flow.dispatch.region +// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} // ----- @@ -420,7 +416,7 @@ func.func @dont_collapse_dueto_index(%height : index, %width : index) -> !type { } // CHECK-LABEL: func.func @dont_collapse -// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]} +// CHECK: linalg.generic {indexing_maps = [#[[$MAP:.+]]], iterator_types = ["parallel", "parallel"]} // ----- @@ -456,8 +452,146 @@ func.func @collapse12() -> (!type,!type,!type,!type) { return %6, %7, %8, %9 : !type,!type,!type,!type } -// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @collapse12 -// CHECK: %[[RES:.+]] = flow.dispatch.region -// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} +// CHECK: %[[RES:.+]] = flow.dispatch.region +// CHECK: linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} + +// ----- + +func.func @multi_reduce_dim(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} { + %cst = arith.constant -0.000000e+00 : f32 + %0 = hal.tensor.import %arg0 : !hal.buffer_view -> tensor<2x32x10x4096xf32> + %1 = tensor.empty() : tensor<2x32xf32> + %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<2x32xf32>) -> tensor<2x32xf32> + %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0 : tensor<2x32x10x4096xf32>) outs(%2 : tensor<2x32xf32>) { + ^bb0(%arg1: 
+    %6 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<2x32xf32>
+  %4 = tensor.expand_shape %3 [[0], [1, 2, 3]] : tensor<2x32xf32> into tensor<2x32x1x1xf32>
+  %5 = hal.tensor.export %4 : tensor<2x32x1x1xf32> -> !hal.buffer_view
+  return %5 : !hal.buffer_view
+}
+
+// Check that we collapse dimensions.
+// CHECK-LABEL: @multi_reduce_dim(
+// CHECK-DAG:     %[[ARG0:.+]] = hal.tensor.import
+// CHECK:         %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+// CHECK:           %[[FILL:.+]] = linalg.fill
+// CHECK-SAME:        outs(%[[EMPTY]] :
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        ins(%[[COLLAPSE]] :
+// CHECK-SAME:        outs(%[[FILL]] :
+// CHECK:           flow.return %[[GENERIC]]
+// CHECK:         %[[EXPAND:.+]] = tensor.expand_shape %[[DISPATCH]] {{\[}}[0, 1]{{\]}}
+
+// -----
+
+// Collapsing is not supported when an input is broadcast: we cannot, for
+// example, reshape the tensor<4xf32> operand below into the collapsed
+// tensor<32xf32> shape.
+
+func.func @input_broadcast(%arg0: tensor<4x8xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
+  %empty = tensor.empty() : tensor<f32>
+  %reduce = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> ()>], iterator_types = ["reduction", "reduction"]} ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<4xf32>) outs(%empty : tensor<f32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %out: f32):
+    %div = arith.divf %arg2, %arg3 : f32
+    %add = arith.addf %out, %div : f32
+    linalg.yield %add : f32
+  } -> tensor<f32>
+  return %reduce : tensor<f32>
+}
+// CHECK:       @input_broadcast
+// CHECK-NOT:     tensor.collapse_shape
+
+// -----
+
+// Do nothing if the dispatch is not a single elementwise op (with
+// tensor.empty/linalg.fill producers).
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3, d4)>
+#map3 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>
+#map4 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+module {
+  func.func @quantized_matmul(%arg0: tensor<4096x32x128xi8>, %arg1: tensor<1x1x32x128xf32>) -> tensor<1x1x4096xf32> {
+    %cst = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
+    %cst_0 = arith.constant dense_resource<__elided__> : tensor<4096x32xf32>
+    %0 = flow.dispatch.region -> (tensor<1x1x4096xf32>) {
+      %cst_1 = arith.constant 0.000000e+00 : f32
+      %1 = tensor.empty() : tensor<1x1x4096xf32>
+      %2 = tensor.empty() : tensor<4096x32x128xf32>
+      %3 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x1x4096xf32>) -> tensor<1x1x4096xf32>
+      %4 = linalg.generic {indexing_maps = [#map, #map1, #map1, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %cst, %cst_0 : tensor<4096x32x128xi8>, tensor<4096x32xf32>, tensor<4096x32xf32>) outs(%2 : tensor<4096x32x128xf32>) {
+      ^bb0(%in: i8, %in_2: f32, %in_3: f32, %out: f32):
+        %6 = arith.extui %in : i8 to i32
+        %7 = arith.uitofp %6 : i32 to f32
+        %8 = arith.subf %7, %in_3 : f32
+        %9 = arith.mulf %8, %in_2 : f32
+        linalg.yield %9 : f32
+      } -> tensor<4096x32x128xf32>
+      %5 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%arg1, %4 : tensor<1x1x32x128xf32>, tensor<4096x32x128xf32>) outs(%3 : tensor<1x1x4096xf32>) {
+      ^bb0(%in: f32, %in_2: f32, %out: f32):
+        %6 = arith.mulf %in, %in_2 : f32
+        %7 = arith.addf %6, %out : f32
+        linalg.yield %7 : f32
+      } -> tensor<1x1x4096xf32>
+      flow.return %5 : tensor<1x1x4096xf32>
+    }
+    return %0 : tensor<1x1x4096xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @quantized_matmul
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           linalg.generic
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK:           linalg.generic
+// CHECK-SAME:        iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
+// CHECK:           flow.return
+// CHECK:         return %[[DISPATCH]]
+
+// -----
+
+module {
+  func.func @batchnorm_failure_repro(%arg0 : tensor<2x4xf32>, %arg1 : tensor<4xf32>) -> tensor<2x4xf32> {
+    %0 = tensor.empty() : tensor<2x4xf32>
+    %1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0, %arg1 : tensor<2x4xf32>, tensor<4xf32>) outs(%0 : tensor<2x4xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+        %2 = arith.addf %b0, %b1 : f32
+        linalg.yield %2 : f32
+    } -> tensor<2x4xf32>
+    return %1 : tensor<2x4xf32>
+  }
+}
+// CHECK-LABEL: func @batchnorm_failure_repro
+// CHECK:         %[[DISPATCH:.+]] = flow.dispatch.region
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        iterator_types = ["parallel", "parallel"]
+// CHECK:           flow.return %[[GENERIC]]
+// CHECK:         return %[[DISPATCH]]
+
+// -----
+
+module {
+  func.func @catch_invalid_collapse(%arg0 : tensor<10x20x30xf32>) -> tensor<10x30x40xf32> {
+    %0 = tensor.empty() : tensor<10x30x40xf32>
+    %1 = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>],
+        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+        ins(%arg0 : tensor<10x20x30xf32>) outs(%0 : tensor<10x30x40xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        linalg.yield %b0 : f32
+    } -> tensor<10x30x40xf32>
+    return %1 : tensor<10x30x40xf32>
+  }
+}
+// CHECK-LABEL: func @catch_invalid_collapse
+// CHECK:         linalg.generic
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel", "parallel"]