diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs
index 2a261d8bd..e55c2c912 100644
--- a/candle-examples/examples/llama-bitnet/main.rs
+++ b/candle-examples/examples/llama-bitnet/main.rs
@@ -32,6 +32,7 @@ enum Which {
     BitnetB1_58Large,
     Bitnet51_58XL,
     Bitnet51_38_3B,
+    Falcon3_7bInstruct158,
 }
 
 #[derive(Parser, Debug)]
@@ -128,6 +129,7 @@ fn main() -> Result<()> {
             Which::BitnetB1_58Large => "1bitLLM/bitnet_b1_58-large",
             Which::Bitnet51_58XL => "1bitLLM/bitnet_b1_58-xl",
             Which::Bitnet51_38_3B => "1bitLLM/bitnet_b1_38-3b",
+            Which::Falcon3_7bInstruct158 => "tiiuae/Falcon3-7B-Instruct-1.58bit",
         };
         str.to_string()
     });
@@ -141,7 +143,7 @@ fn main() -> Result<()> {
     let config = config.into_config(args.use_flash_attn);
 
     let filenames = match args.which {
-        Which::BitnetB1_58Large => {
+        Which::Falcon3_7bInstruct158 | Which::BitnetB1_58Large => {
             vec![api.get("model.safetensors")?]
         }
         Which::Bitnet51_38_3B | Which::Bitnet51_58XL => {
diff --git a/candle-transformers/src/models/llama_bitnet.rs b/candle-transformers/src/models/llama_bitnet.rs
index b76c77d42..1fc870cb7 100644
--- a/candle-transformers/src/models/llama_bitnet.rs
+++ b/candle-transformers/src/models/llama_bitnet.rs
@@ -4,7 +4,7 @@
 //!
 //! Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)
 
-use super::with_tracing::{bit_linear_no_bias as bit_linear, linear, BitLinear, Linear, RmsNorm};
+use super::with_tracing::{bit_linear_no_bias as bit_linear, linear_no_bias as linear, BitLinear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::{collections::HashMap, f32::consts::PI};