From 1c23c651c520e884c2f3fd1cd73f4cd099ce25fc Mon Sep 17 00:00:00 2001
From: "Lu, Chengjun"
Date: Fri, 18 Oct 2024 15:42:42 +0000
Subject: [PATCH] Update the fast math pass

---
Reviewer notes (text between '---' and the first 'diff --git' is not part
of the commit):

- compiler.py: set_fast_math is no longer called on the MLIR module between
  the add_to_llvmir and add_arith_to_llvmir pipeline entries; it is now
  called on the translated LLVM module, right after set_spv_target_triple.
- triton_xpu.cc: the set_fast_math binding now takes llvm::Module* and sets
  the 'fast' flags on every matching instruction, instead of walking MLIR
  ops and setting a fast-math attribute on them.
- NOTE(review): this copy of the patch had lost its line breaks and the
  template arguments of both dyn_cast calls. They are restored below as
  arith::ArithFastMathInterface (removed code) and FPMathOperator (added
  code). Leading whitespace of the diff lines is reconstructed, and the
  From: header has no e-mail address. Confirm all of this against the
  upstream commit before applying.

 third_party/intel/backend/compiler.py |  2 +-
 third_party/intel/triton_xpu.cc       | 22 +++++++++++++---------
 2 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 5d1a9fdc75..bc10672d43 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -279,7 +279,6 @@ def make_llir(src, metadata, options):
         if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
             intel.passes.ttgpuir.add_allocate_shared_memory(pm)
         intel.passes.ttgpuir.add_to_llvmir(pm)
-        intel.set_fast_math(mod)
         passes.convert.add_arith_to_llvmir(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
@@ -292,6 +291,7 @@ def make_llir(src, metadata, options):
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
         intel.set_spv_target_triple(llvm_mod)
+        intel.set_fast_math(llvm_mod)
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 951de6ce35..5039942234 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -216,15 +216,19 @@ void init_triton_intel(py::module &&m) {
   });
 
   // May do this after llvm ir according to user fmath flag.
-  m.def("set_fast_math", [](mlir::ModuleOp mod) {
-    using namespace mlir;
-    MLIRContext *ctx = mod.getContext();
-    mod.walk([&](Operation *op) {
-      if (auto fmIf = dyn_cast<arith::ArithFastMathInterface>(op))
-        op->setAttr(
-            fmIf.getFastMathAttrName(),
-            arith::FastMathFlagsAttr::get(ctx, arith::FastMathFlags::fast));
-    });
+  m.def("set_fast_math", [](llvm::Module *mod) {
+    using namespace llvm;
+    for (auto &func : *mod) {
+      for (auto &bb : func) {
+        for (auto &inst : bb) {
+          if (auto *op = dyn_cast<FPMathOperator>(&inst)) {
+            FastMathFlags FMF;
+            FMF.setFast(true);
+            inst.setFastMathFlags(FMF);
+          }
+        }
+      }
+    }
   });
 
   m.def("set_spv_target_triple", [](llvm::Module *mod) {