Merge commit '3379361a1cf89d4512e8b95adc9204881bc17d11'
whitneywhtsang committed Jul 17, 2024
2 parents b5f4f4a + 3379361 commit 1ae1c3c
Showing 8 changed files with 58 additions and 32 deletions.
@@ -80,6 +80,10 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);

void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit);

void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
PatternBenefit benefit);
11 changes: 9 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -433,14 +433,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
} // namespace
} // namespace mlir::triton::gpu

void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<gpu::ConvertLayoutOpUsingLinearLayoutsConversion>(typeConverter,
benefit);
}

void mlir::triton::populateConvertLayoutOpToLLVMPatterns(
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
RewritePatternSet &patterns, PatternBenefit benefit) {
// We prefer using the linear layout conversion, so it gets a higher benefit.
// Eventually the LL conversion will subsume all of the others and be the only
// one left.
patterns.add<gpu::ConvertLayoutOpUsingLinearLayoutsConversion>(
typeConverter, benefit.getBenefit() + 1);
mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern(
typeConverter, targetInfo, patterns, benefit.getBenefit() + 1);
patterns.add<gpu::ConvertLayoutOpConversion>(typeConverter, targetInfo,
benefit);
patterns.add<gpu::LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp
@@ -97,9 +97,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
rewriter.eraseOp(amendedFuncOp);
newFuncOp.setLinkage(LLVM::Linkage::Internal);
}
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// Set an attribute for reqntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid",
newFuncOp->setAttr("nvvm.reqntid",
rewriter.getDenseI32ArrayAttr(32 * numWarps));
rewriter.eraseOp(funcOp);
return success();
8 changes: 8 additions & 0 deletions python/test/unit/hopper/test_experimental_tma.py
@@ -74,6 +74,14 @@ def kernel(Z, desc, SIZE: tl.constexpr):
@triton.jit
def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
# TODO(embg) remove TMA fence after __grid_constant__ lands
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
pid_m = pid % num_pid_m
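
For context, a sketch of how a kernel like matmul_kernel_tma above could be launched now that the fence is emitted inside the kernel. The shapes, block sizes, and descriptor argument order are illustrative assumptions, not taken from this commit; create_2d_tma_descriptor is the helper from python/triton/tools/experimental_descriptor.py (modified below), and matmul_kernel_tma is assumed to be in scope.

import torch
import triton
from triton.tools.experimental_descriptor import create_2d_tma_descriptor

M, N, K = 512, 512, 512
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# One mutable TMA descriptor per operand; each SM invalidates its cached copy
# via the fence at the top of the kernel, so no host-side flush is needed.
desc_a = create_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K, a.element_size())
desc_b = create_2d_tma_descriptor(b.data_ptr(), K, N, BLOCK_K, BLOCK_N, b.element_size())
desc_c = create_2d_tma_descriptor(c.data_ptr(), M, N, BLOCK_M, BLOCK_N, c.element_size())
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )
matmul_kernel_tma[grid](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K)
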
23 changes: 8 additions & 15 deletions python/triton/tools/experimental_descriptor.py
@@ -1,35 +1,28 @@
import torch

import triton
import triton.language as tl


@triton.jit
def flush_TMA_cache(desc_ptr):
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[desc_ptr], dtype=tl.int32, is_pure=False, pack=1)


# Constructs a 1D TMA descriptor in mutable GPU memory.
#
# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's
# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu.
def create_1d_tma_descriptor(ptr, dim, block_dim, element_size):
TMA_SIZE = 128
desc = torch.empty(TMA_SIZE, dtype=torch.int8)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc.data_ptr())
gpu_desc = desc.cuda()
# TMA cache is not being flushed in between dispacthes, therefore we should
# manually flush the cache every time we create a new TMA descriptor to make
# sure the following dispatch don't use stale cache when accessing TMA.
flush_TMA_cache[(1, )](gpu_desc, num_warps=1)
return gpu_desc


# Constructs a 2D TMA descriptor in mutable GPU memory.
#
# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's
# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu.
def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size):
TMA_SIZE = 128
desc = torch.empty(TMA_SIZE, dtype=torch.int8)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size,
desc.data_ptr())
gpu_desc = desc.cuda()
# TMA cache is not being flushed in between dispacthes, therefore we should
# manually flush the cache every time we create a new TMA descriptor to make
# sure the following dispatch don't use stale cache when accessing TMA.
flush_TMA_cache[(1, )](gpu_desc, num_warps=1)
return gpu_desc
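
With this change the host-side flush_TMA_cache helper is gone: create_1d_tma_descriptor and create_2d_tma_descriptor now only fill the descriptor and copy it to the GPU, and each kernel that consumes a descriptor is expected to issue the tensormap fence itself before its first TMA access (the TODO(embg) fences added to the kernels in this commit). A minimal sketch of that split, using only the APIs shown in this diff; the kernel name and its placeholder body are hypothetical.

import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor


@triton.jit
def consume_desc_kernel(desc_ptr):
    # Invalidate any stale TMA-cache entry for this descriptor before first use;
    # the host no longer launches a flush kernel on our behalf.
    tl.inline_asm_elementwise(
        "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg",
        "=r, l", [desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
    # ... TMA loads/stores through desc_ptr would follow here ...


x = torch.zeros(128, dtype=torch.float32, device="cuda")
# Host side: just build the descriptor in GPU memory, no flush kernel launch.
desc = create_1d_tma_descriptor(x.data_ptr(), x.numel(), 128, x.element_size())
consume_desc_kernel[(1, )](desc, num_warps=1)
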
8 changes: 8 additions & 0 deletions python/tutorials/09-persistent-matmul.py
@@ -259,6 +259,14 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
GROUP_SIZE_M: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
NUM_SMS: tl.constexpr): #
# TODO(embg) remove TMA fence after __grid_constant__ lands
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l",
[c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1)

dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm.mlir
@@ -3,7 +3,7 @@
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>)
// Here the 128 comes from the 4 in module attribute multiples 32
// CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = array<i32: 128>
// CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 128>
tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
tt.return
30 changes: 18 additions & 12 deletions third_party/nvidia/backend/driver.c
@@ -39,6 +39,17 @@ static bool gpuAssert(CUresult code, const char *file, int line) {
} \
} while (0)

// Used to check if functions exist in old CUDA driver versions.
#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \
do { \
if ((funcPointer) == NULL) { \
(funcPointer) = (initializerFunction)(); \
if ((funcPointer) == NULL) { \
return NULL; \
} \
} \
} while (0)

static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
int device_id;
if (!PyArg_ParseTuple(args, "i", &device_id))
@@ -215,12 +226,8 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
config.attrs = launchAttr;

static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
if (cuOccupancyMaxActiveClusters == NULL) {
cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle();
if (cuOccupancyMaxActiveClusters == NULL) {
return NULL;
}
}
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters,
getCuOccupancyMaxActiveClustersHandle);

Py_BEGIN_ALLOW_THREADS;
CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
@@ -303,12 +310,8 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) {
assert((elementSize * tensorDim) >= 32 && "block size too small.");
int rank = 1;
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
if (cuTensorMapEncodeTiled == NULL) {
cuTensorMapEncodeTiled = getCuTensorMapEncodeTiledHandle();
if (cuTensorMapEncodeTiled == NULL) {
return NULL;
}
}
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
getCuTensorMapEncodeTiledHandle);
CUresult result = cuTensorMapEncodeTiled(
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
@@ -369,6 +372,9 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) {
if (contigDimSizeInByte > 128) {
tensorDims[0] = 128 / elementSize;
}
static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
getCuTensorMapEncodeTiledHandle);
CUresult result = cuTensorMapEncodeTiled(
(CUtensorMap *)desc_address, type, rank, (void *)global_address, dims,
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
