Skip to content

Commit

Permalink
Fixing Q2K bug (present in ggml).
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jan 5, 2024
1 parent 9f83a17 commit 351744c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 157 deletions.
41 changes: 1 addition & 40 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{GgmlDType, QStorage};
use crate::{DType, MetalDevice, MetalError, MetalStorage, Result};
use crate::{DType, MetalDevice, MetalStorage, Result};
use metal::Buffer;
use std::sync::Arc;

Expand Down Expand Up @@ -27,45 +27,6 @@ impl QMetalStorage {
}

pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
// let buffer = self
// .device
// .new_buffer(elem_count, DType::F32, "dequantize")?;
// let device = &self.device;
// let command_buffer = device.command_buffer()?;
// let name = match self.dtype {
// GgmlDType::Q4_0 => "kernel_dequantize_q4_0",
// GgmlDType::Q4_1 => "kernel_dequantize_q4_1",
// GgmlDType::Q5_0 => "kernel_dequantize_q5_0",
// GgmlDType::Q5_1 => "kernel_dequantize_q5_1",
// GgmlDType::Q8_0 => "kernel_dequantize_q8_0",
// GgmlDType::Q8_1 => "kernel_dequantize_q8_1",
// GgmlDType::Q2K => "kernel_dequantize_q2_K",
// GgmlDType::Q3K => "kernel_dequantize_q3_K",
// GgmlDType::Q4K => "kernel_dequantize_q4_K",
// GgmlDType::Q5K => "kernel_dequantize_q5_K",
// GgmlDType::Q6K => "kernel_dequantize_q6_K",
// GgmlDType::Q8K => "kernel_dequantize_q8_K",
// GgmlDType::F16 => "kernel_dequantize_f16",
// GgmlDType::F32 => "kernel_dequantize_f32",
// };
// candle_metal_kernels::call_quantized_dequantize(
// device.device(),
// &command_buffer,
// device.kernels(),
// name,
// elem_count,
// &self.buffer,
// &buffer,
// )
// .map_err(MetalError::from)?;
let length = self.buffer.length() as usize;
let size = self.dtype.block_size();
// if length != size * elem_count {
// crate::bail!(
// "The Metal buffer length is not aligned with dtype {:?} ({length} vs {size})",
// self.dtype
// );
// }
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?;
command_buffer.set_label("to_cpu");
Expand Down
16 changes: 8 additions & 8 deletions candle-core/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -756,14 +756,14 @@ quantized_matmul!(
// quantized_matmul_q8_1_metal,
// GgmlDType::Q8_1
// );
// TODO Bugged in metal - Bug also present in GGML
// quantized_matmul!(
// quantized_matmul_q2k_bis,
// quantized_matmul_q2k_cpu,
// quantized_matmul_q2k_cuda,
// quantized_matmul_q2k_metal,
// GgmlDType::Q2K
// );
// TODO This is bugged (also bugged in GGML
quantized_matmul!(
quantized_matmul_q2k_bis,
quantized_matmul_q2k_cpu,
quantized_matmul_q2k_cuda,
quantized_matmul_q2k_metal,
GgmlDType::Q2K
);
quantized_matmul!(
quantized_matmul_q3k_bis,
quantized_matmul_q3k_cpu,
Expand Down
138 changes: 29 additions & 109 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,7 @@ pub fn call_quantized_matmul_t(
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;

let (thread_groups_count, threads_per_threadgroup) = match dtype {
let (nth0, nth1, align) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
Expand All @@ -1628,113 +1628,58 @@ pub fn call_quantized_matmul_t(
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
let thread_groups_count = MTLSize {
width: (ne01 as u64 + 7) / 8,
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q2K => {
let nth0 = 2;
let nth1 = 32;
let thread_groups_count = MTLSize {
width: (ne01 as u64 + 7) / 8,
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let thread_groups_count = MTLSize {
width: (ne01 as u64 + 3) / 4,
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let thread_groups_count = MTLSize {
width: (ne01 as u64 + 3) / 4,
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let thread_groups_count = MTLSize {
width: (ne01 as u64 + 1) / 2,
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
// Original implem uses rows
let nth0 = 32;
let nth1 = 1;
let nrows = 1;
let ny = (ne11 + nrows - 1) / nrows;
let thread_groups_count = MTLSize {
width: ne01 as u64,
height: ny as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
let align = 8;
(nth0, nth1, align)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let nrows = 4;
let ny = (ne11 + nrows - 1) / nrows;
let thread_groups_count = MTLSize {
width: ne01 as u64,
height: ny as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
(thread_groups_count, threads_per_threadgroup)
let align = 8;
(nth0, nth1, align)
}
};
let thread_groups_count = MTLSize {
width: divide(ne01 as usize, align),
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
let name = match dtype {
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
Expand Down Expand Up @@ -1794,31 +1739,6 @@ pub fn call_quantized_matmul_t(
Ok(())
}

pub fn call_quantized_dequantize(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
elem_count: usize,
src: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
// float4x4
let (thread_group_count, thread_group_size) = linear_split(&pipeline, elem_count / 16);
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (elem_count, 0usize, src, output));
encoder.use_resource(src, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();

Ok(())
}

fn divide(m: usize, b: usize) -> NSUInteger {
((m + b - 1) / b) as NSUInteger
}
Expand Down

0 comments on commit 351744c

Please sign in to comment.