diff --git a/candle-examples/examples/llama-bitnet/main.rs b/candle-examples/examples/llama-bitnet/main.rs
index fe093fe2a..2a261d8bd 100644
--- a/candle-examples/examples/llama-bitnet/main.rs
+++ b/candle-examples/examples/llama-bitnet/main.rs
@@ -141,10 +141,10 @@ fn main() -> Result<()> {
     let config = config.into_config(args.use_flash_attn);
 
     let filenames = match args.which {
-        | Which::BitnetB1_58Large => {
+        Which::BitnetB1_58Large => {
             vec![api.get("model.safetensors")?]
         }
-        | Which::Bitnet51_38_3B | Which::Bitnet51_58XL => {
+        Which::Bitnet51_38_3B | Which::Bitnet51_58XL => {
            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
        }
    };
diff --git a/candle-nn/src/bit_linear.rs b/candle-nn/src/bit_linear.rs
index ded3c351c..b20e8dd71 100644
--- a/candle-nn/src/bit_linear.rs
+++ b/candle-nn/src/bit_linear.rs
@@ -31,7 +31,7 @@ fn weight_quant(x: &Tensor) -> Result<Tensor> {
         .abs()?
         .mean_all()?
         .clamp(1e-5, f32::INFINITY)?)?
-        .to_dtype(x.dtype())?;
+    .to_dtype(x.dtype())?;
 
     let u = (x.broadcast_mul(&scale))?
         .round()?
@@ -47,17 +47,16 @@ fn activation_quant(x: &Tensor) -> Result<Tensor> {
         .max(D::Minus1)?
         .max(D::Minus1)?
         .clamp(1e-5, f32::INFINITY)?)?
-        .to_dtype(x.dtype())?;
+    .to_dtype(x.dtype())?;
 
     let y = x
-            .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)?
-            .clamp(-128.0, 127.0)?
-            .broadcast_div(&scale)?;
+        .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)?
+        .clamp(-128.0, 127.0)?
+        .broadcast_div(&scale)?;
 
     Ok(y)
 }
 
-
 impl BitLinear {
     pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
         let weight = weight_quant(&weight).unwrap();
@@ -109,7 +108,11 @@ pub fn bit_linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Resul
 }
 
 /// Create or initialize a new bit_linear layer without biases.
-pub fn bit_linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<BitLinear> {
+pub fn bit_linear_no_bias(
+    in_dim: usize,
+    out_dim: usize,
+    vb: crate::VarBuilder,
+) -> Result<BitLinear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
     let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
     Ok(BitLinear::new(ws, None))
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index 2917c60eb..530e8fb2b 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -17,6 +17,7 @@
 
 pub mod activation;
 pub mod batch_norm;
+pub mod bit_linear;
 pub mod conv;
 pub mod embedding;
 pub mod encoding;
@@ -34,9 +35,9 @@ pub mod rotary_emb;
 pub mod sequential;
 pub mod var_builder;
 pub mod var_map;
-pub mod bit_linear;
 pub use activation::{prelu, Activation, PReLU};
 pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
+pub use bit_linear::{bit_linear, bit_linear_b, bit_linear_no_bias, BitLinear};
 pub use conv::{
     conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
     conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
@@ -48,7 +49,6 @@ pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_b, linear_no_bias, Linear};
-pub use bit_linear::{bit_linear, BitLinear, bit_linear_b, bit_linear_no_bias};
 pub use ops::Dropout;
 pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
diff --git a/candle-transformers/src/models/llama_bitnet.rs b/candle-transformers/src/models/llama_bitnet.rs
index c1b9ffd26..b76c77d42 100644
--- a/candle-transformers/src/models/llama_bitnet.rs
+++ b/candle-transformers/src/models/llama_bitnet.rs
@@ -4,7 +4,7 @@
 //!
 //! Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)
 
-use super::with_tracing::{bit_linear_no_bias as bit_linear, BitLinear, RmsNorm, Linear, linear};
+use super::with_tracing::{bit_linear_no_bias as bit_linear, linear, BitLinear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::{collections::HashMap, f32::consts::PI};
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 94eb79679..e47174b37 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -48,6 +48,7 @@ pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;
 pub mod llama2_c_weights;
+pub mod llama_bitnet;
 pub mod llava;
 pub mod mamba;
 pub mod marian;
@@ -110,4 +111,3 @@ pub mod whisper;
 pub mod with_tracing;
 pub mod wuerstchen;
 pub mod yi;
-pub mod llama_bitnet;
\ No newline at end of file
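
Note on the quantization math: weight_quant above follows the BitNet b1.58 ternary recipe, scaling the weights by the reciprocal of their clamped mean absolute value, rounding into {-1, 0, 1}, and dividing the scale back out. The sketch below mirrors that arithmetic on a plain f32 slice so it can be checked without candle. It is reference-only: the name weight_quant_ref is hypothetical, and it assumes the scale expression opens with "1.0 /" (the opening parenthesis sits above the hunk context).

// Reference-only sketch of weight_quant's arithmetic on a plain f32 slice.
// The name weight_quant_ref is hypothetical and not part of this patch.
fn weight_quant_ref(w: &[f32]) -> Vec<f32> {
    // scale = 1 / clamp(mean(|w|), 1e-5, +inf), as in the hunk above.
    let mean_abs = w.iter().map(|v| v.abs()).sum::<f32>() / w.len() as f32;
    let scale = 1.0 / mean_abs.clamp(1e-5, f32::INFINITY);
    // round(w * scale) clamped to [-1, 1] yields ternary values {-1, 0, 1};
    // dividing by the scale maps each weight to a multiple of mean(|w|).
    w.iter()
        .map(|v| (v * scale).round().clamp(-1.0, 1.0) / scale)
        .collect()
}

fn main() {
    let w = [0.7_f32, -0.2, 0.05, -1.3];
    // Prints roughly [0.5625, 0.0, 0.0, -0.5625]: each surviving weight is
    // +/- mean(|w|), the rest are zeroed.
    println!("{:?}", weight_quant_ref(&w));
}

activation_quant applies the same round-trip pattern to activations, clamping the scaled values to [-128.0, 127.0] before dividing the scale back out. With the lib.rs re-exports in this patch, downstream crates reach the layer as candle_nn::{bit_linear, bit_linear_b, bit_linear_no_bias, BitLinear}.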