diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index b66ef71193..c0398fb60b 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
   TritonTransforms
   TritonGPUTransforms
   TritonNvidiaGPUTransforms
+  TritonIntelLLVMIR
   MLIRGPUToROCDLTransforms
   ${dialect_libs}
   ${conversion_libs}
@@ -88,6 +89,7 @@ target_link_libraries(triton-llvm-opt PRIVATE
   LLVMSupport
   LLVMOption
   LLVMCodeGen
+  TritonIntelLLVMIR
   TritonIntelGPUIR
   )
 export_executable_symbols_for_plugins(triton-llvm-opt)
diff --git a/bin/triton-llvm-opt.cpp b/bin/triton-llvm-opt.cpp
index 1ec804cb50..f521394f47 100644
--- a/bin/triton-llvm-opt.cpp
+++ b/bin/triton-llvm-opt.cpp
@@ -1,6 +1,7 @@
 /// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
 /// passes.
 #include "lib/Target/LLVMIR/LLVMPasses.h"
+#include "third_party/intel/lib/LLVMIR/LLVMPasses.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -42,6 +43,11 @@ static cl::opt<bool>
                         llvm::cl::desc("run pass to break phi struct"),
                         cl::init(false));
 
+static cl::opt<bool> FreezeMaskedDivRem(
+    "freeze-masked-div-rem",
+    llvm::cl::desc("run pass to insert freeze between masked load and div/rem"),
+    cl::init(false));
+
 namespace {
 static std::function<Error(Module *)> makeOptimizingPipeline() {
   return [](Module *m) -> Error {
@@ -62,6 +68,8 @@ static std::function<Error(Module *)> makeOptimizingPipeline() {
     llvm::FunctionPassManager fpm;
     if (BreakStructPhiNodes)
       fpm.addPass(BreakStructPhiNodesPass());
+    if (FreezeMaskedDivRem)
+      fpm.addPass(FreezeMaskedDivRemPass());
     mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
     mpm.run(*m, mam);
     return Error::success();
diff --git a/python/test/regression/test_divide.py b/python/test/regression/test_divide.py
new file mode 100644
index 0000000000..282b7b5c50
--- /dev/null
+++ b/python/test/regression/test_divide.py
@@ -0,0 +1,84 @@
+# flake8: noqa: F821, F841
+import torch
+import pytest
+
+import triton
+import triton.language as tl
+
+aten = torch.ops.aten
+
+
+def patch_kernel(template, to_replace):
+    kernel = triton.JITFunction(template.fn)
+    for key, value in to_replace.items():
+        kernel.src = kernel.src.replace(key, value)
+    return kernel
+
+
+@pytest.mark.parametrize("float_div", [True, False])
+@pytest.mark.parametrize("floor", [True, False])
+@pytest.mark.parametrize("trunc", [True, False])
+def test_divide(float_div, floor, trunc, device):
+    # regression test for various division cases
+
+    @triton.jit
+    def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
+        xoffset = tl.program_id(0) * XBLOCK
+        xindex = xoffset + tl.arange(0, XBLOCK)[:]
+        xmask = xindex < xnumel
+        x0 = xindex
+        tmp0 = tl.load(a + (x0), xmask)
+        tmp2 = tl.load(b + (x0), xmask)
+        # custom bits
+        tmp1 = tmp0.to(tl.float32)
+        tmp3 = tmp2.to(tl.float32)
+        tmp4 = tmp1 / tmp3
+        tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
+                        tmp0 // tmp2)
+        tmp6 = tmp0 // tmp2
+        GENERATE_OUTPUTS_HERE
+
+    torch.manual_seed(0)
+
+    outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n    tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else ""
+    outputs_floor = "    tl.store(out_ptr1 + (x0), tmp5, xmask)\n    tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else ""
+    outputs_trunc = "    tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else ""
+
+    divide_kernel = patch_kernel(divide_kernel,
+                                 {"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})
+
+    def launch_triton(a, b):
+        output0 = torch.zeros_like(a)
+        output1 = torch.zeros_like(a)
+        output2 = torch.zeros_like(a)
+        output3 = torch.zeros_like(a)
+        output4 = torch.zeros_like(a)
+
+        n_elements = output0.numel()
+
+        grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )
+
+        divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)
+
+        return (output0, output1, output2, output3, output4)
+
+    def launch_torch(a, b):
+        return (
+            aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
+            aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
+            aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
+            a / b if float_div is True else torch.zeros_like(a),
+            a // b if floor is True else torch.zeros_like(a),
+        )
+
+    a = torch.randint(2**32, 2**40, [100, 100], device=device)
+    b = torch.randint(-10, -1, [100, 100], device=device)
+
+    for iter in range(100):
+        triton_result = launch_triton(a, b)
+        torch_result = launch_torch(a, b)
+
+        for i in range(5):
+            torch.testing.assert_close(
+                triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
+                f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
diff --git a/test/LLVMIR/freeze-masked-div-rem.ll b/test/LLVMIR/freeze-masked-div-rem.ll
new file mode 100644
index 0000000000..0909a0b994
--- /dev/null
+++ b/test/LLVMIR/freeze-masked-div-rem.ll
@@ -0,0 +1,57 @@
+; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s
+
+define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
+; CHECK-LABEL: @phi_div_of_zero_okay(
+entry:
+  %cmp = icmp ult i8 %i, 9
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:
+  %y = load i8, ptr %v, align 8
+  br label %if.end
+
+if.end:
+  %yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
+  ; CHECK: [[F0:%.*]] = freeze i8 %yy
+  ; CHECK-NEXT: %z = sdiv i8 %x, [[F0]]
+  %z = sdiv i8 %x, %yy
+  br i1 %cmp, label %if2.then, label %if2.end
+
+if2.then:
+  store i8 %z, ptr %v, align 8
+  br label %if2.end
+
+if2.end:
+  ret void
+}
+
+define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
+; CHECK-LABEL: @two_phi_div_of_zero_okay(
+entry:
+  %cmp = icmp ult i8 %i, 9
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:
+  %y = load i8, ptr %v, align 8
+  %vv = getelementptr inbounds i64, ptr %v, i64 1
+  %b = load i8, ptr %vv, align 8
+  br label %if.end
+
+if.end:
+  %bb = phi i8 [ %b, %if.then ], [ undef, %entry ]
+  %yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
+  ; CHECK: [[F0:%.*]] = freeze i8 %yy
+  ; CHECK-NEXT: %z = sdiv i8 %x, [[F0]]
+  %z = sdiv i8 %x, %yy
+  ; CHECK: [[F1:%.*]] = freeze i8 %bb
+  ; CHECK-NEXT: %zz = sdiv i8 %x, [[F1]]
+  %zz = sdiv i8 %x, %bb
+  br i1 %cmp, label %if2.then, label %if2.end
+
+if2.then:
+  store i8 %z, ptr %v, align 8
+  br label %if2.end
+
+if2.end:
+  ret void
+}
diff --git a/third_party/intel/lib/CMakeLists.txt b/third_party/intel/lib/CMakeLists.txt
index b2d8e610d0..2b58d7f122 100644
--- a/third_party/intel/lib/CMakeLists.txt
+++ b/third_party/intel/lib/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_subdirectory(Analysis)
 add_subdirectory(Dialect)
 add_subdirectory(GPUToTritonGEN)
+add_subdirectory(LLVMIR)
 add_subdirectory(Target)
 add_subdirectory(TritonAnnotateModule)
 add_subdirectory(TritonGENToLLVM)
diff --git a/third_party/intel/lib/LLVMIR/CMakeLists.txt b/third_party/intel/lib/LLVMIR/CMakeLists.txt
new file mode 100644
index 0000000000..6da101e950
--- /dev/null
+++ b/third_party/intel/lib/LLVMIR/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_triton_library(TritonIntelLLVMIR
+  LLVMIRFreezeMaskedDivRem.cpp
+
+  DEPENDS
+  LLVMIRIncGen
+  )
diff --git a/third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp b/third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp
new file mode 100644
index 0000000000..5344f92d70
--- /dev/null
+++ b/third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp
@@ -0,0 +1,61 @@
+#include "LLVMPasses.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+
+using namespace llvm;
+
+/// Freeze \p PhiNode where it is the divisor of an sdiv/srem in its own block.
+///
+/// Only phis with at least one undef or zero incoming value are
+/// considered. Each matching sdiv/srem gets its own freeze of the phi inserted
+/// immediately before it, and its divisor operand is redirected to that
+/// freeze. Returns true if the IR was modified.
+static bool processPhiNode(PHINode *PhiNode) {
+  if (none_of(PhiNode->incoming_values(), [](Use &U) {
+        const auto *C = dyn_cast<Constant>(U.get());
+        return C && (isa<UndefValue>(C) || C->isNullValue());
+      })) {
+    return false;
+  }
+
+  bool Changed = false;
+  BasicBlock *BB = PhiNode->getParent();
+  for (Instruction &I : *BB) {
+    if (I.getOpcode() == Instruction::SDiv ||
+        I.getOpcode() == Instruction::SRem) {
+      // Operand 1 is the divisor.
+      const size_t OpIdx = 1;
+      if (I.getOperand(OpIdx) == PhiNode) {
+        auto *freezePhi = new FreezeInst(
+            PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
+        I.setOperand(OpIdx, freezePhi);
+        Changed = true;
+      }
+    }
+  }
+  return Changed;
+}
+
+/// Apply processPhiNode to every phi in \p F; returns true on any change.
+static bool runOnFunction(Function &F) {
+  bool Changed = false;
+
+  for (BasicBlock &BB : F) {
+    for (PHINode &PhiNode : BB.phis()) {
+      Changed |= processPhiNode(&PhiNode);
+    }
+  }
+
+  return Changed;
+}
+
+/// Pass entry point: reports every analysis as preserved when the function
+/// was left untouched, and none otherwise.
+PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F,
+                                              FunctionAnalysisManager &FAM) {
+  const auto b = runOnFunction(F);
+
+  return b ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
diff --git a/third_party/intel/lib/LLVMIR/LLVMPasses.h b/third_party/intel/lib/LLVMIR/LLVMPasses.h
new file mode 100644
index 0000000000..72f71dd983
--- /dev/null
+++ b/third_party/intel/lib/LLVMIR/LLVMPasses.h
@@ -0,0 +1,18 @@
+#ifndef TRITON_THIRD_PARTY_INTEL_LIB_LLVMIR_LLVMPASSES_H
+#define TRITON_THIRD_PARTY_INTEL_LIB_LLVMIR_LLVMPASSES_H
+
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+
+namespace llvm {
+
+/// Function pass that inserts a freeze between a phi and the sdiv/srem that
+/// uses it as divisor.
+struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  static StringRef name() { return "FreezeMaskedDivRemPass"; }
+};
+
+} // namespace llvm
+
+#endif // TRITON_THIRD_PARTY_INTEL_LIB_LLVMIR_LLVMPASSES_H
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index e6d13915ee..362e404c5f 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -16,6 +16,7 @@
 #include "intel/include/TritonAnnotateModule/Passes.h"
 #include "intel/include/TritonIntelGPUToLLVM/Passes.h"
 #include "intel/include/TritonToTritonGPUWarp/Passes.h"
+#include "intel/lib/LLVMIR/LLVMPasses.h"
 #include "triton/Target/SPIRV/SPIRVTranslation.h"
 #include "triton/Tools/Sys/GetEnv.hpp"
 
@@ -204,6 +205,21 @@ void init_triton_intel(py::module &&m) {
               fpm.addPass(BreakStructPhiNodesPass());
               fpm.addPass(InstCombinePass());
             });
+        pb.registerPeepholeEPCallback(
+            [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
+              // The Triton masked load pattern can generate instances where the
+              // mask value causes undefined behavior in sdiv/srem instructions. The
+              // language allows this UB as the result of those arithmetic
+              // instructions is never used, and control flow to avoid computation
+              // of these instructions would negatively affect performance. But,
+              // LLVM SimplifyCFG aggressively marks code paths with undefined
+              // behavior as dead. This can result in removal of the mask path and
+              // incorrect results from legal Triton kernels due to masked elements
+              // being used in computation. Run a pass to add a freeze instruction
+              // between masked loads and sdiv/srem to signal to LLVM we consider
+              // the sdiv/srem operands to be well defined.
+              fpm.addPass(FreezeMaskedDivRemPass());
+            });
         mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
         mpm.run(*mod, mam);
       });