From 36e1dcc12f46991bb76f2b4de9eea93c6e347e4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Sun, 22 Dec 2024 07:59:37 +0100 Subject: [PATCH 01/11] wip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: José Carlos García --- tensor-tools/src/main.rs | 55 +++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 0bda36d524..51910b1ad4 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -1,7 +1,8 @@ use candle::quantized::{gguf_file, GgmlDType, QTensor}; -use candle::{Device, Result}; +use candle::{Device, Result, Tensor}; use clap::{Parser, Subcommand, ValueEnum}; use rayon::prelude::*; +use safetensors::tensor; #[derive(ValueEnum, Debug, Clone)] enum QuantizationMode { @@ -11,7 +12,7 @@ enum QuantizationMode { } impl QuantizationMode { - fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result { + fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType, bitnet_mode: bool) -> Result { match self { Self::Llama => { // Same behavior as the llama.cpp quantization. @@ -143,6 +144,9 @@ enum Command { #[arg(long)] out_file: std::path::PathBuf, + #[clap(long, short, action)] + bitnet_mode: bool, + /// The quantization schema to apply. #[arg(long, value_enum)] quantization: Quantization, @@ -395,10 +399,32 @@ fn run_ls( Ok(()) } +fn unpack_bitnet_weights(tensor: &Tensor) -> Result { + let packed_vec = tensor.to_vec2::().unwrap(); + + let rows = tensor.dim(0).unwrap(); + let cols = tensor.dim(1).unwrap(); + + let mut unpacked_vec = vec![0f32; rows * 4 * cols]; + for i in 0..rows { + for j in 0..cols { + let packed = packed_vec[i][j]; + for k in 0..4 { + let bits = ((packed >> (k * 2)) & 0b11) as i8 - 1; + unpacked_vec[(i * 4 + k) * cols + j] = bits as f32; + } + } + } + + let unpacked_tensor = Tensor::from_vec(unpacked_vec, (rows*4, cols), tensor.device())?; + Ok(unpacked_tensor) +} + fn run_quantize_safetensors( in_files: &[std::path::PathBuf], out_file: std::path::PathBuf, q: Quantization, + bitnet_mode: bool, ) -> Result<()> { let mut out_file = std::fs::File::create(out_file)?; let mut tensors = std::collections::HashMap::new(); @@ -414,7 +440,22 @@ fn run_quantize_safetensors( let qtensors = tensors .into_par_iter() .map(|(name, tensor)| { - let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; + let mut should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; + let mut tensor = tensor; + if should_quantize && bitnet_mode { + let is_bitnet_weight = name.contains("self_attn.k_proj") || + name.contains("self_attn.v_proj") || + name.contains("self_attn.q_proj") || + name.contains("self_attn.o_proj") || + name.contains("mlp.down_proj") || + name.contains("mlp.up_proj") || + name.contains("mlp.gate_proj"); + + if is_bitnet_weight { + println!(" unpacking {name} {tensor:?} {should_quantize}"); + tensor = unpack_bitnet_weights(&tensor)?; + } + } println!(" quantizing {name} {tensor:?} {should_quantize}"); let tensor = if should_quantize { QTensor::quantize(&tensor, dtype)? 
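// Not part of the patch: a minimal sketch of the 2-bit layout that
// unpack_bitnet_weights above assumes. Each packed byte holds four weights,
// two bits each, least-significant pair first, and every pair is offset by -1,
// so 0b00 -> -1, 0b01 -> 0, 0b10 -> 1 and 0b11 -> 2.
#[cfg(test)]
mod bitnet_unpack_sketch {
    #[test]
    fn decode_one_packed_byte() {
        let packed: u8 = 0b1110_0100; // pairs (k = 3..0) are 11, 10, 01, 00
        let decoded: Vec<i8> = (0..4)
            .map(|k| ((packed >> (k * 2)) & 0b11) as i8 - 1)
            .collect();
        assert_eq!(decoded, vec![-1i8, 0, 1, 2]);
    }
}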
@@ -454,6 +495,7 @@ fn run_quantize( out_file: std::path::PathBuf, q: Quantization, qmode: QuantizationMode, + bitnet_mode: bool, device: &Device, ) -> Result<()> { if in_files.is_empty() { @@ -466,7 +508,7 @@ fn run_quantize( } if let Some(extension) = in_files[0].extension() { if extension == "safetensors" { - return run_quantize_safetensors(in_files, out_file, q); + return run_quantize_safetensors(in_files, out_file, q, bitnet_mode); } } @@ -488,7 +530,7 @@ fn run_quantize( println!(" quantizing {name}"); let mut in_file = std::fs::File::open(&in_files[0])?; let tensor = content.tensor(&mut in_file, name, device)?; - let tensor = qmode.quantize(name, tensor, dtype)?; + let tensor = qmode.quantize(name, tensor, dtype, bitnet_mode)?; Ok((name, tensor)) }) .collect::>>()?; @@ -535,7 +577,8 @@ fn main() -> anyhow::Result<()> { out_file, quantization, mode, - } => run_quantize(&in_file, out_file, quantization, mode, &device)?, + bitnet_mode, + } => run_quantize(&in_file, out_file, quantization, mode, bitnet_mode, &device)?, Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?, } Ok(()) From 23373d11c978cc4449aa2ce8029bd95349f4a6e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 23 Dec 2024 00:00:09 +0100 Subject: [PATCH 02/11] initial: q support --- Cargo.toml | 2 +- candle-core/src/quantized/gguf_file.rs | 18 +- .../examples/quantized-bitnet/main.rs | 417 +++++++++++++ candle-examples/examples/quantized/main.rs | 3 +- candle-transformers/src/models/mod.rs | 1 + .../src/models/quantized_llama_bitnet.rs | 552 ++++++++++++++++++ tensor-tools/Cargo.toml | 1 + tensor-tools/src/main.rs | 89 ++- 8 files changed, 1075 insertions(+), 8 deletions(-) create mode 100644 candle-examples/examples/quantized-bitnet/main.rs create mode 100644 candle-transformers/src/models/quantized_llama_bitnet.rs diff --git a/Cargo.toml b/Cargo.toml index 17e7e4ba57..fe6d4f7e01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,7 +66,7 @@ serde = { version = "1.0.171", features = ["derive"] } serde_plain = "1.0.2" serde_json = "1.0.99" thiserror = "1" -tokenizers = { version = "0.19.1", default-features = false } +tokenizers = { version = "0.21.0", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index ccbd59eb5c..af5d3a46e8 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -174,6 +174,22 @@ impl Value { } } + pub fn from_u8(v: u8) -> Self { + Self::U8(v) + } + + pub fn from_u64(v: u64) -> Self { + Self::U64(v) + } + + pub fn from_u32(v: u32) -> Self { + Self::U32(v) + } + + pub fn from_f32(v: f32) -> Self { + Self::F32(v) + } + pub fn to_u8(&self) -> Result { match self { Self::U8(v) => Ok(*v), @@ -489,7 +505,7 @@ fn write_string(w: &mut W, str: &str) -> Result<()> { pub fn write( w: &mut W, - metadata: &[(&str, &Value)], + metadata: &[(&str, Value)], tensors: &[(&str, &QTensor)], ) -> Result<()> { w.write_u32::(0x46554747)?; diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs new file mode 100644 index 0000000000..9407c30f0e --- /dev/null +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -0,0 +1,417 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use 
tokenizers::Tokenizer; + +use candle::quantized::{ggml_file, gguf_file}; +use candle::Tensor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_llama_bitnet as model; +use model::ModelWeights; + +const DEFAULT_PROMPT: &str = "My favorite theorem is "; + +#[derive(Debug)] +enum Prompt { + Interactive, + Chat, + One(String), +} + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "falcon3-1b-1.58")] + Falcon3_1b1_58, +} + +impl Which { + fn is_mistral(&self) -> bool { + match self { + Self::Falcon3_1b1_58 => false, + } + } + + fn is_zephyr(&self) -> bool { + match self { + Self::Falcon3_1b1_58 => false, + } + } + + fn is_open_chat(&self) -> bool { + match self { + Self::Falcon3_1b1_58 => false, + } + } + + fn tokenizer_repo(&self) -> &'static str { + match self { + Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit", + } + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from l + /// lama.cpp + #[arg(long)] + model: Option, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive model where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format. + #[arg(long)] + tokenizer: Option, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Display the token for the specified prompt. + #[arg(long)] + verbose_prompt: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. + #[arg(long, default_value = "falcon3-1b-1.58")] + which: Which, + + /// Group-Query Attention, use 8 for the 70B version of LLaMAv2. + #[arg(long)] + gqa: Option, + + /// Use the slower dmmv cuda kernel. + #[arg(long)] + force_dmmv: bool, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = self.which.tokenizer_repo(); + let api = api.model(repo.to_string()); + api.get("tokenizer.json")? 
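// Illustrative only, not in the patch: with the defaults above, the example can
// be launched roughly as follows (flag names mirror the clap derives in Args):
//
//   cargo run --example quantized-bitnet --release -- \
//       --which falcon3-1b-1.58 --prompt "My favorite theorem is " --temperature 0.8
//
// When --tokenizer and --model are omitted, tokenizer.json is fetched from the
// repo returned by Which::tokenizer_repo() and the GGUF weights from the
// repo/filename pair resolved in Args::model().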
+ } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + fn model(&self) -> anyhow::Result { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename) = match self.which { + Which::Falcon3_1b1_58 => ( + "tiiuae/Falcon3-1B-Instruct-1.58bit", + "Falcon3-1B-Instruct-1.58bit.gguf", + ), + }; + let revision = "main"; + let api = hf_hub::api::sync::Api::new()?; + api.repo(hf_hub::Repo::with_revision( + repo.to_string(), + hf_hub::RepoType::Model, + revision.to_string(), + )) + .get(filename)? + } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{}B", size_in_bytes) + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + #[cfg(feature = "cuda")] + candle::quantized::cuda::set_force_dmmv(args.force_dmmv); + + candle::cuda::set_gemm_reduced_precision_f16(true); + candle::cuda::set_gemm_reduced_precision_bf16(true); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = match model_path.extension().and_then(|v| v.to_str()) { + Some("gguf") => { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + ModelWeights::from_gguf(model, &mut file, &device)? + } + Some("ggml" | "bin") | Some(_) | None => { + let model = ggml_file::Content::read(&mut file, &device) + .map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensors.iter() { + let elem_count = tensor.shape().elem_count(); + total_size_in_bytes += + elem_count * tensor.dtype().type_size() / tensor.dtype().block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensors.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + println!("params: {:?}", model.hparams); + let default_gqa = 1; + ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? 
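// Worked example, not in the patch: the size accounting above computes
// bytes = elem_count * type_size / block_size. For a hypothetical 2048 x 2048
// tensor stored as Q2b0 (type_size 64, block_size QK_K = 256, both added later
// in this series), that is 4_194_304 * 64 / 256 = 1_048_576 bytes, which
// format_size() prints as 1.05MB.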
+ } + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + let mut tos = TokenOutputStream::new(tokenizer); + let prompt = match args.prompt.as_deref() { + Some("chat") => Prompt::Chat, + Some("interactive") => Prompt::Interactive, + Some(s) => Prompt::One(s.to_string()), + None => Prompt::One(DEFAULT_PROMPT.to_string()), + }; + + let mut pre_prompt_tokens = vec![]; + for prompt_index in 0.. { + let prompt_str = match &prompt { + Prompt::One(prompt) => prompt.clone(), + Prompt::Interactive | Prompt::Chat => { + let is_interactive = matches!(prompt, Prompt::Interactive); + print!("> "); + std::io::stdout().flush()?; + let mut prompt = String::new(); + std::io::stdin().read_line(&mut prompt)?; + if prompt.ends_with('\n') { + prompt.pop(); + if prompt.ends_with('\r') { + prompt.pop(); + } + } + if args.which.is_open_chat() { + format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:") + } else if args.which.is_zephyr() { + if prompt_index == 0 || is_interactive { + format!("<|system|>\n\n<|user|>\n{prompt}\n<|assistant|>",) + } else { + format!("<|user|>\n{prompt}\n<|assistant|>") + } + } else if args.which.is_mistral() { + format!("[INST] {prompt} [/INST]") + } else { + prompt + } + } + }; + print!("{}", &prompt_str); + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + if args.verbose_prompt { + for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { + let token = token.replace('▁', " ").replace("<0x0A>", "\n"); + println!("{id:7} -> '{token}'"); + } + } + + let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat(); + let to_sample = args.sample_len.saturating_sub(1); + let prompt_tokens = if prompt_tokens.len() + to_sample > model::MAX_SEQ_LEN - 10 { + let to_remove = prompt_tokens.len() + to_sample + 10 - model::MAX_SEQ_LEN; + prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec() + } else { + prompt_tokens + }; + let mut all_tokens = vec![]; + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let start_prompt_processing = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } else { + let mut next_token = 0; + for (pos, token) in prompt_tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? + } + next_token + }; + let prompt_dt = start_prompt_processing.elapsed(); + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? 
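// Example, not in the patch: given the sampling match above, running with
// --temperature 0.8 --top-k 40 --top-p 0.9 selects
// Sampling::TopKThenTopP { k: 40, p: 0.9, temperature: 0.8 }, while
// --temperature 0 short-circuits to greedy Sampling::ArgMax regardless of the
// other two flags.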
{ + print!("{t}"); + std::io::stdout().flush()?; + } + + let eos_token = match args.which { + Which::Falcon3_1b1_58 => "<|endoftext|>", + }; + + let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); + let start_post_prompt = std::time::Instant::now(); + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, prompt_tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + if next_token == eos_token { + break; + }; + } + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + prompt_tokens.len(), + prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + + match prompt { + Prompt::One(_) => break, + Prompt::Interactive => {} + Prompt::Chat => { + pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat() + } + } + } + + Ok(()) +} diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs index 2b537aac9e..a089b05380 100644 --- a/candle-examples/examples/quantized/main.rs +++ b/candle-examples/examples/quantized/main.rs @@ -198,7 +198,8 @@ impl Which { #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { - /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp + /// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from l + /// lama.cpp #[arg(long)] model: Option, diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index be1f15c413..34fc8b57e4 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -110,3 +110,4 @@ pub mod whisper; pub mod with_tracing; pub mod wuerstchen; pub mod yi; +pub mod quantized_llama_bitnet; \ No newline at end of file diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs new file mode 100644 index 0000000000..30d9b47726 --- /dev/null +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -0,0 +1,552 @@ +//! Quantized llama model implementation. +//! +//! This provides a quantized implementation of the llama language model architecture. +//! The model implements parameter efficient quantization for reduced memory usage +//! while maintaining model quality. +//! +//! Key characteristics: +//! - Transformer decoder architecture +//! - Support for 2/3/4/8-bit quantization +//! - Optimized memory usage through quantization +//! - Configurable model sizes and parameter counts +//! +//! - 💻 [GH Link](https://github.com/facebookresearch/llama) +//! - 📝 [Paper](https://arxiv.org/abs/2302.13971) +//! +//! 
![](https://raw.githubusercontent.com/huggingface/candle/main/candle-examples/examples/quantized/assets/aoc.gif) +//! + +use std::collections::HashMap; + +use crate::quantized_nn::RmsNorm; +use candle::quantized::QTensor; +use candle::quantized::{ggml_file, gguf_file}; +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, Module}; + +pub const MAX_SEQ_LEN: usize = 4096; + +// QMatMul wrapper adding some tracing. +#[derive(Debug, Clone)] +struct QMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + +impl QMatMul { + fn from_qtensor(qtensor: QTensor) -> Result { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs) + } +} + +// BitQMatMul wrapper adding some tracing. +#[derive(Debug, Clone)] +struct BitQMatMul { + inner: candle::quantized::QMatMul, + span: tracing::Span, +} + + +fn activation_quant(x: &Tensor) -> Result<(Tensor, Tensor)> { + let scale = (127.0 + / x.abs()? + .max(D::Minus1)? + .max(D::Minus1)? + .clamp(1e-5, f32::INFINITY)?)? + .to_dtype(x.dtype())?; + + let y = x + .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? + .round()? + .clamp(-128.0, 127.0)?; + + Ok((y, scale)) +} + +impl BitQMatMul { + fn from_qtensor(qtensor: QTensor) -> Result { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + Ok(Self { inner, span }) + } + + fn forward(&self, x: &Tensor) -> Result { + let (x, x_scale) = activation_quant(x)?; + + let _enter = self.span.enter(); + self.inner.forward(&x)?.broadcast_div(&x_scale) + } +} + + +#[derive(Debug, Clone)] +struct Mlp { + feed_forward_w1: BitQMatMul, + feed_forward_w2: BitQMatMul, + feed_forward_w3: BitQMatMul, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let w1 = self.feed_forward_w1.forward(xs)?; + let w3 = self.feed_forward_w3.forward(xs)?; + self.feed_forward_w2 + .forward(&(candle_nn::ops::silu(&w1)? * w3)?) + } +} + +#[derive(Debug, Clone)] +enum MlpOrMoe { + Mlp(Mlp), + MoE { + n_expert_used: usize, + feed_forward_gate_inp: QMatMul, + experts: Vec, + }, +} + +impl Module for MlpOrMoe { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::MoE { + feed_forward_gate_inp, + experts, + n_expert_used, + } => { + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + let xs = xs.reshape(((), hidden_dim))?; + let router_logits = feed_forward_gate_inp.forward(&xs)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // In order to extract topk, we extract the data from the tensor and manipulate it + // directly. Maybe we will want to use some custom ops instead at some point. + let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::()?; + + // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + // top_x contains the row indexes to evaluate for each expert. 
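// Worked example, not in the patch, assuming n_expert_used = 2: for a row with
// softmax routing weights [0.10, 0.60, 0.30], the sort below picks experts 1
// and 2, sum_routing_weights = 0.90, top_x[1] and top_x[2] both record this
// row index, and the renormalized weights pushed into selected_rws are
// 0.60 / 0.90 ≈ 0.667 and 0.30 / 0.90 ≈ 0.333.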
+ let mut top_x = vec![vec![]; experts.len()]; + let mut selected_rws = vec![vec![]; experts.len()]; + for (row_idx, rw) in routing_weights.iter().enumerate() { + let mut dst = (0..rw.len() as u32).collect::>(); + dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize])); + let mut sum_routing_weights = 0f32; + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + sum_routing_weights += routing_weight; + top_x[expert_idx].push(row_idx as u32); + } + for &expert_idx in dst.iter().take(*n_expert_used) { + let expert_idx = expert_idx as usize; + let routing_weight = rw[expert_idx]; + selected_rws[expert_idx].push(routing_weight / sum_routing_weights) + } + } + + // routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + let mut ys = xs.zeros_like()?; + for (expert_idx, expert_layer) in experts.iter().enumerate() { + let top_x = &top_x[expert_idx]; + if top_x.is_empty() { + continue; + } + let top_x = Tensor::new(top_x.as_slice(), xs.device())?; + let selected_rws = + Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())? + .reshape(((), 1))?; + // Index the correct hidden states and compute the expert hidden state for + // the current expert. We need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1 and top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?; + // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None]) + let current_hidden_states = expert_layer.forward(¤t_state)?; + let current_hidden_states = + current_hidden_states.broadcast_mul(&selected_rws)?; + ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?; + } + + let ys = ys.reshape((b_size, seq_len, hidden_dim))?; + Ok(ys) + } + Self::Mlp(mlp) => mlp.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +struct LayerWeights { + attention_wq: BitQMatMul, + attention_wk: BitQMatMul, + attention_wv: BitQMatMul, + attention_wo: BitQMatMul, + attention_norm: RmsNorm, + mlp_or_moe: MlpOrMoe, + ffn_norm: RmsNorm, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + neg_inf: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, + span_mlp: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { + let shape = mask.shape(); + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result { + let _enter = self.span_rot.enter(); + let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + // The call to contiguous below is only necessary when processing the prompt. + // When the seq_len is 1 in the inference loop, this is a no-op. + candle_nn::rotary_emb::rope_i(&x.contiguous()?, &cos, &sin) + } + + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let q = self.attention_wq.forward(x)?; + let k = self.attention_wk.forward(x)?; + let v = self.attention_wv.forward(x)?; + + let q = q + .reshape((b_sz, seq_len, self.n_head, self.head_dim))? 
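// Shape sketch, not in the patch: with hypothetical values b_sz = 1, seq_len = S,
// n_head = 8, n_kv_head = 2, head_dim = 64 (so n_embd = 512), q goes
// (1, S, 512) -> (1, S, 8, 64) -> (1, 8, S, 64) after the transpose below, while
// k and v come out as (1, 2, S, 64) and are later expanded with repeat_kv by a
// factor of n_head / n_kv_head = 4 to match the query heads.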
+ .transpose(1, 2)?; + let k = k + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))? + .transpose(1, 2)? + // This call to contiguous ensures that the fast kernel can be called below. It's + // actually a no-op except when processing the initial prompt so has no significant + // impact on performance. + .contiguous()?; + + let q = self.apply_rotary_emb(&q, index_pos)?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k, v) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?; + let v = Tensor::cat(&[v_cache, &v], 2)?; + (k, v) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let y = if q.device().is_metal() && seq_len == 1 { + // SDPA will do MQA for us + candle_nn::ops::sdpa(&q, &k, &v, 1. / (self.head_dim as f32).sqrt(), 1.)? + } else { + // Support for MQA, useful for 70B models and mistral. + let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + att.matmul(&v.contiguous()?)? + }; + + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attention_wo.forward(&y)?; + Ok(y) + } +} + +#[derive(Debug, Clone)] +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec, + norm: RmsNorm, + output: QMatMul, + masks: HashMap, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? 
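// Sketch, not in the patch: theta has head_dim / 2 entries with
// theta[j] = freq_base^(-2j / head_dim), e.g. for head_dim = 64 and
// freq_base = 10000 the frequencies run from 1.0 down to 10000^(-62/64).
// The matmul below builds a (MAX_SEQ_LEN, head_dim / 2) grid of
// position * theta products whose cos / sin are cached for rope_i.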
+ .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +impl ModelWeights { + pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result { + let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; + let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?; + let tok_embeddings = ct.remove("tok_embeddings.weight")?; + let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; + let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?; + let output = ct.remove("output.weight")?; + let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); + for layer_idx in 0..ct.hparams.n_layer { + let prefix = format!("layers.{layer_idx}"); + let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let mlp_or_moe = { + let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3)?, + }) + }; + let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; + let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: BitQMatMul::from_qtensor(attention_wq)?, + attention_wk: BitQMatMul::from_qtensor(attention_wk)?, + attention_wv: BitQMatMul::from_qtensor(attention_wv)?, + attention_wo: BitQMatMul::from_qtensor(attention_wo)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?, + mlp_or_moe, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?, + n_head: ct.hparams.n_head as usize, + n_kv_head: ct.hparams.n_head as usize / gqa, + head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize, + cos: cos.clone(), + sin: sin.clone(), + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), + layers, + norm, + output: QMatMul::from_qtensor(output)?, + masks: HashMap::new(), + span, + span_output, + }) + } + + pub fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + ) -> Result { + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let n_expert = md_get("llama.expert_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; + let n_expert_used = md_get("llama.expert_used_count") + .and_then(|v| v.to_u32()) + .unwrap_or(0) as usize; + let head_count = md_get("llama.attention.head_count")?.to_u32()? 
as usize; + let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("llama.block_count")?.to_u32()? as usize; + let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize; + // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default. + let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; + + let rope_freq_base = md_get("llama.rope.freq_base") + .and_then(|m| m.to_f32()) + .unwrap_or(10000f32); + let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; + + let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings_q.dequantize(device)?; + let norm = RmsNorm::from_qtensor( + ct.tensor(reader, "output_norm.weight", device)?, + rms_norm_eps, + )?; + let output = match ct.tensor(reader, "output.weight", device) { + Ok(tensor) => tensor, + Err(_) => tok_embeddings_q, + }; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; + let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let mlp_or_moe = if n_expert <= 1 { + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + MlpOrMoe::Mlp(Mlp { + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3)?, + }) + } else { + let feed_forward_gate_inp = + ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?; + let mut experts = Vec::with_capacity(n_expert); + for i in 0..n_expert { + let feed_forward_w1 = + ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; + let feed_forward_w2 = + ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; + let feed_forward_w3 = + ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; + + experts.push(Mlp { + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3)?, + }) + } + MlpOrMoe::MoE { + n_expert_used, + feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?, + experts, + } + }; + let attention_norm = + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?; + let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); + layers.push(LayerWeights { + attention_wq: BitQMatMul::from_qtensor(attention_wq)?, + attention_wk: BitQMatMul::from_qtensor(attention_wk)?, + attention_wv: BitQMatMul::from_qtensor(attention_wv)?, + attention_wo: 
BitQMatMul::from_qtensor(attention_wo)?, + attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?, + mlp_or_moe, + ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_rot, + span_mlp, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + norm, + output: QMatMul::from_qtensor(output)?, + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize, device: &Device) -> Result { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result { + let (_b_sz, seq_len) = x.dims2()?; + let mask = if seq_len == 1 { + None + } else { + Some(self.mask(seq_len, x.device())?) + }; + let _enter = self.span.enter(); + let mut layer_in = self.tok_embeddings.forward(x)?; + for layer in self.layers.iter_mut() { + let x = layer_in; + let residual = &x; + let x = layer.attention_norm.forward(&x)?; + let attn = layer.forward_attn(&x, mask.as_ref(), index_pos)?; + let x = (attn + residual)?; + + // MLP + let _enter = layer.span_mlp.enter(); + let residual = &x; + let x = layer.ffn_norm.forward(&x)?; + let x = layer.mlp_or_moe.forward(&x)?; + let x = (x + residual)?; + layer_in = x + } + let x = self.norm.forward(&layer_in)?; + let x = x.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&x) + } +} \ No newline at end of file diff --git a/tensor-tools/Cargo.toml b/tensor-tools/Cargo.toml index eecd7e4353..b48e81cf1d 100644 --- a/tensor-tools/Cargo.toml +++ b/tensor-tools/Cargo.toml @@ -14,3 +14,4 @@ candle = { workspace = true } clap = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } +serde_json = { workspace = true } \ No newline at end of file diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 51910b1ad4..2f588d0abe 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -3,6 +3,7 @@ use candle::{Device, Result, Tensor}; use clap::{Parser, Subcommand, ValueEnum}; use rayon::prelude::*; use safetensors::tensor; +use serde_json; #[derive(ValueEnum, Debug, Clone)] enum QuantizationMode { @@ -428,7 +429,11 @@ fn run_quantize_safetensors( ) -> Result<()> { let mut out_file = std::fs::File::create(out_file)?; let mut tensors = std::collections::HashMap::new(); + let metadata_file = in_files.iter().find(|f| f.to_string_lossy().ends_with("config.json")); for in_file in in_files.iter() { + if metadata_file.is_some() && in_file == metadata_file.unwrap() { + continue; + } let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?; tensors.extend(in_tensors) } @@ -439,14 +444,15 @@ fn run_quantize_safetensors( let qtensors = tensors .into_par_iter() - .map(|(name, tensor)| { - let mut should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; + .map(|(mut name, tensor)| { + let should_quantize = tensor.rank() == 2 && tensor.dim(1)? 
% block_size == 0; let mut tensor = tensor; if should_quantize && bitnet_mode { - let is_bitnet_weight = name.contains("self_attn.k_proj") || + let is_bitnet_weight = name.contains("self_attn.v_proj") || name.contains("self_attn.q_proj") || name.contains("self_attn.o_proj") || + name.contains("self_attn.k_proj") || name.contains("mlp.down_proj") || name.contains("mlp.up_proj") || name.contains("mlp.gate_proj"); @@ -462,6 +468,30 @@ fn run_quantize_safetensors( } else { QTensor::quantize(&tensor, GgmlDType::F32)? }; + + if name == "model.embed_tokens.weight" { + name = "token_embd.weight".to_string(); + } + + if name == "model.norm.weight" { + name = "output_norm.weight".to_string() + } + + if name == "lm_head.weight" { + name = "output.weight".to_string() + } + + name = name.replace("model.layers.", "blk."); + name = name.replace("self_attn.q_proj", "attn_q"); + name = name.replace("self_attn.k_proj", "attn_k"); + name = name.replace("self_attn.v_proj", "attn_v"); + name = name.replace("self_attn.o_proj", "attn_output"); + name = name.replace("mlp.gate_proj", "ffn_gate"); + name = name.replace("mlp.down_proj", "ffn_down"); + name = name.replace("mlp.up_proj", "ffn_up"); + name = name.replace("input_layernorm", "attn_norm"); + name = name.replace("post_attention_layernorm", "ffn_norm"); + Ok((name, tensor)) }) .collect::>>()?; @@ -469,7 +499,56 @@ fn run_quantize_safetensors( .iter() .map(|(k, v)| (k.as_str(), v)) .collect::>(); - gguf_file::write(&mut out_file, &[], &qtensors)?; + + // Load metadata + let gguf_metadata: Vec<(&str, gguf_file::Value)> = if let Some(metadata_file) = metadata_file { + let metadata = std::fs::read_to_string(metadata_file)?; + let metadata: serde_json::Value = serde_json::from_str(&metadata).unwrap(); + + let num_attention_heads = gguf_file::Value::from_u32(metadata["num_attention_heads"].as_u64().unwrap() as u32); + let num_attention_heads_kv = gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32); + + let num_hidden_layers = gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32); + let embedding_length = gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32); + let rope_dimension_count = gguf_file::Value::from_u32( + (metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32) + ); + let layer_norm_eps = gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32); + + let mut gguf_metadata: Vec<(&str, gguf_file::Value)> = Vec::new(); + gguf_metadata.push(( + "llama.attention.head_count", + num_attention_heads.clone(), + )); + gguf_metadata.push(( + "llama.attention.head_count_kv", + num_attention_heads_kv.clone(), + )); + gguf_metadata.push(( + "llama.block_count", + num_hidden_layers.clone(), + )); + gguf_metadata.push(( + "llama.embedding_length", + embedding_length.clone(), + )); + gguf_metadata.push(( + "llama.attention.layer_norm_rms_epsilon", layer_norm_eps.clone() + )); + gguf_metadata.push(( + "llama.rope.dimension_count", + rope_dimension_count.clone(), + )); + + // Print metadata + for (key, value) in gguf_metadata.iter() { + println!(" {key}: {value:?}"); + } + gguf_metadata + } else { + Vec::new() + }; + gguf_file::write(&mut out_file, gguf_metadata.as_slice(), &qtensors)?; Ok(()) } @@ -542,7 +621,7 @@ fn run_quantize( let metadata = content .metadata .iter() - .map(|(k, v)| (k.as_str(), v)) + .map(|(k, v)| (k.as_str(), v.clone())) .collect::>(); gguf_file::write(&mut out_file, metadata.as_slice(), 
&qtensors)?; Ok(()) From e7e23e324f8b99e5e905ffb64a8a41c53829fbd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 23 Dec 2024 21:43:23 +0100 Subject: [PATCH 03/11] Initial quantized support --- candle-core/src/quantized/ggml_file.rs | 6 ++ candle-core/src/quantized/k_quants.rs | 92 +++++++++++++++++++ candle-core/src/quantized/mod.rs | 7 +- candle-core/src/quantized/neon.rs | 59 +++++++++++- .../examples/quantized-bitnet/main.rs | 48 ++++------ .../src/models/quantized_llama_bitnet.rs | 84 +++++++++-------- tensor-tools/src/main.rs | 34 +++++-- 7 files changed, 255 insertions(+), 75 deletions(-) diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 0f7e9c118c..0afd150e5d 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -183,6 +183,12 @@ pub fn qtensor_from_ggml( GgmlDType::Q6K => { from_raw_data::(raw_data, size_in_bytes, dims, device) } + GgmlDType::Q8K => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } + GgmlDType::Q2b0 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 1d3e053898..44a38b63ca 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -154,6 +154,14 @@ pub struct BlockQ8K { } const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::()); +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQ2b0 { + pub(crate) qs: [i8; QK_K / 4], // Every single byte represents 4 values. +} + +const _: () = assert!(QK_K / 4 == std::mem::size_of::()); + impl GgmlType for BlockQ4_0 { const DTYPE: GgmlDType = GgmlDType::Q4_0; const BLCK_SIZE: usize = QK4_0; @@ -1838,6 +1846,90 @@ impl GgmlType for BlockQ8K { } } +impl GgmlType for BlockQ2b0 { + const DTYPE: GgmlDType = GgmlDType::Q2b0; + const BLCK_SIZE: usize = QK_K; + type VecDotType = BlockQ8K; + + fn to_float(xs: &[Self], ys: &mut [f32]) -> crate::Result<()> { + let k = ys.len(); + if k % Self::BLCK_SIZE != 0 { + crate::bail!( + "to_float Q2b0: size {} is not divisible by {}", + k, + Self::BLCK_SIZE + ); + } + + let nb = k / Self::BLCK_SIZE; + for i in 0..nb { + let base = i * Self::BLCK_SIZE; + for (j, &qbyte) in xs[i].qs.iter().enumerate() { + let start = base + j * 4; + ys[start] = (qbyte & 0b11) as f32 - 2.0; + ys[start + 1] = ((qbyte >> 2) & 0b11) as f32 - 2.0; + ys[start + 2] = ((qbyte >> 4) & 0b11) as f32 - 2.0; + ys[start + 3] = (((qbyte >> 6) & 0b11) as f32 - 2.0); + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> crate::Result<()> { + let k = xs.len(); + if k % Self::BLCK_SIZE != 0 { + crate::bail!("from_float Q2b0: size {} is not divisible by {}", k, Self::BLCK_SIZE); + } + + let nb = k / Self::BLCK_SIZE; + for i in 0..nb { + let slice = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + ys[i].qs.fill(0); + + for (j, qbyte) in ys[i].qs.iter_mut().enumerate() { + let start = j * 4; + let q0 = ((slice[start] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; + let q1 = ((slice[start + 1] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; + let q2 = ((slice[start + 2] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; + let q3 = ((slice[start + 3] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; + + *qbyte = q0 | (q1 << 2) | (q2 << 4) | (q3 << 6); + } + } + Ok(()) + } + + fn vec_dot(n: usize, xs: &[Self], ys: 
&[Self::VecDotType]) -> crate::Result { + let nb = n / Self::BLCK_SIZE; + let mut sumf = 0f32; + + for i in 0..nb { + let d8 = ys[i].d; + + for (j, &qbyte) in xs[i].qs.iter().enumerate() { + let idx_base = j * 4; + let q_vals = [ + (qbyte & 0b11) - 2, + ((qbyte >> 2) & 0b11) - 2, + ((qbyte >> 4) & 0b11) - 2, + ((qbyte >> 6) & 0b11) - 2, + ]; + + let sum_i = q_vals.iter().zip(ys[i].qs[idx_base..idx_base + 4].iter()) + .map(|(&q_val, &y_val)| q_val as i32 * y_val as i32) + .sum::(); + + sumf += sum_i as f32 * d8; + } + } + Ok(sumf) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } +} + // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 pub fn matmul( mkn: (usize, usize, usize), diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 236f5a9811..9d6d9abfaf 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -146,6 +146,7 @@ pub enum GgmlDType { Q5K, Q6K, Q8K, + Q2b0, } impl GgmlDType { @@ -165,6 +166,7 @@ impl GgmlDType { 13 => Self::Q5K, 14 => Self::Q6K, 15 => Self::Q8K, + 40 => Self::Q2b0, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -186,6 +188,7 @@ impl GgmlDType { Self::Q5K => 13, Self::Q6K => 14, Self::Q8K => 15, + Self::Q2b0 => 40, } } @@ -206,6 +209,7 @@ impl GgmlDType { Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]), Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), + Self::Q2b0 => Box::new(vec![BlockQ2b0::zeros(); elem_count / BlockQ2b0::BLCK_SIZE]), } } /// The type size for blocks in bytes. 
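// Sketch, not in the patch: with the Q2b0 plumbing in this patch, the new dtype
// maps to GGUF type id 40 and packs QK_K = 256 weights into a
// size_of::<BlockQ2b0>() = 64-byte block (QK_K / 4 bytes of 2-bit codes),
// i.e. 2 bits per weight with no per-block scale. Assuming this runs inside
// candle-core where from_u32 is visible:
//
//     let dt = GgmlDType::from_u32(40)?;   // GgmlDType::Q2b0
//     assert_eq!(dt.to_u32(), 40);
//     assert_eq!(dt.block_size(), 256);    // QK_K
//     assert_eq!(dt.type_size(), 64);      // size_of::<BlockQ2b0>()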
@@ -227,6 +231,7 @@ impl GgmlDType { Self::Q5K => std::mem::size_of::(), Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), + Self::Q2b0 => std::mem::size_of::(), } } @@ -241,7 +246,7 @@ impl GgmlDType { Self::Q5_1 => k_quants::QK5_1, Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, - Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, + Self::Q2b0 | Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, } } } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index c4d5d6f41a..5b775bbaed 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,5 +1,5 @@ use super::k_quants::{ - BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, + BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, BlockQ2b0 }; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -517,6 +517,63 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res Ok(sumf) } +#[inline(always)] +pub(crate) fn vec_dot_q2b0_q8k(n: usize, xs: &[BlockQ2b0], ys: &[BlockQ8K]) -> crate::Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {QK_K}") + } + + let mut sumf = 0f32; + + unsafe { + let m2b = vdupq_n_u8(0b11); // Máscara para extraer 2 bits por elemento + + for (x, y) in xs.iter().zip(ys.iter()) { + let d = y.d; + + let mut q2_ptr = x.qs.as_ptr(); + let mut q8_ptr = y.qs.as_ptr(); + + let mut sumi = 0i32; + + for _ in 0..QK_K / 16 { + // Cargar bloques de datos + let q2bytes = vld1q_u8(q2_ptr as *const u8); + q2_ptr = q2_ptr.add(16); + + let q8bytes_low = vld1q_s8(q8_ptr); + let q8bytes_high = vld1q_s8(q8_ptr.add(16)); + q8_ptr = q8_ptr.add(32); + + // Extraer los valores de los 2 bits (4 valores por byte) + let q2_vals_low = vreinterpretq_s8_u8(vandq_u8(q2bytes, m2b)); + let q2_vals_high = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bytes, 2), m2b)); + let q2_vals_mid_low = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bytes, 4), m2b)); + let q2_vals_mid_high = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bytes, 6), m2b)); + + // Ajustar los valores de Q2 (centro en -2) + let q2_low = vsubq_s8(q2_vals_low, vdupq_n_s8(2)); + let q2_high = vsubq_s8(q2_vals_high, vdupq_n_s8(2)); + let q2_mid_low = vsubq_s8(q2_vals_mid_low, vdupq_n_s8(2)); + let q2_mid_high = vsubq_s8(q2_vals_mid_high, vdupq_n_s8(2)); + + // Calcular productos punto para cada parte + let prod0 = vdotq_s32(q2_low, q8bytes_low); + let prod1 = vdotq_s32(q2_high, q8bytes_high); + let prod2 = vdotq_s32(q2_mid_low, q8bytes_low); + let prod3 = vdotq_s32(q2_mid_high, q8bytes_high); + + // Sumar los productos + sumi += vaddvq_s32(prod0) + vaddvq_s32(prod1) + vaddvq_s32(prod2) + vaddvq_s32(prod3); + } + + sumf += sumi as f32 * d; + } + } + + Ok(sumf) +} + #[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index 9407c30f0e..a867a8652b 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -29,30 +29,19 @@ enum Prompt { enum Which { #[value(name = "falcon3-1b-1.58")] Falcon3_1b1_58, + #[value(name = "falcon3-3b-1.58")] + Falcon3_3b1_58, } impl Which { - fn is_mistral(&self) -> bool { - match self { - 
Self::Falcon3_1b1_58 => false, - } - } - - fn is_zephyr(&self) -> bool { - match self { - Self::Falcon3_1b1_58 => false, - } - } - - fn is_open_chat(&self) -> bool { - match self { - Self::Falcon3_1b1_58 => false, - } + fn is_falcon(&self) -> bool { + matches!(self, Self::Falcon3_1b1_58) } fn tokenizer_repo(&self) -> &'static str { match self { Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit", + Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit", } } } @@ -155,6 +144,10 @@ impl Args { "tiiuae/Falcon3-1B-Instruct-1.58bit", "Falcon3-1B-Instruct-1.58bit.gguf", ), + Which::Falcon3_3b1_58 => ( + "tiiuae/Falcon3-3B-Instruct-1.58bit", + "Falcon3-3B-Instruct-1.58bit.gguf", + ), }; let revision = "main"; let api = hf_hub::api::sync::Api::new()?; @@ -270,7 +263,13 @@ fn main() -> anyhow::Result<()> { let mut pre_prompt_tokens = vec![]; for prompt_index in 0.. { let prompt_str = match &prompt { - Prompt::One(prompt) => prompt.clone(), + Prompt::One(prompt) => { + if args.which.is_falcon() { + format!("<|user|>\n{prompt}\n<|assistant|>") + } else { + prompt.clone() + } + } Prompt::Interactive | Prompt::Chat => { let is_interactive = matches!(prompt, Prompt::Interactive); print!("> "); @@ -283,16 +282,8 @@ fn main() -> anyhow::Result<()> { prompt.pop(); } } - if args.which.is_open_chat() { - format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:") - } else if args.which.is_zephyr() { - if prompt_index == 0 || is_interactive { - format!("<|system|>\n\n<|user|>\n{prompt}\n<|assistant|>",) - } else { - format!("<|user|>\n{prompt}\n<|assistant|>") - } - } else if args.which.is_mistral() { - format!("[INST] {prompt} [/INST]") + if args.which.is_falcon() { + format!("<|user|>\n{prompt}\n<|assistant|>") } else { prompt } @@ -358,10 +349,11 @@ fn main() -> anyhow::Result<()> { } let eos_token = match args.which { - Which::Falcon3_1b1_58 => "<|endoftext|>", + Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => "<|endoftext|>", }; let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); + let start_post_prompt = std::time::Instant::now(); let mut sampled = 0; for index in 0..to_sample { diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs index 30d9b47726..1c520b1686 100644 --- a/candle-transformers/src/models/quantized_llama_bitnet.rs +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -51,37 +51,20 @@ impl QMatMul { struct BitQMatMul { inner: candle::quantized::QMatMul, span: tracing::Span, -} - - -fn activation_quant(x: &Tensor) -> Result<(Tensor, Tensor)> { - let scale = (127.0 - / x.abs()? - .max(D::Minus1)? - .max(D::Minus1)? - .clamp(1e-5, f32::INFINITY)?)? - .to_dtype(x.dtype())?; - - let y = x - .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? - .round()? 
- .clamp(-128.0, 127.0)?; - - Ok((y, scale)) + weight_scale: Tensor, } impl BitQMatMul { - fn from_qtensor(qtensor: QTensor) -> Result { + fn from_qtensor(qtensor: QTensor, weight_scale: QTensor) -> Result { let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - Ok(Self { inner, span }) + let weight_scale = weight_scale.dequantize(&weight_scale.device())?; + Ok(Self { inner, span, weight_scale }) } fn forward(&self, x: &Tensor) -> Result { - let (x, x_scale) = activation_quant(x)?; - let _enter = self.span.enter(); - self.inner.forward(&x)?.broadcast_div(&x_scale) + self.inner.forward(&x)?.broadcast_div(&self.weight_scale) } } @@ -333,17 +316,24 @@ impl ModelWeights { for layer_idx in 0..ct.hparams.n_layer { let prefix = format!("layers.{layer_idx}"); let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?; + let attention_wq_ws = ct.remove(&format!("{prefix}.attention.wq.weight_scale"))?; let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?; + let attention_wk_ws = ct.remove(&format!("{prefix}.attention.wk.weight_scale"))?; let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?; + let attention_wv_ws = ct.remove(&format!("{prefix}.attention.wv.weight_scale"))?; let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?; + let attention_wo_ws = ct.remove(&format!("{prefix}.attention.wo.weight_scale"))?; let mlp_or_moe = { let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; + let feed_forward_w1_ws = ct.remove(&format!("{prefix}.feed_forward.w1.weight_scale"))?; let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; + let feed_forward_w2_ws = ct.remove(&format!("{prefix}.feed_forward.w2.weight_scale"))?; let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; + let feed_forward_w3_ws = ct.remove(&format!("{prefix}.feed_forward.w3.weight_scale"))?; MlpOrMoe::Mlp(Mlp { - feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3)?, + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?, }) }; let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?; @@ -352,10 +342,10 @@ impl ModelWeights { let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); layers.push(LayerWeights { - attention_wq: BitQMatMul::from_qtensor(attention_wq)?, - attention_wk: BitQMatMul::from_qtensor(attention_wk)?, - attention_wv: BitQMatMul::from_qtensor(attention_wv)?, - attention_wo: BitQMatMul::from_qtensor(attention_wo)?, + attention_wq: BitQMatMul::from_qtensor(attention_wq, attention_wq_ws)?, + attention_wk: BitQMatMul::from_qtensor(attention_wk, attention_wk_ws)?, + attention_wv: BitQMatMul::from_qtensor(attention_wv, attention_wv_ws)?, + attention_wo: BitQMatMul::from_qtensor(attention_wo, attention_wo_ws)?, attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?, mlp_or_moe, ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?, @@ -429,20 +419,30 @@ impl ModelWeights { for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let attention_wq = ct.tensor(reader, 
&format!("{prefix}.attn_q.weight"), device)?; + let attention_wq_ws = ct.tensor(reader, &format!("{prefix}.attn_q.weight_scale"), device)?; let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; + let attention_wk_ws = ct.tensor(reader, &format!("{prefix}.attn_k.weight_scale"), device)?; let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; + let attention_wv_ws = ct.tensor(reader, &format!("{prefix}.attn_v.weight_scale"), device)?; let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let attention_wo_ws = ct.tensor(reader, &format!("{prefix}.attn_output.weight_scale"), device)?; let mlp_or_moe = if n_expert <= 1 { let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; + let feed_forward_w1_ws = + ct.tensor(reader, &format!("{prefix}.ffn_gate.weight_scale"), device)?; let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?; + let feed_forward_w2_ws = + ct.tensor(reader, &format!("{prefix}.ffn_down.weight_scale"), device)?; let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?; + let feed_forward_w3_ws = + ct.tensor(reader, &format!("{prefix}.ffn_up.weight_scale"), device)?; MlpOrMoe::Mlp(Mlp { - feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3)?, + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?, }) } else { let feed_forward_gate_inp = @@ -451,15 +451,21 @@ impl ModelWeights { for i in 0..n_expert { let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; + let feed_forward_w1_ws = + ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight_scale"), device)?; let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; + let feed_forward_w2_ws = + ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight_scale"), device)?; let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; + let feed_forward_w3_ws = + ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight_scale"), device)?; experts.push(Mlp { - feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1)?, - feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2)?, - feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3)?, + feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, + feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, + feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?, }) } MlpOrMoe::MoE { @@ -475,10 +481,10 @@ impl ModelWeights { let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); layers.push(LayerWeights { - attention_wq: BitQMatMul::from_qtensor(attention_wq)?, - attention_wk: BitQMatMul::from_qtensor(attention_wk)?, - attention_wv: BitQMatMul::from_qtensor(attention_wv)?, - attention_wo: BitQMatMul::from_qtensor(attention_wo)?, + attention_wq: BitQMatMul::from_qtensor(attention_wq, attention_wq_ws)?, + attention_wk: BitQMatMul::from_qtensor(attention_wk, attention_wk_ws)?, + attention_wv: BitQMatMul::from_qtensor(attention_wv, 
attention_wv_ws)?, + attention_wo: BitQMatMul::from_qtensor(attention_wo, attention_wo_ws)?, attention_norm: RmsNorm::from_qtensor(attention_norm, rms_norm_eps)?, mlp_or_moe, ffn_norm: RmsNorm::from_qtensor(ffn_norm, rms_norm_eps)?, diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 2f588d0abe..af944057d8 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -1,3 +1,4 @@ +use candle::op::Op; use candle::quantized::{gguf_file, GgmlDType, QTensor}; use candle::{Device, Result, Tensor}; use clap::{Parser, Subcommand, ValueEnum}; @@ -47,6 +48,8 @@ enum Quantization { Q8_0, #[value(name = "q8_1")] Q8_1, + #[value(name = "q2b0")] + Q2b0, Q2k, Q3k, Q4k, @@ -74,6 +77,7 @@ impl Quantization { Quantization::Q8k => GgmlDType::Q8K, Quantization::F16 => GgmlDType::F16, Quantization::F32 => GgmlDType::F32, + Quantization::Q2b0 => GgmlDType::Q2b0, } } } @@ -148,6 +152,10 @@ enum Command { #[clap(long, short, action)] bitnet_mode: bool, + // Allow to specify quantization_bitnet in case of bitnet_mode + #[arg(long, value_enum)] + bitnet_quantization: Option, + /// The quantization schema to apply. #[arg(long, value_enum)] quantization: Quantization, @@ -293,7 +301,10 @@ fn run_print( let tensor = tensor.dequantize(device)?; println!("{tensor}") } - Err(_) => println!("not found"), + Err(e) => { + eprintln!("error: {e}"); + println!("not found") + } } } } @@ -407,17 +418,20 @@ fn unpack_bitnet_weights(tensor: &Tensor) -> Result { let cols = tensor.dim(1).unwrap(); let mut unpacked_vec = vec![0f32; rows * 4 * cols]; + for i in 0..rows { for j in 0..cols { let packed = packed_vec[i][j]; + for k in 0..4 { let bits = ((packed >> (k * 2)) & 0b11) as i8 - 1; - unpacked_vec[(i * 4 + k) * cols + j] = bits as f32; + let index = (k * rows + i) * cols + j; + unpacked_vec[index] = bits as f32; } } } - let unpacked_tensor = Tensor::from_vec(unpacked_vec, (rows*4, cols), tensor.device())?; + let unpacked_tensor = Tensor::from_vec(unpacked_vec, (rows * 4, cols), tensor.device())?; Ok(unpacked_tensor) } @@ -425,6 +439,7 @@ fn run_quantize_safetensors( in_files: &[std::path::PathBuf], out_file: std::path::PathBuf, q: Quantization, + bq: Option, bitnet_mode: bool, ) -> Result<()> { let mut out_file = std::fs::File::create(out_file)?; @@ -445,6 +460,7 @@ fn run_quantize_safetensors( let qtensors = tensors .into_par_iter() .map(|(mut name, tensor)| { + let mut local_dtype = dtype.clone(); let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; let mut tensor = tensor; if should_quantize && bitnet_mode { @@ -460,11 +476,12 @@ fn run_quantize_safetensors( if is_bitnet_weight { println!(" unpacking {name} {tensor:?} {should_quantize}"); tensor = unpack_bitnet_weights(&tensor)?; + local_dtype = bq.clone().unwrap().dtype(); } } println!(" quantizing {name} {tensor:?} {should_quantize}"); let tensor = if should_quantize { - QTensor::quantize(&tensor, dtype)? + QTensor::quantize(&tensor, local_dtype)? } else { QTensor::quantize(&tensor, GgmlDType::F32)? 
}; @@ -574,12 +591,16 @@ fn run_quantize( out_file: std::path::PathBuf, q: Quantization, qmode: QuantizationMode, + bq: Option, bitnet_mode: bool, device: &Device, ) -> Result<()> { if in_files.is_empty() { candle::bail!("no specified input files") } + if bitnet_mode && bq.is_none() { + candle::bail!("bitnet mode requires a bitnet quantization") + } if let Some(extension) = out_file.extension() { if extension == "safetensors" { candle::bail!("the generated file cannot use the safetensors extension") @@ -587,7 +608,7 @@ fn run_quantize( } if let Some(extension) = in_files[0].extension() { if extension == "safetensors" { - return run_quantize_safetensors(in_files, out_file, q, bitnet_mode); + return run_quantize_safetensors(in_files, out_file, q, bq, bitnet_mode); } } @@ -655,9 +676,10 @@ fn main() -> anyhow::Result<()> { in_file, out_file, quantization, + bitnet_quantization, mode, bitnet_mode, - } => run_quantize(&in_file, out_file, quantization, mode, bitnet_mode, &device)?, + } => run_quantize(&in_file, out_file, quantization, mode, bitnet_quantization, bitnet_mode, &device)?, Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?, } Ok(()) From 81fe4833355e5ab020b7c9a08bb496e2cee73eaf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Wed, 25 Dec 2024 20:25:36 +0100 Subject: [PATCH 04/11] wip --- candle-core/src/quantized/k_quants.rs | 130 ++++++------ candle-core/src/quantized/metal.rs | 5 + candle-core/src/quantized/neon.rs | 75 +++---- .../examples/quantized-bitnet/main.rs | 21 +- candle-metal-kernels/src/lib.rs | 6 +- .../src/models/quantized_llama_bitnet.rs | 19 +- tensor-tools/src/main.rs | 200 ++++++++---------- 7 files changed, 243 insertions(+), 213 deletions(-) diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index 44a38b63ca..d1b010a5a5 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -157,7 +157,8 @@ const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::( #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ2b0 { - pub(crate) qs: [i8; QK_K / 4], // Every single byte represents 4 values. 
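+    // Ternary weights are split across two bit-planes: bit k of `qs` set means weight k is +1,
+    // bit k of `qd` set means it is -1, and both bits clear means 0. vec_dot below then
+    // accumulates the activations selected by `qs` minus those selected by `qd`.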
+ pub(crate) qs: [u8; QK_K / 8], // Every single bit represents positive values, is a vector of {0, 1} + pub(crate) qd: [u8; QK_K / 8], // Every single bit represents negatives values, is a vector of {0, 1} } const _: () = assert!(QK_K / 4 == std::mem::size_of::()); @@ -1846,87 +1847,94 @@ impl GgmlType for BlockQ8K { } } + impl GgmlType for BlockQ2b0 { const DTYPE: GgmlDType = GgmlDType::Q2b0; const BLCK_SIZE: usize = QK_K; type VecDotType = BlockQ8K; - fn to_float(xs: &[Self], ys: &mut [f32]) -> crate::Result<()> { - let k = ys.len(); - if k % Self::BLCK_SIZE != 0 { - crate::bail!( - "to_float Q2b0: size {} is not divisible by {}", - k, - Self::BLCK_SIZE - ); + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % QK_K != 0 { + crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {QK_K}"); } - - let nb = k / Self::BLCK_SIZE; - for i in 0..nb { - let base = i * Self::BLCK_SIZE; - for (j, &qbyte) in xs[i].qs.iter().enumerate() { - let start = base + j * 4; - ys[start] = (qbyte & 0b11) as f32 - 2.0; - ys[start + 1] = ((qbyte >> 2) & 0b11) as f32 - 2.0; - ys[start + 2] = ((qbyte >> 4) & 0b11) as f32 - 2.0; - ys[start + 3] = (((qbyte >> 6) & 0b11) as f32 - 2.0); - } + let mut sumf = 0.0; + for (x, y) in xs.iter().zip(ys.iter()) { + let mut isum = 0i32; + for i in 0..QK_K / 8 { + let qs = x.qs[i]; + let qd = x.qd[i]; + let mut y_cache = [0i32; 8]; + y_cache.copy_from_slice(&y.qs[i * 8..(i + 1) * 8].iter().map(|&x| x as i32).collect::>()[..]); + + let pos_sum: i32 = (0..8).map(|bit| { + let mask = 1 << bit; + let is_active = ((qs & mask) >> bit) as i32; + is_active * y_cache[bit] + }).sum(); + + let neg_sum: i32 = (0..8).map(|bit| { + let mask = 1 << bit; + let is_active = ((qd & mask) >> bit) as i32; + is_active * y_cache[bit] + }).sum(); + + isum += pos_sum - neg_sum; + } + sumf += isum as f32 * y.d; } - Ok(()) + Ok(sumf) } - fn from_float(xs: &[f32], ys: &mut [Self]) -> crate::Result<()> { - let k = xs.len(); - if k % Self::BLCK_SIZE != 0 { - crate::bail!("from_float Q2b0: size {} is not divisible by {}", k, Self::BLCK_SIZE); + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() % QK_K != 0 { + crate::bail!("quantize_row_q2b0: size mismatch {} not divisible by {}", xs.len(), QK_K); } - let nb = k / Self::BLCK_SIZE; - for i in 0..nb { - let slice = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; - ys[i].qs.fill(0); + for (block, x) in ys.iter_mut().zip(xs.chunks_exact(QK_K)) { + for (i, chunk) in x.chunks_exact(8).enumerate() { + let mut qs = 0u8; + let mut qd = 0u8; - for (j, qbyte) in ys[i].qs.iter_mut().enumerate() { - let start = j * 4; - let q0 = ((slice[start] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; - let q1 = ((slice[start + 1] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; - let q2 = ((slice[start + 2] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; - let q3 = ((slice[start + 3] + 2.0).round().clamp(0.0, 3.0) as i8) & 0b11; - - *qbyte = q0 | (q1 << 2) | (q2 << 4) | (q3 << 6); + for (b, &value) in chunk.iter().enumerate() { + if value > 0.0 { + qs |= 1 << b; + } else if value < 0.0 { + qd |= 1 << b; + } + } + block.qs[i] = qs; + block.qd[i] = qd; } } Ok(()) } - fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> crate::Result { - let nb = n / Self::BLCK_SIZE; - let mut sumf = 0f32; - - for i in 0..nb { - let d8 = ys[i].d; - - for (j, &qbyte) in xs[i].qs.iter().enumerate() { - let idx_base = j * 4; - let 
q_vals = [ - (qbyte & 0b11) - 2, - ((qbyte >> 2) & 0b11) - 2, - ((qbyte >> 4) & 0b11) - 2, - ((qbyte >> 6) & 0b11) - 2, - ]; + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if ys.len() % QK_K != 0 { + crate::bail!("dequantize_row_q2b0: size mismatch {} not divisible by {}", ys.len(), QK_K); + } - let sum_i = q_vals.iter().zip(ys[i].qs[idx_base..idx_base + 4].iter()) - .map(|(&q_val, &y_val)| q_val as i32 * y_val as i32) - .sum::(); + for (block, y) in xs.iter().zip(ys.chunks_exact_mut(QK_K)) { + for (i, chunk) in y.chunks_exact_mut(8).enumerate() { + let qs = block.qs[i]; + let qd = block.qd[i]; - sumf += sum_i as f32 * d8; + for b in 0..8 { + chunk[b] = if (qs >> b) & 1 != 0 { + 1.0 + } else if (qd >> b) & 1 != 0 { + -1.0 + } else { + 0.0 + }; + } } } - Ok(sumf) - } - - fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - Self::vec_dot_unopt(n, xs, ys) + Ok(()) } } diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index f7f5b68ac2..4f7e1f1534 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -103,6 +103,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ8K::to_float(&vec, &mut out)?; } + GgmlDType::Q2b0 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ2b0::to_float(&vec, &mut out)?; + } } let buffer = self.device.new_buffer_with_data(&out)?; @@ -225,6 +229,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K, GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, + GgmlDType::Q2b0 => candle_metal_kernels::GgmlDType::Q2b0, } } } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 5b775bbaed..e459892b24 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -523,51 +523,52 @@ pub(crate) fn vec_dot_q2b0_q8k(n: usize, xs: &[BlockQ2b0], ys: &[BlockQ8K]) -> c crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {QK_K}") } - let mut sumf = 0f32; + let mut sumf = 0.0_f32; unsafe { - let m2b = vdupq_n_u8(0b11); // Máscara para extraer 2 bits por elemento - for (x, y) in xs.iter().zip(ys.iter()) { - let d = y.d; + let mut isum = 0_i32; - let mut q2_ptr = x.qs.as_ptr(); - let mut q8_ptr = y.qs.as_ptr(); + for i in 0..(QK_K / 8) { + let qs = x.qs[i]; + let qd = x.qd[i]; - let mut sumi = 0i32; + // Load y_cache: Load 8 i8 values from y.qs[i * 8..(i + 1) * 8] + let y_qs_ptr = y.qs.as_ptr().add(i * 8); + let y_cache_i8x8 = vld1_s8(y_qs_ptr); + + // Extend y_cache_i8x8 to int16x8_t vector + let y_cache_i16x8 = vmovl_s8(y_cache_i8x8); + + // Prepare shift amounts: [0, -1, -2, -3, -4, -5, -6, -7] + let shift_vec_data: [i8; 8] = [0, -1, -2, -3, -4, -5, -6, -7]; + let shift_vec = vld1_s8(shift_vec_data.as_ptr()); + + // Duplicate qs and qd into vectors + let qs_vec = vdup_n_u8(qs); + let qd_vec = vdup_n_u8(qd); + + // Shift to bring bits into LSB + let qs_shifted = vshl_u8(qs_vec, shift_vec); + let qd_shifted = vshl_u8(qd_vec, shift_vec); + + // Mask LSB to get bits + let one_vec = vdup_n_u8(1); + let qs_bits = vand_u8(qs_shifted, one_vec); + let qd_bits = vand_u8(qd_shifted, one_vec); + + // Convert bits to int16x8_t + let qs_bits_i16x8 = vreinterpretq_s16_u16(vmovl_u8(qs_bits)); + let qd_bits_i16x8 = vreinterpretq_s16_u16(vmovl_u8(qd_bits)); + + // Multiply and accumulate + let pos_sum = vaddvq_s16(vmulq_s16(qs_bits_i16x8, 
y_cache_i16x8)); + let neg_sum = vaddvq_s16(vmulq_s16(qd_bits_i16x8, y_cache_i16x8)); - for _ in 0..QK_K / 16 { - // Cargar bloques de datos - let q2bytes = vld1q_u8(q2_ptr as *const u8); - q2_ptr = q2_ptr.add(16); - - let q8bytes_low = vld1q_s8(q8_ptr); - let q8bytes_high = vld1q_s8(q8_ptr.add(16)); - q8_ptr = q8_ptr.add(32); - - // Extraer los valores de los 2 bits (4 valores por byte) - let q2_vals_low = vreinterpretq_s8_u8(vandq_u8(q2bytes, m2b)); - let q2_vals_high = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bytes, 2), m2b)); - let q2_vals_mid_low = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bytes, 4), m2b)); - let q2_vals_mid_high = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bytes, 6), m2b)); - - // Ajustar los valores de Q2 (centro en -2) - let q2_low = vsubq_s8(q2_vals_low, vdupq_n_s8(2)); - let q2_high = vsubq_s8(q2_vals_high, vdupq_n_s8(2)); - let q2_mid_low = vsubq_s8(q2_vals_mid_low, vdupq_n_s8(2)); - let q2_mid_high = vsubq_s8(q2_vals_mid_high, vdupq_n_s8(2)); - - // Calcular productos punto para cada parte - let prod0 = vdotq_s32(q2_low, q8bytes_low); - let prod1 = vdotq_s32(q2_high, q8bytes_high); - let prod2 = vdotq_s32(q2_mid_low, q8bytes_low); - let prod3 = vdotq_s32(q2_mid_high, q8bytes_high); - - // Sumar los productos - sumi += vaddvq_s32(prod0) + vaddvq_s32(prod1) + vaddvq_s32(prod2) + vaddvq_s32(prod3); + isum += pos_sum as i32 - neg_sum as i32; } - sumf += sumi as f32 * d; + sumf += isum as f32 * y.d; } } diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index a867a8652b..9caa409257 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -31,17 +31,24 @@ enum Which { Falcon3_1b1_58, #[value(name = "falcon3-3b-1.58")] Falcon3_3b1_58, + #[value(name = "llama3-8b-1.58")] + Llama3_8b1_58, } impl Which { fn is_falcon(&self) -> bool { - matches!(self, Self::Falcon3_1b1_58) + matches!(self, Self::Falcon3_1b1_58 | Self::Falcon3_3b1_58) + } + + fn is_llama(&self) -> bool { + matches!(self, Self::Llama3_8b1_58) } fn tokenizer_repo(&self) -> &'static str { match self { Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit", Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit", + Self::Llama3_8b1_58 => "HF1BitLLM/Llama3-8B-1.58-100B-tokens", } } } @@ -148,6 +155,10 @@ impl Args { "tiiuae/Falcon3-3B-Instruct-1.58bit", "Falcon3-3B-Instruct-1.58bit.gguf", ), + Which::Llama3_8b1_58 => ( + "HF1BitLLM/Llama3-8B-1.58-100B-tokens", + "Llama3-8B-1.58bit.gguf", + ), }; let revision = "main"; let api = hf_hub::api::sync::Api::new()?; @@ -252,6 +263,7 @@ fn main() -> anyhow::Result<()> { println!("model built"); let tokenizer = args.tokenizer()?; + let mut tos = TokenOutputStream::new(tokenizer); let prompt = match args.prompt.as_deref() { Some("chat") => Prompt::Chat, @@ -265,7 +277,7 @@ fn main() -> anyhow::Result<()> { let prompt_str = match &prompt { Prompt::One(prompt) => { if args.which.is_falcon() { - format!("<|user|>\n{prompt}\n<|assistant|>") + format!("<|user|>{prompt}<|assistant|>") } else { prompt.clone() } @@ -284,6 +296,10 @@ fn main() -> anyhow::Result<()> { } if args.which.is_falcon() { format!("<|user|>\n{prompt}\n<|assistant|>") + } else if args.which.is_llama() { + format!( + "<|start_header_id|>user<|end_header_id|>\n{prompt}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + ) } else { prompt } @@ -350,6 +366,7 @@ fn main() -> anyhow::Result<()> { let eos_token = match args.which { Which::Falcon3_3b1_58 | 
Which::Falcon3_1b1_58 => "<|endoftext|>", + Which::Llama3_8b1_58 => "<|eot_id|>", }; let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 5f948cbf4c..1b8a947345 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2164,6 +2164,7 @@ pub enum GgmlDType { Q8K, F16, F32, + Q2b0, } #[allow(clippy::too_many_arguments)] @@ -2229,7 +2230,7 @@ pub fn call_quantized_matmul_mv_t( let align = 4; (nth0, nth1, align) } - GgmlDType::Q3K | GgmlDType::Q5K => { + GgmlDType::Q2b0 | GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; let align = 4; @@ -2253,7 +2254,7 @@ pub fn call_quantized_matmul_mv_t( let nth1 = 1; let align = 8; (nth0, nth1, align) - } + }, }; let thread_groups_count = MTLSize { width: divide(ne01 as usize, align), @@ -2280,6 +2281,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", + GgmlDType::Q2b0 => todo!(), }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs index 1c520b1686..3a5e9df49e 100644 --- a/candle-transformers/src/models/quantized_llama_bitnet.rs +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -54,6 +54,21 @@ struct BitQMatMul { weight_scale: Tensor, } +fn activation_quant(x: &Tensor) -> Result<(Tensor, Tensor)> { + let scale = (127.0 + / x.abs()? + .max(D::Minus1)? + .max(D::Minus1)? + .clamp(1e-5, f32::INFINITY)?)? + .to_dtype(x.dtype())?; + + let y = x + .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? 
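+        // clamp the scaled activations to the signed 8-bit range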
+ .clamp(-128.0, 127.0)?; + + Ok((y, scale)) +} + impl BitQMatMul { fn from_qtensor(qtensor: QTensor, weight_scale: QTensor) -> Result { let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; @@ -63,8 +78,10 @@ impl BitQMatMul { } fn forward(&self, x: &Tensor) -> Result { + let (x, xscale) = activation_quant(&x)?; let _enter = self.span.enter(); - self.inner.forward(&x)?.broadcast_div(&self.weight_scale) + let scale = self.weight_scale.broadcast_mul(&xscale)?; + self.inner.forward(&x)?.broadcast_div(&scale) } } diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index af944057d8..9adde2282a 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -435,137 +435,117 @@ fn unpack_bitnet_weights(tensor: &Tensor) -> Result { Ok(unpacked_tensor) } +use std::collections::HashMap; +use std::fs::File; +use std::path::PathBuf; +use rayon::prelude::*; +use serde_json::Value; + fn run_quantize_safetensors( - in_files: &[std::path::PathBuf], - out_file: std::path::PathBuf, + in_files: &[PathBuf], + out_file: PathBuf, q: Quantization, bq: Option, bitnet_mode: bool, ) -> Result<()> { - let mut out_file = std::fs::File::create(out_file)?; - let mut tensors = std::collections::HashMap::new(); + let mut out_file = File::create(out_file)?; + let dtype = q.dtype(); + let block_size = dtype.block_size(); + let metadata_file = in_files.iter().find(|f| f.to_string_lossy().ends_with("config.json")); - for in_file in in_files.iter() { - if metadata_file.is_some() && in_file == metadata_file.unwrap() { - continue; + + let mut qtensors = Vec::new(); + + for in_file in in_files { + if let Some(metadata) = &metadata_file { + if Some(in_file) == Some(metadata) { + continue; + } } - let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?; - tensors.extend(in_tensors) - } - println!("tensors: {}", tensors.len()); - let dtype = q.dtype(); - let block_size = dtype.block_size(); + println!("Loading tensors from file: {:?}", in_file); + let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?; - let qtensors = tensors - .into_par_iter() - .map(|(mut name, tensor)| { - let mut local_dtype = dtype.clone(); - let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; - let mut tensor = tensor; - if should_quantize && bitnet_mode { - let is_bitnet_weight = - name.contains("self_attn.v_proj") || - name.contains("self_attn.q_proj") || - name.contains("self_attn.o_proj") || - name.contains("self_attn.k_proj") || - name.contains("mlp.down_proj") || - name.contains("mlp.up_proj") || - name.contains("mlp.gate_proj"); - - if is_bitnet_weight { - println!(" unpacking {name} {tensor:?} {should_quantize}"); - tensor = unpack_bitnet_weights(&tensor)?; - local_dtype = bq.clone().unwrap().dtype(); + let processed_tensors = in_tensors + .into_par_iter() + .map(|(mut name, tensor)| { + let mut local_dtype = dtype.clone(); + let should_quantize = tensor.rank() == 2 && tensor.dim(1)? 
% block_size == 0; + let mut tensor = tensor; + + if should_quantize && bitnet_mode { + let is_bitnet_weight = + name.contains("self_attn.v_proj") || + name.contains("self_attn.q_proj") || + name.contains("self_attn.o_proj") || + name.contains("self_attn.k_proj") || + name.contains("mlp.down_proj") || + name.contains("mlp.up_proj") || + name.contains("mlp.gate_proj"); + + if is_bitnet_weight { + println!(" unpacking {name} {tensor:?} {should_quantize}"); + tensor = unpack_bitnet_weights(&tensor)?; + local_dtype = bq.clone().unwrap().dtype(); + } } - } - println!(" quantizing {name} {tensor:?} {should_quantize}"); - let tensor = if should_quantize { - QTensor::quantize(&tensor, local_dtype)? - } else { - QTensor::quantize(&tensor, GgmlDType::F32)? - }; - - if name == "model.embed_tokens.weight" { - name = "token_embd.weight".to_string(); - } + println!(" quantizing {name} {tensor:?} {should_quantize}"); + let tensor = if should_quantize { + QTensor::quantize(&tensor, local_dtype)? + } else { + QTensor::quantize(&tensor, GgmlDType::F32)? + }; - if name == "model.norm.weight" { - name = "output_norm.weight".to_string() - } + if name == "model.embed_tokens.weight" { + name = "token_embd.weight".to_string(); + } else if name == "model.norm.weight" { + name = "output_norm.weight".to_string(); + } else if name == "lm_head.weight" { + name = "output.weight".to_string(); + } - if name == "lm_head.weight" { - name = "output.weight".to_string() - } + name = name.replace("model.layers.", "blk."); + name = name.replace("self_attn.q_proj", "attn_q"); + name = name.replace("self_attn.k_proj", "attn_k"); + name = name.replace("self_attn.v_proj", "attn_v"); + name = name.replace("self_attn.o_proj", "attn_output"); + name = name.replace("mlp.gate_proj", "ffn_gate"); + name = name.replace("mlp.down_proj", "ffn_down"); + name = name.replace("mlp.up_proj", "ffn_up"); + name = name.replace("input_layernorm", "attn_norm"); + name = name.replace("post_attention_layernorm", "ffn_norm"); + + Ok((name, tensor)) + }) + .collect::>>()?; - name = name.replace("model.layers.", "blk."); - name = name.replace("self_attn.q_proj", "attn_q"); - name = name.replace("self_attn.k_proj", "attn_k"); - name = name.replace("self_attn.v_proj", "attn_v"); - name = name.replace("self_attn.o_proj", "attn_output"); - name = name.replace("mlp.gate_proj", "ffn_gate"); - name = name.replace("mlp.down_proj", "ffn_down"); - name = name.replace("mlp.up_proj", "ffn_up"); - name = name.replace("input_layernorm", "attn_norm"); - name = name.replace("post_attention_layernorm", "ffn_norm"); + qtensors.extend(processed_tensors); + } - Ok((name, tensor)) - }) - .collect::>>()?; let qtensors = qtensors .iter() .map(|(k, v)| (k.as_str(), v)) .collect::>(); - // Load metadata - let gguf_metadata: Vec<(&str, gguf_file::Value)> = if let Some(metadata_file) = metadata_file { - let metadata = std::fs::read_to_string(metadata_file)?; - let metadata: serde_json::Value = serde_json::from_str(&metadata).unwrap(); - - let num_attention_heads = gguf_file::Value::from_u32(metadata["num_attention_heads"].as_u64().unwrap() as u32); - let num_attention_heads_kv = gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32); - - let num_hidden_layers = gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32); - let embedding_length = gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32); - let rope_dimension_count = gguf_file::Value::from_u32( - (metadata["hidden_size"].as_u64().unwrap() as u32) 
/ (metadata["num_attention_heads"].as_u64().unwrap() as u32) - ); - let layer_norm_eps = gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32); - - let mut gguf_metadata: Vec<(&str, gguf_file::Value)> = Vec::new(); - gguf_metadata.push(( - "llama.attention.head_count", - num_attention_heads.clone(), - )); - gguf_metadata.push(( - "llama.attention.head_count_kv", - num_attention_heads_kv.clone(), - )); - gguf_metadata.push(( - "llama.block_count", - num_hidden_layers.clone(), - )); - gguf_metadata.push(( - "llama.embedding_length", - embedding_length.clone(), - )); - gguf_metadata.push(( - "llama.attention.layer_norm_rms_epsilon", layer_norm_eps.clone() - )); - gguf_metadata.push(( - "llama.rope.dimension_count", - rope_dimension_count.clone(), - )); - - // Print metadata - for (key, value) in gguf_metadata.iter() { - println!(" {key}: {value:?}"); - } - gguf_metadata + let gguf_metadata = if let Some(metadata_file) = metadata_file { + let metadata_content = std::fs::read_to_string(metadata_file)?; + let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap(); + + vec![ + ("llama.attention.head_count", gguf_file::Value::from_u32(metadata["num_attention_heads"].as_u64().unwrap() as u32)), + ("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)), + ("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)), + ("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)), + ("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)), + ("llama.rope.dimension_count", gguf_file::Value::from_u32( + (metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32), + )), + ] } else { - Vec::new() + vec![] }; - gguf_file::write(&mut out_file, gguf_metadata.as_slice(), &qtensors)?; + + gguf_file::write(&mut out_file, &gguf_metadata, &qtensors)?; Ok(()) } From 677d03ae51d2eb49f581b25083dc2952ae1a0d41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Fri, 27 Dec 2024 10:17:44 +0100 Subject: [PATCH 05/11] Pre-eliminar qbitnet implementation --- candle-core/src/quantized/ggml_file.rs | 3 + candle-core/src/quantized/k_quants.rs | 116 +++++++++++++++--- candle-core/src/quantized/metal.rs | 5 + candle-core/src/quantized/mod.rs | 9 +- candle-core/src/quantized/neon.rs | 14 +-- .../examples/quantized-bitnet/main.rs | 37 ++++-- .../src/models/quantized_llama_bitnet.rs | 34 ++--- tensor-tools/src/main.rs | 5 +- 8 files changed, 172 insertions(+), 51 deletions(-) diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 0afd150e5d..7cde3eaa72 100644 --- a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -189,6 +189,9 @@ pub fn qtensor_from_ggml( GgmlDType::Q2b0 => { from_raw_data::(raw_data, size_in_bytes, dims, device) } + GgmlDType::QI8 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index d1b010a5a5..e9c2379868 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -6,6 +6,7 @@ use super::GgmlDType; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; use 
half::f16; +use num_traits::real::Real; use rayon::prelude::*; // Default to QK_K 256 rather than 64. @@ -18,6 +19,8 @@ pub const QK5_0: usize = 32; pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; +pub const Q2B_0: usize = 32; +pub const QI8: usize = 32; pub trait GgmlType: Sized + Clone + Send + Sync { const DTYPE: GgmlDType; @@ -157,11 +160,18 @@ const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::( #[derive(Debug, Clone, PartialEq)] #[repr(C)] pub struct BlockQ2b0 { - pub(crate) qs: [u8; QK_K / 8], // Every single bit represents positive values, is a vector of {0, 1} - pub(crate) qd: [u8; QK_K / 8], // Every single bit represents negatives values, is a vector of {0, 1} + pub(crate) qs: [u8; Q2B_0 / 8], // Every single bit represents positive values, is a vector of {0, 1} + pub(crate) qd: [u8; Q2B_0 / 8], // Every single bit represents negatives values, is a vector of {0, 1} } -const _: () = assert!(QK_K / 4 == std::mem::size_of::()); +const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQI8{ + pub(crate) qs: [i8; QI8], +} +const _: () = assert!(std::mem::size_of::() == QI8); impl GgmlType for BlockQ4_0 { const DTYPE: GgmlDType = GgmlDType::Q4_0; @@ -1850,21 +1860,24 @@ impl GgmlType for BlockQ8K { impl GgmlType for BlockQ2b0 { const DTYPE: GgmlDType = GgmlDType::Q2b0; - const BLCK_SIZE: usize = QK_K; - type VecDotType = BlockQ8K; + const BLCK_SIZE: usize = Q2B_0; + type VecDotType = BlockQI8; fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2b0_qi8(n, xs, ys); + Self::vec_dot_unopt(n, xs, ys) } fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { - if n % QK_K != 0 { - crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {QK_K}"); + if n % Q2B_0 != 0 { + crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {Q2B_0}"); } let mut sumf = 0.0; for (x, y) in xs.iter().zip(ys.iter()) { let mut isum = 0i32; - for i in 0..QK_K / 8 { + for i in 0..Q2B_0 / 8 { let qs = x.qs[i]; let qd = x.qd[i]; let mut y_cache = [0i32; 8]; @@ -1884,17 +1897,17 @@ impl GgmlType for BlockQ2b0 { isum += pos_sum - neg_sum; } - sumf += isum as f32 * y.d; + sumf += isum as f32; } - Ok(sumf) + Ok(sumf as f32) } fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { - if xs.len() % QK_K != 0 { - crate::bail!("quantize_row_q2b0: size mismatch {} not divisible by {}", xs.len(), QK_K); + if xs.len() % Q2B_0 != 0 { + crate::bail!("quantize_row_q2b0: size mismatch {} not divisible by {}", xs.len(), Q2B_0); } - for (block, x) in ys.iter_mut().zip(xs.chunks_exact(QK_K)) { + for (block, x) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) { for (i, chunk) in x.chunks_exact(8).enumerate() { let mut qs = 0u8; let mut qd = 0u8; @@ -1914,11 +1927,11 @@ impl GgmlType for BlockQ2b0 { } fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { - if ys.len() % QK_K != 0 { - crate::bail!("dequantize_row_q2b0: size mismatch {} not divisible by {}", ys.len(), QK_K); + if ys.len() % Q2B_0 != 0 { + crate::bail!("dequantize_row_q2b0: size mismatch {} not divisible by {}", ys.len(), Q2B_0); } - for (block, y) in xs.iter().zip(ys.chunks_exact_mut(QK_K)) { + for (block, y) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) { for (i, chunk) in y.chunks_exact_mut(8).enumerate() { let qs = block.qs[i]; let qd = block.qd[i]; @@ -1938,6 +1951,77 @@ impl GgmlType for BlockQ2b0 { } } +impl GgmlType for BlockQI8 { + 
const DTYPE: GgmlDType = GgmlDType::QI8; + const BLCK_SIZE: usize = QI8; + type VecDotType = BlockQI8; + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + let k = ys.len(); + if k % QI8 != 0 { + crate::bail!("dequantize_row_qi8: {k} is not divisible by {QI8}"); + } + + let nb = k / QI8; + + for i in 0..nb { + for j in 0..QI8 { + ys[i * QI8 + j] = xs[i].qs[j] as f32; + } + } + Ok(()) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + // quantize_row_q8_0 + let k = xs.len(); + if k % Self::BLCK_SIZE != 0 { + crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE); + }; + let nb = k / Self::BLCK_SIZE; + if ys.len() != nb { + crate::bail!( + "size mismatch {} {} {}", + xs.len(), + ys.len(), + Self::BLCK_SIZE + ) + } + for (i, ys) in ys.iter_mut().enumerate() { + let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; + for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { + *y = x as i8; + } + } + Ok(()) + } + + #[allow(unreachable_code)] + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + let qk = QK8_0; + if n % QI8 != 0 { + crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}") + } + + // Generic implementation. + let mut sumf = 0f32; + for (xs, ys) in xs.iter().zip(ys.iter()) { + let sum_i = xs + .qs + .iter() + .zip(ys.qs.iter()) + .map(|(&x, &y)| x as i32 * y as i32) + .sum::(); + sumf += sum_i as f32; + } + Ok(sumf) + } +} + // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605 pub fn matmul( mkn: (usize, usize, usize), diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 4f7e1f1534..4ed1941ed9 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -107,6 +107,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ2b0::to_float(&vec, &mut out)?; } + GgmlDType::QI8 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQI8::to_float(&vec, &mut out)?; + } } let buffer = self.device.new_buffer_with_data(&out)?; @@ -230,6 +234,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::Q2b0 => candle_metal_kernels::GgmlDType::Q2b0, + GgmlDType::QI8 => todo!(), } } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 9d6d9abfaf..728466e87a 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -147,6 +147,7 @@ pub enum GgmlDType { Q6K, Q8K, Q2b0, + QI8, } impl GgmlDType { @@ -167,6 +168,7 @@ impl GgmlDType { 14 => Self::Q6K, 15 => Self::Q8K, 40 => Self::Q2b0, + 41 => Self::QI8, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -189,6 +191,7 @@ impl GgmlDType { Self::Q6K => 14, Self::Q8K => 15, Self::Q2b0 => 40, + Self::QI8 => 41, } } @@ -210,6 +213,7 @@ impl GgmlDType { Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]), Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), Self::Q2b0 => Box::new(vec![BlockQ2b0::zeros(); elem_count / BlockQ2b0::BLCK_SIZE]), + Self::QI8 => Box::new(vec![BlockQI8::zeros(); elem_count / BlockQI8::BLCK_SIZE]), } } /// The type size for blocks in bytes. 
@@ -232,6 +236,7 @@ impl GgmlDType { Self::Q6K => std::mem::size_of::(), Self::Q8K => std::mem::size_of::(), Self::Q2b0 => std::mem::size_of::(), + Self::QI8 => std::mem::size_of::(), } } @@ -246,7 +251,9 @@ impl GgmlDType { Self::Q5_1 => k_quants::QK5_1, Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, - Self::Q2b0 | Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, + Self::Q2b0 => k_quants::Q2B_0, + Self::QI8 => k_quants::QI8, + Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, } } } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index e459892b24..8387555940 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,6 +1,6 @@ -use super::k_quants::{ - BlockQ2K, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, QK8_0, QK_K, BlockQ2b0 -}; +use super::{k_quants::{ + BlockQ2K, BlockQ2b0, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, Q2B_0, QK8_0, QK_K +}, BlockQI8}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -518,8 +518,8 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res } #[inline(always)] -pub(crate) fn vec_dot_q2b0_q8k(n: usize, xs: &[BlockQ2b0], ys: &[BlockQ8K]) -> crate::Result { - if n % QK_K != 0 { +pub(crate) fn vec_dot_q2b0_qi8(n: usize, xs: &[BlockQ2b0], ys: &[BlockQI8]) -> crate::Result { + if n % Q2B_0 != 0 { crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {QK_K}") } @@ -529,7 +529,7 @@ pub(crate) fn vec_dot_q2b0_q8k(n: usize, xs: &[BlockQ2b0], ys: &[BlockQ8K]) -> c for (x, y) in xs.iter().zip(ys.iter()) { let mut isum = 0_i32; - for i in 0..(QK_K / 8) { + for i in 0..(Q2B_0 / 8) { let qs = x.qs[i]; let qd = x.qd[i]; @@ -568,7 +568,7 @@ pub(crate) fn vec_dot_q2b0_q8k(n: usize, xs: &[BlockQ2b0], ys: &[BlockQ8K]) -> c isum += pos_sum as i32 - neg_sum as i32; } - sumf += isum as f32 * y.d; + sumf += isum as f32; } } diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index 9caa409257..e450905b1f 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -5,8 +5,9 @@ extern crate intel_mkl_src; extern crate accelerate_src; use clap::{Parser, ValueEnum}; +use tracing_subscriber::fmt::time::FormatTime; use std::io::Write; -use tokenizers::Tokenizer; +use tokenizers::{Tokenizer, AddedToken}; use candle::quantized::{ggml_file, gguf_file}; use candle::Tensor; @@ -31,13 +32,17 @@ enum Which { Falcon3_1b1_58, #[value(name = "falcon3-3b-1.58")] Falcon3_3b1_58, + #[value(name = "falcon3-7b-1.58")] + Falcon3_7b1_58, + #[value(name = "falcon3-10b-1.58")] + Falcon3_10b1_58, #[value(name = "llama3-8b-1.58")] Llama3_8b1_58, } impl Which { fn is_falcon(&self) -> bool { - matches!(self, Self::Falcon3_1b1_58 | Self::Falcon3_3b1_58) + matches!(self, Self::Falcon3_1b1_58 | Self::Falcon3_3b1_58 | Self::Falcon3_7b1_58 | Self::Falcon3_10b1_58) } fn is_llama(&self) -> bool { @@ -49,6 +54,8 @@ impl Which { Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit", Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit", Self::Llama3_8b1_58 => "HF1BitLLM/Llama3-8B-1.58-100B-tokens", + Self::Falcon3_10b1_58 => "tiiuae/Falcon3-10B-Instruct-1.58bit", + Self::Falcon3_7b1_58 => "tiiuae/Falcon3-7B-Instruct-1.58bit", } } } @@ -76,7 +83,7 @@ struct Args { tokenizer: Option, /// The temperature used to generate 
samples, use 0 for greedy sampling. - #[arg(long, default_value_t = 0.8)] + #[arg(long, default_value_t = 0.2)] temperature: f64, /// Nucleus sampling probability cutoff. @@ -108,7 +115,7 @@ struct Args { cpu: bool, /// Penalty to be applied for repeating tokens, 1. means no penalty. - #[arg(long, default_value_t = 1.1)] + #[arg(long, default_value_t = 1.5)] repeat_penalty: f32, /// The context size to consider for the repeat penalty. @@ -155,6 +162,14 @@ impl Args { "tiiuae/Falcon3-3B-Instruct-1.58bit", "Falcon3-3B-Instruct-1.58bit.gguf", ), + Which::Falcon3_10b1_58 => ( + "tiiuae/Falcon3-10B-Instruct-1.58bit", + "Falcon3-10B-Instruct-1.58bit.gguf", + ), + Which::Falcon3_7b1_58 => ( + "tiiuae/Falcon3-7B-Instruct-1.58bit", + "Falcon3-7B-Instruct-1.58bit.gguf", + ), Which::Llama3_8b1_58 => ( "HF1BitLLM/Llama3-8B-1.58-100B-tokens", "Llama3-8B-1.58bit.gguf", @@ -256,7 +271,7 @@ fn main() -> anyhow::Result<()> { start.elapsed().as_secs_f32(), ); println!("params: {:?}", model.hparams); - let default_gqa = 1; + let default_gqa = 0; ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))? } }; @@ -277,7 +292,11 @@ fn main() -> anyhow::Result<()> { let prompt_str = match &prompt { Prompt::One(prompt) => { if args.which.is_falcon() { - format!("<|user|>{prompt}<|assistant|>") + format!("<|user|>\n{prompt}\n<|assistant|>") + } else if args.which.is_llama() { + format!( + "{prompt}" + ) } else { prompt.clone() } @@ -298,7 +317,7 @@ fn main() -> anyhow::Result<()> { format!("<|user|>\n{prompt}\n<|assistant|>") } else if args.which.is_llama() { format!( - "<|start_header_id|>user<|end_header_id|>\n{prompt}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{prompt}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>" ) } else { prompt @@ -365,10 +384,10 @@ fn main() -> anyhow::Result<()> { } let eos_token = match args.which { - Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => "<|endoftext|>", + Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => "<|endoftext|>", Which::Llama3_8b1_58 => "<|eot_id|>", }; - + let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); let start_post_prompt = std::time::Instant::now(); diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs index 3a5e9df49e..7a7c2ee85e 100644 --- a/candle-transformers/src/models/quantized_llama_bitnet.rs +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -54,20 +54,7 @@ struct BitQMatMul { weight_scale: Tensor, } -fn activation_quant(x: &Tensor) -> Result<(Tensor, Tensor)> { - let scale = (127.0 - / x.abs()? - .max(D::Minus1)? - .max(D::Minus1)? - .clamp(1e-5, f32::INFINITY)?)? - .to_dtype(x.dtype())?; - - let y = x - .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)? 
- .clamp(-128.0, 127.0)?; - - Ok((y, scale)) -} + impl BitQMatMul { fn from_qtensor(qtensor: QTensor, weight_scale: QTensor) -> Result { @@ -77,11 +64,24 @@ impl BitQMatMul { Ok(Self { inner, span, weight_scale }) } + pub fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> { + let last_dim = x.rank().saturating_sub(1); + let max_abs = x.abs()?.max_keepdim(last_dim)?; + + let clamped = max_abs.clamp(1e-5, f32::INFINITY)?; + let scale = (127.0 / &clamped)?; + + let scaled_rounded = x.broadcast_mul(&scale)?.round()?.clamp(-128f32, 127f32)?; + + Ok((scaled_rounded, scale)) + } + fn forward(&self, x: &Tensor) -> Result { - let (x, xscale) = activation_quant(&x)?; + let (x, xscale) = self.activation_quant(x)?; let _enter = self.span.enter(); - let scale = self.weight_scale.broadcast_mul(&xscale)?; - self.inner.forward(&x)?.broadcast_div(&scale) + self.inner.forward(&x)? + .broadcast_div(&self.weight_scale)? + .broadcast_div(&xscale) } } diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 9adde2282a..037c884491 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -50,6 +50,8 @@ enum Quantization { Q8_1, #[value(name = "q2b0")] Q2b0, + #[value(name = "qi8")] + QI8, Q2k, Q3k, Q4k, @@ -78,6 +80,7 @@ impl Quantization { Quantization::F16 => GgmlDType::F16, Quantization::F32 => GgmlDType::F32, Quantization::Q2b0 => GgmlDType::Q2b0, + Quantization::QI8 => GgmlDType::QI8, } } } @@ -470,7 +473,7 @@ fn run_quantize_safetensors( .into_par_iter() .map(|(mut name, tensor)| { let mut local_dtype = dtype.clone(); - let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; + let mut should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0; let mut tensor = tensor; if should_quantize && bitnet_mode { From 1ada7fa68155aedc3a7a744824d6e7d71bbf1d6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 30 Dec 2024 10:39:08 +0100 Subject: [PATCH 06/11] fix an issue while quantizing llama models --- candle-core/src/quantized/gguf_file.rs | 4 + .../examples/quantized-bitnet/main.rs | 38 ++++---- .../src/models/quantized_llama_bitnet.rs | 26 +++--- tensor-tools/src/main.rs | 92 +++++++++++++++---- 4 files changed, 113 insertions(+), 47 deletions(-) diff --git a/candle-core/src/quantized/gguf_file.rs b/candle-core/src/quantized/gguf_file.rs index af5d3a46e8..cfad05d009 100644 --- a/candle-core/src/quantized/gguf_file.rs +++ b/candle-core/src/quantized/gguf_file.rs @@ -190,6 +190,10 @@ impl Value { Self::F32(v) } + pub fn from_string(v: String) -> Self { + Self::String(v) + } + pub fn to_u8(&self) -> Result { match self { Self::U8(v) => Ok(*v), diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index e450905b1f..d91c68d400 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -291,15 +291,7 @@ fn main() -> anyhow::Result<()> { for prompt_index in 0.. 
{ let prompt_str = match &prompt { Prompt::One(prompt) => { - if args.which.is_falcon() { - format!("<|user|>\n{prompt}\n<|assistant|>") - } else if args.which.is_llama() { - format!( - "{prompt}" - ) - } else { - prompt.clone() - } + prompt.clone() } Prompt::Interactive | Prompt::Chat => { let is_interactive = matches!(prompt, Prompt::Interactive); @@ -324,6 +316,7 @@ fn main() -> anyhow::Result<()> { } } }; + print!("{}", &prompt_str); let tokens = tos .tokenizer() @@ -331,7 +324,7 @@ fn main() -> anyhow::Result<()> { .map_err(anyhow::Error::msg)?; if args.verbose_prompt { for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { - let token = token.replace('▁', " ").replace("<0x0A>", "\n"); + let token = token.to_string().replace('▁', " ").replace("<0x0A>", "\n"); println!("{id:7} -> '{token}'"); } } @@ -383,12 +376,24 @@ fn main() -> anyhow::Result<()> { std::io::stdout().flush()?; } - let eos_token = match args.which { - Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => "<|endoftext|>", - Which::Llama3_8b1_58 => "<|eot_id|>", + let eos_tokens = match args.which { + Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => { + vec!["<|endoftext|>"] + } + Which::Llama3_8b1_58 => { + vec!["<|eot_id|>"] + } }; - - let eos_token = *tos.tokenizer().get_vocab(true).get(eos_token).unwrap(); + + let eos_tokens: Vec = eos_tokens + .iter() + .map(|token| { + *tos.tokenizer() + .get_vocab(true) + .get(*token) + .unwrap_or_else(|| panic!("EoS token not found: {}", token)) + }) + .collect(); let start_post_prompt = std::time::Instant::now(); let mut sampled = 0; @@ -413,7 +418,8 @@ fn main() -> anyhow::Result<()> { std::io::stdout().flush()?; } sampled += 1; - if next_token == eos_token { + + if eos_tokens.contains(&next_token) { break; }; } diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs index 7a7c2ee85e..e9a38a26a7 100644 --- a/candle-transformers/src/models/quantized_llama_bitnet.rs +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -54,8 +54,6 @@ struct BitQMatMul { weight_scale: Tensor, } - - impl BitQMatMul { fn from_qtensor(qtensor: QTensor, weight_scale: QTensor) -> Result { let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; @@ -65,23 +63,27 @@ impl BitQMatMul { } pub fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> { - let last_dim = x.rank().saturating_sub(1); - let max_abs = x.abs()?.max_keepdim(last_dim)?; - - let clamped = max_abs.clamp(1e-5, f32::INFINITY)?; - let scale = (127.0 / &clamped)?; - - let scaled_rounded = x.broadcast_mul(&scale)?.round()?.clamp(-128f32, 127f32)?; - + let target_dim = x.rank().saturating_sub(1); + + let max_abs = x.abs()?.max_keepdim(target_dim)?; + + let scale = (127.0/ &max_abs)?; + + let scaled_rounded = x + .broadcast_mul(&scale)? + .round()? + .clamp(-128f32, 127f32)?; + + Ok((scaled_rounded, scale)) } fn forward(&self, x: &Tensor) -> Result { let (x, xscale) = self.activation_quant(x)?; let _enter = self.span.enter(); + let scale = self.weight_scale.broadcast_mul(&xscale)?; self.inner.forward(&x)? - .broadcast_div(&self.weight_scale)? 
- .broadcast_div(&xscale) + .broadcast_div(&scale) } } diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 037c884491..4a1e9bc039 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -301,8 +301,10 @@ fn run_print( println!("==== {name} ===="); match content.tensor(&mut file, name, device) { Ok(tensor) => { + let dtype = tensor.dtype(); + let tensor = tensor.dequantize(device)?; - println!("{tensor}") + println!("{tensor} {dtype:?}") } Err(e) => { eprintln!("error: {e}"); @@ -438,12 +440,36 @@ fn unpack_bitnet_weights(tensor: &Tensor) -> Result { Ok(unpacked_tensor) } +use core::num; use std::collections::HashMap; use std::fs::File; use std::path::PathBuf; use rayon::prelude::*; use serde_json::Value; +fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option) -> Result { + let n_head = match n_head_kv { + Some(n_head_kv) if n_head != n_head_kv => n_head_kv, + _ => n_head, + }; + + let shape = weights.shape(); + let shape0 = shape.dims()[0]; + if shape0 % (n_head * 2) != 0 { + candle::bail!("weights.shape()[0] is not divisible by (n_head * 2)"); + } + + let mut new_shape = vec![n_head, 2, shape0 / (n_head * 2)]; + new_shape.extend_from_slice(&shape.dims()[1..]); + + let permuted = weights + .reshape(new_shape)? + .transpose(1, 2)? + .reshape(weights.shape())?; + + Ok(permuted) +} + fn run_quantize_safetensors( in_files: &[PathBuf], out_file: PathBuf, @@ -459,6 +485,34 @@ fn run_quantize_safetensors( let mut qtensors = Vec::new(); + let mut num_attention_heads = 0; + let mut num_key_value_heads = 0; + let mut architecture = String::new(); + + + let gguf_metadata = if let Some(metadata_file) = metadata_file { + let metadata_content = std::fs::read_to_string(metadata_file)?; + let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap(); + + num_attention_heads = metadata["num_attention_heads"].as_u64().unwrap(); + num_key_value_heads = metadata["num_key_value_heads"].as_u64().unwrap(); + architecture = metadata["model_type"].as_str().unwrap().to_string(); + + vec![ + ("llama.attention.head_count", gguf_file::Value::from_u32(num_attention_heads as u32)), + ("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)), + ("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)), + ("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)), + ("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)), + ("llama.rope.dimension_count", gguf_file::Value::from_u32( + (metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32), + )), + ("llama.rope.freq_base", gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32)), + ("general.architecture", gguf_file::Value::from_string(architecture.clone())), + ] + } else { + vec![] + }; for in_file in in_files { if let Some(metadata) = &metadata_file { if Some(in_file) == Some(metadata) { @@ -492,6 +546,24 @@ fn run_quantize_safetensors( local_dtype = bq.clone().unwrap().dtype(); } } + + if name == "lm_head.weight" { + local_dtype = GgmlDType::Q6K; + } + + // apply transformations to the tensors, based on the architecture + match architecture.as_str() { + "llama" => { + if name.ends_with("self_attn.q_proj.weight") { + tensor = permute(&tensor, num_attention_heads as usize, Some(num_attention_heads as usize))?; + } + 
if name.ends_with("self_attn.k_proj.weight") { + tensor = permute(&tensor, num_attention_heads as usize, Some(num_key_value_heads as usize))?; + } + } + _ => {} + } + println!(" quantizing {name} {tensor:?} {should_quantize}"); let tensor = if should_quantize { QTensor::quantize(&tensor, local_dtype)? @@ -530,24 +602,6 @@ fn run_quantize_safetensors( .map(|(k, v)| (k.as_str(), v)) .collect::>(); - let gguf_metadata = if let Some(metadata_file) = metadata_file { - let metadata_content = std::fs::read_to_string(metadata_file)?; - let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap(); - - vec![ - ("llama.attention.head_count", gguf_file::Value::from_u32(metadata["num_attention_heads"].as_u64().unwrap() as u32)), - ("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)), - ("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)), - ("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)), - ("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)), - ("llama.rope.dimension_count", gguf_file::Value::from_u32( - (metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32), - )), - ] - } else { - vec![] - }; - gguf_file::write(&mut out_file, &gguf_metadata, &qtensors)?; Ok(()) } From 4fee75a4e275ab09d7cf8eeb32a301652e3e34e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Mon, 30 Dec 2024 17:31:31 +0100 Subject: [PATCH 07/11] initial metal support --- .../examples/quantized-bitnet/main.rs | 9 +- candle-metal-kernels/src/lib.rs | 5 +- candle-metal-kernels/src/quantized.metal | 124 ++++++++++++++++++ 3 files changed, 131 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index d91c68d400..3ae5e51f46 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -54,7 +54,7 @@ impl Which { Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit", Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit", Self::Llama3_8b1_58 => "HF1BitLLM/Llama3-8B-1.58-100B-tokens", - Self::Falcon3_10b1_58 => "tiiuae/Falcon3-10B-Instruct-1.58bit", + Self::Falcon3_10b1_58 => "tiiuae/Falcon3-10B-Base-1.58bit", Self::Falcon3_7b1_58 => "tiiuae/Falcon3-7B-Instruct-1.58bit", } } @@ -305,11 +305,10 @@ fn main() -> anyhow::Result<()> { prompt.pop(); } } - if args.which.is_falcon() { - format!("<|user|>\n{prompt}\n<|assistant|>") - } else if args.which.is_llama() { + + if args.which.is_llama() { format!( - "<|start_header_id|>user<|end_header_id|>\n\n{prompt}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + "<|start_header_id|>user<|end_header_id|>{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>" ) } else { prompt diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1b8a947345..fe82e87196 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2210,6 +2210,7 @@ pub fn call_quantized_matmul_mv_t( | GgmlDType::Q5_0 | GgmlDType::Q5_1 | GgmlDType::Q8_0 + | GgmlDType::Q2b0 | GgmlDType::Q8_1 => { let nth0 = 8; let nth1 = 8; @@ -2230,7 +2231,7 @@ pub fn call_quantized_matmul_mv_t( let align = 4; (nth0, nth1, align) } - GgmlDType::Q2b0 | GgmlDType::Q3K | 
GgmlDType::Q5K => { + GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; let align = 4; @@ -2281,7 +2282,7 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", - GgmlDType::Q2b0 => todo!(), + GgmlDType::Q2b0 => "kernel_mul_mv_q2b0_f32" }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index fef6ac54f8..ac860fd06f 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -42,6 +42,17 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +#define Q2B_0 32 +typedef struct { + uint8_t qs[Q2B_0 / 8]; // Every single bit represents positive values, is a vector of {0, 1} + uint8_t qd[Q2B_0 / 8]; // Every single bit represents negative values, is a vector of {0, 1} +} block_q2b_0; + +#define QI8 32 +typedef struct { + int8_t qs[QI8]; // quants +} block_qi8; + #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { @@ -3469,6 +3480,119 @@ kernel void kernel_mul_mv_q6_K_f32( kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } +#define NB_Q2B_0 8 +void kernel_mul_mv_q2b0_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nr = N_DST; + const int nsg = N_SIMDGROUP; + const int nw = N_SIMDWIDTH; + const int nb = ne00 / Q2B_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint i12 = im % ne12; + const uint i13 = im / ne12; + + const uint offset0 = first_row * nb + + (i12 / r2) * (nb * ne01) + + (i13 / r3) * (nb * ne01 * ne02); + + device const block_q2b_0 * x = (device const block_q2b_0 *) src0 + offset0; + device const float * y = (device const float *) src1 + + r1 * ne10 + + im * ne00 * ne1; + + float yl[NB_Q2B_0]; + float sumf[nr]; + for (int i = 0; i < nr; ++i) { + sumf[i] = 0.0f; + } + + const int ix = tiisg / 4; + const int il = tiisg % 4; + + device const float * yb = y + ix * Q2B_0 + NB_Q2B_0 * il; + + for (int ib = ix; ib < nb; ib += (nw / 4)) { + for (int i = 0; i < NB_Q2B_0; ++i) { + yl[i] = yb[i]; + } + + for (int row = 0; row < nr; row++) { + device const block_q2b_0 * bx = x + ib + row * nb; + + float sumq = 0.f; + const int startBit = NB_Q2B_0 * il; + + for (int iBit = 0; iBit < NB_Q2B_0; iBit++) { + int bit = startBit + iBit; + int bByte = bit >> 3; + int bMask = 1 << (bit & 7); + int isPos = ((bx->qs[bByte] & bMask) != 0) ? 1 : 0; + int isNeg = ((bx->qd[bByte] & bMask) != 0) ? 
1 : 0; + + sumq += float(isPos - isNeg) * yl[iBit]; + } + + sumf[row] += sumq; + } + + yb += NB_Q2B_0 * nw; + } + + for (int row = 0; row < nr; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && (first_row + row) < ne01) { + dst[r1 * ne0 + im * ne0 * ne1 + (first_row + row)] = tot; + } + } +} + + +[[host_name("kernel_mul_mv_q2b0_f32")]] +kernel void kernel_mul_mv_q2b0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2b0_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template From 4c669dcac62c321365f9cb28bd96d694c9240117 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Tue, 31 Dec 2024 10:36:07 +0100 Subject: [PATCH 08/11] Add initial commit --- .../examples/quantized-bitnet/main.rs | 89 +++++++++++-------- candle-metal-kernels/src/lib.rs | 7 +- candle-metal-kernels/src/quantized.metal | 9 +- .../src/models/quantized_llama_bitnet.rs | 19 ++-- 4 files changed, 68 insertions(+), 56 deletions(-) diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index 3ae5e51f46..2b024b50b4 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -28,34 +28,35 @@ enum Prompt { #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] enum Which { - #[value(name = "falcon3-1b-1.58")] - Falcon3_1b1_58, + #[value(name = "falcon3-1b-instruct-1.58")] + Falcon3_1bInstruct1_58, + #[value(name = "falcon3-3b-instruct-1.58")] + Falcon3_3bInstruct1_58, #[value(name = "falcon3-3b-1.58")] Falcon3_3b1_58, + #[value(name = "falcon3-7b-instruct-1.58")] + Falcon3_7bInstruct1_58, #[value(name = "falcon3-7b-1.58")] Falcon3_7b1_58, + #[value(name = "falcon3-10b-instruct-1.58")] + Falcon3_10bInstruct1_58, #[value(name = "falcon3-10b-1.58")] Falcon3_10b1_58, #[value(name = "llama3-8b-1.58")] Llama3_8b1_58, } -impl Which { - fn is_falcon(&self) -> bool { - matches!(self, Self::Falcon3_1b1_58 | Self::Falcon3_3b1_58 | Self::Falcon3_7b1_58 | Self::Falcon3_10b1_58) - } - - fn is_llama(&self) -> bool { - matches!(self, Self::Llama3_8b1_58) - } - +impl Which { fn tokenizer_repo(&self) -> &'static str { match self { - Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit", - Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit", - Self::Llama3_8b1_58 => "HF1BitLLM/Llama3-8B-1.58-100B-tokens", - Self::Falcon3_10b1_58 => "tiiuae/Falcon3-10B-Base-1.58bit", - Self::Falcon3_7b1_58 => "tiiuae/Falcon3-7B-Instruct-1.58bit", + Self::Falcon3_1bInstruct1_58 => "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF", + Self::Falcon3_3bInstruct1_58 => "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF", + Self::Falcon3_3b1_58 => 
"nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF", + Self::Falcon3_7bInstruct1_58 => "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF", + Self::Falcon3_10b1_58 => "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF", + Self::Falcon3_10bInstruct1_58 => "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF", + Self::Falcon3_7b1_58 => "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF", + Self::Llama3_8b1_58 => "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF", } } } @@ -123,7 +124,7 @@ struct Args { repeat_last_n: usize, /// The model size to use. - #[arg(long, default_value = "falcon3-1b-1.58")] + #[arg(long, default_value = "falcon3-1b-instruct-1.58")] which: Which, /// Group-Query Attention, use 8 for the 70B version of LLaMAv2. @@ -154,25 +155,37 @@ impl Args { Some(config) => std::path::PathBuf::from(config), None => { let (repo, filename) = match self.which { - Which::Falcon3_1b1_58 => ( - "tiiuae/Falcon3-1B-Instruct-1.58bit", - "Falcon3-1B-Instruct-1.58bit.gguf", + Which::Falcon3_1bInstruct1_58 => ( + "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF", + "Falcon3-1B-Instruct-1.58bit-q2b0.gguf", + ), + Which::Falcon3_3bInstruct1_58 => ( + "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF", + "Falcon3-3B-Instruct-1.58bit-q2b0.gguf", ), Which::Falcon3_3b1_58 => ( - "tiiuae/Falcon3-3B-Instruct-1.58bit", - "Falcon3-3B-Instruct-1.58bit.gguf", + "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF", + "Falcon3-3B-Base-1.58bit-q2b0.gguf", ), - Which::Falcon3_10b1_58 => ( - "tiiuae/Falcon3-10B-Instruct-1.58bit", - "Falcon3-10B-Instruct-1.58bit.gguf", + Which::Falcon3_7bInstruct1_58 => ( + "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF", + "Falcon3-7B-Instruct-1.58bit-q2b0.gguf", ), Which::Falcon3_7b1_58 => ( - "tiiuae/Falcon3-7B-Instruct-1.58bit", - "Falcon3-7B-Instruct-1.58bit.gguf", + "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF", + "Falcon3-7B-Base-1.58bit-q2b0.gguf", + ), + Which::Falcon3_10b1_58 => ( + "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF", + "Falcon3-10B-Base-1.58bit-q2b0.gguf", + ), + Which::Falcon3_10bInstruct1_58 => ( + "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF", + "Falcon3-10B-Instruct-1.58bit-q2b0.gguf", ), Which::Llama3_8b1_58 => ( - "HF1BitLLM/Llama3-8B-1.58-100B-tokens", - "Llama3-8B-1.58bit.gguf", + "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF", + "Llama3-8B-1.58-100B-tokens-q2b0.gguf", ), }; let revision = "main"; @@ -306,13 +319,7 @@ fn main() -> anyhow::Result<()> { } } - if args.which.is_llama() { - format!( - "<|start_header_id|>user<|end_header_id|>{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>" - ) - } else { - prompt - } + prompt } }; @@ -376,11 +383,17 @@ fn main() -> anyhow::Result<()> { } let eos_tokens = match args.which { - Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => { + Which::Falcon3_10b1_58 | + Which::Falcon3_10bInstruct1_58 | + Which::Falcon3_7bInstruct1_58 | + Which::Falcon3_7b1_58 | + Which::Falcon3_3bInstruct1_58 | + Which::Falcon3_3b1_58 | + Which::Falcon3_1bInstruct1_58 => { vec!["<|endoftext|>"] } Which::Llama3_8b1_58 => { - vec!["<|eot_id|>"] + vec!["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>"] } }; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index fe82e87196..ddf0c6bb2f 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2210,7 +2210,6 @@ pub fn call_quantized_matmul_mv_t( | GgmlDType::Q5_0 | GgmlDType::Q5_1 | GgmlDType::Q8_0 - | GgmlDType::Q2b0 | GgmlDType::Q8_1 => { let nth0 = 8; let nth1 = 8; @@ -2231,6 +2230,12 @@ pub fn call_quantized_matmul_mv_t( let align = 4; 
(nth0, nth1, align) } + GgmlDType::Q2b0 => { + let nth0 = 8; + let nth1 = 8; + let align = 8; + (nth0, nth1, align) + } GgmlDType::Q3K | GgmlDType::Q5K => { let nth0 = 2; let nth1 = 32; diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index ac860fd06f..6972d38250 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -3544,10 +3544,11 @@ void kernel_mul_mv_q2b0_f32_impl( int bit = startBit + iBit; int bByte = bit >> 3; int bMask = 1 << (bit & 7); - int isPos = ((bx->qs[bByte] & bMask) != 0) ? 1 : 0; - int isNeg = ((bx->qd[bByte] & bMask) != 0) ? 1 : 0; - - sumq += float(isPos - isNeg) * yl[iBit]; + if ((bx->qs[bByte] & bMask) != 0) { + sumq += yl[iBit]; + } else if ((bx->qd[bByte] & bMask) != 0) { + sumq -= yl[iBit]; + } } sumf[row] += sumq; diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs index e9a38a26a7..bfc3113ce7 100644 --- a/candle-transformers/src/models/quantized_llama_bitnet.rs +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -62,20 +62,13 @@ impl BitQMatMul { Ok(Self { inner, span, weight_scale }) } - pub fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> { - let target_dim = x.rank().saturating_sub(1); - - let max_abs = x.abs()?.max_keepdim(target_dim)?; - - let scale = (127.0/ &max_abs)?; - - let scaled_rounded = x - .broadcast_mul(&scale)? - .round()? - .clamp(-128f32, 127f32)?; - + fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> { + let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?; + let scale = (127.0 / scale)?; + + let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?; - Ok((scaled_rounded, scale)) + Ok((y, scale)) } fn forward(&self, x: &Tensor) -> Result { From 49d44b5190e39aa1a331d2bee9d3ef16d09f491b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Tue, 31 Dec 2024 10:40:43 +0100 Subject: [PATCH 09/11] fix unit tests --- candle-pyo3/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index b8695cc8a0..3f12cbe8ac 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -1456,7 +1456,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) let converted_metadata: Vec<_> = metadata .iter() - .map(|(name, value)| (name.as_str(), value)) + .map(|(name, value)| (name.as_str(), value.clone())) .collect(); let converted_tensors: Vec<_> = tensors From eefc3365e88daea45bc3166dd1472ef71c4025c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Wed, 1 Jan 2025 12:33:41 +0100 Subject: [PATCH 10/11] Q2B1: Add new quant with optimized performance --- candle-core/src/quantized/ggml_file.rs | 3 + candle-core/src/quantized/k_quants.rs | 203 ++++++++++++++++-- candle-core/src/quantized/metal.rs | 5 + candle-core/src/quantized/mod.rs | 6 + candle-core/src/quantized/neon.rs | 83 ++++++- .../examples/quantized-bitnet/main.rs | 30 ++- candle-metal-kernels/src/lib.rs | 6 +- candle-metal-kernels/src/quantized.metal | 152 +++++++++++++ candle-transformers/src/models/mod.rs | 2 +- .../src/models/quantized_llama_bitnet.rs | 77 +++++-- tensor-tools/src/main.rs | 104 ++++++--- 11 files changed, 579 insertions(+), 92 deletions(-) diff --git a/candle-core/src/quantized/ggml_file.rs b/candle-core/src/quantized/ggml_file.rs index 7cde3eaa72..4cbddfb9c5 100644 --- 
a/candle-core/src/quantized/ggml_file.rs +++ b/candle-core/src/quantized/ggml_file.rs @@ -192,6 +192,9 @@ pub fn qtensor_from_ggml( GgmlDType::QI8 => { from_raw_data::(raw_data, size_in_bytes, dims, device) } + GgmlDType::Q2b1 => { + from_raw_data::(raw_data, size_in_bytes, dims, device) + } _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"), } } diff --git a/candle-core/src/quantized/k_quants.rs b/candle-core/src/quantized/k_quants.rs index e9c2379868..5dd9595773 100644 --- a/candle-core/src/quantized/k_quants.rs +++ b/candle-core/src/quantized/k_quants.rs @@ -20,6 +20,7 @@ pub const QK5_1: usize = 32; pub const QK8_0: usize = 32; pub const QK8_1: usize = 32; pub const Q2B_0: usize = 32; +pub const Q2B_1: usize = 32; pub const QI8: usize = 32; pub trait GgmlType: Sized + Clone + Send + Sync { @@ -168,7 +169,14 @@ const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::()); #[derive(Debug, Clone, PartialEq)] #[repr(C)] -pub struct BlockQI8{ +pub struct BlockQ2b1 { + pub(crate) qs: [u8; Q2B_0 / 4], // Every single 2-bit represents {-1, 0, 1} +} +const _: () = assert!(Q2B_0 / 4 == std::mem::size_of::()); + +#[derive(Debug, Clone, PartialEq)] +#[repr(C)] +pub struct BlockQI8 { pub(crate) qs: [i8; QI8], } const _: () = assert!(std::mem::size_of::() == QI8); @@ -1857,7 +1865,6 @@ impl GgmlType for BlockQ8K { } } - impl GgmlType for BlockQ2b0 { const DTYPE: GgmlDType = GgmlDType::Q2b0; const BLCK_SIZE: usize = Q2B_0; @@ -1869,7 +1876,7 @@ impl GgmlType for BlockQ2b0 { Self::vec_dot_unopt(n, xs, ys) } - + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { if n % Q2B_0 != 0 { crate::bail!("vec_dot_q2b0_q8k: {n} is not divisible by {Q2B_0}"); @@ -1881,20 +1888,29 @@ impl GgmlType for BlockQ2b0 { let qs = x.qs[i]; let qd = x.qd[i]; let mut y_cache = [0i32; 8]; - y_cache.copy_from_slice(&y.qs[i * 8..(i + 1) * 8].iter().map(|&x| x as i32).collect::>()[..]); - - let pos_sum: i32 = (0..8).map(|bit| { - let mask = 1 << bit; - let is_active = ((qs & mask) >> bit) as i32; - is_active * y_cache[bit] - }).sum(); - - let neg_sum: i32 = (0..8).map(|bit| { - let mask = 1 << bit; - let is_active = ((qd & mask) >> bit) as i32; - is_active * y_cache[bit] - }).sum(); - + y_cache.copy_from_slice( + &y.qs[i * 8..(i + 1) * 8] + .iter() + .map(|&x| x as i32) + .collect::>()[..], + ); + + let pos_sum: i32 = (0..8) + .map(|bit| { + let mask = 1 << bit; + let is_active = ((qs & mask) >> bit) as i32; + is_active * y_cache[bit] + }) + .sum(); + + let neg_sum: i32 = (0..8) + .map(|bit| { + let mask = 1 << bit; + let is_active = ((qd & mask) >> bit) as i32; + is_active * y_cache[bit] + }) + .sum(); + isum += pos_sum - neg_sum; } sumf += isum as f32; @@ -1904,7 +1920,11 @@ impl GgmlType for BlockQ2b0 { fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { if xs.len() % Q2B_0 != 0 { - crate::bail!("quantize_row_q2b0: size mismatch {} not divisible by {}", xs.len(), Q2B_0); + crate::bail!( + "quantize_row_q2b0: size mismatch {} not divisible by {}", + xs.len(), + Q2B_0 + ); } for (block, x) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) { @@ -1928,7 +1948,11 @@ impl GgmlType for BlockQ2b0 { fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { if ys.len() % Q2B_0 != 0 { - crate::bail!("dequantize_row_q2b0: size mismatch {} not divisible by {}", ys.len(), Q2B_0); + crate::bail!( + "dequantize_row_q2b0: size mismatch {} not divisible by {}", + ys.len(), + Q2B_0 + ); } for (block, y) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) { @@ -1951,6 +1975,145 @@ impl GgmlType for BlockQ2b0 
{ } } +const fn build_decode_q2b1_lut_i8() -> [[i8; 4]; 256] { + let mut table = [[0i8; 4]; 256]; + let mut i = 0; + while i < 256 { + let byte = i as u8; + let mut dec = [0i8; 4]; + let mut b = 0; + while b < 4 { + let code = (byte >> (2 * b)) & 0b11; + dec[b as usize] = match code { + 0b00 => 0, + 0b01 => 1, + 0b10 => -1, + 0b11 => 0, + _ => unreachable!(), + }; + b += 1; + } + table[i] = dec; + i += 1; + } + table +} + +static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = build_decode_q2b1_lut_i8(); +const fn build_decode_q2b1_lut_f32() -> [[f32; 4]; 256] { + let mut table = [[0.0_f32; 4]; 256]; + let mut i = 0; + while i < 256 { + let byte = i as u8; + let mut dec = [0.0_f32; 4]; + let mut b = 0; + while b < 4 { + let code = (byte >> (2 * b)) & 0b11; + dec[b as usize] = match code { + 0b00 => 0.0, + 0b01 => 1.0, + 0b10 => -1.0, + 0b11 => 0.0, + _ => unreachable!(), + }; + b += 1; + } + table[i] = dec; + i += 1; + } + table +} + +static LUT_DECODE_Q2B1_F32: [[f32; 4]; 256] = build_decode_q2b1_lut_f32(); +impl GgmlType for BlockQ2b1 { + const DTYPE: GgmlDType = GgmlDType::Q2b1; + const BLCK_SIZE: usize = Q2B_0; + type VecDotType = BlockQI8; + + fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + #[cfg(target_feature = "neon")] + return super::neon::vec_dot_q2b1_qi8(n, xs, ys); + Self::vec_dot_unopt(n, xs, ys) + } + + fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result { + if n % Q2B_0 != 0 { + crate::bail!("vec_dot_q2b1_qi8: n = {n} is not divisible by {Q2B_0}"); + } + + let mut sumf = 0.0; + + for (block_x, block_y) in xs.iter().zip(ys.iter()) { + let mut isum = 0i32; + + for i in 0..(Q2B_0 / 4) { + let enc_x = block_x.qs[i]; + let y_slice = &block_y.qs[i * 4..(i + 1) * 4]; + + let dec_x = &LUT_DECODE_Q2B1_I8[enc_x as usize]; + + for b in 0..4 { + let x_val = dec_x[b] as i32; + let y_val = y_slice[b] as i32; + isum += x_val * y_val; + } + } + sumf += isum as f32; + } + + Ok(sumf) + } + + fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> { + if xs.len() % Q2B_0 != 0 { + crate::bail!( + "quantize_row_q2b1: size {} is not divisible by {}", + xs.len(), + Q2B_0 + ); + } + + for (block, chunk) in ys.iter_mut().zip(xs.chunks_exact(Q2B_0)) { + for (i, subchunk) in chunk.chunks_exact(4).enumerate() { + let mut encoded: u8 = 0; + for (b, &val) in subchunk.iter().enumerate() { + let bits = if val > 0.0 { + 0b01 + } else if val < 0.0 { + 0b10 + } else { + 0b00 + }; + encoded |= bits << (2 * b); + } + block.qs[i] = encoded; + } + } + + Ok(()) + } + + fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> { + if ys.len() % Q2B_0 != 0 { + crate::bail!( + "dequantize_row_q2b1: size {} is not divisible by {}", + ys.len(), + Q2B_0 + ); + } + + for (block, out_chunk) in xs.iter().zip(ys.chunks_exact_mut(Q2B_0)) { + for (i, subchunk) in out_chunk.chunks_exact_mut(4).enumerate() { + let enc = block.qs[i]; + let dec = &LUT_DECODE_Q2B1_F32[enc as usize]; + subchunk.copy_from_slice(dec); + } + } + + Ok(()) + } +} + impl GgmlType for BlockQI8 { const DTYPE: GgmlDType = GgmlDType::QI8; const BLCK_SIZE: usize = QI8; @@ -1990,7 +2153,7 @@ impl GgmlType for BlockQI8 { for (i, ys) in ys.iter_mut().enumerate() { let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE]; for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) { - *y = x as i8; + *y = x as i8; } } Ok(()) diff --git a/candle-core/src/quantized/metal.rs b/candle-core/src/quantized/metal.rs index 4ed1941ed9..426721db8f 100644 --- a/candle-core/src/quantized/metal.rs +++ b/candle-core/src/quantized/metal.rs @@ -107,6 
+107,10 @@ impl QMetalStorage { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQ2b0::to_float(&vec, &mut out)?; } + GgmlDType::Q2b1 => { + let vec: Vec = read_to_vec(&buffer, block_len); + crate::quantized::BlockQ2b1::to_float(&vec, &mut out)?; + } GgmlDType::QI8 => { let vec: Vec = read_to_vec(&buffer, block_len); crate::quantized::BlockQI8::to_float(&vec, &mut out)?; @@ -234,6 +238,7 @@ impl From for candle_metal_kernels::GgmlDType { GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16, GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32, GgmlDType::Q2b0 => candle_metal_kernels::GgmlDType::Q2b0, + GgmlDType::Q2b1 => candle_metal_kernels::GgmlDType::Q2b1, GgmlDType::QI8 => todo!(), } } diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 728466e87a..4912f0f33d 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -147,6 +147,7 @@ pub enum GgmlDType { Q6K, Q8K, Q2b0, + Q2b1, QI8, } @@ -169,6 +170,7 @@ impl GgmlDType { 15 => Self::Q8K, 40 => Self::Q2b0, 41 => Self::QI8, + 42 => Self::Q2b1, _ => crate::bail!("unknown dtype for tensor {u}"), }; Ok(dtype) @@ -192,6 +194,7 @@ impl GgmlDType { Self::Q8K => 15, Self::Q2b0 => 40, Self::QI8 => 41, + Self::Q2b1 => 42, } } @@ -214,6 +217,7 @@ impl GgmlDType { Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]), Self::Q2b0 => Box::new(vec![BlockQ2b0::zeros(); elem_count / BlockQ2b0::BLCK_SIZE]), Self::QI8 => Box::new(vec![BlockQI8::zeros(); elem_count / BlockQI8::BLCK_SIZE]), + Self::Q2b1 => Box::new(vec![BlockQ2b1::zeros(); elem_count / BlockQ2b1::BLCK_SIZE]), } } /// The type size for blocks in bytes. @@ -237,6 +241,7 @@ impl GgmlDType { Self::Q8K => std::mem::size_of::(), Self::Q2b0 => std::mem::size_of::(), Self::QI8 => std::mem::size_of::(), + Self::Q2b1 => std::mem::size_of::(), } } @@ -252,6 +257,7 @@ impl GgmlDType { Self::Q8_0 => k_quants::QK8_0, Self::Q8_1 => k_quants::QK8_1, Self::Q2b0 => k_quants::Q2B_0, + Self::Q2b1 => k_quants::Q2B_1, Self::QI8 => k_quants::QI8, Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => k_quants::QK_K, } diff --git a/candle-core/src/quantized/neon.rs b/candle-core/src/quantized/neon.rs index 8387555940..53de2b9007 100644 --- a/candle-core/src/quantized/neon.rs +++ b/candle-core/src/quantized/neon.rs @@ -1,6 +1,10 @@ -use super::{k_quants::{ - BlockQ2K, BlockQ2b0, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, BlockQ8K, BlockQ8_0, Q2B_0, QK8_0, QK_K -}, BlockQI8}; +use super::{ + k_quants::{ + BlockQ2K, BlockQ2b0, BlockQ2b1, BlockQ3K, BlockQ4K, BlockQ4_0, BlockQ5K, BlockQ6K, + BlockQ8K, BlockQ8_0, Q2B_0, QK8_0, QK_K, + }, + BlockQI8, +}; use crate::Result; use byteorder::{ByteOrder, LittleEndian}; @@ -11,6 +15,7 @@ use core::arch::arm::*; #[allow(unused_imports)] #[cfg(target_arch = "aarch64")] use core::arch::aarch64::*; +use std::ptr; #[inline(always)] unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t { @@ -575,6 +580,78 @@ pub(crate) fn vec_dot_q2b0_qi8(n: usize, xs: &[BlockQ2b0], ys: &[BlockQI8]) -> c Ok(sumf) } +static LUT_DECODE_Q2B1_I8: [[i8; 4]; 256] = { + const fn build_decode_table() -> [[i8; 4]; 256] { + let mut table = [[0i8; 4]; 256]; + let mut i = 0; + while i < 256 { + let byte = i as u8; + let mut dec = [0i8; 4]; + let mut b = 0; + while b < 4 { + let code = (byte >> (2 * b)) & 0b11; + dec[b as usize] = match code { + 0b00 => 0, + 0b01 => 1, + 0b10 => -1, + 0b11 => 0, + _ => 0, + }; + b += 1; + } + table[i] = dec; + i += 1; + } + table + 
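            // The 2-bit codes map LSB-first as 01 => +1, 10 => -1, and 00/11 => 0,
            // so one byte packs four ternary weights: e.g. 0b10_01_00_01 (0x91)
            // decodes to [+1, 0, +1, -1] for elements 0..3. q2b0 stores the same
            // information as two one-bit planes instead (qs marks +1 positions,
            // qd marks -1 positions).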
} + build_decode_table() +}; + +unsafe fn decode_q2b1_16(input: &[u8]) -> int8x16_t { + debug_assert_eq!(input.len(), 4, "input must be 4 bytes long"); + let mut tmp = [0i8; 16]; + + for (i, &byte) in input.iter().enumerate() { + let decoded4 = LUT_DECODE_Q2B1_I8[byte as usize]; + tmp[i * 4..i * 4 + 4].copy_from_slice(&decoded4); + } + + vld1q_s8(tmp.as_ptr()) +} + +#[inline(always)] +pub fn vec_dot_q2b1_qi8(n: usize, xs: &[BlockQ2b1], ys: &[BlockQI8]) -> crate::Result { + let blocks = n / 32; + + let mut total_sum = 0i32; + + unsafe { + for i in 0..blocks { + let x_block = &xs[i]; + let y_block = &ys[i]; + + let x_dec_lo = decode_q2b1_16(&x_block.qs[0..4]); + let x_dec_hi = decode_q2b1_16(&x_block.qs[4..8]); + + let y_lo = vld1q_s8(y_block.qs[0..16].as_ptr()); + let y_hi = vld1q_s8(y_block.qs[16..32].as_ptr()); + + let mut acc0 = vdupq_n_s32(0); + let mut acc1 = vdupq_n_s32(0); + + acc0 = vaddq_s32(acc0, vdotq_s32(x_dec_lo, y_lo)); + acc1 = vaddq_s32(acc1, vdotq_s32(x_dec_hi, y_hi)); + + let sum0 = vaddvq_s32(acc0); + let sum1 = vaddvq_s32(acc1); + + total_sum += sum0 + sum1; + } + } + + Ok(total_sum as f32) +} + #[inline(always)] pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Result { if n % QK_K != 0 { diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index 2b024b50b4..a40f3a8747 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -5,9 +5,9 @@ extern crate intel_mkl_src; extern crate accelerate_src; use clap::{Parser, ValueEnum}; -use tracing_subscriber::fmt::time::FormatTime; use std::io::Write; -use tokenizers::{Tokenizer, AddedToken}; +use tokenizers::{AddedToken, Tokenizer}; +use tracing_subscriber::fmt::time::FormatTime; use candle::quantized::{ggml_file, gguf_file}; use candle::Tensor; @@ -46,7 +46,7 @@ enum Which { Llama3_8b1_58, } -impl Which { +impl Which { fn tokenizer_repo(&self) -> &'static str { match self { Self::Falcon3_1bInstruct1_58 => "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF", @@ -303,9 +303,7 @@ fn main() -> anyhow::Result<()> { let mut pre_prompt_tokens = vec![]; for prompt_index in 0.. 
{ let prompt_str = match &prompt { - Prompt::One(prompt) => { - prompt.clone() - } + Prompt::One(prompt) => prompt.clone(), Prompt::Interactive | Prompt::Chat => { let is_interactive = matches!(prompt, Prompt::Interactive); print!("> "); @@ -318,11 +316,11 @@ fn main() -> anyhow::Result<()> { prompt.pop(); } } - - prompt + + prompt.clone() } }; - + print!("{}", &prompt_str); let tokens = tos .tokenizer() @@ -383,13 +381,13 @@ fn main() -> anyhow::Result<()> { } let eos_tokens = match args.which { - Which::Falcon3_10b1_58 | - Which::Falcon3_10bInstruct1_58 | - Which::Falcon3_7bInstruct1_58 | - Which::Falcon3_7b1_58 | - Which::Falcon3_3bInstruct1_58 | - Which::Falcon3_3b1_58 | - Which::Falcon3_1bInstruct1_58 => { + Which::Falcon3_10b1_58 + | Which::Falcon3_10bInstruct1_58 + | Which::Falcon3_7bInstruct1_58 + | Which::Falcon3_7b1_58 + | Which::Falcon3_3bInstruct1_58 + | Which::Falcon3_3b1_58 + | Which::Falcon3_1bInstruct1_58 => { vec!["<|endoftext|>"] } Which::Llama3_8b1_58 => { diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index ddf0c6bb2f..7e683596c2 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2165,6 +2165,7 @@ pub enum GgmlDType { F16, F32, Q2b0, + Q2b1, } #[allow(clippy::too_many_arguments)] @@ -2230,7 +2231,7 @@ pub fn call_quantized_matmul_mv_t( let align = 4; (nth0, nth1, align) } - GgmlDType::Q2b0 => { + GgmlDType::Q2b1 | GgmlDType::Q2b0 => { let nth0 = 8; let nth1 = 8; let align = 8; @@ -2287,7 +2288,8 @@ pub fn call_quantized_matmul_mv_t( GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32", GgmlDType::F16 => "kernel_mul_mv_f16_f32", GgmlDType::F32 => "kernel_mul_mv_f32_f32", - GgmlDType::Q2b0 => "kernel_mul_mv_q2b0_f32" + GgmlDType::Q2b0 => "kernel_mul_mv_q2b0_f32", + GgmlDType::Q2b1 => "kernel_mul_mv_q2b1_f32" }; let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?; diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal index 6972d38250..73a3744556 100644 --- a/candle-metal-kernels/src/quantized.metal +++ b/candle-metal-kernels/src/quantized.metal @@ -48,6 +48,11 @@ typedef struct { uint8_t qd[Q2B_0 / 8]; // Every single bit represents negative values, is a vector of {0, 1} } block_q2b_0; +#define Q2B_1 32 +typedef struct { + uint8_t qs[Q2B_1 / 4]; // Every single 2-bit represents {-1, 0, 1} +} block_q2b_1; + #define QI8 32 typedef struct { int8_t qs[QI8]; // quants @@ -3594,6 +3599,153 @@ kernel void kernel_mul_mv_q2b0_f32( kernel_mul_mv_q2b0_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); } +#define NB_Q2B_1 8 +constant float code_lut[4] = { 0.0f, 1.0f, -1.0f, 0.0f }; + +inline void kernel_mul_mv_q2b1_f32_impl( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne10, + constant int64_t & ne12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig [[threadgroup_position_in_grid]], + uint tiisg [[thread_index_in_simdgroup]], + uint sgitg [[simdgroup_index_in_threadgroup]] +) { + // These come from your headers or #defines + const int nr = N_DST; // number of "rows" each thread processes + const int nsg = N_SIMDGROUP; // number of simdgroups per dimension + const int nw = N_SIMDWIDTH; // simd width + + const int nb = ne00 / Q2B_0; // number of Q2B_0 blocks in a row of X + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + 
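    // Overview note: this kernel computes the mat-vec dst = W * y directly on the
    // q2b1 data. Each 2-bit code is turned into +1 / -1 / 0 via code_lut and
    // multiplied against the fp32 activations, so no dequantized weight tensor is
    // ever materialized.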
const int im = tgpig.z;
+
+    // Each simdgroup processes 'nr' output rows; figure out which chunk this one starts at:
+    const int first_row = (r0 * nsg + sgitg) * nr;
+
+    // Flatten z index using ne12
+    const uint i12 = im % ne12;
+    const uint i13 = im / ne12;
+
+    // Compute offset into src0 (the quantized blocks array)
+    const uint offset0 =
+        first_row * nb +
+        (i12 / r2) * (nb * ne01) +
+        (i13 / r3) * (nb * ne01 * ne02);
+
+    // Pointer to the first quantized block
+    device const block_q2b_1 * x = (device const block_q2b_1 *)src0 + offset0;
+
+    // Pointer to the appropriate row of src1
+    device const float * y = src1 +
+        r1 * ne10 // stride in y dimension
+        + im * ne00 * ne1; // stride in z dimension
+
+    // Accumulators for partial sums, one per row
+    float sumf[nr];
+    for (int i = 0; i < nr; i++) {
+        sumf[i] = 0.0f;
+    }
+
+    // ix picks this lane's starting block; il picks which quarter (8 values) of each block it decodes
+    const int ix = tiisg / 4;
+    const int il = tiisg % 4;
+
+    // This pointer yb will move through src1 in steps of NB_Q2B_1*nw
+    device const float * yb = y + ix * Q2B_0 + NB_Q2B_1 * il;
+
+    // Main loop: each thread processes some subset of 'nb' blocks
+    for (int ib = ix; ib < nb; ib += (nw / 4)) {
+
+        // Load 8 floats (NB_Q2B_0) into local array to keep them in registers
+        float yl[NB_Q2B_0];
+        {
+            // Small fixed-size copy; the compiler would likely unroll it anyway,
+            // but make it explicit:
+            #pragma unroll 8
+            for (int i = 0; i < NB_Q2B_0; i++) {
+                yl[i] = yb[i];
+            }
+        }
+
+        // For each row in [0..nr), compute partial dot-product
+        // with quantized data from 'x + ib + row * nb'
+        for (int row = 0; row < nr; row++) {
+            device const block_q2b_1 * bq = x + ib + row * nb;
+
+            float sumq = 0.0f;
+
+            // Each lane decodes its 8 ternary values (2 bits each) from the block;
+            // 'startBit' is the index of this lane's first value within the block.
+            // This loop is unrolled as well.
+            const int startBit = NB_Q2B_1 * il;
+            #pragma unroll 8
+            for (int iBit = 0; iBit < NB_Q2B_0; iBit++) {
+                const int bit = startBit + iBit;
+                const int bByte = bit >> 2; // bit / 4
+                const int shift = 2 * (bit & 3); // (bit % 4)*2
+                const int code = (bq->qs[bByte] >> shift) & 0x3;
+
+                // Use the LUT to get +1 / -1 / 0
+                sumq += code_lut[code] * yl[iBit];
+            }
+
+            sumf[row] += sumq;
+        }
+
+        // Advance yb to this lane's slice of the next group of blocks
+        yb += NB_Q2B_1 * nw;
+    }
+
+    // Reduction across the simdgroup: each row's sum -> simd_sum(...)
+ // Then store to output if we're the "first lane" (tiisg == 0) + for (int row = 0; row < nr; row++) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && (first_row + row) < ne01) { + dst[r1 * ne0 + im * ne0 * ne1 + (first_row + row)] = tot; + } + } +} + + +[[host_name("kernel_mul_mv_q2b1_f32")]] +kernel void kernel_mul_mv_q2b1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_q2b1_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg); +} + //============================= templates and their specializations ============================= // NOTE: this is not dequantizing - we are simply fitting the template diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 34fc8b57e4..0fca20e2d4 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -75,6 +75,7 @@ pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; pub mod quantized_llama2_c; +pub mod quantized_llama_bitnet; pub mod quantized_metavoice; pub mod quantized_mistral; pub mod quantized_mixformer; @@ -110,4 +111,3 @@ pub mod whisper; pub mod with_tracing; pub mod wuerstchen; pub mod yi; -pub mod quantized_llama_bitnet; \ No newline at end of file diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs index bfc3113ce7..ed745f622e 100644 --- a/candle-transformers/src/models/quantized_llama_bitnet.rs +++ b/candle-transformers/src/models/quantized_llama_bitnet.rs @@ -59,15 +59,22 @@ impl BitQMatMul { let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); let weight_scale = weight_scale.dequantize(&weight_scale.device())?; - Ok(Self { inner, span, weight_scale }) + Ok(Self { + inner, + span, + weight_scale, + }) } fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> { - let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?; + let scale = x + .abs()? + .max_keepdim(D::Minus1)? + .clamp(1e-5, f32::INFINITY)?; let scale = (127.0 / scale)?; let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?; - + Ok((y, scale)) } @@ -75,12 +82,10 @@ impl BitQMatMul { let (x, xscale) = self.activation_quant(x)?; let _enter = self.span.enter(); let scale = self.weight_scale.broadcast_mul(&xscale)?; - self.inner.forward(&x)? 
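        // Net effect of forward(): activations are quantized on the fly to int8
        // with an absmax scale along the last dimension (127 / max|x|, see
        // activation_quant above), the quantized matmul runs against the ternary
        // weights, and the result is rescaled by dividing out weight_scale * xscale,
        // which matches the BitNet b1.58 recipe (ternary weights, int8 activations).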
- .broadcast_div(&scale) + self.inner.forward(&x)?.broadcast_div(&scale) } } - #[derive(Debug, Clone)] struct Mlp { feed_forward_w1: BitQMatMul, @@ -337,11 +342,14 @@ impl ModelWeights { let attention_wo_ws = ct.remove(&format!("{prefix}.attention.wo.weight_scale"))?; let mlp_or_moe = { let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?; - let feed_forward_w1_ws = ct.remove(&format!("{prefix}.feed_forward.w1.weight_scale"))?; + let feed_forward_w1_ws = + ct.remove(&format!("{prefix}.feed_forward.w1.weight_scale"))?; let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?; - let feed_forward_w2_ws = ct.remove(&format!("{prefix}.feed_forward.w2.weight_scale"))?; + let feed_forward_w2_ws = + ct.remove(&format!("{prefix}.feed_forward.w2.weight_scale"))?; let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?; - let feed_forward_w3_ws = ct.remove(&format!("{prefix}.feed_forward.w3.weight_scale"))?; + let feed_forward_w3_ws = + ct.remove(&format!("{prefix}.feed_forward.w3.weight_scale"))?; MlpOrMoe::Mlp(Mlp { feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, @@ -431,13 +439,21 @@ impl ModelWeights { for layer_idx in 0..block_count { let prefix = format!("blk.{layer_idx}"); let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?; - let attention_wq_ws = ct.tensor(reader, &format!("{prefix}.attn_q.weight_scale"), device)?; + let attention_wq_ws = + ct.tensor(reader, &format!("{prefix}.attn_q.weight_scale"), device)?; let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?; - let attention_wk_ws = ct.tensor(reader, &format!("{prefix}.attn_k.weight_scale"), device)?; + let attention_wk_ws = + ct.tensor(reader, &format!("{prefix}.attn_k.weight_scale"), device)?; let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?; - let attention_wv_ws = ct.tensor(reader, &format!("{prefix}.attn_v.weight_scale"), device)?; - let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; - let attention_wo_ws = ct.tensor(reader, &format!("{prefix}.attn_output.weight_scale"), device)?; + let attention_wv_ws = + ct.tensor(reader, &format!("{prefix}.attn_v.weight_scale"), device)?; + let attention_wo = + ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?; + let attention_wo_ws = ct.tensor( + reader, + &format!("{prefix}.attn_output.weight_scale"), + device, + )?; let mlp_or_moe = if n_expert <= 1 { let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?; @@ -463,21 +479,36 @@ impl ModelWeights { for i in 0..n_expert { let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?; - let feed_forward_w1_ws = - ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight_scale"), device)?; + let feed_forward_w1_ws = ct.tensor( + reader, + &format!("{prefix}.ffn_gate.{i}.weight_scale"), + device, + )?; let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?; - let feed_forward_w2_ws = - ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight_scale"), device)?; + let feed_forward_w2_ws = ct.tensor( + reader, + &format!("{prefix}.ffn_down.{i}.weight_scale"), + device, + )?; let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?; let feed_forward_w3_ws = ct.tensor(reader, 
&format!("{prefix}.ffn_up.{i}.weight_scale"), device)?; - + experts.push(Mlp { - feed_forward_w1: BitQMatMul::from_qtensor(feed_forward_w1, feed_forward_w1_ws)?, - feed_forward_w2: BitQMatMul::from_qtensor(feed_forward_w2, feed_forward_w2_ws)?, - feed_forward_w3: BitQMatMul::from_qtensor(feed_forward_w3, feed_forward_w3_ws)?, + feed_forward_w1: BitQMatMul::from_qtensor( + feed_forward_w1, + feed_forward_w1_ws, + )?, + feed_forward_w2: BitQMatMul::from_qtensor( + feed_forward_w2, + feed_forward_w2_ws, + )?, + feed_forward_w3: BitQMatMul::from_qtensor( + feed_forward_w3, + feed_forward_w3_ws, + )?, }) } MlpOrMoe::MoE { @@ -567,4 +598,4 @@ impl ModelWeights { let _enter = self.span_output.enter(); self.output.forward(&x) } -} \ No newline at end of file +} diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index 4a1e9bc039..78ba68840c 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -14,7 +14,13 @@ enum QuantizationMode { } impl QuantizationMode { - fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType, bitnet_mode: bool) -> Result { + fn quantize( + &self, + name: &str, + tensor: QTensor, + dtype: GgmlDType, + bitnet_mode: bool, + ) -> Result { match self { Self::Llama => { // Same behavior as the llama.cpp quantization. @@ -50,6 +56,8 @@ enum Quantization { Q8_1, #[value(name = "q2b0")] Q2b0, + #[value(name = "q2b1")] + Q2b1, #[value(name = "qi8")] QI8, Q2k, @@ -81,6 +89,7 @@ impl Quantization { Quantization::F32 => GgmlDType::F32, Quantization::Q2b0 => GgmlDType::Q2b0, Quantization::QI8 => GgmlDType::QI8, + Quantization::Q2b1 => GgmlDType::Q2b1, } } } @@ -441,11 +450,11 @@ fn unpack_bitnet_weights(tensor: &Tensor) -> Result { } use core::num; +use rayon::prelude::*; +use serde_json::Value; use std::collections::HashMap; use std::fs::File; use std::path::PathBuf; -use rayon::prelude::*; -use serde_json::Value; fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option) -> Result { let n_head = match n_head_kv { @@ -464,7 +473,7 @@ fn permute(weights: &Tensor, n_head: usize, n_head_kv: Option) -> Result< let permuted = weights .reshape(new_shape)? - .transpose(1, 2)? + .transpose(1, 2)? 
.reshape(weights.shape())?; Ok(permuted) @@ -481,7 +490,9 @@ fn run_quantize_safetensors( let dtype = q.dtype(); let block_size = dtype.block_size(); - let metadata_file = in_files.iter().find(|f| f.to_string_lossy().ends_with("config.json")); + let metadata_file = in_files + .iter() + .find(|f| f.to_string_lossy().ends_with("config.json")); let mut qtensors = Vec::new(); @@ -489,7 +500,6 @@ fn run_quantize_safetensors( let mut num_key_value_heads = 0; let mut architecture = String::new(); - let gguf_metadata = if let Some(metadata_file) = metadata_file { let metadata_content = std::fs::read_to_string(metadata_file)?; let metadata: serde_json::Value = serde_json::from_str(&metadata_content).unwrap(); @@ -499,16 +509,41 @@ fn run_quantize_safetensors( architecture = metadata["model_type"].as_str().unwrap().to_string(); vec![ - ("llama.attention.head_count", gguf_file::Value::from_u32(num_attention_heads as u32)), - ("llama.attention.head_count_kv", gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32)), - ("llama.block_count", gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32)), - ("llama.embedding_length", gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32)), - ("llama.attention.layer_norm_rms_epsilon", gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32)), - ("llama.rope.dimension_count", gguf_file::Value::from_u32( - (metadata["hidden_size"].as_u64().unwrap() as u32) / (metadata["num_attention_heads"].as_u64().unwrap() as u32), - )), - ("llama.rope.freq_base", gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32)), - ("general.architecture", gguf_file::Value::from_string(architecture.clone())), + ( + "llama.attention.head_count", + gguf_file::Value::from_u32(num_attention_heads as u32), + ), + ( + "llama.attention.head_count_kv", + gguf_file::Value::from_u32(metadata["num_key_value_heads"].as_u64().unwrap() as u32), + ), + ( + "llama.block_count", + gguf_file::Value::from_u32(metadata["num_hidden_layers"].as_u64().unwrap() as u32), + ), + ( + "llama.embedding_length", + gguf_file::Value::from_u32(metadata["hidden_size"].as_u64().unwrap() as u32), + ), + ( + "llama.attention.layer_norm_rms_epsilon", + gguf_file::Value::from_f32(metadata["rms_norm_eps"].as_f64().unwrap() as f32), + ), + ( + "llama.rope.dimension_count", + gguf_file::Value::from_u32( + (metadata["hidden_size"].as_u64().unwrap() as u32) + / (metadata["num_attention_heads"].as_u64().unwrap() as u32), + ), + ), + ( + "llama.rope.freq_base", + gguf_file::Value::from_f32(metadata["rope_theta"].as_f64().unwrap() as f32), + ), + ( + "general.architecture", + gguf_file::Value::from_string(architecture.clone()), + ), ] } else { vec![] @@ -531,14 +566,13 @@ fn run_quantize_safetensors( let mut tensor = tensor; if should_quantize && bitnet_mode { - let is_bitnet_weight = - name.contains("self_attn.v_proj") || - name.contains("self_attn.q_proj") || - name.contains("self_attn.o_proj") || - name.contains("self_attn.k_proj") || - name.contains("mlp.down_proj") || - name.contains("mlp.up_proj") || - name.contains("mlp.gate_proj"); + let is_bitnet_weight = name.contains("self_attn.v_proj") + || name.contains("self_attn.q_proj") + || name.contains("self_attn.o_proj") + || name.contains("self_attn.k_proj") + || name.contains("mlp.down_proj") + || name.contains("mlp.up_proj") + || name.contains("mlp.gate_proj"); if is_bitnet_weight { println!(" unpacking {name} {tensor:?} {should_quantize}"); @@ -555,10 +589,18 
@@ fn run_quantize_safetensors( match architecture.as_str() { "llama" => { if name.ends_with("self_attn.q_proj.weight") { - tensor = permute(&tensor, num_attention_heads as usize, Some(num_attention_heads as usize))?; + tensor = permute( + &tensor, + num_attention_heads as usize, + Some(num_attention_heads as usize), + )?; } if name.ends_with("self_attn.k_proj.weight") { - tensor = permute(&tensor, num_attention_heads as usize, Some(num_key_value_heads as usize))?; + tensor = permute( + &tensor, + num_attention_heads as usize, + Some(num_key_value_heads as usize), + )?; } } _ => {} @@ -716,7 +758,15 @@ fn main() -> anyhow::Result<()> { bitnet_quantization, mode, bitnet_mode, - } => run_quantize(&in_file, out_file, quantization, mode, bitnet_quantization, bitnet_mode, &device)?, + } => run_quantize( + &in_file, + out_file, + quantization, + mode, + bitnet_quantization, + bitnet_mode, + &device, + )?, Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?, } Ok(()) From 9da753aad6ffb315e90e5d95ddac385689f11c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?= Date: Wed, 1 Jan 2025 13:14:32 +0100 Subject: [PATCH 11/11] Q2B1 is the default quant on example --- .../examples/quantized-bitnet/main.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs index a40f3a8747..3196ab881e 100644 --- a/candle-examples/examples/quantized-bitnet/main.rs +++ b/candle-examples/examples/quantized-bitnet/main.rs @@ -157,35 +157,35 @@ impl Args { let (repo, filename) = match self.which { Which::Falcon3_1bInstruct1_58 => ( "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF", - "Falcon3-1B-Instruct-1.58bit-q2b0.gguf", + "Falcon3-1B-Instruct-1.58bit-q2b1.gguf", ), Which::Falcon3_3bInstruct1_58 => ( "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF", - "Falcon3-3B-Instruct-1.58bit-q2b0.gguf", + "Falcon3-3B-Instruct-1.58bit-q2b1.gguf", ), Which::Falcon3_3b1_58 => ( "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF", - "Falcon3-3B-Base-1.58bit-q2b0.gguf", + "Falcon3-3B-Base-1.58bit-q2b1.gguf", ), Which::Falcon3_7bInstruct1_58 => ( "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF", - "Falcon3-7B-Instruct-1.58bit-q2b0.gguf", + "Falcon3-7B-Instruct-1.58bit-q2b1.gguf", ), Which::Falcon3_7b1_58 => ( "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF", - "Falcon3-7B-Base-1.58bit-q2b0.gguf", + "Falcon3-7B-Base-1.58bit-q2b1.gguf", ), Which::Falcon3_10b1_58 => ( "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF", - "Falcon3-10B-Base-1.58bit-q2b0.gguf", + "Falcon3-10B-Base-1.58bit-q2b1.gguf", ), Which::Falcon3_10bInstruct1_58 => ( "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF", - "Falcon3-10B-Instruct-1.58bit-q2b0.gguf", + "Falcon3-10B-Instruct-1.58bit-q2b1.gguf", ), Which::Llama3_8b1_58 => ( "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF", - "Llama3-8B-1.58-100B-tokens-q2b0.gguf", + "Llama3-8B-1.58-100B-tokens-q2b1.gguf", ), }; let revision = "main";
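                // The *-q2b1.gguf files referenced above can be produced with tensor-tools.
                // A rough sketch (flag names inferred from the clap definitions in
                // tensor-tools/src/main.rs; treat paths and flags as illustrative only):
                //
                //   tensor-tools quantize model.safetensors config.json \
                //       --quantization q2b1 --bitnet-mode \
                //       --out-file Falcon3-1B-Instruct-1.58bit-q2b1.gguf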