From 351744c668fc7a41a3121680b96885369b1e1e56 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 5 Jan 2024 11:04:50 +0100
Subject: [PATCH] Fixing Q2K bug (present in ggml).

---
 candle-core/src/quantized/metal.rs   |  41 +--------
 candle-core/tests/quantized_tests.rs |  16 ++--
 candle-metal-kernels/src/lib.rs      | 138 ++++++---------------------
 3 files changed, 38 insertions(+), 157 deletions(-)

diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs
index 1278fed968..b945a1125b 100644
--- a/candle-core/src/quantized/metal.rs
+++ b/candle-core/src/quantized/metal.rs
@@ -1,5 +1,5 @@
 use super::{GgmlDType, QStorage};
-use crate::{DType, MetalDevice, MetalError, MetalStorage, Result};
+use crate::{DType, MetalDevice, MetalStorage, Result};
 use metal::Buffer;
 use std::sync::Arc;
 
@@ -27,45 +27,6 @@ impl QMetalStorage {
     }
 
     pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
-        // let buffer = self
-        //     .device
-        //     .new_buffer(elem_count, DType::F32, "dequantize")?;
-        // let device = &self.device;
-        // let command_buffer = device.command_buffer()?;
-        // let name = match self.dtype {
-        //     GgmlDType::Q4_0 => "kernel_dequantize_q4_0",
-        //     GgmlDType::Q4_1 => "kernel_dequantize_q4_1",
-        //     GgmlDType::Q5_0 => "kernel_dequantize_q5_0",
-        //     GgmlDType::Q5_1 => "kernel_dequantize_q5_1",
-        //     GgmlDType::Q8_0 => "kernel_dequantize_q8_0",
-        //     GgmlDType::Q8_1 => "kernel_dequantize_q8_1",
-        //     GgmlDType::Q2K => "kernel_dequantize_q2_K",
-        //     GgmlDType::Q3K => "kernel_dequantize_q3_K",
-        //     GgmlDType::Q4K => "kernel_dequantize_q4_K",
-        //     GgmlDType::Q5K => "kernel_dequantize_q5_K",
-        //     GgmlDType::Q6K => "kernel_dequantize_q6_K",
-        //     GgmlDType::Q8K => "kernel_dequantize_q8_K",
-        //     GgmlDType::F16 => "kernel_dequantize_f16",
-        //     GgmlDType::F32 => "kernel_dequantize_f32",
-        // };
-        // candle_metal_kernels::call_quantized_dequantize(
-        //     device.device(),
-        //     &command_buffer,
-        //     device.kernels(),
-        //     name,
-        //     elem_count,
-        //     &self.buffer,
-        //     &buffer,
-        // )
-        // .map_err(MetalError::from)?;
-        let length = self.buffer.length() as usize;
-        let size = self.dtype.block_size();
-        // if length != size * elem_count {
-        //     crate::bail!(
-        //         "The Metal buffer length is not aligned with dtype {:?} ({length} vs {size})",
-        //         self.dtype
-        //     );
-        // }
         let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("to_cpu");
diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs
index cc96663ba4..9bcfab7263 100644
--- a/candle-core/tests/quantized_tests.rs
+++ b/candle-core/tests/quantized_tests.rs
@@ -756,14 +756,14 @@ quantized_matmul!(
 //     quantized_matmul_q8_1_metal,
 //     GgmlDType::Q8_1
 // );
-// TODO Bugged in metal - Bug also present in GGML
-// quantized_matmul!(
-//     quantized_matmul_q2k_bis,
-//     quantized_matmul_q2k_cpu,
-//     quantized_matmul_q2k_cuda,
-//     quantized_matmul_q2k_metal,
-//     GgmlDType::Q2K
-// );
+// TODO This is bugged (also bugged in GGML)
+quantized_matmul!(
+    quantized_matmul_q2k_bis,
+    quantized_matmul_q2k_cpu,
+    quantized_matmul_q2k_cuda,
+    quantized_matmul_q2k_metal,
+    GgmlDType::Q2K
+);
 quantized_matmul!(
     quantized_matmul_q3k_bis,
     quantized_matmul_q3k_cpu,
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index f8b81db3ce..65ee39b723 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1619,7 +1619,7 @@ pub fn call_quantized_matmul_t(
     let r2: u32 = (ne12 / ne02) as u32;
     let r3: u32 = (ne13 / ne03) as u32;
 
-    let (thread_groups_count, threads_per_threadgroup) = match dtype {
+    let (nth0, nth1, align) = match dtype {
         GgmlDType::Q4_0
         | GgmlDType::Q4_1
         | GgmlDType::Q5_0
@@ -1628,113 +1628,58 @@ pub fn call_quantized_matmul_t(
         | GgmlDType::Q5_1
         | GgmlDType::Q8_0
         | GgmlDType::Q8_1 => {
             let nth0 = 8;
             let nth1 = 8;
-            let thread_groups_count = MTLSize {
-                width: (ne01 as u64 + 7) / 8,
-                height: ne11 as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            let align = 8;
+            (nth0, nth1, align)
         }
         GgmlDType::Q2K => {
-            let nth0 = 2;
-            let nth1 = 32;
-            let thread_groups_count = MTLSize {
-                width: (ne01 as u64 + 7) / 8,
-                height: ne11 as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            // Fixes the Q2K dispatch bug (also present in ggml).
+            let nth0 = 4;
+            let nth1 = 8;
+            let align = 4;
+            (nth0, nth1, align)
         }
         GgmlDType::Q4K => {
             let nth0 = 4;
             let nth1 = 8;
-            let thread_groups_count = MTLSize {
-                width: (ne01 as u64 + 3) / 4,
-                height: ne11 as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            let align = 4;
+            (nth0, nth1, align)
         }
         GgmlDType::Q3K | GgmlDType::Q5K => {
             let nth0 = 2;
             let nth1 = 32;
-            let thread_groups_count = MTLSize {
-                width: (ne01 as u64 + 3) / 4,
-                height: ne11 as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            let align = 4;
+            (nth0, nth1, align)
         }
         GgmlDType::Q6K => {
             let nth0 = 2;
             let nth1 = 32;
-            let thread_groups_count = MTLSize {
-                width: (ne01 as u64 + 1) / 2,
-                height: ne11 as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            let align = 2;
+            (nth0, nth1, align)
         }
         GgmlDType::F16 | GgmlDType::Q8K => {
+            // The original implementation dispatches by rows.
             let nth0 = 32;
             let nth1 = 1;
-            let nrows = 1;
-            let ny = (ne11 + nrows - 1) / nrows;
-            let thread_groups_count = MTLSize {
-                width: ne01 as u64,
-                height: ny as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            let align = 8;
+            (nth0, nth1, align)
        }
         GgmlDType::F32 => {
             let nth0 = 32;
             let nth1 = 1;
-            let nrows = 4;
-            let ny = (ne11 + nrows - 1) / nrows;
-            let thread_groups_count = MTLSize {
-                width: ne01 as u64,
-                height: ny as u64,
-                depth: (ne12 * ne13) as u64,
-            };
-            let threads_per_threadgroup = MTLSize {
-                width: nth0,
-                height: nth1,
-                depth: 1,
-            };
-            (thread_groups_count, threads_per_threadgroup)
+            let align = 8;
+            (nth0, nth1, align)
         }
     };
+    let thread_groups_count = MTLSize {
+        width: divide(ne01 as usize, align),
+        height: ne11 as u64,
+        depth: (ne12 * ne13) as u64,
+    };
+    let threads_per_threadgroup = MTLSize {
+        width: nth0,
+        height: nth1,
+        depth: 1,
+    };
     let name = match dtype {
         GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
         GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
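
[Editorial note, not part of the patch: the hunk above collapses six
hand-rolled MTLSize pairs into one per-dtype (nth0, nth1, align) tuple plus a
single shared dispatch computation. The Q2K arm carries the actual fix: the
old code launched 2x32-thread groups over 8-row blocks, while the fixed code
uses the same 4x8-thread, 4-row geometry as Q4K. Below is a self-contained
sketch of the resulting dispatch arithmetic, with a stand-in Size struct in
place of metal::MTLSize; all names in it are illustrative only.

    // Stand-in for metal::MTLSize, just for this sketch.
    #[derive(Debug, PartialEq)]
    struct Size {
        width: u64,
        height: u64,
        depth: u64,
    }

    // Ceiling division, mirroring the patch's divide() helper.
    fn divide(m: usize, b: usize) -> u64 {
        ((m + b - 1) / b) as u64
    }

    // Threadgroup count as computed after the match, for a given alignment.
    fn thread_groups(ne01: usize, ne11: u64, ne12: u64, ne13: u64, align: usize) -> Size {
        Size {
            width: divide(ne01, align),
            height: ne11,
            depth: ne12 * ne13,
        }
    }

    fn main() {
        // For a 64-row weight matrix, the old Q2K path launched
        // (64 + 7) / 8 = 8 threadgroups along the rows; the fixed
        // align = 4 path launches 64 / 4 = 16.
        let groups = thread_groups(64, 1, 1, 1, 4);
        assert_eq!(groups, Size { width: 16, height: 1, depth: 1 });
    }

End of note.]
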
@@ -1794,31 +1739,6 @@ pub fn call_quantized_matmul_t(
     Ok(())
 }
 
-pub fn call_quantized_dequantize(
-    device: &Device,
-    command_buffer: &CommandBufferRef,
-    kernels: &Kernels,
-    name: &'static str,
-    elem_count: usize,
-    src: &Buffer,
-    output: &Buffer,
-) -> Result<(), MetalKernelError> {
-    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
-    // float4x4
-    let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count / 16);
-    let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
-    encoder.set_compute_pipeline_state(&pipeline);
-    set_params!(encoder, (elem_count, 0usize, src, output));
-    encoder.use_resource(src, metal::MTLResourceUsage::Read);
-    encoder.use_resource(output, metal::MTLResourceUsage::Write);
-    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
-    encoder.end_encoding();
-
-    Ok(())
-}
-
 fn divide(m: usize, b: usize) -> NSUInteger {
     ((m + b - 1) / b) as NSUInteger
 }
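
[Editorial note, not part of the patch: with the commented-out kernel path and
call_quantized_dequantize removed, QMetalStorage::dequantize now just copies
the quantized buffer into a managed (CPU-visible) buffer and lets the existing
CPU code unpack the blocks. As a rough illustration of what such a host-side
path does, here is a minimal block dequantizer for the Q8_0 layout (an f16
scale followed by 32 signed quants per block); the type and function names are
this sketch's own, not candle's, and the half crate is assumed for f16.

    use half::f16;

    const QK8_0: usize = 32;

    // One Q8_0 block: a per-block scale plus 32 quantized values.
    #[repr(C)]
    struct BlockQ8_0 {
        d: f16,
        qs: [i8; QK8_0],
    }

    // Expand each block back to f32: y = scale * quant.
    fn dequantize_q8_0(blocks: &[BlockQ8_0], ys: &mut [f32]) {
        assert_eq!(ys.len(), blocks.len() * QK8_0);
        for (block, out) in blocks.iter().zip(ys.chunks_exact_mut(QK8_0)) {
            let d = block.d.to_f32();
            for (y, &q) in out.iter_mut().zip(block.qs.iter()) {
                *y = d * q as f32;
            }
        }
    }

End of note.]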