-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Insert freeze between masked loads and sdiv/srem instructions (#2775)
Close #2726 From the code comments: The Triton masked load pattern can generate instances where the mask value causes undefined behavior in sdiv/srem instructions. The language allows this UB as the result of those arithmetic instructions is never used, and control flow to avoid computation of these instructions would negatively affect performance. But, LLVM SimplifyCFG aggressively marks code paths with undefined behavior as dead. This can result in removal of the mask path and incorrect results from legal Triton kernels due to masked elements being used in computation. Run a pass to add a freeze instruction between masked loads and sdiv/srem to signal to LLVM we consider the sdiv/srem operands to be well defined. The strategy here is to basically invalidate the assumptions under which SimplifyCFG can remove UB for sdiv/srem. The rationale is that, unlike C/C++, Triton explicitly allows UB in sdiv/srem instructions (likely because the hardware Triton is targeting allows that). Inserting a `freeze` instruction both signals that we expect the behavior of sdiv/srem to be well defined and hides the constant 0 in the phi from SimplifyCFG's UB optimizations. The pass needs to run after every instance of `InstCombine` because the LLVM optimization that removes UB only occurs if the sdiv/srem are in the same BB as the phi, which can happen after any `InstCombine`. Note that the directory structure for this pass is a little different than `BreakStructPhiNodesPass` because we are already using those directories in `third_party` for MLIR code. If we want to change that, I can open an issue but let's do it separately from this PR. ---------
- Loading branch information
Showing
9 changed files
with
236 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# flake8: noqa: F821, F841 | ||
import torch | ||
import pytest | ||
|
||
import triton | ||
import triton.language as tl | ||
|
||
aten = torch.ops.aten | ||
|
||
|
||
def patch_kernel(template, to_replace): | ||
kernel = triton.JITFunction(template.fn) | ||
for key, value in to_replace.items(): | ||
kernel.src = kernel.src.replace(key, value) | ||
return kernel | ||
|
||
|
||
@pytest.mark.parametrize("float_div", [True, False]) | ||
@pytest.mark.parametrize("floor", [True, False]) | ||
@pytest.mark.parametrize("trunc", [True, False]) | ||
def test_divide(float_div, floor, trunc, device): | ||
# regression test for various division cases | ||
|
||
@triton.jit | ||
def divide_kernel(a, b, out_ptr0, out_ptr1, out_ptr2, out_ptr3, out_ptr4, xnumel, XBLOCK: tl.constexpr): | ||
xoffset = tl.program_id(0) * XBLOCK | ||
xindex = xoffset + tl.arange(0, XBLOCK)[:] | ||
xmask = xindex < xnumel | ||
x0 = xindex | ||
tmp0 = tl.load(a + (x0), xmask) | ||
tmp2 = tl.load(b + (x0), xmask) | ||
# custom bits | ||
tmp1 = tmp0.to(tl.float32) | ||
tmp3 = tmp2.to(tl.float32) | ||
tmp4 = tmp1 / tmp3 | ||
tmp5 = tl.where((tmp0 < 0) != (tmp2 < 0), tl.where(tmp0 % tmp2 != 0, tmp0 // tmp2 - 1, tmp0 // tmp2), | ||
tmp0 // tmp2) | ||
tmp6 = tmp0 // tmp2 | ||
GENERATE_OUTPUTS_HERE | ||
|
||
torch.manual_seed(0) | ||
|
||
outputs_float_div = "tl.store(out_ptr0 + (x0), tmp4, xmask)\n tl.store(out_ptr3 + (x0), tmp4, xmask)" if float_div else "" | ||
outputs_floor = " tl.store(out_ptr1 + (x0), tmp5, xmask)\n tl.store(out_ptr4 + (x0), tmp5, xmask)" if floor else "" | ||
outputs_trunc = " tl.store(out_ptr2 + (x0), tmp6, xmask)" if trunc else "" | ||
|
||
divide_kernel = patch_kernel(divide_kernel, | ||
{"GENERATE_OUTPUTS_HERE": f"{outputs_float_div}\n{outputs_floor}\n{outputs_trunc}"}) | ||
|
||
def launch_triton(a, b): | ||
output0 = torch.zeros_like(a) | ||
output1 = torch.zeros_like(a) | ||
output2 = torch.zeros_like(a) | ||
output3 = torch.zeros_like(a) | ||
output4 = torch.zeros_like(a) | ||
|
||
n_elements = output0.numel() | ||
|
||
grid = lambda meta: (triton.cdiv(n_elements, meta['XBLOCK']), ) | ||
|
||
divide_kernel[grid](a, b, output0, output1, output2, output3, output4, n_elements, XBLOCK=128) | ||
|
||
return (output0, output1, output2, output3, output4) | ||
|
||
def launch_torch(a, b): | ||
return ( | ||
aten.div(a, b, rounding_mode=None) if float_div is True else torch.zeros_like(a), | ||
aten.div(a, b, rounding_mode="floor") if floor is True else torch.zeros_like(a), | ||
aten.div(a, b, rounding_mode="trunc") if trunc is True else torch.zeros_like(a), | ||
a / b if float_div is True else torch.zeros_like(a), | ||
a // b if floor is True else torch.zeros_like(a), | ||
) | ||
|
||
a = torch.randint(2**32, 2**40, [100, 100], device=device) | ||
b = torch.randint(-10, -1, [100, 100], device=device) | ||
|
||
for iter in range(100): | ||
triton_result = launch_triton(a, b) | ||
torch_result = launch_torch(a, b) | ||
|
||
for i in range(5): | ||
torch.testing.assert_close( | ||
triton_result[i], torch_result[i], check_dtype=False, msg=lambda msg: | ||
f"Float: {float_div}, Floor: {floor}, Trunc: {trunc}\nIteration {iter}, {i} failed\n{msg}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
; RUN: triton-llvm-opt -freeze-masked-div-rem %s | FileCheck %s | ||
|
||
define void @phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) { | ||
; CHECK-LABEL: @phi_div_of_zero_okay( | ||
entry: | ||
%cmp = icmp ult i8 %i, 9 | ||
br i1 %cmp, label %if.then, label %if.end | ||
|
||
if.then: | ||
%y = load i8, ptr %v, align 8 | ||
br label %if.end | ||
|
||
if.end: | ||
%yy = phi i8 [ %y, %if.then ], [ 0, %entry ] | ||
; CHECK: [[F0:%.*]] = freeze i8 %yy | ||
; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]] | ||
%z = sdiv i8 %x, %yy | ||
br i1 %cmp, label %if2.then, label %if2.end | ||
|
||
if2.then: | ||
store i8 %z, ptr %v, align 8 | ||
br label %if2.end | ||
|
||
if2.end: | ||
ret void | ||
} | ||
|
||
define void @two_phi_div_of_zero_okay(i8 noundef %x, i8 %i, ptr %v) { | ||
; CHECK-LABEL: @two_phi_div_of_zero_okay( | ||
entry: | ||
%cmp = icmp ult i8 %i, 9 | ||
br i1 %cmp, label %if.then, label %if.end | ||
|
||
if.then: | ||
%y = load i8, ptr %v, align 8 | ||
%vv = getelementptr inbounds i64, ptr %v, i64 1 | ||
%b = load i8, ptr %vv, align 8 | ||
br label %if.end | ||
|
||
if.end: | ||
%bb = phi i8 [ %b, %if.then ], [ undef, %entry ] | ||
%yy = phi i8 [ %y, %if.then ], [ 0, %entry ] | ||
; CHECK: [[F0:%.*]] = freeze i8 %yy | ||
; CHECK-NEXT: %z = sdiv i8 %x, [[F0:%.*]] | ||
%z = sdiv i8 %x, %yy | ||
; CHECK: [[F1:%.*]] = freeze i8 %bb | ||
; CHECK-NEXT: %zz = sdiv i8 %x, [[F1:%.*]] | ||
%zz = sdiv i8 %x, %bb | ||
br i1 %cmp, label %if2.then, label %if2.end | ||
|
||
if2.then: | ||
store i8 %z, ptr %v, align 8 | ||
br label %if2.end | ||
|
||
if2.end: | ||
ret void | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
add_triton_library(TritonIntelLLVMIR | ||
LLVMIRFreezeMaskedDivRem.cpp | ||
|
||
DEPENDS | ||
LLVMIRIncGen | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#include "LLVMPasses.h" | ||
#include "llvm/Analysis/TargetTransformInfo.h" | ||
#include "llvm/Analysis/ValueTracking.h" | ||
#include "llvm/IR/Dominators.h" | ||
#include "llvm/IR/Instructions.h" | ||
|
||
using namespace llvm; | ||
|
||
static bool processPhiNode(PHINode *PhiNode) { | ||
if (none_of(PhiNode->incoming_values(), [](Use &U) { | ||
Constant *C = dyn_cast<Constant>(&U); | ||
return isa<UndefValue>(U) || C && C->isNullValue(); | ||
})) { | ||
return false; | ||
} | ||
|
||
bool Changed = false; | ||
BasicBlock *BB = const_cast<BasicBlock *>(PhiNode->getParent()); | ||
for (Instruction &I : *BB) { | ||
if (I.getOpcode() == Instruction::SDiv || | ||
I.getOpcode() == Instruction::SRem) { | ||
const size_t OpIdx = 1; | ||
if (I.getOperand(OpIdx) == PhiNode) { | ||
auto *freezePhi = new FreezeInst( | ||
PhiNode, PhiNode->getName() + ".frozen", I.getIterator()); | ||
I.setOperand(OpIdx, freezePhi); | ||
Changed = true; | ||
} | ||
} | ||
} | ||
return Changed; | ||
} | ||
|
||
static bool runOnFunction(Function &F) { | ||
bool Changed = false; | ||
|
||
for (BasicBlock &BB : F) { | ||
for (PHINode &PhiNode : BB.phis()) { | ||
Changed |= processPhiNode(&PhiNode); | ||
} | ||
} | ||
|
||
return Changed; | ||
} | ||
|
||
PreservedAnalyses FreezeMaskedDivRemPass::run(Function &F, | ||
FunctionAnalysisManager &FAM) { | ||
const auto b = runOnFunction(F); | ||
|
||
return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/Pass.h" | ||
|
||
namespace llvm { | ||
|
||
struct FreezeMaskedDivRemPass : PassInfoMixin<FreezeMaskedDivRemPass> { | ||
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); | ||
static StringRef name() { return "FreezeMaskedDivRemPass"; } | ||
}; | ||
|
||
} // namespace llvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters