Skip to content

Commit

Permalink
Make TopK work with arbitrary rank (iree-org#15268)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman authored Oct 23, 2023
1 parent 99dc6bc commit e74287f
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1777,15 +1777,15 @@ struct ApproxTopK final : OpRewritePattern<mlir::stablehlo::CustomCallOp> {
if (!funcOp)
return rewriter.notifyMatchFailure(op, "computation function not found.");

int64_t k = cast<ShapedType>(op.getType(0)).getDimSize(1);
int64_t k = cast<ShapedType>(op.getType(0)).getShape().back();
auto input = op.getOperand(0);
auto iota = op.getOperand(1);

if (auto iotaOp =
dyn_cast_or_null<mlir::stablehlo::IotaOp>(iota.getDefiningOp())) {
int64_t iotaDim = iotaOp.getIotaDimension();
auto iotaLastDim = cast<ShapedType>(iotaOp.getType()).getRank() - 1;
if (iotaDim != iotaLastDim || iotaLastDim != 1) {
if (iotaDim != iotaLastDim) {
return rewriter.notifyMatchFailure(op, "Iota of last dim not found.");
}
}
Expand Down

0 comments on commit e74287f

Please sign in to comment.