From 3aeb223819d632303dd2b45f4dc533d6af90dc46 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Wed, 17 Jul 2024 00:33:59 -0400 Subject: [PATCH 1/4] [RUNTIME] Fix the function lookup problem for CUDA 11 driver (#4335) There was a function pointer lookup missing in the previous patch. https://github.com/triton-lang/triton/commit/f9f2960deef376da4ebc1ff8b1546051c66894a4 --- third_party/nvidia/backend/driver.c | 30 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index 1de2c0f234..f9f60271fe 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -39,6 +39,17 @@ static bool gpuAssert(CUresult code, const char *file, int line) { } \ } while (0) +// Used to check if functions exist in old CUDA driver versions. +#define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \ + do { \ + if ((funcPointer) == NULL) { \ + (funcPointer) = (initializerFunction)(); \ + if ((funcPointer) == NULL) { \ + return NULL; \ + } \ + } \ + } while (0) + static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { int device_id; if (!PyArg_ParseTuple(args, "i", &device_id)) @@ -215,12 +226,8 @@ static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) { config.attrs = launchAttr; static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL; - if (cuOccupancyMaxActiveClusters == NULL) { - cuOccupancyMaxActiveClusters = getCuOccupancyMaxActiveClustersHandle(); - if (cuOccupancyMaxActiveClusters == NULL) { - return NULL; - } - } + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters, + getCuOccupancyMaxActiveClustersHandle); Py_BEGIN_ALLOW_THREADS; CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute( @@ -303,12 +310,8 @@ static PyObject *fill1DTMADescriptor(PyObject *self, PyObject *args) { assert((elementSize * tensorDim) >= 32 && "block size too small."); int rank = 1; static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; - if (cuTensorMapEncodeTiled == NULL) { - cuTensorMapEncodeTiled = getCuTensorMapEncodeTiledHandle(); - if (cuTensorMapEncodeTiled == NULL) { - return NULL; - } - } + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); CUresult result = cuTensorMapEncodeTiled( (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, globalStrides, boxDim, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, @@ -369,6 +372,9 @@ static PyObject *fill2DTMADescriptor(PyObject *self, PyObject *args) { if (contigDimSizeInByte > 128) { tensorDims[0] = 128 / elementSize; } + static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL; + INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled, + getCuTensorMapEncodeTiledHandle); CUresult result = cuTensorMapEncodeTiled( (CUtensorMap *)desc_address, type, rank, (void *)global_address, dims, globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, From 0dd9029abf61c949471eb512a0b1e0da55339859 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Wed, 17 Jul 2024 10:27:18 -0700 Subject: [PATCH 2/4] [BACKEND] Use nvvm.reqntid instead of nvvm.maxntid (#4343) nvvm.reqntid has a stronger semantic and should allow better optimization in the finalizer. 
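For context on the numbers involved: Triton derives the block size from the kernel's `num_warps` (32 threads per warp), and `nvvm.reqntid` asserts that exact thread count while `nvvm.maxntid` only bounds it from above. A minimal sketch of the launch-side view, assuming stock Triton on a CUDA GPU (kernel and shapes are illustrative, not part of this patch):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(x_ptr, n, BLOCK: tl.constexpr):
    # Trivial elementwise kernel, just to have something to launch.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(x_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask)


x = torch.zeros(4096, device="cuda")
# num_warps=4 -> 4 * 32 = 128 threads per CTA. After this patch the backend
# records that 128 as nvvm.reqntid (an exact block size) instead of
# nvvm.maxntid (an upper bound), which is what gives the finalizer more room.
add_one_kernel[(triton.cdiv(x.numel(), 1024),)](x, x.numel(), BLOCK=1024, num_warps=4)
```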
--- lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp | 4 ++-- test/Conversion/tritongpu_to_llvm.mlir | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp index 47f40ebecd..6850df2ea9 100644 --- a/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp @@ -97,9 +97,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { rewriter.eraseOp(amendedFuncOp); newFuncOp.setLinkage(LLVM::Linkage::Internal); } - // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // Set an attribute for reqntidx, it could be used in latter LLVM codegen // for `nvvm.annotation` metadata. - newFuncOp->setAttr("nvvm.maxntid", + newFuncOp->setAttr("nvvm.reqntid", rewriter.getDenseI32ArrayAttr(32 * numWarps)); rewriter.eraseOp(funcOp); return success(); diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index ff033da1fa..1a8f53e06c 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -3,7 +3,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>) // Here the 128 comes from the 4 in module attribute multiples 32 - // CHECK: nvvm.kernel = 1 : ui1, nvvm.maxntid = array + // CHECK: nvvm.kernel = 1 : ui1, nvvm.reqntid = array tt.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return tt.return From b4e079eaf7c85a64c3a74338985ee1312fbb48bf Mon Sep 17 00:00:00 2001 From: Zahi Moudallal Date: Wed, 17 Jul 2024 11:04:51 -0700 Subject: [PATCH 3/4] [BACKEND] Separate layout conversion codegen using LL to a separate pattern (#4344) --- .../TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h | 4 ++++ .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index d1494fd7ee..29af2c5f7c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -80,6 +80,10 @@ void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit); + void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 031a72eeee..545c907439 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -433,14 +433,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } // namespace } // namespace mlir::triton::gpu +void mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(typeConverter, + benefit); +} + void mlir::triton::populateConvertLayoutOpToLLVMPatterns( LLVMTypeConverter &typeConverter, 
const TargetInfoBase &targetInfo, RewritePatternSet &patterns, PatternBenefit benefit) { // We prefer using the linear layout conversion, so it gets a higher benefit. // Eventually the LL conversion will subsume all of the others and be the only // one left. - patterns.add( - typeConverter, benefit.getBenefit() + 1); + mlir::triton::populateConvertLayoutOpUsingLinearLayoutsToLLVMPattern( + typeConverter, targetInfo, patterns, benefit.getBenefit() + 1); patterns.add(typeConverter, targetInfo, benefit); patterns.add(typeConverter, targetInfo, benefit); From 3379361a1cf89d4512e8b95adc9204881bc17d11 Mon Sep 17 00:00:00 2001 From: Elliot Gorokhovsky Date: Wed, 17 Jul 2024 14:08:23 -0400 Subject: [PATCH 4/4] Short-term solution for TMA descriptor cache management (#4342) This PR fixes the bug demonstrated [here](https://github.com/embg/triton/blob/ed125e4a44e397e9a40e691bb7ce40c698120a1a/tma_repro.py), which is the probable root cause of https://github.com/triton-lang/triton/issues/4332. ## The problem NVIDIA docs [recommend](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#using-tma-to-transfer-multi-dimensional-arrays) that TMA descriptors should be passed through immutable device memory, but Triton currently passes them through mutable device memory. This is unsafe unless the TMA descriptor cache is flushed *on every SM*. The current implementation attempts to flush the cache by launching a [dedicated TMA flush kernel](https://github.com/triton-lang/triton/blob/3aeb223819d632303dd2b45f4dc533d6af90dc46/python/triton/tools/experimental_descriptor.py#L34). Unfortunately, this kernel does not run on all SMs. As a result, Triton TMA kernels may hang or return incorrect results. According to @ThomasRaoux, it isn't possible to guarantee a kernel will run on every SM (as there may be another workload on a different CUDA stream). So flushing in a separate kernel is not possible. ## Proposed solution * Add fences to all example code via inline assembly. * Add documentation to inform users about the fence issue. * Remove the existing cache flush code since it is incorrect. ## Why this solution? Since each kernel needs to issue its own fence instruction, we have three options: * Inline assembly * Add a new op, like `tl._experimental_tma_acquire_fence(addr)` * Use compiler analysis to insert the fence automatically I believe we should not add a new op or analysis pass until both `__grid_constant__` and on-device descriptor mutation are landed. Once host-side descriptors switch to `__grid_constant__`, the fence will only be needed for on-device mutation, which won't require a separate op or analysis pass (simply add a fence while lowering the mutation op). If I'm wrong and we do end up needing a separate op or analysis pass, it will be trivial to clean up 6 lines of inline assembly. ## Checklist - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. - Select one of the following. - [x] I have not added any `lit` tests. 
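Of the three options above, the PR goes with inline assembly; the same fence can also be kept in one place as a `@triton.jit` helper rather than pasted into every kernel. A sketch under that assumption (the helper name is illustrative; the diff below inlines the asm directly):

```python
import triton
import triton.language as tl


@triton.jit
def tma_descriptor_acquire_fence(desc_ptr):
    # Same PTX fence the kernels below issue: makes the descriptor bytes at
    # desc_ptr visible to the tensormap proxy on the SM executing this program.
    tl.inline_asm_elementwise(
        "fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg",
        "=r, l", [desc_ptr], dtype=tl.int32, is_pure=False, pack=1)
```

A kernel would call this once per descriptor argument before its first TMA access through that descriptor; as noted above, once host-side descriptors move to `__grid_constant__` the fence is only needed for descriptors mutated on the device.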
--- .../test/unit/hopper/test_experimental_tma.py | 8 +++++++ .../triton/tools/experimental_descriptor.py | 23 +++++++------------ python/tutorials/09-persistent-matmul.py | 8 +++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 7e6158cafb..c0228fb54a 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -74,6 +74,14 @@ def kernel(Z, desc, SIZE: tl.constexpr): @triton.jit def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, # M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + # TODO(embg) remove TMA fence after __grid_constant__ lands + tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", + [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", + [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", + [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) pid_m = pid % num_pid_m diff --git a/python/triton/tools/experimental_descriptor.py b/python/triton/tools/experimental_descriptor.py index 10ce4363c7..c1265ba04b 100644 --- a/python/triton/tools/experimental_descriptor.py +++ b/python/triton/tools/experimental_descriptor.py @@ -1,35 +1,28 @@ import torch import triton -import triton.language as tl - - -@triton.jit -def flush_TMA_cache(desc_ptr): - tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", - [desc_ptr], dtype=tl.int32, is_pure=False, pack=1) +# Constructs a 1D TMA descriptor in mutable GPU memory. +# +# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's +# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu. def create_1d_tma_descriptor(ptr, dim, block_dim, element_size): TMA_SIZE = 128 desc = torch.empty(TMA_SIZE, dtype=torch.int8) triton.runtime.driver.active.utils.fill_1d_tma_descriptor(ptr, dim, block_dim, element_size, desc.data_ptr()) gpu_desc = desc.cuda() - # TMA cache is not being flushed in between dispacthes, therefore we should - # manually flush the cache every time we create a new TMA descriptor to make - # sure the following dispatch don't use stale cache when accessing TMA. - flush_TMA_cache[(1, )](gpu_desc, num_warps=1) return gpu_desc +# Constructs a 2D TMA descriptor in mutable GPU memory. +# +# Note: on the first use of a new descriptor, each SM must invalidate the descriptor's +# address in TMA cache via fence.proxy.tensormap::generic.acquire.gpu. def create_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size): TMA_SIZE = 128 desc = torch.empty(TMA_SIZE, dtype=torch.int8) triton.runtime.driver.active.utils.fill_2d_tma_descriptor(ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc.data_ptr()) gpu_desc = desc.cuda() - # TMA cache is not being flushed in between dispacthes, therefore we should - # manually flush the cache every time we create a new TMA descriptor to make - # sure the following dispatch don't use stale cache when accessing TMA. 
- flush_TMA_cache[(1, )](gpu_desc, num_warps=1) return gpu_desc diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index fdbdbfecfb..460c374d7f 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -259,6 +259,14 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, # GROUP_SIZE_M: tl.constexpr, # FP8_OUTPUT: tl.constexpr, # NUM_SMS: tl.constexpr): # + # TODO(embg) remove TMA fence after __grid_constant__ lands + tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", + [a_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", + [b_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + tl.inline_asm_elementwise("fence.proxy.tensormap::generic.acquire.gpu [$1], 128; // $0 dummy reg", "=r, l", + [c_desc_ptr], dtype=tl.int32, is_pure=False, pack=1) + dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16 start_pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
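With the flush kernel removed, the host side reduces to building the descriptor and handing it to a kernel that begins with the fence. A short usage sketch assuming the signature above, with `dim1` taken as the outer (row) dimension; shapes and the launch line are illustrative:

```python
import torch
from triton.tools.experimental_descriptor import create_2d_tma_descriptor

M, K = 1024, 512
BLOCK_M, BLOCK_K = 128, 64
a = torch.randn(M, K, device="cuda", dtype=torch.float16)

# The returned descriptor is a small int8 CUDA tensor in mutable device memory.
# Nothing on the host invalidates the per-SM TMA cache anymore, so the consuming
# kernel must issue the tensormap acquire fence before its first use of it.
desc_a = create_2d_tma_descriptor(a.data_ptr(), M, K, BLOCK_M, BLOCK_K,
                                  a.element_size())

# Illustrative launch; the kernels patched above start with exactly that fence:
# matmul_kernel_tma[grid](desc_a, ..., num_warps=8)
```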