From 1f5a3446fae8ce033338331488d4aa6864ff510a Mon Sep 17 00:00:00 2001
From: Nicolas <344493+haricot@users.noreply.github.com>
Date: Tue, 7 Jan 2025 14:05:48 +0100
Subject: [PATCH] add cuda fallback bf16 for compute_cap < 8.0

---
 candle-kernels/src/binary.cu         |  3 +++
 candle-kernels/src/compatibility.cuh | 30 ++++++++++++++--------------
 candle-kernels/src/cuda_utils.cuh    |  2 ++
 candle-kernels/src/fill.cu           |  2 ++
 candle-kernels/src/indexing.cu       |  2 ++
 candle-kernels/src/reduce.cu         |  4 ++++
 candle-kernels/src/unary.cu          | 13 +++++++++++-
 7 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu
index d44e3b20ee..7a0e74ef92 100644
--- a/candle-kernels/src/binary.cu
+++ b/candle-kernels/src/binary.cu
@@ -17,6 +17,9 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#include "cuda_bf16.h"
+BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
+BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
 BINARY_OP(__half, badd_f16, x + y)
 BINARY_OP(__half, bdiv_f16, x / y)
 BINARY_OP(__half, bmul_f16, x * y)
diff --git a/candle-kernels/src/compatibility.cuh b/candle-kernels/src/compatibility.cuh
index d0791749bb..799f509691 100644
--- a/candle-kernels/src/compatibility.cuh
+++ b/candle-kernels/src/compatibility.cuh
@@ -40,21 +40,21 @@ __device__ double atomicAdd(double* address, double val) {
 // The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
 // Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
 __device__ __half atomicAdd(__half *address, __half val) {
-    // unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
-    // unsigned int old = *address_as_ui;
-    // unsigned int assumed;
-    // bool unaligned = (size_t) address & 2;
-    // do {
-    //     assumed = old;
-    //     unsigned int hsum;
-    //     hsum = unaligned ? (old >> 16) : (old & 0xffff);
-    //     hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
-    //     old = atomicCAS(address_as_ui, assumed,
-    //         unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
-    //     );
-
-    // } while (assumed != old);
-    // return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
+    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
+    unsigned int old = *address_as_ui;
+    unsigned int assumed;
+    bool unaligned = (size_t) address & 2;
+    do {
+        assumed = old;
+        unsigned int hsum;
+        hsum = unaligned ? (old >> 16) : (old & 0xffff);
+        hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
+        old = atomicCAS(address_as_ui, assumed,
+            unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
+        );
+
+    } while (assumed != old);
+    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
 }
 #endif
 
diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh
index 2673b8aaf1..4ae23317b7 100644
--- a/candle-kernels/src/cuda_utils.cuh
+++ b/candle-kernels/src/cuda_utils.cuh
@@ -159,6 +159,8 @@ __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a,
 __device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
 __device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
 #if __CUDA_ARCH__ >= 530
+#include "cuda_bf16.h"
+__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
 __device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
 __device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
 __device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu
index ca448d989f..54962efeae 100644
--- a/candle-kernels/src/fill.cu
+++ b/candle-kernels/src/fill.cu
@@ -37,8 +37,10 @@ COPY2D_OP(uint32_t, copy2d_u32)
 COPY2D_OP(int64_t, copy2d_i64)
 
 #if __CUDA_ARCH__ >= 530
+#include <cuda_bf16.h>
 extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
 COPY2D_OP(__half, copy2d_f16)
+COPY2D_OP(__nv_bfloat16, copy2d_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 800
diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 8af2954d13..bf1d173b52 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -162,6 +162,8 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#include "cuda_bf16.h"
+IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
 IS_OP(__half, int64_t, is_i64_f16)
 IS_OP(__half, uint32_t, is_u32_f16)
 IS_OP(__half, uint8_t, is_u8_f16)
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 079c370873..9cfc6aed71 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -581,6 +581,10 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 #endif
 
 #if __CUDA_ARCH__ >= 530
+#include "cuda_bf16.h"
+ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
+SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
+RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
 LAYERNORM_OP(__half, layernorm_f16)
diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu
index c82a88375d..fdf7310752 100644
--- a/candle-kernels/src/unary.cu
+++ b/candle-kernels/src/unary.cu
@@ -124,7 +124,18 @@ UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
 UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
 #endif
 
-#if __CUDA_ARCH__ >= 530
+#if __CUDA_ARCH__ >= 530
+#include "cuda_bf16.h"
+template<typename T>
+__device__ __forceinline__ T silu_fwd_fallback(T x) {
+    // silu(x) = x / (1 + exp(-x)), written generically so it works for
+    // __nv_bfloat16 via the operator fallbacks in cuda_bf16.h.
+    const T one = T(1.0f);
+    const T neg_x = -x;
+    const T exp_neg_x = expg(neg_x);
+    return x / (one + exp_neg_x);
+}
+
+UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
+UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, urecip_f16, recipg(x))
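
Note on the fallback path: silu_fwd_fallback above computes silu(x) = x / (1 + exp(-x)) generically, and on devices with compute capability below 8.0 the __nv_bfloat16 arithmetic it relies on is emulated by cuda_bf16.h through float conversions. The standalone sketch below, which is not part of the patch (the kernel name silu_bf16_ref and the test values are hypothetical), makes that float round-trip explicit and can be used to sanity-check the bf16 result against an f32 reference:

// Hypothetical sanity check: apply silu to a bf16 buffer by converting to
// float, computing x / (1 + exp(-x)) in f32, and converting back, the same
// emulation strategy cuda_bf16.h falls back to for bf16 math on sm < 80.
#include <cstdio>
#include <cmath>
#include <cuda_bf16.h>

__global__ void silu_bf16_ref(const __nv_bfloat16 *in, __nv_bfloat16 *out, size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    float x = __bfloat162float(in[i]);                 // bf16 -> f32
    out[i] = __float2bfloat16(x / (1.0f + expf(-x)));  // silu in f32, back to bf16
}

int main() {
    const size_t n = 8;
    __nv_bfloat16 h_in[n], h_out[n];
    for (size_t i = 0; i < n; ++i) h_in[i] = __float2bfloat16((float)i - 4.0f);

    __nv_bfloat16 *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(__nv_bfloat16));
    cudaMalloc(&d_out, n * sizeof(__nv_bfloat16));
    cudaMemcpy(d_in, h_in, n * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);
    silu_bf16_ref<<<1, 32>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);

    for (size_t i = 0; i < n; ++i) {
        float x = (float)i - 4.0f;
        printf("silu(%+.1f) = % .6f (f32 reference % .6f)\n",
               x, __bfloat162float(h_out[i]), x / (1.0f + expf(-x)));
    }
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Compiling this with a pre-Ampere target (for example nvcc -arch=sm_75) exercises the same f32-emulated bf16 path that the patch enables for compute capability < 8.0; the bf16 output should match the f32 reference to bf16 precision (about 3 significant decimal digits).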