Skip to content

Commit

Permalink
Insert freeze between masked loads and sdiv/srem instructions (#2775)
Browse files Browse the repository at this point in the history
Close #2726

From the code comments:

The Triton masked load pattern can generate instances where the
mask value causes undefined behavior in sdiv/srem instructions. The
language allows this UB as the result of those arithmetic
instructions is never used, and control flow to avoid computation
of these instructions would negatively affect performance. But,
LLVM SimplifyCFG aggressively marks code paths with undefined
behavior as dead. This can result in removal of the mask path and
incorrect results from legal Triton kernels due to masked elements
being used in computation. Run a pass to add a freeze instruction
between masked loads and sdiv/srem to signal to LLVM we consider
the sdiv/srem operands to be well defined.

The strategy here is to basically invalidate the assumptions under which
SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike
C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely
because the hardware Triton is targeting allows that). Inserting a
`freeze` instruction both signals that we expect the behavior of
sdiv/srem to be well defined and hides the constant 0 in the phi from
SimplifyCFG's UB optimizations.

The pass needs to run after every instance of `InstCombine` because the
LLVM optimization that removes UB only occurs if the sdiv/srem are in
the same BB as the phi, which can happen after any `InstCombine`.

Note that the directory structure for this pass is a little different
than `BreakStructPhiNodesPass` because we are already using those
directories in `third_party` for MLIR code. If we want to change that, I
can open an issue but let's do it separately from this PR.

---------
  • Loading branch information
alexbaden authored Dec 3, 2024
1 parent 02346d9 commit 78c13a5
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 0 deletions.
2 changes: 2 additions & 0 deletions bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ target_link_libraries(triton-opt PRIVATE
TritonTransforms
TritonGPUTransforms
TritonNvidiaGPUTransforms
TritonIntelLLVMIR
MLIRGPUToROCDLTransforms
${dialect_libs}
${conversion_libs}
Expand Down Expand Up @@ -88,6 +89,7 @@ target_link_libraries(triton-llvm-opt PRIVATE
LLVMSupport
LLVMOption
LLVMCodeGen
TritonIntelLLVMIR
TritonIntelGPUIR
)
export_executable_symbols_for_plugins(triton-llvm-opt)
Expand Down
8 changes: 8 additions & 0 deletions bin/triton-llvm-opt.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir
/// passes.
#include "lib/Target/LLVMIR/LLVMPasses.h"
#include "third_party/intel/lib/LLVMIR/LLVMPasses.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
Expand Down Expand Up @@ -42,6 +43,11 @@ static cl::opt<bool>
llvm::cl::desc("run pass to break phi struct"),
cl::init(false));

static cl::opt<bool> FreezeMaskedDivRem(
"freeze-masked-div-rem",
llvm::cl::desc("run pass to insert freeze between masked load and div/rem"),
cl::init(false));

namespace {
static std::function<Error(Module *)> makeOptimizingPipeline() {
return [](Module *m) -> Error {
Expand All @@ -62,6 +68,8 @@ static std::function<Error(Module *)> makeOptimizingPipeline() {
llvm::FunctionPassManager fpm;
if (BreakStructPhiNodes)
fpm.addPass(BreakStructPhiNodesPass());
if (FreezeMaskedDivRem)
fpm.addPass(FreezeMaskedDivRemPass());
mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm)));
mpm.run(*m, mam);
return Error::success();
Expand Down
84 changes: 84 additions & 0 deletions python/test/regression/test_divide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# flake8: noqa: F821, F841
import torch
import pytest

import triton
import triton.language as tl

aten = torch.ops.aten


def patch_kernel(template, to_replace):
kernel = triton.JITFunction(template.fn)
for key, value in to_replace.items():
kernel.src = kernel.src.replace(key, value)
return kernel


@pytest.mark.parametrize("float_div", [True, False])
@pytest.mark.parametrize("floor", [True, False])
@pytest.mark.parametrize("trunc", [True, False])
def test_divide(float_div, floor, trunc, device):
# regression test for various division cases

@triton.jit
def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(a + (x0), xmask)
tmp2 = tl.load(b + (x0), xmask)
# custom bits
tmp1 = tmp0.to(tl.float32)
tmp3 = tmp2.to(tl.float32)
tmp4 = tmp1 / tmp3
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2),
tmp0 // tmp2)
tmp6 = tmp0 // tmp2
GENERATE_OUTPUTS_HERE

torch.manual_seed(0)

outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else ""
outputs_floor = " tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else ""
outputs_trunc = " tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else ""

divide_kernel = patch_kernel(divide_kernel,
{"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"})

def launch_triton(a, b):
output0 = torch.zeros_like(a)
output1 = torch.zeros_like(a)
output2 = torch.zeros_like(a)
output3 = torch.zeros_like(a)
output4 = torch.zeros_like(a)

n_elements = output0.numel()

grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), )

divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128)

return (output0, output1, output2, output3, output4)

def launch_torch(a, b):
return (
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a),
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a),
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a),
a / b if float_div is True else torch.zeros_like(a),
a // b if floor is True else torch.zeros_like(a),
)

a = torch.randint(2**32, 2**40, [100, 100], device=device)
b = torch.randint(-10, -1, [100, 100], device=device)

for iter in range(100):
triton_result = launch_triton(a, b)
torch_result = launch_torch(a, b)

for i in range(5):
torch.testing.assert_close(
triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg:
f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}")
57 changes: 57 additions & 0 deletions test/LLVMIR/freeze-masked-div-rem.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s

define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
; CHECK-LABEL: @phi_div_of_zero_okay(
entry:
%cmp = icmp ult i8 %i, 9
br i1 %cmp, label %if.then, label %if.end

if.then:
%y = load i8, ptr %v, align 8
br label %if.end

if.end:
%yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
; CHECK: [[F0:%.*]] = freeze i8 %yy
; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]]
%z = sdiv i8 %x, %yy
br i1 %cmp, label %if2.then, label %if2.end

if2.then:
store i8 %z, ptr %v, align 8
br label %if2.end

if2.end:
ret void
}

define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) {
; CHECK-LABEL: @two_phi_div_of_zero_okay(
entry:
%cmp = icmp ult i8 %i, 9
br i1 %cmp, label %if.then, label %if.end

if.then:
%y = load i8, ptr %v, align 8
%vv = getelementptr inbounds i64, ptr %v, i64 1
%b = load i8, ptr %vv, align 8
br label %if.end

if.end:
%bb = phi i8 [ %b, %if.then ], [ undef, %entry ]
%yy = phi i8 [ %y, %if.then ], [ 0, %entry ]
; CHECK: [[F0:%.*]] = freeze i8 %yy
; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]]
%z = sdiv i8 %x, %yy
; CHECK: [[F1:%.*]] = freeze i8 %bb
; CHECK-NEXT: %zz = sdiv i8 %x, [[F1:%.*]]
%zz = sdiv i8 %x, %bb
br i1 %cmp, label %if2.then, label %if2.end

if2.then:
store i8 %z, ptr %v, align 8
br label %if2.end

if2.end:
ret void
}
1 change: 1 addition & 0 deletions third_party/intel/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(GPUToTritonGEN)
add_subdirectory(LLVMIR)
add_subdirectory(Target)
add_subdirectory(TritonAnnotateModule)
add_subdirectory(TritonGENToLLVM)
Expand Down
6 changes: 6 additions & 0 deletions third_party/intel/lib/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
add_triton_library(TritonIntelLLVMIR
LLVMIRFreezeMaskedDivRem.cpp

DEPENDS
LLVMIRIncGen
)
51 changes: 51 additions & 0 deletions third_party/intel/lib/LLVMIR/LLVMIRFreezeMaskedDivRem.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#include "LLVMPasses.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

static bool processPhiNode(PHINode *PhiNode) {
if (none_of(PhiNode->incoming_values(), [](Use &U) {
Constant *C = dyn_cast<Constant>(&U);
return isa<UndefValue>(U) || C && C->isNullValue();
})) {
return false;
}

bool Changed = false;
BasicBlock *BB = const_cast<BasicBlock *>(PhiNode->getParent());
for (Instruction &I : *BB) {
if (I.getOpcode() == Instruction::SDiv ||
I.getOpcode() == Instruction::SRem) {
const size_t OpIdx = 1;
if (I.getOperand(OpIdx) == PhiNode) {
auto *freezePhi = new FreezeInst(
PhiNode, PhiNode->getName() + ".frozen", I.getIterator());
I.setOperand(OpIdx, freezePhi);
Changed = true;
}
}
}
return Changed;
}

static bool runOnFunction(Function &F) {
bool Changed = false;

for (BasicBlock &BB : F) {
for (PHINode &PhiNode : BB.phis()) {
Changed |= processPhiNode(&PhiNode);
}
}

return Changed;
}

PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F,
FunctionAnalysisManager &FAM) {
const auto b = runOnFunction(F);

return b ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
11 changes: 11 additions & 0 deletions third_party/intel/lib/LLVMIR/LLVMPasses.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"

namespace llvm {

struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> {
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
static StringRef name() { return "FreezeMaskedDivRemPass"; }
};

} // namespace llvm
16 changes: 16 additions & 0 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "intel/include/TritonAnnotateModule/Passes.h"
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"
#include "intel/include/TritonToTritonGPUWarp/Passes.h"
#include "intel/lib/LLVMIR/LLVMPasses.h"

#include "triton/Target/SPIRV/SPIRVTranslation.h"
#include "triton/Tools/Sys/GetEnv.hpp"
Expand Down Expand Up @@ -204,6 +205,21 @@ void init_triton_intel(py::module &&m) {
fpm.addPass(BreakStructPhiNodesPass());
fpm.addPass(InstCombinePass());
});
pb.registerPeepholeEPCallback(
[&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) {
// The Triton masked load pattern can generate instances where the
// mask value causes undefined behavior in sdiv/srem instructions. The
// language allows this UB as the result of those arithmetic
// instructions is never used, and control flow to avoid computation
// of these instructions would negatively affect performance. But,
// LLVM SimplifyCFG aggressively marks code paths with undefined
// behavior as dead. This can result in removal of the mask path and
// incorrect results from legal Triton kernels due to masked elements
// being used in computation. Run a pass to add a freeze instruction
// between masked loads and sdiv/srem to signal to LLVM we consider
// the sdiv/srem operands to be well defined.
fpm.addPass(FreezeMaskedDivRemPass());
});
mpm.addPass(pb.buildPerModuleDefaultPipeline(opt));
mpm.run(*mod, mam);
});
Expand Down

0 comments on commit 78c13a5

Please sign in to comment.