Apply cargo fmt.
LaurentMazare committed Dec 9, 2024
1 parent f64d885 commit 8367e6a
Showing 5 changed files with 16 additions and 13 deletions.
4 changes: 2 additions & 2 deletions candle-examples/examples/llama-bitnet/main.rs
@@ -141,10 +141,10 @@ fn main() -> Result<()> {
     let config = config.into_config(args.use_flash_attn);
 
     let filenames = match args.which {
-        | Which::BitnetB1_58Large => {
+        Which::BitnetB1_58Large => {
             vec![api.get("model.safetensors")?]
         }
-        | Which::Bitnet51_38_3B | Which::Bitnet51_58XL => {
+        Which::Bitnet51_38_3B | Which::Bitnet51_58XL => {
            candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
         }
     };
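The only change in the hunk above is rustfmt's default `match_arm_leading_pipes = "Never"` behavior: a redundant leading `|` before the first alternative of a match arm is dropped, while the `|` separating alternatives is kept. A minimal stand-alone illustration; the enum mirrors the variants from this example file, and `shard_count` is a made-up helper:

#[allow(non_camel_case_types)]
enum Which {
    BitnetB1_58Large,
    Bitnet51_38_3B,
    Bitnet51_58XL,
}

fn shard_count(which: Which) -> usize {
    match which {
        // rustfmt rewrites `| Which::BitnetB1_58Large => 1,` to this form.
        Which::BitnetB1_58Large => 1,
        // The `|` between alternatives stays; only the leading one goes.
        Which::Bitnet51_38_3B | Which::Bitnet51_58XL => 2,
    }
}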
17 changes: 10 additions & 7 deletions candle-nn/src/bit_linear.rs
@@ -31,7 +31,7 @@ fn weight_quant(x: &Tensor) -> Result<Tensor> {
             .abs()?
             .mean_all()?
             .clamp(1e-5, f32::INFINITY)?)?
-            .to_dtype(x.dtype())?;
+        .to_dtype(x.dtype())?;
 
     let u = (x.broadcast_mul(&scale))?
         .round()?
@@ -47,17 +47,16 @@ fn activation_quant(x: &Tensor) -> Result<Tensor> {
             .max(D::Minus1)?
             .max(D::Minus1)?
             .clamp(1e-5, f32::INFINITY)?)?
-            .to_dtype(x.dtype())?;
+        .to_dtype(x.dtype())?;
 
     let y = x
-            .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)?
-            .clamp(-128.0, 127.0)?
-            .broadcast_div(&scale)?;
+        .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)?
+        .clamp(-128.0, 127.0)?
+        .broadcast_div(&scale)?;
 
     Ok(y)
 }
 
-
 impl BitLinear {
     pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
         let weight = weight_quant(&weight).unwrap();
@@ -109,7 +108,11 @@ pub fn bit_linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Resul
 }
 
 /// Create or initialize a new bit_linear layer without biases.
-pub fn bit_linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<BitLinear> {
+pub fn bit_linear_no_bias(
+    in_dim: usize,
+    out_dim: usize,
+    vb: crate::VarBuilder,
+) -> Result<BitLinear> {
     let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
     let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
     Ok(BitLinear::new(ws, None))
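For context on the layer being reformatted, here is a minimal usage sketch. It assumes `BitLinear` implements `Module` the way `Linear` does, and that the re-exports added to `candle-nn/src/lib.rs` below are in scope; the dimensions and the `demo` function are illustrative:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{bit_linear_no_bias, Module, VarBuilder, VarMap};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    // Fresh, randomly initialized variables; a real model would load a checkpoint.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    // 16 inputs -> 8 outputs, no bias. Per `BitLinear::new` above, the weight
    // is quantized through `weight_quant` as soon as the layer is built.
    let layer = bit_linear_no_bias(16, 8, vb)?;
    let xs = Tensor::randn(0f32, 1f32, (2, 16), &dev)?;
    let ys = layer.forward(&xs)?;
    assert_eq!(ys.dims(), &[2, 8]);
    Ok(())
}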
4 changes: 2 additions & 2 deletions candle-nn/src/lib.rs
@@ -17,6 +17,7 @@
 pub mod activation;
 pub mod batch_norm;
+pub mod bit_linear;
 pub mod conv;
 pub mod embedding;
 pub mod encoding;
@@ -34,9 +35,9 @@ pub mod rotary_emb;
 pub mod sequential;
 pub mod var_builder;
 pub mod var_map;
-pub mod bit_linear;
 pub use activation::{prelu, Activation, PReLU};
 pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
+pub use bit_linear::{bit_linear, bit_linear_b, bit_linear_no_bias, BitLinear};
 pub use conv::{
     conv1d, conv1d_no_bias, conv2d, conv2d_no_bias, conv_transpose1d, conv_transpose1d_no_bias,
     conv_transpose2d, conv_transpose2d_no_bias, Conv1d, Conv1dConfig, Conv2d, Conv2dConfig,
@@ -48,7 +49,6 @@ pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
 pub use linear::{linear, linear_b, linear_no_bias, Linear};
-pub use bit_linear::{bit_linear, BitLinear, bit_linear_b, bit_linear_no_bias};
 pub use ops::Dropout;
 pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
 pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
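Moving `pub mod bit_linear;` up rather than editing it is rustfmt's `reorder_modules` behavior (on by default): `mod` declarations within a contiguous block are sorted alphabetically, so a declaration appended at the end of the list migrates to its sorted position. The same rule explains the `pub mod llama_bitnet;` move in `candle-transformers/src/models/mod.rs` below.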
2 changes: 1 addition & 1 deletion candle-transformers/src/models/llama_bitnet.rs
@@ -4,7 +4,7 @@
 //!
 //! Implementation based on Hugging Face's [transformers](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)
-use super::with_tracing::{bit_linear_no_bias as bit_linear, BitLinear, RmsNorm, Linear, linear};
+use super::with_tracing::{bit_linear_no_bias as bit_linear, linear, BitLinear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
 use std::{collections::HashMap, f32::consts::PI};
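The reordering inside the braces follows rustfmt's default import sorting: names starting with a lowercase letter (`bit_linear`, `linear`) sort before names starting with an uppercase one (`BitLinear`, `Linear`, `RmsNorm`), each group alphabetically.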
2 changes: 1 addition & 1 deletion candle-transformers/src/models/mod.rs
@@ -48,6 +48,7 @@ pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;
 pub mod llama2_c_weights;
+pub mod llama_bitnet;
 pub mod llava;
 pub mod mamba;
 pub mod marian;
@@ -110,4 +111,3 @@ pub mod whisper;
 pub mod with_tracing;
 pub mod wuerstchen;
 pub mod yi;
-pub mod llama_bitnet;
