[Flow] Add patterns to lower tensor.reshape to flow.tensor.reshape (iree-org#15226)

`tensor.reshape` takes a 1D tensor that specifies the target shape, thereby supporting unranked reshapes. IREE does not support unranked tensors; in practice, however, frontends rarely generate them, so here we simply opt to extract the necessary tensor extents into a static list (hopefully folding with a producing `tensor.from_elements`).
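For illustration, a minimal before/after sketch of the intended lowering (hypothetical function and value names; it mirrors the `from_elements` test added below). Before:

func.func @example(%src: tensor<?x4xf32>, %d0: index) -> tensor<?x1xf32> {
  %c1 = arith.constant 1 : index
  %shape = tensor.from_elements %d0, %c1 : tensor<2xindex>
  %0 = tensor.reshape %src(%shape)
      : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<?x1xf32>
  return %0 : tensor<?x1xf32>
}

After, with only the dynamic extents carried inline:

%c0 = arith.constant 0 : index
%dim = tensor.dim %src, %c0 : tensor<?x4xf32>
%0 = flow.tensor.reshape %src : tensor<?x4xf32>{%dim} -> tensor<?x1xf32>{%d0}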
saienduri authored Oct 27, 2023
1 parent c878412 commit 222bcf2
Showing 2 changed files with 122 additions and 0 deletions.
@@ -183,6 +183,61 @@ struct ConvertTensorFromElementsPattern
  }
};

/// Returns the dynamic dimension sizes of |tensor| (one Value per dynamic
/// dimension of |type|).
static SmallVector<Value> getDynamicTensorSizes(OpBuilder &builder,
                                                Location loc,
                                                RankedTensorType type,
                                                Value tensor) {
  SmallVector<Value> sizes;
  for (const auto [idx, size] : enumerate(type.getShape())) {
    if (type.isDynamicDim(idx)) {
      Value dim = builder.create<tensor::DimOp>(loc, tensor, idx);
      sizes.push_back(dim);
    }
  }
  return sizes;
}

/// Convert tensor.reshape ops into flow.tensor.reshape ops where possible.
struct ConvertTensorDialectReshapeOpPattern
    : public OpRewritePattern<tensor::ReshapeOp> {
  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
                                PatternRewriter &rewriter) const override {
    if (op->getParentOfType<Flow::DispatchWorkgroupsOp>()) {
      return failure();
    }
    auto loc = op.getLoc();
    Value input = op.getSource();
    auto inputType = dyn_cast<RankedTensorType>(input.getType());
    auto shapeOperandType = dyn_cast<ShapedType>(op.getShape().getType());
    auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
    if (!inputType || !resultType) {
      return rewriter.notifyMatchFailure(op, "not ranked tensor types");
    }

    SmallVector<Value> srcSizes =
        getDynamicTensorSizes(rewriter, loc, inputType, input);

    // flow.tensor.reshape only carries the dynamic extents of the source and
    // result; static dimensions are omitted.
    SmallVector<Value> destSizes;
    for (int64_t i = 0; i < shapeOperandType.getShape()[0]; ++i) {
      if (ShapedType::isDynamic(resultType.getShape()[i])) {
        Value idx = rewriter.create<arith::ConstantIndexOp>(loc, i);
        Value element = rewriter.create<tensor::ExtractOp>(
            loc, op.getShape(), ValueRange{idx});
        destSizes.push_back(element);
      }
    }

    rewriter.replaceOpWithNewOp<IREE::Flow::TensorReshapeOp>(
        op, resultType, input, srcSizes, destSizes);
    return success();
  }
};
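A note on the mechanics (not part of the diff): the rewrite reads the target extents out of the shape tensor with plain `tensor.extract` ops, and the `ConvertTensorExtractPattern` registered alongside it lowers any surviving extracts to `flow.tensor.load`, which is the form the tests below check for. Roughly, with hypothetical values:

// Emitted by ConvertTensorDialectReshapeOpPattern (sketch):
%c1 = arith.constant 1 : index
%e = tensor.extract %shape[%c1] : tensor<2xindex>
%r = flow.tensor.reshape %src : tensor<2x4xf32> -> tensor<1x?xf32>{%e}

// After ConvertTensorExtractPattern rewrites the extract:
%v = flow.tensor.load %shape[%c1] : tensor<2xindex>
%r2 = flow.tensor.reshape %src : tensor<2x4xf32> -> tensor<1x?xf32>{%v}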

/// Converts tensor reshaping operations (tensor.collapse_shape and
/// tensor.expand_shape) into flow.tensor.reshape operations.
template <typename TensorReshapeOp>
@@ -248,6 +303,7 @@ void populateTensorToFlowConversionPatterns(MLIRContext *context,
ConvertTensorCastPattern, ConvertTensorExtractPattern,
ConvertTensorExtractSlicePattern, ConvertTensorInsertSlicePattern,
ConvertTensorInsertPattern, ConvertTensorFromElementsPattern,
ConvertTensorDialectReshapeOpPattern,
ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
}
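These patterns are exercised by the lit test below; assuming the usual IREE conventions, the test file would be driven by a RUN line along these lines (the pass flag is an assumption, not shown in this diff):

// RUN: iree-opt --split-input-file --iree-flow-convert-to-flow %s | FileCheck %s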
@@ -31,3 +31,69 @@ func.func @turn_fill_into_splat(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg
// CHECK-DAG: %[[RD1:.+]] = affine.apply #[[MAP]](%[[D1]])[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SPLAT:.+]] = flow.tensor.splat %[[VAL]] : tensor<?x?xf32>{%[[RD0]], %[[RD1]]}
// CHECK: flow.tensor.update %[[ARG0]], %[[SPLAT]]

// -----

func.func @static_tensor_reshape(%arg0: tensor<2x4xf32>, %arg1: tensor<2xindex>) -> tensor<1x8xf32> {
  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<2x4xf32> -> tensor<1x8xf32>
  // CHECK: return %[[RESULT]]
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<2x4xf32>, tensor<2xindex>) -> tensor<1x8xf32>
  return %0 : tensor<1x8xf32>
}

// -----

func.func @dynamic_tensor_reshape(%arg0: tensor<2x4xf32>, %arg1: tensor<2xindex>) -> tensor<?x?xf32> {
  // CHECK: func.func @dynamic_tensor_reshape
  // CHECK-SAME: %[[ARG0:.+]]: tensor<2x4xf32>
  // CHECK-SAME: %[[ARG1:.+]]: tensor<2xindex>
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[VAL:.+]] = flow.tensor.load %[[ARG1]][%[[C0]]] : tensor<2xindex>
  // CHECK-DAG: %[[VAL1:.+]] = flow.tensor.load %[[ARG1]][%[[C1]]] : tensor<2xindex>
  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<2x4xf32> -> tensor<?x?xf32>{%[[VAL]], %[[VAL1]]}
  // CHECK: return %[[RESULT]]
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<2x4xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// -----

func.func @mix_dynamic_and_static_tensor_reshape(%arg0: tensor<2x4xf32>, %arg1: tensor<2xindex>) -> tensor<1x?xf32> {
  // CHECK: func.func @mix_dynamic_and_static_tensor_reshape
  // CHECK-SAME: %[[ARG0:.+]]: tensor<2x4xf32>
  // CHECK-SAME: %[[ARG1:.+]]: tensor<2xindex>
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[VAL:.+]] = flow.tensor.load %[[ARG1]][%[[C1]]] : tensor<2xindex>
  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<2x4xf32> -> tensor<1x?xf32>{%[[VAL]]}
  // CHECK: return %[[RESULT]]
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<2x4xf32>, tensor<2xindex>) -> tensor<1x?xf32>
  return %0 : tensor<1x?xf32>
}

// -----

func.func @dynamic_input_and_output_tensor_reshape(%arg0: tensor<?x4xf32>, %arg1: tensor<2xindex>) -> tensor<1x?xf32> {
  // CHECK: func.func @dynamic_input_and_output_tensor_reshape
  // CHECK-SAME: %[[ARG0:.+]]: tensor<?x4xf32>
  // CHECK-SAME: %[[ARG1:.+]]: tensor<2xindex>
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
  // CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4xf32>
  // CHECK-DAG: %[[VAL:.+]] = flow.tensor.load %[[ARG1]][%[[C1]]] : tensor<2xindex>
  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %[[ARG0]] : tensor<?x4xf32>{%[[D0]]} -> tensor<1x?xf32>{%[[VAL]]}
  // CHECK: return %[[RESULT]]
  %0 = tensor.reshape %arg0(%arg1)
      : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<1x?xf32>
  return %0 : tensor<1x?xf32>
}

// -----

func.func @from_elements_test_reshape(%arg0: tensor<?x4xf32>, %arg1: index, %arg2: index) -> tensor<?x1xf32> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
  // CHECK-DAG: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor<?x4xf32>
  // CHECK-DAG: %[[RESULT:.*]] = flow.tensor.reshape %arg0 : tensor<?x4xf32>{%[[D0]]} -> tensor<?x1xf32>{%arg1}
  // CHECK: return %[[RESULT]]
  %0 = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
  %1 = tensor.reshape %arg0(%0)
      : (tensor<?x4xf32>, tensor<2xindex>) -> tensor<?x1xf32>
  return %1 : tensor<?x1xf32>
}
