From 4c669dcac62c321365f9cb28bd96d694c9240117 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jose=CC=81=20Carlos=20Garci=CC=81a?=
Date: Tue, 31 Dec 2024 10:36:07 +0100
Subject: [PATCH] Add Falcon3 1.58-bit instruct variants and fix the q2b0 Metal kernel

---
 .../examples/quantized-bitnet/main.rs        | 89 +++++++++++--------
 candle-metal-kernels/src/lib.rs              |  7 +-
 candle-metal-kernels/src/quantized.metal     |  9 +-
 .../src/models/quantized_llama_bitnet.rs     | 19 ++--
 4 files changed, 68 insertions(+), 56 deletions(-)

diff --git a/candle-examples/examples/quantized-bitnet/main.rs b/candle-examples/examples/quantized-bitnet/main.rs
index 3ae5e51f4..2b024b50b 100644
--- a/candle-examples/examples/quantized-bitnet/main.rs
+++ b/candle-examples/examples/quantized-bitnet/main.rs
@@ -28,34 +28,35 @@ enum Prompt {
 
 #[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
 enum Which {
-    #[value(name = "falcon3-1b-1.58")]
-    Falcon3_1b1_58,
+    #[value(name = "falcon3-1b-instruct-1.58")]
+    Falcon3_1bInstruct1_58,
+    #[value(name = "falcon3-3b-instruct-1.58")]
+    Falcon3_3bInstruct1_58,
     #[value(name = "falcon3-3b-1.58")]
     Falcon3_3b1_58,
+    #[value(name = "falcon3-7b-instruct-1.58")]
+    Falcon3_7bInstruct1_58,
     #[value(name = "falcon3-7b-1.58")]
     Falcon3_7b1_58,
+    #[value(name = "falcon3-10b-instruct-1.58")]
+    Falcon3_10bInstruct1_58,
     #[value(name = "falcon3-10b-1.58")]
     Falcon3_10b1_58,
     #[value(name = "llama3-8b-1.58")]
     Llama3_8b1_58,
 }
 
-impl Which {
-    fn is_falcon(&self) -> bool {
-        matches!(self, Self::Falcon3_1b1_58 | Self::Falcon3_3b1_58 | Self::Falcon3_7b1_58 | Self::Falcon3_10b1_58)
-    }
-
-    fn is_llama(&self) -> bool {
-        matches!(self, Self::Llama3_8b1_58)
-    }
-
+impl Which {
     fn tokenizer_repo(&self) -> &'static str {
         match self {
-            Self::Falcon3_1b1_58 => "tiiuae/Falcon3-1B-Instruct-1.58bit",
-            Self::Falcon3_3b1_58 => "tiiuae/Falcon3-3B-Instruct-1.58bit",
-            Self::Llama3_8b1_58 => "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
-            Self::Falcon3_10b1_58 => "tiiuae/Falcon3-10B-Base-1.58bit",
-            Self::Falcon3_7b1_58 => "tiiuae/Falcon3-7B-Instruct-1.58bit",
+            Self::Falcon3_1bInstruct1_58 => "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF",
+            Self::Falcon3_3bInstruct1_58 => "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF",
+            Self::Falcon3_3b1_58 => "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF",
+            Self::Falcon3_7bInstruct1_58 => "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF",
+            Self::Falcon3_10b1_58 => "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF",
+            Self::Falcon3_10bInstruct1_58 => "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF",
+            Self::Falcon3_7b1_58 => "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF",
+            Self::Llama3_8b1_58 => "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF",
         }
     }
 }
@@ -123,7 +124,7 @@ struct Args {
     repeat_last_n: usize,
 
     /// The model size to use.
-    #[arg(long, default_value = "falcon3-1b-1.58")]
+    #[arg(long, default_value = "falcon3-1b-instruct-1.58")]
     which: Which,
 
     /// Group-Query Attention, use 8 for the 70B version of LLaMAv2.
@@ -154,25 +155,37 @@ impl Args {
             Some(config) => std::path::PathBuf::from(config),
             None => {
                 let (repo, filename) = match self.which {
-                    Which::Falcon3_1b1_58 => (
-                        "tiiuae/Falcon3-1B-Instruct-1.58bit",
-                        "Falcon3-1B-Instruct-1.58bit.gguf",
+                    Which::Falcon3_1bInstruct1_58 => (
+                        "nebuxcloud/Falcon3-1B-Instruct-1.58bit-GGUF",
+                        "Falcon3-1B-Instruct-1.58bit-q2b0.gguf",
+                    ),
+                    Which::Falcon3_3bInstruct1_58 => (
+                        "nebuxcloud/Falcon3-3B-Instruct-1.58bit-GGUF",
+                        "Falcon3-3B-Instruct-1.58bit-q2b0.gguf",
                     ),
                     Which::Falcon3_3b1_58 => (
-                        "tiiuae/Falcon3-3B-Instruct-1.58bit",
-                        "Falcon3-3B-Instruct-1.58bit.gguf",
+                        "nebuxcloud/Falcon3-3B-Base-1.58bit-GGUF",
+                        "Falcon3-3B-Base-1.58bit-q2b0.gguf",
                     ),
-                    Which::Falcon3_10b1_58 => (
-                        "tiiuae/Falcon3-10B-Instruct-1.58bit",
-                        "Falcon3-10B-Instruct-1.58bit.gguf",
+                    Which::Falcon3_7bInstruct1_58 => (
+                        "nebuxcloud/Falcon3-7B-Instruct-1.58bit-GGUF",
+                        "Falcon3-7B-Instruct-1.58bit-q2b0.gguf",
                     ),
                     Which::Falcon3_7b1_58 => (
-                        "tiiuae/Falcon3-7B-Instruct-1.58bit",
-                        "Falcon3-7B-Instruct-1.58bit.gguf",
+                        "nebuxcloud/Falcon3-7B-Base-1.58bit-GGUF",
+                        "Falcon3-7B-Base-1.58bit-q2b0.gguf",
+                    ),
+                    Which::Falcon3_10b1_58 => (
+                        "nebuxcloud/Falcon3-10B-Base-1.58bit-GGUF",
+                        "Falcon3-10B-Base-1.58bit-q2b0.gguf",
+                    ),
+                    Which::Falcon3_10bInstruct1_58 => (
+                        "nebuxcloud/Falcon3-10B-Instruct-1.58bit-GGUF",
+                        "Falcon3-10B-Instruct-1.58bit-q2b0.gguf",
                     ),
                     Which::Llama3_8b1_58 => (
-                        "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
-                        "Llama3-8B-1.58bit.gguf",
+                        "nebuxcloud/Llama3-8B-1.58-100B-tokens-GGUF",
+                        "Llama3-8B-1.58-100B-tokens-q2b0.gguf",
                     ),
                 };
                 let revision = "main";
@@ -306,13 +319,7 @@ fn main() -> anyhow::Result<()> {
                 }
             }
 
-            if args.which.is_llama() {
-                format!(
-                    "<|start_header_id|>user<|end_header_id|>{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
-                )
-            } else {
-                prompt
-            }
+            prompt
         }
     };
 
@@ -376,11 +383,17 @@ fn main() -> anyhow::Result<()> {
     }
 
     let eos_tokens = match args.which {
-        Which::Falcon3_10b1_58 | Which::Falcon3_7b1_58 | Which::Falcon3_3b1_58 | Which::Falcon3_1b1_58 => {
+        Which::Falcon3_10b1_58
+        | Which::Falcon3_10bInstruct1_58
+        | Which::Falcon3_7bInstruct1_58
+        | Which::Falcon3_7b1_58
+        | Which::Falcon3_3bInstruct1_58
+        | Which::Falcon3_3b1_58
+        | Which::Falcon3_1bInstruct1_58 => {
             vec!["<|endoftext|>"]
         }
         Which::Llama3_8b1_58 => {
-            vec!["<|eot_id|>"]
+            vec!["<|eot_id|>", "<|end_header_id|>", "<|start_header_id|>"]
         }
     };
 
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index fe82e8719..ddf0c6bb2 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -2210,7 +2210,6 @@ pub fn call_quantized_matmul_mv_t(
         | GgmlDType::Q5_0
         | GgmlDType::Q5_1
         | GgmlDType::Q8_0
-        | GgmlDType::Q2b0
         | GgmlDType::Q8_1 => {
             let nth0 = 8;
             let nth1 = 8;
@@ -2231,6 +2230,12 @@ pub fn call_quantized_matmul_mv_t(
             let align = 4;
             (nth0, nth1, align)
         }
+        GgmlDType::Q2b0 => {
+            let nth0 = 8;
+            let nth1 = 8;
+            let align = 8;
+            (nth0, nth1, align)
+        }
         GgmlDType::Q3K | GgmlDType::Q5K => {
             let nth0 = 2;
             let nth1 = 32;
diff --git a/candle-metal-kernels/src/quantized.metal b/candle-metal-kernels/src/quantized.metal
index ac860fd06..6972d3825 100644
--- a/candle-metal-kernels/src/quantized.metal
+++ b/candle-metal-kernels/src/quantized.metal
@@ -3544,10 +3544,11 @@ void kernel_mul_mv_q2b0_f32_impl(
             int bit = startBit + iBit;
             int bByte = bit >> 3;
             int bMask = 1 << (bit & 7);
-            int isPos = ((bx->qs[bByte] & bMask) != 0) ? 1 : 0;
-            int isNeg = ((bx->qd[bByte] & bMask) != 0) ? 1 : 0;
-
-            sumq += float(isPos - isNeg) * yl[iBit];
+            if ((bx->qs[bByte] & bMask) != 0) {
+                sumq += yl[iBit];
+            } else if ((bx->qd[bByte] & bMask) != 0) {
+                sumq -= yl[iBit];
+            }
         }
 
         sumf[row] += sumq;
 
diff --git a/candle-transformers/src/models/quantized_llama_bitnet.rs b/candle-transformers/src/models/quantized_llama_bitnet.rs
index e9a38a26a..bfc3113ce 100644
--- a/candle-transformers/src/models/quantized_llama_bitnet.rs
+++ b/candle-transformers/src/models/quantized_llama_bitnet.rs
@@ -62,20 +62,13 @@ impl BitQMatMul {
         Ok(Self { inner, span, weight_scale })
     }
 
-    pub fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
-        let target_dim = x.rank().saturating_sub(1);
-
-        let max_abs = x.abs()?.max_keepdim(target_dim)?;
-
-        let scale = (127.0/ &max_abs)?;
-
-        let scaled_rounded = x
-            .broadcast_mul(&scale)?
-            .round()?
-            .clamp(-128f32, 127f32)?;
-
+    fn activation_quant(&self, x: &Tensor) -> Result<(Tensor, Tensor)> {
+        let scale = x.abs()?.max_keepdim(D::Minus1)?.clamp(1e-5, f32::INFINITY)?;
+        let scale = (127.0 / scale)?;
+
+        let y = (x.broadcast_mul(&scale))?.round()?.clamp(-128., 127.)?;
 
-        Ok((scaled_rounded, scale))
+        Ok((y, scale))
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
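
Note: the q2b0 format decoded by the rewritten Metal loop stores ternary
weights as two bit-planes per block, with qs marking the +1 positions and qd
the -1 positions (a weight is 0 when neither bit is set). As a quick sanity
check for the kernel, here is a plain-Rust sketch of the same dot product over
one block; BlockQ2b0, BLOCK_SIZE, and dot_q2b0 are illustrative names assumed
for this note, not types introduced by this patch.

    // Hypothetical CPU reference for the q2b0 bit-plane dot product.
    const BLOCK_SIZE: usize = 32;

    struct BlockQ2b0 {
        qs: [u8; BLOCK_SIZE / 8], // bit i set => weight i is +1
        qd: [u8; BLOCK_SIZE / 8], // bit i set => weight i is -1
    }

    fn dot_q2b0(block: &BlockQ2b0, y: &[f32; BLOCK_SIZE]) -> f32 {
        let mut sum = 0.0f32;
        for bit in 0..BLOCK_SIZE {
            let byte = bit >> 3;
            let mask = 1u8 << (bit & 7);
            // Same branch order as kernel_mul_mv_q2b0_f32_impl: check the
            // +1 plane first, then the -1 plane; neither set contributes 0.
            if (block.qs[byte] & mask) != 0 {
                sum += y[bit];
            } else if (block.qd[byte] & mask) != 0 {
                sum -= y[bit];
            }
        }
        sum
    }

The activation_quant rewrite follows the usual BitNet b1.58 recipe: scale each
row by 127 over its absmax (clamped to at least 1e-5 so an all-zero row does
not divide by zero), round into [-128, 127], and return the scale so the
caller can rescale the matmul output afterwards.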