diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp index 3de3216abfde..10d426585162 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp @@ -7,7 +7,6 @@ #include #include -#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h" #include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree/compiler/Codegen/Common/GPU/PassDetail.h" #include "iree/compiler/Codegen/Common/GPU/Passes.h" @@ -29,9 +28,6 @@ #define DEBUG_TYPE "iree-codegen-gpu-distribute-shared-memory-copy" -using mlir::iree_compiler::IREE::LinalgExt::LinalgVectorizationPattern; -using mlir::iree_compiler::IREE::LinalgExt::VectorizationPatterns; - /// Prints the given `funcOp` after a leading `step` comment header. void debugPrint(mlir::func::FuncOp funcOp, const char *step) { LLVM_DEBUG({ @@ -274,14 +270,17 @@ static void populateTilingAndDistribute(RewritePatternSet &patterns, StringAttr::get(patterns.getContext(), kCopyDistributed))); } -static void populateVectorizationPatterns(RewritePatternSet &patterns) { - VectorizationPatterns::insert( - patterns, IREE::LinalgExt::LinalgVectorizationOptions(), - IREE::LinalgExt::LinalgTransformationFilter( - {StringAttr::get(patterns.getContext(), - getCopyToWorkgroupMemoryMarker()), - StringAttr::get(patterns.getContext(), kCopyDistributed)}, - std::nullopt)); +static void vectorizeDistributedCopies(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + SmallVector candidates; + funcOp.walk([&](linalg::GenericOp op) { candidates.push_back(op); }); + for (auto op : candidates) { + SmallVector vectorSizes; + SmallVector scalableVecDims; + scalableVecDims.resize(vectorSizes.size()); + (void)linalg::vectorize(rewriter, op, vectorSizes, scalableVecDims, + /*vectorizeGatherAccesses=*/true); + }; } /// Return a flattened Id Value by combining the 3D gpu thread IDs. @@ -436,12 +435,7 @@ class GPUDistributeSharedMemoryCopyPass debugPrint(funcOp, "After step 2: thread distribution"); // Step 3. Vectorize the distributed copies. - RewritePatternSet vectorizationPatterns(context); - populateVectorizationPatterns(vectorizationPatterns); - if (failed(applyPatternsAndFoldGreedily( - funcOp, std::move(vectorizationPatterns)))) { - return signalPassFailure(); - } + vectorizeDistributedCopies(funcOp); debugPrint(funcOp, "After step 3: vectorization"); // Step4. Finally unroll all the loop created diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h index 70aba5cb98dd..d6ec010dffc5 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h @@ -169,59 +169,6 @@ class TilingPatterns { } }; -/// -/// Linalg vectorization patterns. -/// -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `vectorizeLinalgOp` for more details. -struct LinalgVectorizationPattern - : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgVectorizationPattern( - MLIRContext *context, - LinalgVectorizationOptions opts = LinalgVectorizationOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, - LinalgVectorizationOptions opts = LinalgVectorizationOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgVectorizationOptions options; - LinalgTransformationFilter filter; -}; - -template -class VectorizationPatterns; - -template <> -class VectorizationPatterns<> { -public: - static void insert(RewritePatternSet &patterns, - const LinalgVectorizationOptions &opts, - const LinalgTransformationFilter &f) {} -}; - -template -class VectorizationPatterns { -public: - static void insert(RewritePatternSet &patterns, - const LinalgVectorizationOptions &opts, - const LinalgTransformationFilter &f) { - patterns.add(OpTy::getOperationName(), - patterns.getContext(), opts, f); - VectorizationPatterns::insert(patterns, opts, f); - } -}; - /// /// Linalg promotion patterns. /// diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp index a344d5459796..cd6375628b06 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Transforms/Transforms.cpp @@ -100,32 +100,6 @@ LinalgTilingPattern::returningMatchAndRewrite(linalg::LinalgOp op, return res; } -LinalgVectorizationPattern::LinalgVectorizationPattern( - MLIRContext *context, LinalgVectorizationOptions opts, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(opts)), filter(std::move(f)) {} - -LinalgVectorizationPattern::LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, LinalgVectorizationOptions opts, - LinalgExt::LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(opts)), filter(f.addOpNameFilter(opName)) {} - -LogicalResult -LinalgVectorizationPattern::matchAndRewrite(linalg::LinalgOp linalgOp, - PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - SmallVector vectorSizes; - if (options.enableVectorMasking) - vectorSizes.append(options.vectorSizeComputationFunction( - linalgOp, options.canonicalVectorSizes)); - SmallVector scalableVecDims(vectorSizes.size(), false); - return vectorize(rewriter, linalgOp, vectorSizes, scalableVecDims, - options.vectorizeGatherAccesses); -} - } // namespace LinalgExt } // namespace IREE } // namespace iree_compiler