Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cuda fallback bf16 for compute_cap < 8.0 #2704

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions candle-kernels/src/binary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ BINARY_OP_OUT(__nv_bfloat16, uint8_t, ge_bf16, x >= y)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
BINARY_OP(__half, bmul_f16, x * y)
Expand Down
30 changes: 15 additions & 15 deletions candle-kernels/src/compatibility.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ __device__ double atomicAdd(double* address, double val) {
// The 16-bit __half floating-point version of atomicAdd() is only supported by devices of compute capability 7.x and higher.
// Solution adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
// Software atomicAdd for __half on devices below compute capability 7.x,
// where the native 16-bit atomicAdd overload does not exist.  Emulates the
// update with a 32-bit atomicCAS loop on the aligned word that contains the
// target half, leaving the neighbouring half untouched.
// Adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh#L96-L119
// Fix: removed the dead, fully commented-out duplicate of this implementation
// that had been left in the body.
__device__ __half atomicAdd(__half *address, __half val) {
    // Address of the 4-byte-aligned word containing *address (clear bit 1).
    unsigned int *address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    // True when the half occupies the upper 16 bits of the 32-bit word.
    bool unaligned = (size_t) address & 2;
    do {
        assumed = old;
        unsigned int hsum;
        // Extract the current half, add val in half precision, then splice
        // the result back into the containing word.
        hsum = unaligned ? (old >> 16) : (old & 0xffff);
        hsum = __half_as_ushort(__ushort_as_half(hsum) + val);
        old = atomicCAS(address_as_ui, assumed,
            unaligned ? (old & 0xffff) | (hsum << 16) : (old & 0xffff0000) | hsum
        );
    // Retry if another thread changed the word between the read and the CAS.
    } while (assumed != old);
    // Return the value held before this add, matching the atomicAdd contract.
    return __ushort_as_half(unaligned ? (old >> 16) : (old & 0xffff));
}
#endif

Expand Down
2 changes: 2 additions & 0 deletions candle-kernels/src/cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ __device__ __forceinline__ uint32_t maxg(uint32_t a, uint32_t b) { return max(a,
__device__ __forceinline__ uint8_t ming(uint8_t a, uint8_t b) { return min(a, b); }
__device__ __forceinline__ uint8_t maxg(uint8_t a, uint8_t b) { return max(a, b); }
#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half powg(__half a, __half b) { return __float2half(powf(__half2float(a), __half2float(b))); }
__device__ __forceinline__ bool isnang(__half a) { return __hisnan(a); }
__device__ __forceinline__ __half sqrtg(__half a) { return hsqrt(a); }
Expand Down
2 changes: 2 additions & 0 deletions candle-kernels/src/fill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
// NOTE(review): angle-bracket include here vs. quoted "cuda_bf16.h" in the
// sibling files of this change — consider one consistent include style.
// Guarding an #include with __CUDA_ARCH__ also skips it entirely during the
// host compilation pass; presumably intentional here — confirm.
#include <cuda_bf16.h>
// Kernel: delegates to fill_with (defined elsewhere in this file) to write
// `value` into all `numel` elements of `buf`. __half requires SM 5.3+.
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
// bf16 copy2d fallback added for pre-SM80 devices by this change.
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif

#if __CUDA_ARCH__ >= 800
Expand Down
2 changes: 2 additions & 0 deletions candle-kernels/src/indexing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16)
IS_OP(__half, int64_t, is_i64_f16)
IS_OP(__half, uint32_t, is_u32_f16)
IS_OP(__half, uint8_t, is_u8_f16)
Expand Down
4 changes: 4 additions & 0 deletions candle-kernels/src/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,10 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#endif

#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
Expand Down
13 changes: 12 additions & 1 deletion candle-kernels/src/unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,18 @@ UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
#endif

#if __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 530
#include "cuda_bf16.h"
// SiLU forward pass computed as x / (1 + exp(-x)), for precisions that rely
// on this generic fallback; expg supplies the type-appropriate exponential.
template <typename T>
__device__ __forceinline__ T silu_fwd_fallback(T x) {
    return x / (T(1.0f) + expg(-x));
}

UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd_fallback(x))
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, urecip_f16, recipg(x))
Expand Down