From 78c13a56ec09bda4dd1a10e73e28a758c83a5ec8 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Mon, 2 Dec 2024 22:05:04 -0500 Subject: [PATCH] Insert freeze between masked loads and sdiv/srem instructions (#2775) Close #2726 From the code comments: The Triton masked load pattern can generate instances where the mask value causes undefined behavior in sdiv/srem instructions. The language allows this UB as the result of those arithmetic instructions is never used, and control flow to avoid computation of these instructions would negatively affect performance. But, LLVM SimplifyCFG aggressively marks code paths with undefined behavior as dead. This can result in removal of the mask path and incorrect results from legal Triton kernels due to masked elements being used in computation. Run a pass to add a freeze instruction between masked loads and sdiv/srem to signal to LLVM we consider the sdiv/srem operands to be well defined. The strategy here is to basically invalidate the assumptions under which SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely because the hardware Triton is targeting allows that). Inserting a `freeze` instruction both signals that we expect the behavior of sdiv/srem to be well defined and hides the constant 0 in the phi from SimplifyCFG's UB optimizations. The pass needs to run after every instance of `InstCombine` because the LLVM optimization that removes UB only occurs if the sdiv/srem are in the same BB as the phi, which can happen after any `InstCombine`. Note that the directory structure for this pass is a little different than `BreakStructPhiNodesPass` because we are already using those directories in `third_party` for MLIR code. If we want to change that, I can open an issue but let's do it separately from this PR. --------- --- bin/CMakeLists.txt | 2 + bin/triton-llvm-opt.cpp | 8 ++ python/test/regression/test_divide.py | 84 +++++++++++++++++++ test/LLVMIR/freeze-masked-div-rem.ll | 57 +++++++++++++ third_party/intel/lib/CMakeLists.txt | 1 + third_party/intel/lib/LLVMIR/CMakeLists.txt | 6 ++ .../lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp | 51 +++++++++++ third_party/intel/lib/LLVMIR/LLVMPasses.h | 11 +++ third_party/intel/triton_xpu.cc | 16 ++++ 9 files changed, 236 insertions(+) create mode 100644 python/test/regression/test_divide.py create mode 100644 test/LLVMIR/freeze-masked-div-rem.ll create mode 100644 third_party/intel/lib/LLVMIR/CMakeLists.txt create mode 100644 third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp create mode 100644 third_party/intel/lib/LLVMIR/LLVMPasses.h diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index b66ef71193..c0398fb60b 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE TritonTransforms TritonGPUTransforms TritonNvidiaGPUTransforms + TritonIntelLLVMIR MLIRGPUToROCDLTransforms ${dialect_libs} ${conversion_libs} @@ -88,6 +89,7 @@ target_link_libraries(triton-llvm-opt PRIVATE LLVMSupport LLVMOption LLVMCodeGen + TritonIntelLLVMIR TritonIntelGPUIR ) export_executable_symbols_for_plugins(triton-llvm-opt) diff --git a/bin/triton-llvm-opt.cpp b/bin/triton-llvm-opt.cpp index 1ec804cb50..f521394f47 100644 --- a/bin/triton-llvm-opt.cpp +++ b/bin/triton-llvm-opt.cpp @@ -1,6 +1,7 @@ /// Trimmed down clone of llvm opt to be able to test triton custom llvm ir /// passes. #include "lib/Target/LLVMIR/LLVMPasses.h" +#include "third_party/intel/lib/LLVMIR/LLVMPasses.h" #include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -42,6 +43,11 @@ static cl::opt llvm::cl::desc("run pass to break phi struct"), cl::init(false)); +static cl::opt FreezeMaskedDivRem( + "freeze-masked-div-rem", + llvm::cl::desc("run pass to insert freeze between masked load and div/rem"), + cl::init(false)); + namespace { static std::function makeOptimizingPipeline() { return [](Module *m) -> Error { @@ -62,6 +68,8 @@ static std::function makeOptimizingPipeline() { llvm::FunctionPassManager fpm; if (BreakStructPhiNodes) fpm.addPass(BreakStructPhiNodesPass()); + if (FreezeMaskedDivRem) + fpm.addPass(FreezeMaskedDivRemPass()); mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); mpm.run(*m, mam); return Error::success(); diff --git a/python/test/regression/test_divide.py b/python/test/regression/test_divide.py new file mode 100644 index 0000000000..282b7b5c50 --- /dev/null +++ b/python/test/regression/test_divide.py @@ -0,0 +1,84 @@ +# flake8: noqa: F821, F841 +import torch +import pytest + +import triton +import triton.language as tl + +aten = torch.ops.aten + + +def patch_kernel(template, to_replace): + kernel = triton.JITFunction(template.fn) + for key, value in to_replace.items(): + kernel.src = kernel.src.replace(key, value) + return kernel + + +@pytest.mark.parametrize("float_div", [True, False]) +@pytest.mark.parametrize("floor", [True, False]) +@pytest.mark.parametrize("trunc", [True, False]) +def test_divide(float_div, floor, trunc, device): + # regression test for various division cases + + @triton.jit + def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * XBLOCK + xindex = xoffset + tl.arange(0, XBLOCK)[:] + xmask = xindex < xnumel + x0 = xindex + tmp0 = tl.load(a + (x0), xmask) + tmp2 = tl.load(b + (x0), xmask) + # custom bits + tmp1 = tmp0.to(tl.float32) + tmp3 = tmp2.to(tl.float32) + tmp4 = tmp1 / tmp3 + tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), + tmp0 // tmp2) + tmp6 = tmp0 // tmp2 + GENERATE_OUTPUTS_HERE + + torch.manual_seed(0) + + outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else "" + outputs_floor = " tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else "" + outputs_trunc = " tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else "" + + divide_kernel = patch_kernel(divide_kernel, + {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"}) + + def launch_triton(a, b): + output0 = torch.zeros_like(a) + output1 = torch.zeros_like(a) + output2 = torch.zeros_like(a) + output3 = torch.zeros_like(a) + output4 = torch.zeros_like(a) + + n_elements = output0.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), ) + + divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128) + + return (output0, output1, output2, output3, output4) + + def launch_torch(a, b): + return ( + aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a), + aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a), + aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a), + a / b if float_div is True else torch.zeros_like(a), + a // b if floor is True else torch.zeros_like(a), + ) + + a = torch.randint(2**32, 2**40, [100, 100], device=device) + b = torch.randint(-10, -1, [100, 100], device=device) + + for iter in range(100): + triton_result = launch_triton(a, b) + torch_result = launch_torch(a, b) + + for i in range(5): + torch.testing.assert_close( + triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: + f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}") diff --git a/test/LLVMIR/freeze-masked-div-rem.ll b/test/LLVMIR/freeze-masked-div-rem.ll new file mode 100644 index 0000000000..0909a0b994 --- /dev/null +++ b/test/LLVMIR/freeze-masked-div-rem.ll @@ -0,0 +1,57 @@ +; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s + +define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) { +; CHECK-LABEL: @phi_div_of_zero_okay( +entry: + %cmp = icmp ult i8 %i, 9 + br i1 %cmp, label %if.then, label %if.end + +if.then: + %y = load i8, ptr %v, align 8 + br label %if.end + +if.end: + %yy = phi i8 [ %y, %if.then ], [ 0, %entry ] + ; CHECK: [[F0:%.*]] = freeze i8 %yy + ; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]] + %z = sdiv i8 %x, %yy + br i1 %cmp, label %if2.then, label %if2.end + +if2.then: + store i8 %z, ptr %v, align 8 + br label %if2.end + +if2.end: + ret void +} + +define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) { +; CHECK-LABEL: @two_phi_div_of_zero_okay( +entry: + %cmp = icmp ult i8 %i, 9 + br i1 %cmp, label %if.then, label %if.end + +if.then: + %y = load i8, ptr %v, align 8 + %vv = getelementptr inbounds i64, ptr %v, i64 1 + %b = load i8, ptr %vv, align 8 + br label %if.end + +if.end: + %bb = phi i8 [ %b, %if.then ], [ undef, %entry ] + %yy = phi i8 [ %y, %if.then ], [ 0, %entry ] + ; CHECK: [[F0:%.*]] = freeze i8 %yy + ; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]] + %z = sdiv i8 %x, %yy + ; CHECK: [[F1:%.*]] = freeze i8 %bb + ; CHECK-NEXT: %zz = sdiv i8 %x, [[F1:%.*]] + %zz = sdiv i8 %x, %bb + br i1 %cmp, label %if2.then, label %if2.end + +if2.then: + store i8 %z, ptr %v, align 8 + br label %if2.end + +if2.end: + ret void +} diff --git a/third_party/intel/lib/CMakeLists.txt b/third_party/intel/lib/CMakeLists.txt index b2d8e610d0..2b58d7f122 100644 --- a/third_party/intel/lib/CMakeLists.txt +++ b/third_party/intel/lib/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(Analysis) add_subdirectory(Dialect) add_subdirectory(GPUToTritonGEN) +add_subdirectory(LLVMIR) add_subdirectory(Target) add_subdirectory(TritonAnnotateModule) add_subdirectory(TritonGENToLLVM) diff --git a/third_party/intel/lib/LLVMIR/CMakeLists.txt b/third_party/intel/lib/LLVMIR/CMakeLists.txt new file mode 100644 index 0000000000..6da101e950 --- /dev/null +++ b/third_party/intel/lib/LLVMIR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_triton_library(TritonIntelLLVMIR + LLVMIRFreezeMaskedDivRem.cpp + + DEPENDS + LLVMIRIncGen + ) diff --git a/third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp b/third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp new file mode 100644 index 0000000000..5344f92d70 --- /dev/null +++ b/third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp @@ -0,0 +1,51 @@ +#include "LLVMPasses.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiNode(PHINode *PhiNode) { + if (none_of(PhiNode->incoming_values(), [](Use &U) { + Constant *C = dyn_cast(&U); + return isa(U) || C && C->isNullValue(); + })) { + return false; + } + + bool Changed = false; + BasicBlock *BB = const_cast(PhiNode->getParent()); + for (Instruction &I : *BB) { + if (I.getOpcode() == Instruction::SDiv || + I.getOpcode() == Instruction::SRem) { + const size_t OpIdx = 1; + if (I.getOperand(OpIdx) == PhiNode) { + auto *freezePhi = new FreezeInst( + PhiNode, PhiNode->getName() + ".frozen", I.getIterator()); + I.setOperand(OpIdx, freezePhi); + Changed = true; + } + } + } + return Changed; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + + for (BasicBlock &BB : F) { + for (PHINode &PhiNode : BB.phis()) { + Changed |= processPhiNode(&PhiNode); + } + } + + return Changed; +} + +PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F, + FunctionAnalysisManager &FAM) { + const auto b = runOnFunction(F); + + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/third_party/intel/lib/LLVMIR/LLVMPasses.h b/third_party/intel/lib/LLVMIR/LLVMPasses.h new file mode 100644 index 0000000000..72f71dd983 --- /dev/null +++ b/third_party/intel/lib/LLVMIR/LLVMPasses.h @@ -0,0 +1,11 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" + +namespace llvm { + +struct FreezeMaskedDivRemPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + static StringRef name() { return "FreezeMaskedDivRemPass"; } +}; + +} // namespace llvm diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index e6d13915ee..362e404c5f 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -16,6 +16,7 @@ #include "intel/include/TritonAnnotateModule/Passes.h" #include "intel/include/TritonIntelGPUToLLVM/Passes.h" #include "intel/include/TritonToTritonGPUWarp/Passes.h" +#include "intel/lib/LLVMIR/LLVMPasses.h" #include "triton/Target/SPIRV/SPIRVTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" @@ -204,6 +205,21 @@ void init_triton_intel(py::module &&m) { fpm.addPass(BreakStructPhiNodesPass()); fpm.addPass(InstCombinePass()); }); + pb.registerPeepholeEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // The Triton masked load pattern can generate instances where the + // mask value causes undefined behavior in sdiv/srem instructions. The + // language allows this UB as the result of those arithmetic + // instructions is never used, and control flow to avoid computation + // of these instructions would negatively affect performance. But, + // LLVM SimplifyCFG aggressively marks code paths with undefined + // behavior as dead. This can result in removal of the mask path and + // incorrect results from legal Triton kernels due to masked elements + // being used in computation. Run a pass to add a freeze instruction + // between masked loads and sdiv/srem to signal to LLVM we consider + // the sdiv/srem operands to be well defined. + fpm.addPass(FreezeMaskedDivRemPass()); + }); mpm.addPass(pb.buildPerModuleDefaultPipeline(opt)); mpm.run(*mod, mam); });