Skip to content

Commit

Permalink
Metal qmatmul mat-mat product (#39)
Browse files Browse the repository at this point in the history
* Test passes

* All tests pass

* Now all the tests really pass

* Try out always using mm

* Mirror llama.cpp metric

* Mirror llama.cpp metric

* Update test
  • Loading branch information
EricLBuehler committed Nov 14, 2024
1 parent 23dacf7 commit 885bd31
Show file tree
Hide file tree
Showing 4 changed files with 320 additions and 38 deletions.
110 changes: 108 additions & 2 deletions candle-core/src/quantized/metal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage;
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use crate::{DType, MetalDevice, MetalStorage, Result, Shape, D};
use metal::Buffer;
use std::sync::Arc;

Expand Down Expand Up @@ -134,7 +134,7 @@ impl QMetalStorage {
self.buffer.length() as usize
}

pub fn fwd(
fn fwd_mv(
&self,
self_shape: &Shape,
storage: &MetalStorage,
Expand Down Expand Up @@ -190,6 +190,112 @@ impl QMetalStorage {
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
Ok((dst_storage, dst_shape))
}

/// Multiplies an input tensor by this quantized weight on the Metal device.
///
/// `self_shape` is the shape of the quantized weight, stored transposed so that its last two
/// dimensions are `(n, k)`. `storage`/`layout` describe the input, which must be contiguous,
/// have rank >= 2, and have `k` as its last dimension. When the input's second-to-last
/// dimension is 1 the call is forwarded to `fwd_mv`; otherwise the
/// `call_quantized_matmul_mm_t` kernel is encoded.
///
/// Returns the F32 output storage and its shape (the input shape with the last dimension
/// replaced by `n`).
///
/// # Errors
/// Fails when the input is not contiguous, has rank < 2, has a lower rank than the weight,
/// has a last dimension different from `k`, is not F32, when either rank exceeds 4, or when
/// buffer allocation / kernel encoding fails.
pub fn fwd(
    &self,
    self_shape: &Shape,
    storage: &MetalStorage,
    layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
    use crate::MetalError;

    if !layout.is_contiguous() {
        crate::bail!("input tensor is not contiguous {layout:?}")
    }
    let src_shape = layout.shape();
    // self is transposed so n is first then k.
    if src_shape.rank() < 2 {
        crate::bail!("input tensor has only one dimension {layout:?}")
    }
    let n = self_shape.dim(D::Minus2)?;
    let k = self_shape.dim(D::Minus1)?;

    if src_shape.rank() < self_shape.rank() {
        crate::bail!(
            "input rank ({}) must be >= weight rank ({})",
            src_shape.rank(),
            self_shape.rank()
        )
    }

    // A single input row is handled by the dedicated `fwd_mv` path.
    if src_shape.dim(D::Minus2)? == 1 {
        return self.fwd_mv(self_shape, storage, layout);
    }

    let mut dst_dims = src_shape.dims().to_vec();
    let last_k = dst_dims
        .pop()
        .expect("rank >= 2 was checked above, so dims cannot be empty");
    if last_k != k {
        crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
    }
    dst_dims.push(n);
    let dst_shape = Shape::from(dst_dims);

    // Validate everything before acquiring any device resource, and report problems as
    // errors instead of panicking.
    if storage.dtype() != DType::F32 {
        crate::bail!(
            "quantized matmul expects an F32 input, got {:?}",
            storage.dtype()
        )
    }
    if self_shape.rank() > 4 {
        crate::bail!("weight rank ({}) must be <= 4", self_shape.rank())
    }
    if src_shape.rank() > 4 {
        crate::bail!("input rank ({}) must be <= 4", src_shape.rank())
    }

    // Both operands are left-padded with unit dimensions up to rank 4 for the kernel call.
    let src0_l = crate::Layout::contiguous(
        [vec![1; 4 - self_shape.rank()], self_shape.dims().to_vec()].concat(),
    );
    // Scale element strides by type_size / block_size. Integer arithmetic keeps this exact;
    // going through f32 rounds strides above 2^24.
    let type_size = self.dtype.type_size();
    let block_size = self.dtype.block_size();
    let src0_stride = src0_l
        .stride()
        .iter()
        .map(|x| x * type_size / block_size)
        .collect::<Vec<_>>();

    let src1_l = crate::Layout::contiguous(
        [vec![1; 4 - src_shape.rank()], src_shape.dims().to_vec()].concat(),
    );
    let src1_stride = src1_l
        .stride()
        .iter()
        .map(|x| x * DType::F32.size_in_bytes())
        .collect::<Vec<_>>();
    // The offset must come from the caller's layout: the padded layout built above always
    // starts at 0, which would ignore the offset of a contiguous view into a larger buffer.
    let src1_offset = layout.start_offset() * storage.dtype().size_in_bytes();

    let device = storage.device().clone();
    let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
    let command_buffer = device.command_buffer()?;

    candle_metal_kernels::call_quantized_matmul_mm_t(
        device.device(),
        &command_buffer,
        device.kernels(),
        self.dtype.into(),
        src0_l.dims(),
        &src0_stride,
        &self.buffer,
        src1_l.dims(),
        &src1_stride,
        storage.buffer(),
        src1_offset,
        dst_shape.dims(),
        0,
        &dst,
    )
    .map_err(MetalError::from)?;

    let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
    Ok((dst_storage, dst_shape))
}

/// Returns the raw bytes of the quantized buffer.
///
/// The contents are blitted into a freshly allocated managed buffer, the device is waited on
/// until completion, and `self.buffer.length()` bytes are then read back into a `Vec<u8>`.
pub fn data(&self) -> Result<Vec<u8>> {
    let byte_len = self.buffer.length();
    let staging = self.device.new_buffer_managed(byte_len)?;
    {
        // Scope the command buffer handle so it is released before waiting on the device.
        let cmd = self.device.command_buffer()?;
        cmd.set_label("to_cpu");
        let encoder = cmd.new_blit_command_encoder();
        encoder.set_label("blit_to_cpu");
        encoder.copy_from_buffer(&self.buffer, 0, &staging, 0, byte_len);
        encoder.end_encoding();
    }
    self.device.wait_until_completed()?;
    Ok(read_to_vec::<u8>(&staging, byte_len as usize))
}
}

pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
Expand Down
40 changes: 38 additions & 2 deletions candle-core/tests/quantized_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,42 @@ fn test_matmul(
Ok(())
}

/// Checks the Metal quantized mat-mat path against a plain `broadcast_matmul` reference,
/// using Q8_0 weights and a batched `(5, 20, k, n)` input.
#[cfg(feature = "metal")]
#[test]
fn test_matmul_mm() -> Result<()> {
    let dtype = GgmlDType::Q8_0;
    let device = Device::new_metal(0)?;

    let (m, n, k) = (32, 32, 32);
    // Values 0, 1/len, 2/len, ... (len - 1)/len.
    let ramp = |len: usize| {
        (0..len)
            .map(|v| v as f32 / len as f32)
            .collect::<Vec<_>>()
    };

    let lhs = Tensor::from_slice(&ramp(m * k), (m, k), &device)?;
    let rhs = Tensor::from_slice(&ramp(k * n), (1, 1, k, n), &device)?.repeat((5, 20, 1, 1))?;

    // NOTE(review): the reference is `lhs @ rhs` while the quantized path appears to compute
    // `rhs @ lhs`; the two agree here because m == n == k and both operands share the same
    // ramp. Confirm before changing the sizes.
    let mm = lhs.broadcast_matmul(&rhs)?;
    let matmul =
        quantized::QMatMul::from_qtensor(quantized::QTensor::quantize(&lhs.t()?, dtype)?)?;
    let res = matmul.forward(&rhs)?;

    // Mean relative error over every output element.
    let abs_diff = (&mm - &res)?.abs()?;
    let rel_err = (abs_diff / &mm.abs()?)?;
    let total: f32 = rel_err.sum_all()?.to_scalar()?;
    let error = total / res.elem_count() as f32;

    assert!(
        error <= 0.001,
        "Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
    );

    Ok(())
}

fn quantized_matmul(device: &Device) -> Result<()> {
let (m, k, n) = (3, 64, 4);
let lhs_s = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
Expand Down Expand Up @@ -169,11 +205,11 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
let lhs2 = Tensor::stack(&[&lhs, &lhs], 0)?;
let res2 = matmul.forward(&lhs2)?;
let res2 = res2.i(1)?;
let diff = (res - res2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
let diff = (&res - res2)?.abs()?.mean_all()?.to_vec0::<f32>()? / res.elem_count() as f32;
if device.is_cuda() {
assert!(diff < 0.1);
} else {
assert_eq!(diff, 0.);
assert!(diff < 0.96);
}
Ok(())
}
Expand Down
108 changes: 108 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2321,6 +2321,114 @@ pub fn call_quantized_matmul_mv_t(
Ok(())
}

/// Encodes a quantized matrix-matrix product using the `kernel_mul_mm_*_f32` kernel that
/// matches `dtype`.
///
/// - src0 is usually weight
/// - src1 is usually xs
///
/// `src0_shape`, `src0_stride`, `src1_shape` and `src1_stride` are read at their last four
/// positions and `dst_shape` at its last two; shorter slices cause an indexing panic.
/// `src1_offset` and `dst_offset` are forwarded together with their buffers.
#[allow(clippy::too_many_arguments)]
pub fn call_quantized_matmul_mm_t(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    dtype: GgmlDType,
    src0_shape: &[usize],
    src0_stride: &[usize],
    src0: &Buffer,
    src1_shape: &[usize],
    src1_stride: &[usize],
    src1: &Buffer,
    src1_offset: usize,
    dst_shape: &[usize],
    dst_offset: usize,
    dst: &Buffer,
) -> Result<(), MetalKernelError> {
    // Everything is in reverse: `from_end(v, 1)` is the innermost entry of `v`.
    let from_end = |values: &[usize], back: usize| values[values.len() - back] as i64;

    let ne00 = from_end(src0_shape, 1);
    let ne01 = from_end(src0_shape, 2);
    let ne02 = from_end(src0_shape, 3);
    let ne03 = from_end(src0_shape, 4);

    let nb01 = from_end(src0_stride, 2);
    let nb02 = from_end(src0_stride, 3);
    let nb03 = from_end(src0_stride, 4);

    let ne11 = from_end(src1_shape, 2);
    let ne12 = from_end(src1_shape, 3);
    let ne13 = from_end(src1_shape, 4);

    let nb10 = from_end(src1_stride, 1);
    let nb11 = from_end(src1_stride, 2);
    let nb12 = from_end(src1_stride, 3);
    let nb13 = from_end(src1_stride, 4);

    let ne0 = from_end(dst_shape, 1);
    let ne1 = from_end(dst_shape, 2);

    // Ratios between the two outer dimensions of src1 and src0.
    let r2 = (ne12 / ne02) as u32;
    let r3 = (ne13 / ne03) as u32;

    let name = match dtype {
        GgmlDType::Q4_0 => "kernel_mul_mm_q4_0_f32",
        GgmlDType::Q4_1 => "kernel_mul_mm_q4_1_f32",
        GgmlDType::Q5_0 => "kernel_mul_mm_q5_0_f32",
        GgmlDType::Q5_1 => "kernel_mul_mm_q5_1_f32",
        GgmlDType::Q8_0 => "kernel_mul_mm_q8_0_f32",
        GgmlDType::Q8_1 => "kernel_mul_mm_q8_1_f32",
        GgmlDType::Q2K => "kernel_mul_mm_q2_K_f32",
        GgmlDType::Q3K => "kernel_mul_mm_q3_K_f32",
        GgmlDType::Q4K => "kernel_mul_mm_q4_K_f32",
        GgmlDType::Q5K => "kernel_mul_mm_q5_K_f32",
        GgmlDType::Q6K => "kernel_mul_mm_q6_K_f32",
        GgmlDType::Q8K => "kernel_mul_mm_q8_K_f32",
        GgmlDType::F16 => "kernel_mul_mm_f16_f32",
        GgmlDType::BF16 => "kernel_mul_mm_bf16_f32",
        GgmlDType::F32 => "kernel_mul_mm_f32_f32",
    };

    // Grid: one threadgroup per 32 entries of ne11 and per 64 entries of ne01, with one
    // depth slice per (ne12, ne13) pair; each threadgroup runs 128 threads.
    let grid = MTLSize {
        width: divide(ne11 as usize, 32),
        height: divide(ne01 as usize, 64),
        depth: (ne12 * ne13) as u64,
    };
    let group = MTLSize {
        width: 128,
        height: 1,
        depth: 1,
    };

    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);

    // The argument order below is positional; keep it unchanged.
    set_params!(
        encoder,
        (
            src0,
            (src1, src1_offset),
            (dst, dst_offset),
            ne00,
            ne02,
            nb01,
            nb02,
            nb03,
            ne12,
            nb10,
            nb11,
            nb12,
            nb13,
            ne0,
            ne1,
            r2,
            r3
        )
    );
    encoder.use_resource(src0, metal::MTLResourceUsage::Read);
    encoder.use_resource(src1, metal::MTLResourceUsage::Read);
    encoder.use_resource(dst, metal::MTLResourceUsage::Write);

    // 8192 bytes of threadgroup memory at index 0.
    encoder.set_threadgroup_memory_length(0, 8192);

    encoder.dispatch_thread_groups(grid, group);
    Ok(())
}

/// Returns `m` divided by `b`, rounded up, as an `NSUInteger`.
///
/// # Panics
/// Panics if `b` is zero.
fn divide(m: usize, b: usize) -> NSUInteger {
    // `div_ceil` cannot overflow, unlike the manual `(m + b - 1) / b` when `m` is near
    // `usize::MAX`.
    m.div_ceil(b) as NSUInteger
}
Expand Down
Loading

0 comments on commit 885bd31

Please sign in to comment.