Skip to content

Commit

Permalink
Import the ggml_cuda_dp4a function.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Nov 19, 2024
1 parent e865656 commit 9e585a6
Showing 1 changed file with 44 additions and 33 deletions.
77 changes: 44 additions & 33 deletions candle-kernels/src/quantized.cu
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)

static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= MIN_CC_DP4A
const int8_t * a8 = (const int8_t *) &a;
const int8_t * b8 = (const int8_t *) &b;
return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}


#define MMQ_X_Q4_0_RDNA2 64
#define MMQ_Y_Q4_0_RDNA2 128
#define NWARPS_Q4_0_RDNA2 8
Expand Down Expand Up @@ -1821,8 +1832,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

// SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2*i+0], sumi);
sumi = __dp4a(vi1, u[2*i+1], sumi);
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
}

const float2 ds8f = __half22float2(ds8);
Expand All @@ -1844,8 +1855,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;

// SIMD dot product of quantized values
sumi = __dp4a(vi0, u[2*i+0], sumi);
sumi = __dp4a(vi1, u[2*i+1], sumi);
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
}

#ifdef GGML_CUDA_F16
Expand Down Expand Up @@ -1878,14 +1889,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}

const float2 ds8f = __half22float2(ds8);
Expand All @@ -1909,14 +1920,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
sumi = __dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values

int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
}

#ifdef GGML_CUDA_F16
Expand Down Expand Up @@ -1945,7 +1956,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = __dp4a(v[i], u[i], sumi);
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
}

return d8_0*d8_1 * sumi;
Expand All @@ -1959,7 +1970,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
#pragma unroll
for (int i = 0; i < vdr; ++i) {
// SIMD dot product of quantized values
sumi = __dp4a(v[i], u[i], sumi);
sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
}

#ifdef GGML_CUDA_F16
Expand Down Expand Up @@ -1994,13 +2005,13 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(

const int vi = (v >> (2*i)) & 0x03030303;

sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product

// fill int with 4x m
int m = sc >> 4;
m |= m << 8;
m |= m << 16;
sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
}

const float2 dm2f = __half22float2(dm2);
Expand Down Expand Up @@ -2029,8 +2040,8 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(

#pragma unroll
for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
sumi_m = __dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
sumi_d_sc = ggml_cuda_dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
sumi_m = ggml_cuda_dp4a(m, u[i], sumi_m); // multiply sum of q8_1 values with m
}

sumi_d += sumi_d_sc * (sc & 0xF);
Expand Down Expand Up @@ -2071,7 +2082,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(

const int vi = __vsubss4(vil, vih);

sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
}

return d3 * sumf;
Expand All @@ -2089,7 +2100,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
int sumi_sc = 0;

for (int i = i0; i < i0 + QI8_1/2; ++i) {
sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
}

sumi += sumi_sc * scales[i0 / (QI8_1/2)];
Expand All @@ -2114,8 +2125,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;

const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u

sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
Expand All @@ -2140,7 +2151,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(

#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = __dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
}

const float2 ds8f = __half22float2(ds8[i]);
Expand Down Expand Up @@ -2176,8 +2187,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
const int v0i = vl0i | vh0i;
const int v1i = vl1i | vh1i;

const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u

sumf_d += d8[i] * (dot1 * sc[i]);
sumf_m += d8[i] * (dot2 * m[i]);
Expand All @@ -2203,7 +2214,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(

#pragma unroll
for (int j = 0; j < QI8_1; ++j) {
sumi_d = __dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
}

const float2 ds8f = __half22float2(ds8[i]);
Expand Down Expand Up @@ -2237,7 +2248,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(

const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32

sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
}

return d*sumf;
Expand All @@ -2256,11 +2267,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(

#pragma unroll
for (int i = i0; i < i0 + 2; ++i) {
sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product

sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
}

sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
Expand Down Expand Up @@ -2488,10 +2499,10 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const int v1 = q4[0];
const int v2 = q4[4];

const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
const int dot1 = ggml_cuda_dp4a(ui2, v2 & 0x0f0f0f0f, ggml_cuda_dp4a(ui1, v1 & 0x0f0f0f0f, 0));
const int dot2 = ggml_cuda_dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, ggml_cuda_dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
const int dot3 = ggml_cuda_dp4a(0x01010101, ui2, ggml_cuda_dp4a(0x01010101, ui1, 0));
const int dot4 = ggml_cuda_dp4a(0x01010101, ui4, ggml_cuda_dp4a(0x01010101, ui3, 0));

sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
Expand Down Expand Up @@ -2576,8 +2587,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);

const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+ d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
const float sumf_d = d8_1 * (ggml_cuda_dp4a(ui1, v1, 0) * s[0] + ggml_cuda_dp4a(ui2, v2, 0) * s[1])
+ d8_2 * (ggml_cuda_dp4a(ui3, v3, 0) * s[2] + ggml_cuda_dp4a(ui4, v4, 0) * s[3]);

return d * sumf_d;
#endif
Expand Down

0 comments on commit 9e585a6

Please sign in to comment.