From f0d24cdab1bb931e54a9d517707375cb96f96543 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Fri, 12 Jul 2024 08:35:02 -0700 Subject: [PATCH] [Global opt] add flag to generalize matmul ops (#17877) Helps when the producer is a broadcast op. After adding the flag to sdxl scripts, I saw a decent decrease in the number of dispatches. Initially, I was trying to manually generalize+fuse broadcasts [branch here](https://github.com/IanWood1/iree/tree/broadcast_matmul), but quinn saw good results with just this. --------- Signed-off-by: Ian Wood --- .../compiler/Preprocessing/Common/BUILD.bazel | 1 + .../Preprocessing/Common/CMakeLists.txt | 1 + .../Common/GeneralizeLinalgMatMul.cpp | 54 +++++++++++++++++++ .../compiler/Preprocessing/Common/Passes.td | 8 +++ .../Preprocessing/Common/test/BUILD.bazel | 1 + .../Preprocessing/Common/test/CMakeLists.txt | 1 + .../Common/test/generalize_linalg_matmul.mlir | 12 +++++ 7 files changed, 78 insertions(+) create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp create mode 100644 compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir diff --git a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel index e004a550f728..1692c78bf800 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/BUILD.bazel @@ -33,6 +33,7 @@ iree_compiler_cc_library( "ApplyPDLPatterns.cpp", "ConvertConv2DToImg2Col.cpp", "ConvertConvToChannelsLast.cpp", + "GeneralizeLinalgMatMul.cpp", "InterpreterPass.cpp", "MakeSingleDispatchForFunction.cpp", "PadLinalgOps.cpp", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt index c9c127ccca57..4613d4bb404b 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt +++ 
b/compiler/src/iree/compiler/Preprocessing/Common/CMakeLists.txt @@ -29,6 +29,7 @@ iree_cc_library( "ApplyPDLPatterns.cpp" "ConvertConv2DToImg2Col.cpp" "ConvertConvToChannelsLast.cpp" + "GeneralizeLinalgMatMul.cpp" "InterpreterPass.cpp" "MakeSingleDispatchForFunction.cpp" "PadLinalgOps.cpp" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp new file mode 100644 index 000000000000..a533339875e0 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/GeneralizeLinalgMatMul.cpp @@ -0,0 +1,54 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::iree_compiler::Preprocessing { + +#define GEN_PASS_DEF_GENERALIZELINALGMATMULPASS +#include "iree/compiler/Preprocessing/Common/Passes.h.inc" // IWYU pragma: export + +namespace { + +struct GeneralizeLinalgMatMulPass + : public iree_compiler::Preprocessing::impl::GeneralizeLinalgMatMulPassBase< + GeneralizeLinalgMatMulPass> { + using iree_compiler::Preprocessing::impl::GeneralizeLinalgMatMulPassBase< + GeneralizeLinalgMatMulPass>::GeneralizeLinalgMatMulPassBase; + void runOnOperation() override { + auto funcOp = getOperation(); + SmallVector<linalg::LinalgOp> namedOpCandidates; + funcOp.walk([&](linalg::LinalgOp linalgOp) { + if (!IREE::Flow::isNonNullAndOutsideDispatch(linalgOp)) { + return; + } + if (isa_and_nonnull<linalg::MatmulOp, linalg::MatmulTransposeBOp, + linalg::BatchMatmulOp, + linalg::BatchMatmulTransposeBOp>(linalgOp)) { + namedOpCandidates.push_back(linalgOp); + } + }); + + IRRewriter rewriter(&getContext()); + + for (auto linalgOp : 
namedOpCandidates) { + rewriter.setInsertionPoint(linalgOp); + FailureOr<linalg::GenericOp> generalizedOp = + linalg::generalizeNamedOp(rewriter, linalgOp); + if (failed(generalizedOp)) { + linalgOp->emitOpError("failed to generalize operation"); + return signalPassFailure(); + } + } + } +}; +} // namespace +} // namespace mlir::iree_compiler::Preprocessing diff --git a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td index e3316f09a653..e4921b81fe88 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/Passes.td +++ b/compiler/src/iree/compiler/Preprocessing/Common/Passes.td @@ -132,4 +132,12 @@ def TransposeMatmulPass : Pass<"iree-preprocessing-transpose-matmul-pass"> { ]; } +def GeneralizeLinalgMatMulPass : + InterfacePass<"iree-preprocessing-generalize-linalg-matmul-experimental", "mlir::FunctionOpInterface"> { + let summary = "Convert linalg matmul ops to linalg.generics."; + let dependentDialects = [ + "mlir::linalg::LinalgDialect", + ]; +} + #endif // IREE_PREPROCESSING_COMMON_PASSES diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel index 3a5324f80696..54ebb1176caa 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/BUILD.bazel @@ -18,6 +18,7 @@ iree_lit_test_suite( [ "conv2d_to_img2col.mlir", "conv_to_channels_last.mlir", + "generalize_linalg_matmul.mlir", "make_single_dispatch_for_function.mlir", "pad_linalg_ops.mlir", "pad_to_intrinsics_mfma.mlir", diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt index a09c135dfe09..03c92b7423bc 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/CMakeLists.txt @@ -16,6 +16,7 @@ iree_lit_test_suite( 
SRCS "conv2d_to_img2col.mlir" "conv_to_channels_last.mlir" + "generalize_linalg_matmul.mlir" "make_single_dispatch_for_function.mlir" "pad_linalg_ops.mlir" "pad_to_intrinsics_mfma.mlir" diff --git a/compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir b/compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir new file mode 100644 index 000000000000..bb1949c17ef1 --- /dev/null +++ b/compiler/src/iree/compiler/Preprocessing/Common/test/generalize_linalg_matmul.mlir @@ -0,0 +1,12 @@ +// RUN: iree-opt --pass-pipeline="builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))" --verify-each --split-input-file %s | FileCheck %s + +util.func public @generalize_matmul(%arg0: tensor<1x128x128xf32>, %arg1: tensor<1x128x128xf32>) -> tensor<1x128x128xf32> { + %0 = tensor.empty() : tensor<1x128x128xf32> + %1 = linalg.batch_matmul ins(%arg0, %arg1: tensor<1x128x128xf32>, tensor<1x128x128xf32>) outs(%0 : tensor<1x128x128xf32>) -> tensor<1x128x128xf32> + util.return %1 : tensor<1x128x128xf32> +} + +// CHECK-LABEL: util.func public @generalize_matmul +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x128x128xf32>, %[[ARG1:.+]]: tensor<1x128x128xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: %[[ARG0]], %[[ARG1]]