diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
index b10e605a2335..2f5a48b1d986 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp
@@ -29,7 +29,7 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
                                       mlir::FunctionOpInterface entryPoint,
                                       Operation *op) {
   auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp) {
+  if (!linalgOp || !linalg::isaContractionOpInterface(linalgOp)) {
     return failure();
   }
 
@@ -39,23 +39,20 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   const int64_t targetSubgroupSize = target.getPreferredSubgroupSize();
 
   SmallVector<int64_t> bounds = linalgOp.getStaticLoopRanges();
-  FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
-      mlir::linalg::inferContractionDims(linalgOp);
-  if (failed(contractionDims)) {
-    return failure();
-  }
+  mlir::linalg::ContractionDimensions contractionDims =
+      mlir::linalg::inferContractionDims(linalgOp).value();
 
-  if (contractionDims->k.empty() || contractionDims->m.empty() ||
-      contractionDims->n.empty()) {
+  if (contractionDims.k.empty() || contractionDims.m.empty() ||
+      contractionDims.n.empty()) {
     return failure();
   }
 
   // For now we are not being smart and trying to reshape dimensions to allow
   // for better usage of intrinsics, and instead are tiling all dimensions
   // except the inner most m, n, and k dimensions to 1.
-  int64_t mDim = contractionDims->m.back();
-  int64_t nDim = contractionDims->n.back();
-  int64_t kDim = contractionDims->k.back();
+  int64_t mDim = contractionDims.m.back();
+  int64_t nDim = contractionDims.n.back();
+  int64_t kDim = contractionDims.k.back();
 
   // Dynamic dims are expected to be taken care of earlier in the pipeline.
   if (ShapedType::isDynamic(bounds[mDim]) ||
@@ -159,19 +156,19 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
   SmallVector<int64_t> subgroupTileSizes(linalgOp.getNumLoops(), 0);
   // Tile all batch dimensions with unit size.
-  for (int64_t batch : contractionDims->batch) {
+  for (int64_t batch : contractionDims.batch) {
     workgroupTileSizes[batch] = 1;
   }
 
   // Tile all m, n, and k dimensions to 1 except the innermost. Unit dims
   // from this tiling are folded before vectorization.
-  for (int64_t m : llvm::drop_end(contractionDims->m)) {
+  for (int64_t m : llvm::drop_end(contractionDims.m)) {
     workgroupTileSizes[m] = 1;
   }
-  for (int64_t n : llvm::drop_end(contractionDims->n)) {
+  for (int64_t n : llvm::drop_end(contractionDims.n)) {
     workgroupTileSizes[n] = 1;
   }
-  for (int64_t k : llvm::drop_end(contractionDims->k)) {
+  for (int64_t k : llvm::drop_end(contractionDims.k)) {
     reductionTileSizes[k] = 1;
   }
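
Note on the `.value()` unwrap above (commentary, not part of the patch): the new guard `linalg::isaContractionOpInterface(linalgOp)` is what justifies dropping the `failed(contractionDims)` check, since `inferContractionDims` is expected to succeed for any op that satisfies the contraction interface. Below is a minimal standalone sketch of that guard pattern, assuming the usual MLIR Linalg headers; the `classifyContraction` helper is hypothetical and exists only for illustration.

// Sketch of the invariant the patch relies on: isaContractionOpInterface
// returning true means inferContractionDims can be unwrapped directly.
#include "mlir/Dialect/Linalg/IR/Linalg.h"

// Hypothetical helper, not part of ConfigUtils.cpp.
static mlir::LogicalResult classifyContraction(mlir::Operation *op) {
  auto linalgOp = llvm::dyn_cast<mlir::linalg::LinalgOp>(op);
  if (!linalgOp || !mlir::linalg::isaContractionOpInterface(linalgOp))
    return mlir::failure();  // Not a contraction; bail out early.
  // Safe only because of the guard above: for an op satisfying the
  // contraction interface, inferContractionDims is expected to hold a value.
  mlir::linalg::ContractionDimensions dims =
      mlir::linalg::inferContractionDims(linalgOp).value();
  // m/n/k can still be empty in degenerate cases, which is why the patch
  // keeps the emptiness check after the unwrap.
  if (dims.m.empty() || dims.n.empty() || dims.k.empty())
    return mlir::failure();
  return mlir::success();
}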