From 91d46023e6cb9f5aec249a4712e034bb70718331 Mon Sep 17 00:00:00 2001
From: lynxeco
Date: Sat, 9 Nov 2024 11:35:45 -0800
Subject: [PATCH 1/7] Adds support for stella_en_v5 embedding model -400M variant

---
 .../examples/stella-en-400m-v5/README.md      |  46 +
 .../examples/stella-en-400m-v5/main.rs        | 384 ++++++++
 candle-transformers/src/models/mod.rs         |   1 +
 .../src/models/stella_en_v5_400m.rs           | 896 ++++++++++++++++++
 4 files changed, 1327 insertions(+)
 create mode 100644 candle-examples/examples/stella-en-400m-v5/README.md
 create mode 100644 candle-examples/examples/stella-en-400m-v5/main.rs
 create mode 100644 candle-transformers/src/models/stella_en_v5_400m.rs

diff --git a/candle-examples/examples/stella-en-400m-v5/README.md b/candle-examples/examples/stella-en-400m-v5/README.md
new file mode 100644
index 0000000000..ef1de31d09
--- /dev/null
+++ b/candle-examples/examples/stella-en-400m-v5/README.md
@@ -0,0 +1,46 @@
+---
+model-index:
+- name: stella_en_400M_v5
+license: mit
+---
+
+
+# Introduction
+
+The models are trained based on `Alibaba-NLP/gte-large-en-v1.5` and `Alibaba-NLP/gte-Qwen2-1.5B-instruct`. Thanks for
+their contributions!
+
+**We simplify the usage of prompts, providing two prompts for most general tasks: one for s2p and one for s2s.**
+
+The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
+
+```bash
+$ cargo run --example stella-en-400m-v5 --release --features <metal | cuda>
+
+Similarity========
+ [[0.8398, 0.2990],
+ [0.3282, 0.8095]]
+
+Score: 0.83975387
+Query: What are some ways to reduce stress?
+Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.
+
+Score: 0.8095451
+Query: What are the benefits of drinking green tea?
+Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
+
+```
+
+The models are finally trained using [MRL](https://arxiv.org/abs/2205.13147), so they have multiple output dimensions: 512, 768,
+1024, 2048, 4096, 6144 and 8192.
+
+## Supported options:
+- `Stella_en_400m_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (the model card also mentions 512, but no weights for that dimension appear to be published). In this example the dimension is selected with the `--embed-dim` option, e.g. `... --embed-dim 4096`. Defaults to `1024`.
+
+- As per the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require slightly different `query` preprocessing (a different prompt template for each). In this example the task is selected through the `--task` option, as shown in the invocation below.
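+
+For instance, a typical invocation combining these options might look like the following sketch (illustrative only: the `metal`/`cuda` feature flag depends on your build, and the query string is just a placeholder):
+
+```bash
+$ cargo run --example stella-en-400m-v5 --release --features metal -- \
+    --task s2s \
+    --embed-dim 768 \
+    --query "What are the benefits of drinking green tea?"
+```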
\ No newline at end of file diff --git a/candle-examples/examples/stella-en-400m-v5/main.rs b/candle-examples/examples/stella-en-400m-v5/main.rs new file mode 100644 index 0000000000..6774c62aa3 --- /dev/null +++ b/candle-examples/examples/stella-en-400m-v5/main.rs @@ -0,0 +1,384 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use std::path::Path; + +use anyhow::{ anyhow, Error as E, Result }; +use clap::Parser; + +use candle_transformers::models::stella_en_v5_400m::{ + Config, + EmbedDim as StellaEmbedDim, + EmbeddingModel, +}; + +use candle::{ DType, Device, IndexOp, Tensor }; +use candle_nn::VarBuilder; +use hf_hub::{ api::sync::Api, Repo }; +use tokenizers::{ PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer }; + +struct Embedding { + model: EmbeddingModel, + device: Device, + tokenizer: Tokenizer, +} + +impl Embedding { + fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self { + Self { + model, + tokenizer, + device: device.clone(), + } + } + + fn encode(&mut self, task: EncodeTask, text: Option) -> Result<()> { + // Just shocasing embeddings, this has no real value + if let Some(text) = text { + let qry = task.query_preproc(&[text]); + let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?; + + let shape = (1, encoding.len()); + let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; + let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; + + let result = self.model.forward(&input, &mask)?; + println!("result shape: {:?}", result.shape()); + println!("embeddings: {result}"); + } else { + // Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers) + let queries = [ + "What are some ways to reduce stress?".to_string(), + "What are the benefits of drinking green tea?".to_string(), + ]; + + let docs = [ + "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(), + "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. 
The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(), + ]; + + // We only encode the queries and not the data + let qry = task.query_preproc(&queries); + let mut qry_encoded = self.tokenizer.encode_batch(qry, true).map_err(|e| anyhow!(e))?; + + let mut docs_encoded = self.tokenizer + .encode_batch(docs.to_vec(), true) + .map_err(|e| anyhow!(e))?; + + let qry_embed = { + // Now, we generate the tensors for the `input` and `mask` + let shape = (qry_encoded.len(), qry_encoded[1].len()); + let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; + let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; + + for (i, e) in qry_encoded.drain(..).enumerate() { + let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( + 0 + )?; + let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? + .to_dtype(DType::U8)? + .unsqueeze(0)?; + + ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; + } + + // Let's generate the embeddings for the query, we are going to be normalizing the result. + // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data + self.model.forward_norm(&ids, &masks)? + }; + + let doc_embed = { + let shape = (docs_encoded.len(), docs_encoded[1].len()); + let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; + let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; + + for (i, e) in docs_encoded.drain(..).enumerate() { + let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( + 0 + )?; + let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? + .to_dtype(DType::U8)? + .unsqueeze(0)?; + + ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; + } + + // Let's generate the embeddings for the query, we are going to be normalizing the result. + // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data + self.model.forward_norm(&ids, &masks)? + }; + println!("Query vectors:\n {qry_embed}\n"); + println!("Document vectors:\n {doc_embed}\n"); + + println!( + "Embed shapes======\nQuery: {:?}\nDocs: {:?}\n", + qry_embed.shape(), + doc_embed.shape() + ); // [2, 1024] for head dim `1024` + + let answer = self.similarity(&qry_embed, &doc_embed); + println!("Similarity========\n {}", answer.unwrap()); + + // a matmul to generate the `similarity` score + let res = qry_embed.matmul(&doc_embed.t()?)?; + for (k, v) in queries.iter().enumerate() { + let tnsr = res.get(k)?; + let max = tnsr.argmax(0)?.to_scalar::()?; + println!( + "\nScore: {}\nQuery: {}\nAnswer: {}\n\n", + tnsr.get(max as usize)?.to_scalar::()?, + v, + docs[k] + ); + } + } + + Ok(()) + } + + /// Computes the cosine similarity between two tensors of embeddings + /// Similar to sentence-transformers' similarity() function + pub fn similarity(&self, embeddings1: &Tensor, embeddings2: &Tensor) -> Result { + // Normalize the embeddings (L2 norm) + let norm1 = embeddings1.broadcast_div(&embeddings1.sqr()?.sum_keepdim(1)?.sqrt()?)?; + let norm2 = embeddings2.broadcast_div(&embeddings2.sqr()?.sum_keepdim(1)?.sqrt()?)?; + + // Compute cosine similarity: dot product of normalized vectors + Ok(norm1.matmul(&norm2.t()?)?) 
+ } + + fn _encode_single_doc(&mut self) -> Result { + // Example document text + let doc = + "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(); + + // Encode the document + let encoding = self.tokenizer.encode(doc, true).map_err(|e| anyhow!(e))?; + + // Create input tensors + let shape = (1, encoding.len()); + let ids = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; + let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; + + // Get embeddings and print intermediate values + let embeddings = self.model.forward_norm(&ids, &mask)?; + + // Print the shape and first few values + println!("Document embedding shape: {:?}", embeddings.shape()); + println!("First few values: {:?}", embeddings.i(0)?.narrow(0, 0, 3)?.to_vec1::()?); + + // Return normalized embeddings + let norm_embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?; + + Ok(norm_embeddings) + } +} + +#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] +enum EmbedDim { + #[value(name = "256")] + Dim256, + #[value(name = "768")] + Dim768, + #[value(name = "1024")] + Dim1024, + #[value(name = "2048")] + Dim2048, + #[value(name = "4096")] + Dim4096, + #[value(name = "6144")] + Dim6144, + #[value(name = "8192")] + Dim8192, +} + +impl EmbedDim { + /// Returns dir path to the embed head weights int he repo + pub fn embed_dim_default_dir(&self) -> &'static str { + match self { + Self::Dim256 => "2_Dense_256", + Self::Dim768 => "2_Dense_768", + Self::Dim1024 => "2_Dense_1024", + Self::Dim2048 => "2_Dense_2048", + Self::Dim4096 => "2_Dense_4096", + Self::Dim6144 => "2_Dense_6144", + Self::Dim8192 => "2_Dense_8192", + } + } + + /// Resolves the `EmbedDim` for given variant + pub fn embed_dim(&self) -> StellaEmbedDim { + match self { + Self::Dim256 => StellaEmbedDim::Dim256, + Self::Dim768 => StellaEmbedDim::Dim768, + Self::Dim1024 => StellaEmbedDim::Dim1024, + Self::Dim2048 => StellaEmbedDim::Dim2048, + Self::Dim4096 => StellaEmbedDim::Dim4096, + Self::Dim6144 => StellaEmbedDim::Dim6144, + Self::Dim8192 => StellaEmbedDim::Dim8192, + } + } +} + +#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] +pub enum EncodeTask { + /// `s2p` is the `retrieval` task + /// Default in this example + #[value(name = "s2p")] + S2P, + /// `s2s` is the semantic similarity task + #[value(name = "s2s")] + S2S, +} + +impl EncodeTask { + /// Preprocess a set of inputs basef on a template suggested by the model authors + /// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction + pub fn query_preproc(&self, txt: &[String]) -> Vec { + let instruct = match self { + Self::S2P => { + "Given a web search query, retrieve relevant passages that answer the query." + } + Self::S2S => "Retrieve semantically similar text.", + }; + + txt.iter() + .map(|s| format!("Instruct: {instruct}\nQuery: {s}")) + .collect::>() + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). 
+ #[arg(long)] + tracing: bool, + + #[arg(long)] + use_flash_attn: bool, + + #[arg(long)] + query: Option, + + #[arg(long, default_value = "1024")] + embed_dim: Option, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + base_weight_files: Option, + + #[arg(long)] + embed_head_weight_files: Option, + + /// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5) + /// `s2s`: Semantic textual similarity + /// `s2p`: Retrieval task - `Default` in this example + #[arg(long, default_value = "s2p")] + task: Option, +} + +// Tokenizer creation is super critical in our case. +// We are going to be `padding: Left` for each batch +fn create_tokenizer(tokenizer_file: &Path) -> Result { + let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; + + tokenizer.with_padding( + Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + }) + ); + + Ok(tokenizer) +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let embed_dim = match args.embed_dim { + Some(d) => d, + None => EmbedDim::Dim1024, + }; + let repo = api.repo(Repo::model("dunzhang/stella_en_400M_v5".to_string())); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights + // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo + let base_weight_files = match args.base_weight_files { + Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), + None => { vec![repo.get("model.safetensors")?] } + }; + + let embed_weight_files = match args.embed_head_weight_files { + Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), + None => { + let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); + vec![repo.get(&head_w_path)?] + } + }; + + println!("retrieved the files in {:?}", start.elapsed()); + + // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding + let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; + + let start = std::time::Instant::now(); + + let device = candle_examples::device(args.cpu)?; + let dtype = DType::F32; + + let base_vb = unsafe { + VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? + }; + // Embedding layer is always built on F32 for accuracy + let embed_vb = unsafe { + VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? 
+ }; + + let model = EmbeddingModel::new(&Config::new_400m(embed_dim.embed_dim()), base_vb, embed_vb)?; + + println!("loaded the model in {:?}", start.elapsed()); + + let mut embedding = Embedding::new(model, tokenizer, &device); + + let task = args.task.map_or(EncodeTask::S2P, |t| t); + + // let _doc_esmbeddings = embedding._encode_single_doc()?; + embedding.encode(task, args.query)?; + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 23edf349ad..8c0564db84 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -85,6 +85,7 @@ pub mod stable_diffusion; pub mod stable_lm; pub mod starcoder2; pub mod stella_en_v5; +pub mod stella_en_v5_400m; pub mod t5; pub mod trocr; pub mod vgg; diff --git a/candle-transformers/src/models/stella_en_v5_400m.rs b/candle-transformers/src/models/stella_en_v5_400m.rs new file mode 100644 index 0000000000..3b54d14234 --- /dev/null +++ b/candle-transformers/src/models/stella_en_v5_400m.rs @@ -0,0 +1,896 @@ +use candle::{ DType, Device, IndexOp, Module, Result, Tensor }; +use candle_nn::{ Activation, VarBuilder, layer_norm }; +use std::sync::Arc; +use std::time::Instant; + +use super::with_tracing::{ linear, linear_no_bias, Linear, RmsNorm }; + +#[derive(Debug, Clone, Copy)] +pub enum EmbedDim { + Dim256, + Dim768, + Dim1024, + Dim2048, + Dim4096, + Dim6144, + Dim8192, +} + +impl Default for EmbedDim { + fn default() -> Self { + Self::Dim1024 + } +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct EmbedHead { + pub in_features: usize, + pub out_features: usize, +} + +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub max_position_embeddings: usize, + pub type_vocab_size: usize, + pub pad_token_id: usize, + pub hidden_dropout_prob: f64, + pub attention_probs_dropout_prob: f64, + pub layer_norm_eps: f64, + pub initializer_range: f64, + pub position_embedding_type: String, + pub scaling_factor: f64, + pub rope_theta: f64, + pub use_memory_efficient_attention: bool, + pub unpad_inputs: bool, + pub layer_norm_type: String, + pub logn_attention_scale: bool, + pub logn_attention_clip1: bool, + pub activation_fn: Activation, + pub embed_head: EmbedHead, +} + +impl Config { + pub fn new_400m(embed_dim: EmbedDim) -> Self { + let embed_head = EmbedHead { + in_features: 1024, + out_features: match embed_dim { + EmbedDim::Dim256 => 256, + EmbedDim::Dim768 => 768, + EmbedDim::Dim1024 => 1024, + EmbedDim::Dim2048 => 2048, + EmbedDim::Dim4096 => 4096, + EmbedDim::Dim6144 => 6144, + EmbedDim::Dim8192 => 8192, + }, + }; + + Self { + vocab_size: 30528, + hidden_size: 1024, + intermediate_size: 4096, + num_hidden_layers: 24, + num_attention_heads: 16, + max_position_embeddings: 8192, + type_vocab_size: 2, + pad_token_id: 0, + hidden_dropout_prob: 0.1, + attention_probs_dropout_prob: 0.0, + layer_norm_eps: 1e-12, + initializer_range: 0.02, + position_embedding_type: "rope".to_string(), + scaling_factor: 2.0, + rope_theta: 160000.0, + use_memory_efficient_attention: true, + unpad_inputs: false, + layer_norm_type: "layer_norm".to_string(), + logn_attention_scale: false, + logn_attention_clip1: false, + activation_fn: Activation::Gelu, + embed_head, + } + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, + _scaling_factor: f64, + _mixed_b: Option, + 
_dim: usize, + _base: f64, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let scaling_factor = cfg.scaling_factor; // Can be configured in Config if needed + let base = cfg.rope_theta; + + // Calculate scaled position embeddings + let scaled_max_seq_len = ((max_seq_len as f64) * scaling_factor) as usize; + + // Calculate inv_freq with NTK scaling + let inv_freq: Vec<_> = (0..dim / 2) + .map(|i| { + // Apply base scaling + let base = base * scaling_factor; + let freq = 1.0 / base.powf((2.0 * (i as f64)) / (dim as f64)); + + // Apply fixed NTK scaling + let freq = freq / scaling_factor.powf(2.0 / (dim as f64)); + + freq as f32 + }) + .collect(); + + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + + // Calculate position embeddings with scaled sequence length + let t = Tensor::arange(0u32, scaled_max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((scaled_max_seq_len, 1))?; + + let freqs = t.matmul(&inv_freq)?; + let emb = Tensor::cat(&[&freqs, &freqs], 1)?; + + Ok(Self { + sin: emb.sin()?, + cos: emb.cos()?, + _scaling_factor: scaling_factor, + _mixed_b: None, + _dim: dim, + _base: base, + }) + } +} + +#[derive(Debug, Clone)] +enum NormType { + LayerNorm(candle_nn::LayerNorm), + RmsNorm(RmsNorm), +} + +impl NormType { + fn forward(&self, x: &Tensor) -> Result { + match self { + Self::LayerNorm(ln) => ln.forward(x), + Self::RmsNorm(rms) => rms.forward(x), + } + } +} + +#[derive(Debug)] +pub struct Embeddings { + word_embeddings: candle_nn::Embedding, + position_embeddings: Option, + token_type_embeddings: Option, + layer_norm: NormType, + _padding_idx: usize, + _position_embedding_type: String, + rotary_emb: Option>, + position_ids: Option, +} + +impl Embeddings { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let word_embeddings = candle_nn::embedding( + cfg.vocab_size, + cfg.hidden_size, + vb.pp("word_embeddings") + )?; + + let position_embeddings = if cfg.position_embedding_type == "absolute" { + Some( + candle_nn::embedding( + cfg.max_position_embeddings, + cfg.hidden_size, + vb.pp("position_embeddings") + )? + ) + } else { + None + }; + + let token_type_embeddings = if cfg.type_vocab_size > 0 { + Some( + candle_nn::embedding( + cfg.type_vocab_size, + cfg.hidden_size, + vb.pp("token_type_embeddings") + )? + ) + } else { + None + }; + + let layer_norm = if cfg.layer_norm_type == "layer_norm" { + let weight = vb + .pp("LayerNorm") + .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; + let bias = vb + .pp("LayerNorm") + .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; + NormType::LayerNorm(candle_nn::LayerNorm::new(weight, bias, cfg.layer_norm_eps)) + } else { + NormType::RmsNorm( + RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? + ) + }; + + let rotary_emb = if cfg.position_embedding_type == "rope" { + Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) + } else { + None + }; + + let position_ids = if cfg.position_embedding_type == "absolute" { + Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) 
+ } else { + None + }; + + Ok(Self { + word_embeddings, + position_embeddings, + token_type_embeddings, + layer_norm, + _padding_idx: cfg.pad_token_id, + _position_embedding_type: cfg.position_embedding_type.clone(), + rotary_emb, + position_ids, + }) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor>, + inputs_embeds: Option<&Tensor>, + unpad_inputs: bool, + attention_mask: Option<&Tensor> + ) -> Result<(Tensor, Option, Option<(Tensor, Tensor)>, Option>)> { + let (batch_size, seq_length) = input_ids.dims2()?; + + let mut embeddings = match inputs_embeds { + Some(e) => e.clone(), + None => self.word_embeddings.forward(input_ids)?, + }; + + // Get position_ids first + let position_ids = if let Some(ids) = position_ids { + ids.clone() + } else { + // Get device from input_ids which is always available + let device = input_ids.device(); + + // Initialize position_ids if None + if self.position_ids.is_none() { + self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); + } + + // Now check if we need to extend it + if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { + self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); + } + + if unpad_inputs { + // For now, just use the same position IDs as padded case since we don't have lengths + self.position_ids + .as_ref() + .unwrap() + .narrow(0, 0, seq_length)? + .expand((batch_size, seq_length))? + } else { + self.position_ids + .as_ref() + .unwrap() + .narrow(0, 0, seq_length)? + .expand((batch_size, seq_length))? + } + }; + + // Get rotary embeddings if using RoPE + let rope_embeds = if let Some(rotary) = &self.rotary_emb { + // Get the cos and sin for this sequence length + let cos = rotary.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] + let sin = rotary.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] + + // Index using position_ids if needed + let position_ids = position_ids.flatten_all()?; + let cos = cos.index_select(&position_ids, 0)?; // Use index_select instead of i() + let sin = sin.index_select(&position_ids, 0)?; // Use index_select instead of i() + + Some((cos, sin)) + } else { + None + }; + + // Handle token type embeddings + if let Some(token_emb) = &self.token_type_embeddings { + let token_type_ids = if let Some(ids) = token_type_ids { + ids.clone() + } else { + position_ids.zeros_like()? 
// Use mul(0) equivalent + }; + if unpad_inputs { + todo!("Implement unpadded case"); + } else { + embeddings = embeddings.add(&token_emb.forward(&token_type_ids)?)?; + } + } + + // Handle absolute position embeddings + if let Some(pos_emb) = &self.position_embeddings { + let position_embeddings = pos_emb.forward(&position_ids)?; + embeddings = embeddings.add(&position_embeddings)?; + } + + let embeddings = self.layer_norm.forward(&embeddings)?; + + Ok((embeddings, attention_mask.cloned(), rope_embeds, None)) + } +} + +#[derive(Debug)] +struct NewAttention { + qkv_proj: Linear, + o_proj: Linear, + num_heads: usize, + head_dim: usize, + hidden_size: usize, + _use_memory_efficient_attention: bool, +} + +impl NewAttention { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let head_dim = hidden_sz / num_heads; + + let qkv_proj = linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?; + let o_proj = linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + + Ok(Self { + qkv_proj, + o_proj, + num_heads, + head_dim, + hidden_size: hidden_sz, + _use_memory_efficient_attention: cfg.use_memory_efficient_attention, + }) + } + + fn forward( + &self, + hidden_states: &Tensor, + attention_bias: Option<&Tensor>, + rope_embeds: Option<&(Tensor, Tensor)>, + _attention_scale: Option<&Tensor> + ) -> Result { + let (b_sz, seq_len, _) = hidden_states.dims3()?; + + // QKV projection + let qkv = self.qkv_proj.forward(hidden_states)?; + + // Split into Q, K, V and reshape to match PyTorch shapes + let qkv = qkv.reshape((b_sz, seq_len, 3, self.num_heads, self.head_dim))?; + + // Get Q, K, V with shape [batch, seq_len, num_heads, head_dim] + let query_states = qkv.i((.., .., 0, .., ..))?.contiguous()?; + let key_states = qkv.i((.., .., 1, .., ..))?.contiguous()?; + let value_states = qkv.i((.., .., 2, .., ..))?.contiguous()?; + + // Apply RoPE if provided + let (query_states, key_states) = if let Some((cos, sin)) = rope_embeds { + apply_rotary_pos_emb(&query_states, &key_states, cos, sin)? + } else { + (query_states, key_states) + }; + + // Transpose for attention computation [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim] + let query_states = query_states.transpose(1, 2)?.contiguous()?; + let key_states = key_states.transpose(1, 2)?.contiguous()?; + let value_states = value_states.transpose(1, 2)?.contiguous()?; + + // For key, we want to transpose the last two dimensions for the matmul + // Is this equivalent to PyTorch's transpose(-1, -2)? 
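+        // (Yes: the tensor is 4-D here, so swapping dims 2 and 3 matches PyTorch's transpose(-1, -2).)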
+ let key_states_t = key_states.transpose(2, 3)?.contiguous()?; + + // Prepare tensors for batched matmul using matmul + // Reshape tensors to merge batch and head dimensions + let bsz = b_sz as usize; + let nh = self.num_heads as usize; + let s_len = seq_len as usize; + let h_dim = self.head_dim as usize; + + // Reshape tensors to [batch_size * num_heads, seq_len, head_dim] + let query_states_reshaped = query_states.reshape((bsz * nh, s_len, h_dim))?; + let key_states_t_reshaped = key_states_t.reshape((bsz * nh, h_dim, s_len))?; + + // Perform batched matmul using matmul + // The matmul should handle batch dimensions if tensors are 3D + let attn_weights = query_states_reshaped.matmul(&key_states_t_reshaped)?; + + // Reshape attn_weights back to [batch_size, num_heads, seq_len, seq_len] + let attn_weights = attn_weights.reshape((bsz, nh, s_len, s_len))?; + + // Scale attention scores + let scale = 1f32 / (self.head_dim as f32).sqrt(); + + let scale_tensor = Tensor::new(scale, attn_weights.device())? + .to_dtype(attn_weights.dtype())? + .broadcast_as(attn_weights.shape())?; + + // Multiply the attention weights by the scalar tensor + let attn_weights = attn_weights.mul(&scale_tensor)?; + + // Apply attention mask + let mut attn_weights = if let Some(bias) = attention_bias { + let attn_weights = attn_weights.broadcast_add(bias)?; + attn_weights + } else { + attn_weights + }; + + // Normalize attention scores + attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + + // Reshape value_states for batched matmul + let value_states_reshaped = value_states.reshape((bsz * nh, s_len, h_dim))?; + + // Reshape attn_weights to [batch_size * num_heads, seq_len, seq_len] + let attn_weights_reshaped = attn_weights.reshape((bsz * nh, s_len, s_len))?; + + // Compute attention output + let attn_output = attn_weights_reshaped.matmul(&value_states_reshaped)?; + + // Reshape attn_output back to [batch_size, num_heads, seq_len, head_dim] + let attn_output = attn_output.reshape((bsz, nh, s_len, h_dim))?; + + // Transpose back to [batch_size, seq_len, num_heads, head_dim] + let attn_output = attn_output.transpose(1, 2)?; + + // Project to final dimension + let attn_output = attn_output.reshape((b_sz, seq_len, self.hidden_size))?; + self.o_proj.forward(&attn_output) + } +} + +#[derive(Debug)] +struct NewGatedMLP { + up_gate_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl NewGatedMLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_size = cfg.intermediate_size; + let act_fn = cfg.activation_fn; + let up_gate_proj = linear_no_bias(hidden_sz, intermediate_size * 2, vb.pp("up_gate_proj"))?; + let down_proj = linear(intermediate_size, hidden_sz, vb.pp("down_proj"))?; + + Ok(Self { + up_gate_proj, + down_proj, + act_fn, + }) + } +} + +impl Module for NewGatedMLP { + fn forward(&self, xs: &Tensor) -> Result { + let up_gate = self.up_gate_proj.forward(xs)?; + + // Get the dimensions + let (_batch_size, _seq_len, hidden_dim) = up_gate.dims3()?; + let split_size = hidden_dim / 2; + + // Split along the last dimension (hidden_dim) + let up_states = up_gate.narrow(2, 0, split_size)?; + let gate = up_gate.narrow(2, split_size, split_size)?; + + // Apply activation to gate and multiply + let gate = gate.apply(&self.act_fn)?; + + let gated_states = up_states.mul(&gate)?; + + // Project back to hidden dimension + let output = self.down_proj.forward(&gated_states)?; + + Ok(output) + } +} + +#[derive(Debug)] +struct NewLayer { + attention: 
NewAttention, + mlp: NewGatedMLP, + attn_ln: NormType, + mlp_ln: NormType, +} + +impl NewLayer { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let attention = NewAttention::new(cfg, vb.pp("attention"))?; + let mlp = NewGatedMLP::new(cfg, vb.pp("mlp"))?; + + let ln_eps = cfg.layer_norm_eps; + + // Use LayerNorm or RmsNorm based on config + let (attn_ln, mlp_ln) = if cfg.layer_norm_type == "layer_norm" { + let attn_ln = layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { eps: ln_eps, ..Default::default() }, + vb.pp("attn_ln") + )?; + let mlp_ln = layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { eps: ln_eps, ..Default::default() }, + vb.pp("mlp_ln") + )?; + (NormType::LayerNorm(attn_ln), NormType::LayerNorm(mlp_ln)) + } else { + let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; + let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; + (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) + }; + + Ok(Self { + attention, + mlp, + attn_ln, + mlp_ln, + }) + } + + fn forward( + &self, + hidden_states: &Tensor, + attention_bias: Option<&Tensor>, + rope_embeds: Option<&(Tensor, Tensor)>, + attention_scale: Option<&Tensor> + ) -> Result { + // Store original input + let original = hidden_states; + + // Use normalized states for attention + let hidden_states = self.attention.forward( + original, + attention_bias, + rope_embeds, + attention_scale + )?; + + let hidden_states = original.add(&hidden_states)?; + + // Apply layer norm + let hidden_states = self.attn_ln.forward(&hidden_states)?; + + // Store residual + let residual = &hidden_states; + + // Pass through MLP + let hidden_states = self.mlp.forward(&hidden_states)?; + + // Add residual connection + let hidden_states = residual.add(&hidden_states)?; + + // Final layer norm + self.mlp_ln.forward(&hidden_states) + } +} + +#[derive(Debug)] +struct NewEncoder { + layers: Vec, +} + +impl NewEncoder { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("layer"); + for layer_idx in 0..cfg.num_hidden_layers { + layers.push(NewLayer::new(cfg, vb_l.pp(layer_idx))?); + } + Ok(Self { layers }) + } + + fn forward( + &self, + hidden_states: &Tensor, + attention_bias: Option<&Tensor>, + rope_embeds: Option<&(Tensor, Tensor)>, + attention_scale: Option<&Tensor> + ) -> Result { + let mut hidden_states = hidden_states.clone(); + + for layer in self.layers.iter() { + hidden_states = layer.forward( + &hidden_states, + attention_bias, + rope_embeds, + attention_scale + )?; + } + + Ok(hidden_states) + } +} + +#[derive(Debug)] +pub struct NewModel { + embeddings: Embeddings, + encoder: NewEncoder, + device: Device, + dtype: DType, + config: Config, +} + +impl NewModel { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("new"); + let embeddings = Embeddings::new(cfg, vb_m.pp("embeddings"))?; + let encoder = NewEncoder::new(cfg, vb_m.pp("encoder"))?; + Ok(Self { + embeddings, + encoder, + device: vb.device().clone(), + dtype: vb.dtype(), + config: cfg.clone(), + }) + } + + fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result { + let (b_sz, seq_len) = attn_mask.dims2()?; + let mask = attn_mask + .unsqueeze(1) + ? // [b_sz, 1, seq_len] + .unsqueeze(2) + ? 
// [b_sz, 1, 1, seq_len] + .broadcast_as((b_sz, 1, 1, seq_len))?; // [b_sz, 1, 1, seq_len] + + // Use a large negative value for mask instead of -0.0 + let on_true = mask.zeros_like()?.to_dtype(self.dtype)?; + let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)? + .broadcast_as(mask.shape())? + .to_dtype(self.dtype)?; + + mask.where_cond(&on_true, &on_false) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor> + ) -> Result { + let (batch_size, seq_length) = input_ids.dims2()?; + + // Get attention mask if not provided + let attention_mask = match attention_mask { + Some(mask) => mask.clone(), + None => Tensor::ones((batch_size, seq_length), self.dtype, &self.device)?, + }; + + // Prepare attention bias + let attention_bias = if seq_length <= 1 { + None + } else { + Some(self.prepare_attention_mask(&attention_mask)?) + }; + + // Get embeddings and rotary embeddings + let (hidden_states, _, rope_embeds, _) = self.embeddings.forward( + input_ids, + token_type_ids, + position_ids, + None, + self.config.unpad_inputs, + Some(&attention_mask) + )?; + + // Compute attention scale if needed + let attention_scale = if self.config.logn_attention_scale { + let scale = + attention_mask.sum_keepdim(1)?.log()? / + (self.config.max_position_embeddings as f64).ln(); + if self.config.logn_attention_clip1 { + let scale = scale?; + Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) + } else { + Some(scale?) + } + } else { + None + }; + + // Forward through encoder + let hidden_states = self.encoder.forward( + &hidden_states, + attention_bias.as_ref(), + rope_embeds.as_ref(), + attention_scale.as_ref() + )?; + + Ok(hidden_states) + } +} + +// Optional pooler implementation +#[derive(Debug)] +pub struct NewPooler { + dense: Linear, +} + +impl NewPooler { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; + Ok(Self { dense }) + } + + pub fn forward(&self, hidden_states: &Tensor) -> Result { + let first_token = hidden_states.i((.., 0, ..))?; + let pooled = self.dense.forward(&first_token)?; + pooled.tanh() + } +} + +// Complete model with pooler +#[derive(Debug)] +pub struct NewModelWithPooler { + model: NewModel, + pooler: Option, +} + +impl NewModelWithPooler { + pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { + let vb_m = vb.pp("new"); + let model = NewModel::new(cfg, vb_m.pp("model"))?; + let pooler = if add_pooling_layer { + Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) 
+ } else { + None + }; + Ok(Self { model, pooler }) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + attention_mask: Option<&Tensor>, + token_type_ids: Option<&Tensor>, + position_ids: Option<&Tensor> + ) -> Result<(Tensor, Option)> { + let hidden_states = self.model.forward( + input_ids, + attention_mask, + token_type_ids, + position_ids + )?; + + let pooled_output = match &self.pooler { + Some(pooler) => Some(pooler.forward(&hidden_states)?), + None => None, + }; + + Ok((hidden_states, pooled_output)) + } +} + +#[derive(Debug)] +pub struct EmbeddingModel { + base_model: NewModel, + lm_head: Linear, +} + +impl EmbeddingModel { + pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result { + let base_model = NewModel::new(cfg, base_vb)?; + let lm_head = linear( + cfg.embed_head.in_features, + cfg.embed_head.out_features, + embed_vb.pp("linear") + )?; + + Ok(Self { + base_model, + lm_head, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { + let x = self.base_model.forward(input_ids, Some(mask), None, None)?; + let x = self.pool(&x, mask)?; + self.lm_head.forward(&x.to_dtype(DType::F32)?) + } + + pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { + let x = self.forward(input_ids, mask)?; + x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?) + } + fn pool(&self, x: &Tensor, mask: &Tensor) -> Result { + let mask = mask.to_dtype(x.dtype())?; + let (batch_size, seq_len, hidden_dim) = x.dims3()?; + let mask_expanded = mask.unsqueeze(2)?.broadcast_as((batch_size, seq_len, hidden_dim))?; // [B_Sz, Seq_len, Hidden_dim] + let x = x.mul(&mask_expanded)?; + let sum_mask = mask.sum(1)?.unsqueeze(1)?.expand((batch_size, hidden_dim))?; + x.sum(1)? / sum_mask + } +} + +pub fn time_run(f: F) -> (T, std::time::Duration) where F: FnOnce() -> T { + let start = Instant::now(); + let result = f(); + let duration = start.elapsed(); + (result, duration) +} + +fn apply_rotary_pos_emb( + q: &Tensor, + k: &Tensor, + cos: &Tensor, + sin: &Tensor +) -> Result<(Tensor, Tensor)> { + let cos = cos.to_dtype(q.dtype())?; + let sin = sin.to_dtype(q.dtype())?; + + let (batch_size, seq_len, num_heads, head_dim) = q.dims4()?; + let half_dim = head_dim / 2; + + // Reshape q and k to split the head dim for rotation + let q_split = q.chunk(2, 3)?; // Split along head_dim + let k_split = k.chunk(2, 3)?; + + let q1 = &q_split[0]; + let q2 = &q_split[1]; + let k1 = &k_split[0]; + let k2 = &k_split[1]; + + // Handle cos/sin for the sequence length we have + let cos = cos.narrow(0, 0, seq_len)?; + let sin = sin.narrow(0, 0, seq_len)?; + + // Reshape cos/sin to match the dimensions we need + let cos = cos + .reshape((seq_len, head_dim))? + .chunk(2, 1)? + [0].reshape((seq_len, 1, half_dim))? + .broadcast_as((seq_len, num_heads, half_dim))? + .unsqueeze(0)? + .broadcast_as((batch_size, seq_len, num_heads, half_dim))?; + + let sin = sin + .reshape((seq_len, head_dim))? + .chunk(2, 1)? + [0].reshape((seq_len, 1, half_dim))? + .broadcast_as((seq_len, num_heads, half_dim))? + .unsqueeze(0)? 
+ .broadcast_as((batch_size, seq_len, num_heads, half_dim))?; + + // Apply rotation using the formulas: + // q = q * cos + rotate_half(q) * sin + // k = k * cos + rotate_half(k) * sin + let q_out = Tensor::cat( + &[&q1.mul(&cos)?.sub(&q2.mul(&sin)?)?, &q2.mul(&cos)?.add(&q1.mul(&sin)?)?], + 3 + )?; + + let k_out = Tensor::cat( + &[&k1.mul(&cos)?.sub(&k2.mul(&sin)?)?, &k2.mul(&cos)?.add(&k1.mul(&sin)?)?], + 3 + )?; + + Ok((q_out, k_out)) +} From f137f69a80f68e18385d06646f95cab10249efe1 Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Mon, 18 Nov 2024 02:26:55 +0530 Subject: [PATCH 2/7] Unified stella --- .../src/models/stella_en_v5.rs | 119 +++- .../src/models/stella_en_v5_400m.rs | 524 +++++++++--------- 2 files changed, 365 insertions(+), 278 deletions(-) diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index 9d933fade5..419ee56bda 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -3,29 +3,45 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; + // internal representation for identifying which model is being used + #[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub enum ModelVariant { + Large, // 1.5B + Small // 400M +} + +impl Default for ModelVariant { + fn default() -> Self { + Self::Large + } +} + // Same as `qwen2` family of models with the exception being the `embed_head` // The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head` -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)] pub struct Config { + pub variant: ModelVariant, pub vocab_size: usize, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, pub num_attention_heads: usize, - pub num_key_value_heads: usize, pub max_position_embeddings: usize, - pub max_window_layers: usize, - pub tie_word_embeddings: bool, pub rope_theta: f64, - pub rms_norm_eps: f64, - pub hidden_act: Activation, pub embed_head: EmbedHead, + pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M + pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M + // Unique to 1.5B + pub num_key_value_heads: usize, + // Unique to 400M + pub type_vocab_size: usize, + pub scaling_factor: f64, } // Excerpt from `stella` model card: // `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions // Embed head represents the config for various embedding dims supported -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)] pub struct EmbedHead { pub in_features: usize, pub out_features: usize, @@ -51,9 +67,9 @@ impl Default for EmbedDim { } impl EmbedDim { - pub fn config(&self) -> EmbedHead { + pub fn config(&self, in_features: usize) -> EmbedHead { EmbedHead { - in_features: 1536, + in_features, out_features: match &self { Self::Dim256 => 256, Self::Dim768 => 768, @@ -74,7 +90,8 @@ impl Config { // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here Self { - hidden_act: candle_nn::Activation::Silu, + variant: ModelVariant::Large, + activation_fn: 
candle_nn::Activation::Silu, vocab_size: 151646, hidden_size: 1536, intermediate_size: 8960, @@ -82,11 +99,32 @@ impl Config { num_attention_heads: 12, num_key_value_heads: 2, max_position_embeddings: 131072, - max_window_layers: 21, - tie_word_embeddings: false, + // max_window_layers: 21, + // tie_word_embeddings: false, rope_theta: 1000000., - rms_norm_eps: 1e-06, - embed_head: embed_dim.config(), + norm_eps: 1e-06, + embed_head: embed_dim.config(1536), + ..Default::default() + } + } + + /// Initialize new `stella_en_400M_v5` + pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self { + Self { + variant: ModelVariant::Small, + vocab_size: 30528, + hidden_size: 1024, + intermediate_size: 4096, + num_hidden_layers: 24, + num_attention_heads: 16, + max_position_embeddings: 8192, + type_vocab_size: 2, + norm_eps: 1e-12, + scaling_factor: 2.0, + rope_theta: 160000.0, + activation_fn: Activation::Gelu, + embed_head: embed_dim.config(1024), + ..Default::default() } } } @@ -100,23 +138,48 @@ struct RotaryEmbedding { impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let dim = cfg.hidden_size / cfg.num_attention_heads; - let max_seq_len = cfg.max_position_embeddings; + // Factoring in `scaling factor` for `400M` variant + let max_seq_len = if cfg.scaling_factor == 0. { + cfg.max_position_embeddings + } else { + ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize + }; + let inv_freq: Vec<_> = (0..dim) .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .map(|i| { + // Scaled rope_theta for 400M variant + let rope_theta = if cfg.scaling_factor == 0. { cfg.rope_theta } else { cfg.rope_theta * cfg.scaling_factor }; + let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64); + + if cfg.scaling_factor != 0. { + freq /= cfg.scaling_factor.powf(2.0 / (dim as f64)) + } + + freq as f32 + + }) .collect(); + let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + + // Calculate position embeddings with scaled sequence length let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; - let freqs = t.matmul(&inv_freq)?; + let mut freqs = t.matmul(&inv_freq)?; + if cfg.variant == ModelVariant::Small { + freqs = Tensor::cat(&[&freqs, &freqs], 1)? 
+ } + Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, }) } - + + // TODO: re-visit this fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, 0, seq_len)?; @@ -127,6 +190,12 @@ impl RotaryEmbedding { } } +// impl Module for EmbeddingLayer { +// fn forward(&self, xs: &Tensor) -> Result { + +// } +// } + #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { @@ -147,7 +216,7 @@ impl MLP { gate_proj, up_proj, down_proj, - act_fn: cfg.hidden_act, + act_fn: cfg.activation_fn, }) } } @@ -255,10 +324,10 @@ impl DecoderLayer { let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; let input_layernorm = - RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?; let post_attention_layernorm = RmsNorm::new( cfg.hidden_size, - cfg.rms_norm_eps, + cfg.norm_eps, vb.pp("post_attention_layernorm"), )?; Ok(Self { @@ -301,7 +370,7 @@ impl Model { let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let norm = RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?; Ok(Self { embed_tokens, layers, @@ -343,7 +412,7 @@ impl Model { } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct EmbeddingModel { base_model: Model, lm_head: Linear, @@ -355,7 +424,7 @@ impl EmbeddingModel { let lm_head = linear( cfg.embed_head.in_features, cfg.embed_head.out_features, - embed_vb.pp("linear"), + embed_vb.pp("linear") )?; Ok(Self { @@ -396,4 +465,4 @@ impl EmbeddingModel { .expand((batch_size, hidden_dim))?; x.sum(1)? 
/ sum_mask } -} +} \ No newline at end of file diff --git a/candle-transformers/src/models/stella_en_v5_400m.rs b/candle-transformers/src/models/stella_en_v5_400m.rs index 3b54d14234..bbe16377d6 100644 --- a/candle-transformers/src/models/stella_en_v5_400m.rs +++ b/candle-transformers/src/models/stella_en_v5_400m.rs @@ -1,9 +1,9 @@ use candle::{ DType, Device, IndexOp, Module, Result, Tensor }; -use candle_nn::{ Activation, VarBuilder, layer_norm }; +use candle_nn::{ layer_norm, Activation, LayerNorm, VarBuilder }; use std::sync::Arc; use std::time::Instant; -use super::with_tracing::{ linear, linear_no_bias, Linear, RmsNorm }; +use super::with_tracing::{ linear, linear_no_bias, Linear }; #[derive(Debug, Clone, Copy)] pub enum EmbedDim { @@ -37,19 +37,19 @@ pub struct Config { pub num_attention_heads: usize, pub max_position_embeddings: usize, pub type_vocab_size: usize, - pub pad_token_id: usize, - pub hidden_dropout_prob: f64, - pub attention_probs_dropout_prob: f64, - pub layer_norm_eps: f64, - pub initializer_range: f64, - pub position_embedding_type: String, + // pub pad_token_id: usize, + // pub hidden_dropout_prob: f64, + // pub attention_probs_dropout_prob: f64, + pub norm_eps: f64, + // pub initializer_range: f64, + // pub position_embedding_type: String, pub scaling_factor: f64, pub rope_theta: f64, - pub use_memory_efficient_attention: bool, - pub unpad_inputs: bool, - pub layer_norm_type: String, - pub logn_attention_scale: bool, - pub logn_attention_clip1: bool, + // pub use_memory_efficient_attention: bool, + // pub unpad_inputs: bool, + // pub layer_norm_type: String, + // pub logn_attention_scale: bool, + // pub logn_attention_clip1: bool, pub activation_fn: Activation, pub embed_head: EmbedHead, } @@ -77,19 +77,19 @@ impl Config { num_attention_heads: 16, max_position_embeddings: 8192, type_vocab_size: 2, - pad_token_id: 0, - hidden_dropout_prob: 0.1, - attention_probs_dropout_prob: 0.0, - layer_norm_eps: 1e-12, - initializer_range: 0.02, - position_embedding_type: "rope".to_string(), + // pad_token_id: 0, + // hidden_dropout_prob: 0.1, + // attention_probs_dropout_prob: 0.0, + norm_eps: 1e-12, + // initializer_range: 0.02, + // position_embedding_type: "rope".to_string(), scaling_factor: 2.0, rope_theta: 160000.0, - use_memory_efficient_attention: true, - unpad_inputs: false, - layer_norm_type: "layer_norm".to_string(), - logn_attention_scale: false, - logn_attention_clip1: false, + // use_memory_efficient_attention: true, + // unpad_inputs: false, + // layer_norm_type: "layer_norm".to_string(), + // logn_attention_scale: false, + // logn_attention_clip1: false, activation_fn: Activation::Gelu, embed_head, } @@ -100,10 +100,10 @@ impl Config { struct RotaryEmbedding { sin: Tensor, cos: Tensor, - _scaling_factor: f64, - _mixed_b: Option, - _dim: usize, - _base: f64, + // _scaling_factor: f64, + // _mixed_b: Option, + // _dim: usize, + // _base: f64, } impl RotaryEmbedding { @@ -144,10 +144,10 @@ impl RotaryEmbedding { Ok(Self { sin: emb.sin()?, cos: emb.cos()?, - _scaling_factor: scaling_factor, - _mixed_b: None, - _dim: dim, - _base: base, + // _scaling_factor: scaling_factor, + // _mixed_b: None, + // _dim: dim, + // _base: base, }) } } @@ -155,14 +155,14 @@ impl RotaryEmbedding { #[derive(Debug, Clone)] enum NormType { LayerNorm(candle_nn::LayerNorm), - RmsNorm(RmsNorm), + // RmsNorm(RmsNorm), } impl NormType { fn forward(&self, x: &Tensor) -> Result { match self { Self::LayerNorm(ln) => ln.forward(x), - Self::RmsNorm(rms) => rms.forward(x), + // Self::RmsNorm(rms) => 
rms.forward(x), } } } @@ -170,13 +170,13 @@ impl NormType { #[derive(Debug)] pub struct Embeddings { word_embeddings: candle_nn::Embedding, - position_embeddings: Option, - token_type_embeddings: Option, - layer_norm: NormType, - _padding_idx: usize, - _position_embedding_type: String, - rotary_emb: Option>, - position_ids: Option, + // position_embeddings: Option, + token_type_embeddings: candle_nn::Embedding, + layer_norm: LayerNorm, + // _padding_idx: usize, + // _position_embedding_type: String, + rotary_emb: Arc, + position_ids: Tensor, } impl Embeddings { @@ -187,63 +187,70 @@ impl Embeddings { vb.pp("word_embeddings") )?; - let position_embeddings = if cfg.position_embedding_type == "absolute" { - Some( - candle_nn::embedding( - cfg.max_position_embeddings, - cfg.hidden_size, - vb.pp("position_embeddings") - )? - ) - } else { - None - }; - - let token_type_embeddings = if cfg.type_vocab_size > 0 { - Some( - candle_nn::embedding( - cfg.type_vocab_size, - cfg.hidden_size, - vb.pp("token_type_embeddings") - )? - ) - } else { - None - }; - - let layer_norm = if cfg.layer_norm_type == "layer_norm" { + // let position_embeddings = if cfg.position_embedding_type == "absolute" { + // Some( + // candle_nn::embedding( + // cfg.max_position_embeddings, + // cfg.hidden_size, + // vb.pp("position_embeddings") + // )? + // ) + // } else { + // None + // }; + + let token_type_embeddings = candle_nn::embedding( + cfg.type_vocab_size, + cfg.hidden_size, + vb.pp("token_type_embeddings") + )?; + // if cfg.type_vocab_size > 0 { + // Some( + // candle_nn::embedding( + // cfg.type_vocab_size, + // cfg.hidden_size, + // vb.pp("token_type_embeddings") + // )? + // ) + // } else { + // None + // }; + + //if cfg.layer_norm_type == "layer_norm" { let weight = vb .pp("LayerNorm") .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; let bias = vb .pp("LayerNorm") .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; - NormType::LayerNorm(candle_nn::LayerNorm::new(weight, bias, cfg.layer_norm_eps)) - } else { - NormType::RmsNorm( - RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? - ) - }; - - let rotary_emb = if cfg.position_embedding_type == "rope" { - Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) - } else { - None - }; - - let position_ids = if cfg.position_embedding_type == "absolute" { - Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) - } else { - None - }; + let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); + // } else { + // NormType::RmsNorm( + // RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? + // ) + // }; + + // let rotary_emb = if cfg.position_embedding_type == "rope" { + // Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) + // } else { + // None + // }; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + + // let position_ids = if cfg.position_embedding_type == "absolute" { + // Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) 
+ // } else { + // None + // }; + let position_ids = Tensor::arange(0u32, cfg.max_position_embeddings as u32, word_embeddings.embeddings().device())?; Ok(Self { word_embeddings, - position_embeddings, + // position_embeddings, token_type_embeddings, layer_norm, - _padding_idx: cfg.pad_token_id, - _position_embedding_type: cfg.position_embedding_type.clone(), + // _padding_idx: cfg.pad_token_id, + // _position_embedding_type: cfg.position_embedding_type.clone(), rotary_emb, position_ids, }) @@ -252,57 +259,64 @@ impl Embeddings { pub fn forward( &mut self, input_ids: &Tensor, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor>, - inputs_embeds: Option<&Tensor>, - unpad_inputs: bool, - attention_mask: Option<&Tensor> - ) -> Result<(Tensor, Option, Option<(Tensor, Tensor)>, Option>)> { + // token_type_ids: Option<&Tensor>, + // position_ids: Option<&Tensor>, + // inputs_embeds: Option<&Tensor>, + // unpad_inputs: bool, + // attention_mask: Option<&Tensor> + ) -> Result<(Tensor, Option<(Tensor, Tensor)>)> { let (batch_size, seq_length) = input_ids.dims2()?; - let mut embeddings = match inputs_embeds { - Some(e) => e.clone(), - None => self.word_embeddings.forward(input_ids)?, - }; + let mut embeddings = self.word_embeddings.forward(input_ids)?; + // match inputs_embeds { + // Some(e) => e.clone(), + // None => self.word_embeddings.forward(input_ids)?, + // }; // Get position_ids first - let position_ids = if let Some(ids) = position_ids { - ids.clone() - } else { + // let position_ids = if let Some(ids) = position_ids { + // ids.clone() + // } else { // Get device from input_ids which is always available - let device = input_ids.device(); + // let device = input_ids.device(); // Initialize position_ids if None - if self.position_ids.is_none() { - self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - } + // if self.position_ids.is_none() { + // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); + // } - // Now check if we need to extend it - if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { - self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - } + // // Now check if we need to extend it + // if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { + // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); + // } - if unpad_inputs { + // let position_ids = + /*if unpad_inputs { // For now, just use the same position IDs as padded case since we don't have lengths self.position_ids .as_ref() .unwrap() .narrow(0, 0, seq_length)? .expand((batch_size, seq_length))? - } else { - self.position_ids - .as_ref() - .unwrap() - .narrow(0, 0, seq_length)? - .expand((batch_size, seq_length))? - } - }; - - // Get rotary embeddings if using RoPE - let rope_embeds = if let Some(rotary) = &self.rotary_emb { + } else {*/ + // self.position_ids + // .as_ref() + // .unwrap() + // .narrow(0, 0, seq_length)? + // .expand((batch_size, seq_length))?; + // }; + // }; + + let position_ids = self.position_ids + .as_ref() + .narrow(0, 0, seq_length)? 
+ .expand((batch_size, seq_length))?; + + + let rope_embeds = { // Get the cos and sin for this sequence length - let cos = rotary.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] - let sin = rotary.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] + let cos = self.rotary_emb.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] + let sin = self.rotary_emb.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] // Index using position_ids if needed let position_ids = position_ids.flatten_all()?; @@ -310,33 +324,38 @@ impl Embeddings { let sin = sin.index_select(&position_ids, 0)?; // Use index_select instead of i() Some((cos, sin)) - } else { - None }; + // // Get rotary embeddings if using RoPE + // let rope_embeds = if let Some(rotary) = &self.rotary_emb { + + // } else { + // None + // }; // Handle token type embeddings - if let Some(token_emb) = &self.token_type_embeddings { - let token_type_ids = if let Some(ids) = token_type_ids { - ids.clone() - } else { - position_ids.zeros_like()? // Use mul(0) equivalent - }; - if unpad_inputs { - todo!("Implement unpadded case"); - } else { - embeddings = embeddings.add(&token_emb.forward(&token_type_ids)?)?; - } - } + embeddings = embeddings.add(&self.token_type_embeddings.forward(&position_ids.zeros_like()?)?).unwrap(); + // if let Some(token_emb) = &self.token_type_embeddings { + // let token_type_ids = if let Some(ids) = token_type_ids { + // ids.clone() + // } else { + // position_ids.zeros_like()? // Use mul(0) equivalent + // }; + // if unpad_inputs { + // todo!("Implement unpadded case"); + // } else { + // embeddings = embeddings.add(&token_emb.forward(&position_ids.zeros_like()?)?).unwrap(); + // } + // } // Handle absolute position embeddings - if let Some(pos_emb) = &self.position_embeddings { - let position_embeddings = pos_emb.forward(&position_ids)?; - embeddings = embeddings.add(&position_embeddings)?; - } + // if let Some(pos_emb) = &self.position_embeddings { + // let position_embeddings = pos_emb.forward(&position_ids)?; + // embeddings = embeddings.add(&position_embeddings)?; + // } let embeddings = self.layer_norm.forward(&embeddings)?; - Ok((embeddings, attention_mask.cloned(), rope_embeds, None)) + Ok((embeddings, rope_embeds)) } } @@ -347,7 +366,7 @@ struct NewAttention { num_heads: usize, head_dim: usize, hidden_size: usize, - _use_memory_efficient_attention: bool, + // _use_memory_efficient_attention: bool, } impl NewAttention { @@ -365,7 +384,7 @@ impl NewAttention { num_heads, head_dim, hidden_size: hidden_sz, - _use_memory_efficient_attention: cfg.use_memory_efficient_attention, + // _use_memory_efficient_attention: cfg.use_memory_efficient_attention, }) } @@ -374,7 +393,7 @@ impl NewAttention { hidden_states: &Tensor, attention_bias: Option<&Tensor>, rope_embeds: Option<&(Tensor, Tensor)>, - _attention_scale: Option<&Tensor> + // _attention_scale: Option<&Tensor> ) -> Result { let (b_sz, seq_len, _) = hidden_states.dims3()?; @@ -407,10 +426,10 @@ impl NewAttention { // Prepare tensors for batched matmul using matmul // Reshape tensors to merge batch and head dimensions - let bsz = b_sz as usize; - let nh = self.num_heads as usize; - let s_len = seq_len as usize; - let h_dim = self.head_dim as usize; + let bsz = b_sz; + let nh = self.num_heads; + let s_len = seq_len; + let h_dim = self.head_dim; // Reshape tensors to [batch_size * num_heads, seq_len, head_dim] let query_states_reshaped = query_states.reshape((bsz * nh, s_len, h_dim))?; @@ -435,8 +454,8 @@ impl NewAttention { // Apply attention mask let mut 
attn_weights = if let Some(bias) = attention_bias { - let attn_weights = attn_weights.broadcast_add(bias)?; - attn_weights + // let attn_weights = attn_weights.broadcast_add(bias)?; + attn_weights.broadcast_add(bias)? } else { attn_weights }; @@ -525,26 +544,28 @@ impl NewLayer { let attention = NewAttention::new(cfg, vb.pp("attention"))?; let mlp = NewGatedMLP::new(cfg, vb.pp("mlp"))?; - let ln_eps = cfg.layer_norm_eps; + // let ln_eps = cfg.layer_norm_eps; // Use LayerNorm or RmsNorm based on config - let (attn_ln, mlp_ln) = if cfg.layer_norm_type == "layer_norm" { + let (attn_ln, mlp_ln) = { let attn_ln = layer_norm( cfg.hidden_size, - candle_nn::LayerNormConfig { eps: ln_eps, ..Default::default() }, + candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, vb.pp("attn_ln") )?; let mlp_ln = layer_norm( cfg.hidden_size, - candle_nn::LayerNormConfig { eps: ln_eps, ..Default::default() }, + candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, vb.pp("mlp_ln") )?; (NormType::LayerNorm(attn_ln), NormType::LayerNorm(mlp_ln)) - } else { - let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; - let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; - (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) }; + // else + // { + // let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; + // let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; + // (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) + // }; Ok(Self { attention, @@ -559,7 +580,7 @@ impl NewLayer { hidden_states: &Tensor, attention_bias: Option<&Tensor>, rope_embeds: Option<&(Tensor, Tensor)>, - attention_scale: Option<&Tensor> + // attention_scale: Option<&Tensor> ) -> Result { // Store original input let original = hidden_states; @@ -569,7 +590,7 @@ impl NewLayer { original, attention_bias, rope_embeds, - attention_scale + // attention_scale )?; let hidden_states = original.add(&hidden_states)?; @@ -611,7 +632,7 @@ impl NewEncoder { hidden_states: &Tensor, attention_bias: Option<&Tensor>, rope_embeds: Option<&(Tensor, Tensor)>, - attention_scale: Option<&Tensor> + // attention_scale: Option<&Tensor> ) -> Result { let mut hidden_states = hidden_states.clone(); @@ -620,7 +641,7 @@ impl NewEncoder { &hidden_states, attention_bias, rope_embeds, - attention_scale + // attention_scale )?; } @@ -634,7 +655,7 @@ pub struct NewModel { encoder: NewEncoder, device: Device, dtype: DType, - config: Config, + // config: Config, } impl NewModel { @@ -647,7 +668,7 @@ impl NewModel { encoder, device: vb.device().clone(), dtype: vb.dtype(), - config: cfg.clone(), + // config: cfg.clone(), }) } @@ -672,56 +693,53 @@ impl NewModel { pub fn forward( &mut self, input_ids: &Tensor, - attention_mask: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor> + attention_mask: &Tensor, + // token_type_ids: Option<&Tensor>, + // position_ids: Option<&Tensor> ) -> Result { - let (batch_size, seq_length) = input_ids.dims2()?; + let (_, seq_length) = input_ids.dims2()?; // Get attention mask if not provided - let attention_mask = match attention_mask { - Some(mask) => mask.clone(), - None => Tensor::ones((batch_size, seq_length), self.dtype, &self.device)?, - }; + // let attention_mask = mask; // Prepare attention bias let attention_bias = if seq_length <= 1 { None } else { - Some(self.prepare_attention_mask(&attention_mask)?) + Some(self.prepare_attention_mask(attention_mask)?) 
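The `attention_bias` built by `prepare_attention_mask` and consumed through `broadcast_add` above is the usual additive-mask trick: padded key positions receive a large negative value so they contribute essentially nothing after the softmax. A minimal sketch of that idea, assuming a 0/1 padding mask and a `-1e9` fill value (the exact fill value and shape produced by `prepare_attention_mask` are not shown here, so treat them as assumptions):

```rust
use candle::{DType, Result, Tensor};

// Sketch only: turn a [batch, seq] mask (1 = real token, 0 = padding) into an
// additive bias broadcastable over [batch, heads, seq, seq] attention logits.
fn additive_attention_bias(mask: &Tensor, dtype: DType) -> Result<Tensor> {
    let mask = mask.to_dtype(DType::F32)?;
    let padded = (mask.ones_like()? - &mask)?; // 1.0 only at padded positions
    let bias = (padded * -1e9f64)?;            // large negative => ~0 after softmax
    bias.unsqueeze(1)?.unsqueeze(1)?.to_dtype(dtype) // [batch, 1, 1, seq]
}
```

The resulting tensor is then added onto the raw attention logits with `broadcast_add` before the softmax, which is what the hunk above relies on.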
}; // Get embeddings and rotary embeddings - let (hidden_states, _, rope_embeds, _) = self.embeddings.forward( + let (hidden_states, rope_embeds) = self.embeddings.forward( input_ids, - token_type_ids, - position_ids, - None, - self.config.unpad_inputs, - Some(&attention_mask) + // token_type_ids, + // position_ids, + // None, + // self.config.unpad_inputs, + // Some(&attention_mask) )?; // Compute attention scale if needed - let attention_scale = if self.config.logn_attention_scale { - let scale = - attention_mask.sum_keepdim(1)?.log()? / - (self.config.max_position_embeddings as f64).ln(); - if self.config.logn_attention_clip1 { - let scale = scale?; - Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) - } else { - Some(scale?) - } - } else { - None - }; + // let attention_scale = if self.config.logn_attention_scale { + // let scale = + // attention_mask.sum_keepdim(1)?.log()? / + // (self.config.max_position_embeddings as f64).ln(); + // if self.config.logn_attention_clip1 { + // let scale = scale?; + // Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) + // } else { + // Some(scale?) + // } + // } else { + // None + // }; // Forward through encoder let hidden_states = self.encoder.forward( &hidden_states, attention_bias.as_ref(), rope_embeds.as_ref(), - attention_scale.as_ref() + // attention_scale.as_ref() )?; Ok(hidden_states) @@ -729,65 +747,65 @@ impl NewModel { } // Optional pooler implementation -#[derive(Debug)] -pub struct NewPooler { - dense: Linear, -} - -impl NewPooler { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; - Ok(Self { dense }) - } - - pub fn forward(&self, hidden_states: &Tensor) -> Result { - let first_token = hidden_states.i((.., 0, ..))?; - let pooled = self.dense.forward(&first_token)?; - pooled.tanh() - } -} - -// Complete model with pooler -#[derive(Debug)] -pub struct NewModelWithPooler { - model: NewModel, - pooler: Option, -} - -impl NewModelWithPooler { - pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { - let vb_m = vb.pp("new"); - let model = NewModel::new(cfg, vb_m.pp("model"))?; - let pooler = if add_pooling_layer { - Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) 
- } else { - None - }; - Ok(Self { model, pooler }) - } - - pub fn forward( - &mut self, - input_ids: &Tensor, - attention_mask: Option<&Tensor>, - token_type_ids: Option<&Tensor>, - position_ids: Option<&Tensor> - ) -> Result<(Tensor, Option)> { - let hidden_states = self.model.forward( - input_ids, - attention_mask, - token_type_ids, - position_ids - )?; - - let pooled_output = match &self.pooler { - Some(pooler) => Some(pooler.forward(&hidden_states)?), - None => None, - }; - - Ok((hidden_states, pooled_output)) - } -} +// #[derive(Debug)] +// pub struct NewPooler { +// dense: Linear, +// } + +// impl NewPooler { +// pub fn new(cfg: &Config, vb: VarBuilder) -> Result { +// let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; +// Ok(Self { dense }) +// } + +// pub fn forward(&self, hidden_states: &Tensor) -> Result { +// let first_token = hidden_states.i((.., 0, ..))?; +// let pooled = self.dense.forward(&first_token)?; +// pooled.tanh() +// } +// } + +// // Complete model with pooler +// #[derive(Debug)] +// pub struct NewModelWithPooler { +// model: NewModel, +// pooler: Option, +// } + +// impl NewModelWithPooler { +// pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { +// let vb_m = vb.pp("new"); +// let model = NewModel::new(cfg, vb_m.pp("model"))?; +// let pooler = if add_pooling_layer { +// Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) +// } else { +// None +// }; +// Ok(Self { model, pooler }) +// } + +// pub fn forward( +// &mut self, +// input_ids: &Tensor, +// attention_mask: Option<&Tensor>, +// token_type_ids: Option<&Tensor>, +// position_ids: Option<&Tensor> +// ) -> Result<(Tensor, Option)> { +// let hidden_states = self.model.forward( +// input_ids, +// attention_mask, +// token_type_ids, +// position_ids +// )?; + +// let pooled_output = match &self.pooler { +// Some(pooler) => Some(pooler.forward(&hidden_states)?), +// None => None, +// }; + +// Ok((hidden_states, pooled_output)) +// } +// } #[derive(Debug)] pub struct EmbeddingModel { @@ -811,7 +829,7 @@ impl EmbeddingModel { } pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { - let x = self.base_model.forward(input_ids, Some(mask), None, None)?; + let x = self.base_model.forward(input_ids, mask)?;//, None, None)?; let x = self.pool(&x, mask)?; self.lm_head.forward(&x.to_dtype(DType::F32)?) 
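The `pool` call above is a masked mean over the sequence dimension, so padding tokens never dilute the sentence embedding before `lm_head` projects it down to the requested `embed_dim`. A standalone sketch of that computation (the helper name and the f32 mask cast are illustrative, not the exact code in this file):

```rust
use candle::{DType, Result, Tensor};

// hidden: [batch, seq, hidden] (f32), mask: [batch, seq] with 1 = real token.
fn masked_mean_pool(hidden: &Tensor, mask: &Tensor) -> Result<Tensor> {
    let mask = mask.to_dtype(DType::F32)?.unsqueeze(2)?; // [batch, seq, 1]
    let summed = hidden.broadcast_mul(&mask)?.sum(1)?;   // [batch, hidden]
    let counts = mask.sum(1)?;                           // [batch, 1], number of real tokens
    summed.broadcast_div(&counts)                        // average over real tokens only
}
```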
} From 14fc76b3e9d091f69cfed5b3ddc22001c17d5429 Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Mon, 18 Nov 2024 14:10:56 +0530 Subject: [PATCH 3/7] WIP: Unified Stella --- .../src/models/stella_en_v5.rs | 276 ++++++++++++++---- 1 file changed, 217 insertions(+), 59 deletions(-) diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index 419ee56bda..47991d8867 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -1,10 +1,10 @@ use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; -use candle::{DType, Device, IndexOp, Module, Result, Tensor}; -use candle_nn::{Activation, VarBuilder}; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder}; use std::sync::Arc; // internal representation for identifying which model is being used - #[derive(Debug, Clone, PartialEq, serde::Deserialize)] + #[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)] pub enum ModelVariant { Large, // 1.5B Small // 400M @@ -99,8 +99,6 @@ impl Config { num_attention_heads: 12, num_key_value_heads: 2, max_position_embeddings: 131072, - // max_window_layers: 21, - // tie_word_embeddings: false, rope_theta: 1000000., norm_eps: 1e-06, embed_head: embed_dim.config(1536), @@ -190,17 +188,12 @@ impl RotaryEmbedding { } } -// impl Module for EmbeddingLayer { -// fn forward(&self, xs: &Tensor) -> Result { - -// } -// } - #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { + variant: ModelVariant, gate_proj: Linear, - up_proj: Linear, + up_proj: Option, // `up_proj` only for 1.5B variant down_proj: Linear, act_fn: Activation, } @@ -209,10 +202,26 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + + let (gate_proj, up_proj, down_proj) = match cfg.variant { + ModelVariant::Large => { + ( + linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?, + Some(linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?), + linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))? + ) + } + ModelVariant::Small => { + ( + linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?, + None, + linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))? + ) + } + }; + Ok(Self { + variant: cfg.variant, gate_proj, up_proj, down_proj, @@ -223,17 +232,36 @@ impl MLP { impl Module for MLP { fn forward(&self, xs: &Tensor) -> Result { - let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; - let rhs = xs.apply(&self.up_proj)?; + let up = self.gate_proj.forward(xs)?; + + let (lhs, rhs) = match self.variant { + ModelVariant::Large => { + let lhs = up.apply(&self.act_fn)?; + let rhs = xs.apply(self.up_proj.as_ref().unwrap())?; + + (lhs, rhs) + } + ModelVariant::Small => { + // Get the dimensions + let (_batch_size, _seq_len, hidden_dim) = up.dims3()?; + let split_size = hidden_dim / 2; + + // Split along the last dimension (hidden_dim) + let up_states = up.narrow(2, 0, split_size)?; + let gate = up.narrow(2, split_size, split_size)? 
+ .apply(&self.act_fn)?; + + (up_states, gate) + } + }; + (lhs * rhs)?.apply(&self.down_proj) } } #[derive(Debug, Clone)] struct Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, + qkv_proj: Linear, o_proj: Linear, num_heads: usize, num_kv_heads: usize, @@ -241,6 +269,7 @@ struct Attention { head_dim: usize, hidden_size: usize, rotary_emb: Arc, + variant: ModelVariant } impl Attention { @@ -250,14 +279,30 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let num_kv_groups = num_heads / num_kv_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + + let qkv_proj = match cfg.variant { + ModelVariant::Large => { + // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize + // Weights + let q_w = vb.pp("q_proj").get((num_heads * head_dim, hidden_sz), "weight")?; + let k_w = vb.pp("k_proj").get((num_kv_heads * head_dim, hidden_sz), "weight")?; + let v_w = vb.pp("v_proj").get((num_kv_heads * head_dim, hidden_sz), "weight")?; + // Biases + let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?; + let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?; + let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?; + + let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; + let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?; + + Linear::from_weights(qkv_w, Some(qkv_b)) + } + ModelVariant::Small => linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))? + }; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; Ok(Self { - q_proj, - k_proj, - v_proj, + qkv_proj, o_proj, num_heads, num_kv_heads, @@ -265,15 +310,24 @@ impl Attention { head_dim, hidden_size: hidden_sz, rotary_emb, + variant: cfg.variant }) } fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { let (b_sz, q_len, _) = xs.dims3()?; - let query_states = self.q_proj.forward(xs)?; - let key_states = self.k_proj.forward(xs)?; - let value_states = self.v_proj.forward(xs)?; + let qkv = self.qkv_proj.forward(xs)?; + + let (query_states, key_states, value_states) = { + let q_sz = self.num_heads * self.head_dim; + let kv_sz = self.num_kv_heads * self.head_dim; + + let q = qkv.narrow(D::Minus1, 0, q_sz)?; + let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?; + let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?; + (q, k, v) + }; let query_states = query_states .reshape((b_sz, q_len, self.num_heads, self.head_dim))? @@ -289,9 +343,18 @@ impl Attention { .rotary_emb .apply_rotary_emb_qkv(&query_states, &key_states)?; - let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; - let value_states = - crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + // The 1.5B is expected to have grouped query attention + let (key_states, value_states) = if self.variant == ModelVariant::Large { + ( + crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?, + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()? 
+ ) + } else { + ( + key_states, + value_states + ) + }; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); @@ -312,48 +375,135 @@ impl Attention { } #[derive(Debug, Clone)] -struct DecoderLayer { - self_attn: Attention, +enum NormType { + Layer(LayerNorm), + Rms(RmsNorm) +} + +#[derive(Debug, Clone)] +struct Layer { + variant: ModelVariant, + attention: Attention, mlp: MLP, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, + // For 1.5B: this is `input_layernorm` + // For 400M: this is `output_layernorm` + layernorm: NormType, + post_attention_layernorm: NormType, } -impl DecoderLayer { +impl Layer { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { - let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let attention = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let input_layernorm = - RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = RmsNorm::new( - cfg.hidden_size, - cfg.norm_eps, - vb.pp("post_attention_layernorm"), - )?; + let (layernorm, post_attention_layernorm) = match cfg.variant { + ModelVariant::Large => { + ( + NormType::Rms(RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?), + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("post_attention_layernorm"), + )?) + ) + } + ModelVariant::Small => { + ( + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, + vb.pp("attn_ln") + )?), + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, + vb.pp("mlp_ln") + )?) + ) + } + }; + Ok(Self { - self_attn, + variant: cfg.variant, + attention, mlp, - input_layernorm, + layernorm, post_attention_layernorm, }) } fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { + // Here, the application of normalizations and activation calculations differ + // For Large [1.5B]: + // residual = x + // state = other_layernorm(xs) + // state = attention(state) + // state += residual + // residual = state + // state = mlp(attention_layernorm(state)) + // -> residual + state + // For Small [400M]: + // residual = x; + // state = attention(x) + // state += residual + // state = attention_layernorm(state) + // residual = state + // state = mlp(state) + // state += residual + // -> other_layernorm(state) let residual = xs; - let xs = self.input_layernorm.forward(xs)?; - let xs = self.self_attn.forward(&xs, attention_mask)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; - residual + xs + + match self.variant { + ModelVariant::Large => { + let (attn_ln, input_ln) = if let ( + NormType::Rms(attn_ln), + NormType::Rms(input_ln) + ) = ( + &self.post_attention_layernorm, + &self.layernorm + ) { + (attn_ln, input_ln) + } else { + return Err(candle::error::Error::Msg("Stella 1.5B expects RMSNorm".to_string())); + }; + + let xs = input_ln.forward(xs)?; + let xs = (self.attention.forward(&xs, attention_mask)? 
+ residual)?; + + let residual = &xs; + let xs = xs.apply(attn_ln)?.apply(&self.mlp)?; + + residual + xs + } + ModelVariant::Small => { + let (attn_ln, output_ln) = if let ( + NormType::Layer(attn_ln), + NormType::Layer(input_ln) + ) = ( + &self.post_attention_layernorm, + &self.layernorm + ) { + (attn_ln, input_ln) + } else { + return Err(candle::error::Error::Msg("Stella 400M expects RMSNorm".to_string())); + }; + + let xs = self.attention.forward(xs, attention_mask)?; + let xs = attn_ln.forward(&(xs + residual)?)?; + + let residual = &xs; + let xs = (self.mlp.forward(&xs)? + residual)?; + + output_ln.forward(&xs) + } + } } } #[derive(Debug, Clone)] pub struct Model { embed_tokens: candle_nn::Embedding, - layers: Vec, - norm: RmsNorm, + layers: Vec, + norm: Option, device: Device, dtype: DType, } @@ -361,21 +511,24 @@ pub struct Model { impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let vb_m = vb.pp("model"); + // 400M Notes: Embedding seems to follow different path? Investigate and integrate let embed_tokens = candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb_m.pp("layers"); for layer_idx in 0..cfg.num_hidden_layers { - let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?; + let norm = match cfg.variant { + ModelVariant::Large => Some(RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?), + ModelVariant::Small => None + }; Ok(Self { embed_tokens, layers, norm, - // sliding_window: 0, device: vb.device().clone(), dtype: vb.dtype(), }) @@ -408,7 +561,12 @@ impl Model { for layer in self.layers.iter_mut() { xs = layer.forward(&xs, attention_mask.as_ref())? } - xs.apply(&self.norm) + + if let Some(n) = &self.norm { + xs.apply(n) + } else { + Ok(xs) + } } } From c92f5699339058fdd3491b7a2444fb531a2e5d8a Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Tue, 19 Nov 2024 01:58:29 +0530 Subject: [PATCH 4/7] Combined stella for both 1.5B and 400M variants --- .../examples/stella-en-v5/README.md | 24 +- candle-examples/examples/stella-en-v5/main.rs | 67 ++++-- .../src/models/stella_en_v5.rs | 219 ++++++++++++++---- 3 files changed, 247 insertions(+), 63 deletions(-) diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md index 5fcc67c351..3a87b2956a 100644 --- a/candle-examples/examples/stella-en-v5/README.md +++ b/candle-examples/examples/stella-en-v5/README.md @@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. ```bash -$ cargo run --example stella-en-v5 --release --features +$ cargo run --example stella-en-v5 --release --features -- --which 1.5b > > Score: 0.8178786 @@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features > caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types > > of cancer. 
The polyphenols in green tea may also have anti-inflammatory and weight loss properties. > + +$ cargo run --example stella-en-v5 --release --features -- --which 400m + +> +> Score: 0.8397539 +> Query: What are some ways to reduce stress? +> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending +> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent +> stress from building up. +> +> +> +> Score: 0.809545 +> Query: What are the benefits of drinking green tea? +> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage +> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types +> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. +> ``` ## Supported options: -- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. +- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`. + +- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. - As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option. \ No newline at end of file diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index 2408262b1a..e5b4259e89 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -212,6 +212,14 @@ impl EncodeTask { } } +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "1.5b")] + Large, + #[value(name = "400m")] + Small +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -219,6 +227,9 @@ struct Args { #[arg(long)] cpu: bool, + #[arg(long)] + which: Which, + /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -250,24 +261,36 @@ struct Args { // Tokenizer creation is super critical in our case. 
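The scores in the sample output above are pairwise similarities between the query and document embeddings. A hedged sketch of how such a score matrix can be computed from two `[n, embed_dim]` tensors returned by the model (the example's own scoring code may differ in detail):

```rust
use candle::{Result, Tensor};

// query_emb: [n_queries, embed_dim], doc_emb: [n_docs, embed_dim].
fn cosine_scores(query_emb: &Tensor, doc_emb: &Tensor) -> Result<Tensor> {
    // L2-normalise each row, then one matmul yields all pairwise cosine similarities.
    let q = query_emb.broadcast_div(&query_emb.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    let d = doc_emb.broadcast_div(&doc_emb.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    q.matmul(&d.t()?) // [n_queries, n_docs]
}
```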
// We are going to be `padding: Left` for each batch -fn create_tokenizer(tokenizer_file: &Path) -> Result { +fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; - let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { - pad_id - } else { - return Err(anyhow!( - "Tokenizer doesn't contain expected `<|endoftext|>` token" - )); - }; - // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - })); + if which == Which::Large { + let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { + pad_id + } else { + return Err(anyhow!( + "Tokenizer doesn't contain expected `<|endoftext|>` token" + )); + }; + + // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + })); + } else { + tokenizer.with_padding( + Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + }) + ); + } + Ok(tokenizer) } @@ -298,7 +321,13 @@ fn main() -> Result<()> { Some(d) => d, None => EmbedDim::Dim1024, }; - let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string())); + + let (repo, cfg) = match args.which { + Which::Large => ("dunzhang/stella_en_1.5B_v5", Config::new_1_5_b_v5(embed_dim.embed_dim())), + Which::Small => ("dunzhang/stella_en_400M_v5", Config::new_400_m_v5(embed_dim.embed_dim())) + }; + + let repo = api.repo(Repo::model(repo.to_string())); let tokenizer_filename = match args.tokenizer_file { Some(file) => std::path::PathBuf::from(file), None => repo.get("tokenizer.json")?, @@ -330,7 +359,7 @@ fn main() -> Result<()> { println!("retrieved the files in {:?}", start.elapsed()); // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding - let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; + let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?; let start = std::time::Instant::now(); @@ -344,7 +373,7 @@ fn main() -> Result<()> { unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? 
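Note the split in `create_tokenizer` above: the 1.5B path keeps the original example's left padding with `<|endoftext|>` (the usual convention for decoder-style models, so the trailing positions line up across a batch), while the encoder-style 400M uses ordinary right padding. A purely illustrative way to see the difference once a tokenizer has been configured (the sample strings are arbitrary):

```rust
use anyhow::{Error as E, Result};
use tokenizers::Tokenizer;

// With Left padding the pad ids (and the zeros in the attention mask) appear at
// the front of the shorter sequence; with Right padding they appear at the end.
fn inspect_padding(tokenizer: &Tokenizer) -> Result<()> {
    let batch = tokenizer
        .encode_batch(
            vec!["a short query", "a noticeably longer query about green tea"],
            true,
        )
        .map_err(E::msg)?;
    for enc in &batch {
        println!("ids:  {:?}", enc.get_ids());
        println!("mask: {:?}", enc.get_attention_mask());
    }
    Ok(())
}
```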
}; let model = EmbeddingModel::new( - &Config::new_1_5_b_v5(embed_dim.embed_dim()), + &cfg, base_vb, embed_vb, )?; diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index 47991d8867..fb4d539a26 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -1,5 +1,5 @@ use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; -use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D}; use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder}; use std::sync::Arc; @@ -143,6 +143,7 @@ impl RotaryEmbedding { ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize }; + // let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim }; let inv_freq: Vec<_> = (0..dim) .step_by(2) .map(|i| { @@ -166,10 +167,10 @@ impl RotaryEmbedding { let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; - let mut freqs = t.matmul(&inv_freq)?; - if cfg.variant == ModelVariant::Small { - freqs = Tensor::cat(&[&freqs, &freqs], 1)? - } + let freqs = t.matmul(&inv_freq)?; + // if cfg.variant == ModelVariant::Small { + // freqs = Tensor::cat(&[&freqs, &freqs], 1)? + // } Ok(Self { sin: freqs.sin()?, @@ -182,6 +183,7 @@ impl RotaryEmbedding { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, 0, seq_len)?; let sin = self.sin.narrow(0, 0, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; Ok((q_embed, k_embed)) @@ -277,10 +279,10 @@ impl Attention { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; - let num_kv_groups = num_heads / num_kv_heads; + let num_kv_groups = if num_kv_heads > 0 { num_heads / num_kv_heads } else { 0 }; let head_dim = hidden_sz / num_heads; - let qkv_proj = match cfg.variant { + let (qkv_proj, o_proj) = match cfg.variant { ModelVariant::Large => { // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize // Weights @@ -295,12 +297,17 @@ impl Attention { let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?; - Linear::from_weights(qkv_w, Some(qkv_b)) + ( + Linear::from_weights(qkv_w, Some(qkv_b)), + linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))? + ) } - ModelVariant::Small => linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))? + ModelVariant::Small => ( + linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?, + linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))? + ) }; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; Ok(Self { qkv_proj, o_proj, @@ -319,30 +326,46 @@ impl Attention { let qkv = self.qkv_proj.forward(xs)?; - let (query_states, key_states, value_states) = { - let q_sz = self.num_heads * self.head_dim; - let kv_sz = self.num_kv_heads * self.head_dim; - - let q = qkv.narrow(D::Minus1, 0, q_sz)?; - let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?; - let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?; - (q, k, v) + let n_kv_heads = match self.variant { + ModelVariant::Large => self.num_kv_heads, + ModelVariant::Small => self.num_heads }; - let query_states = query_states - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? 
- .transpose(1, 2)?; - let key_states = key_states - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - let value_states = value_states - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + let (query_states, key_states, value_states) = match self.variant { + ModelVariant::Large => { + let q_sz = self.num_heads * self.head_dim; + let kv_sz = n_kv_heads * self.head_dim; + + let q = qkv.narrow(D::Minus1, 0, q_sz)? + .reshape((b_sz, q_len, self.num_heads, self.head_dim))?; + let k = qkv.narrow(D::Minus1, q_sz, kv_sz)? + .reshape((b_sz, q_len, n_kv_heads, self.head_dim))?; + let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)? + .reshape((b_sz, q_len, n_kv_heads, self.head_dim))?; + + (q, k, v) + } + ModelVariant::Small => { + // Split into Q, K, V and reshape to match PyTorch shapes + let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?; + + ( + qkv.i((.., .., 0, .., ..))?, + qkv.i((.., .., 1, .., ..))?, + qkv.i((.., .., 2, .., ..))? + ) + } + }; + + let query_states = query_states.transpose(1, 2)?.contiguous()?; + let key_states = key_states.transpose(1, 2)?.contiguous()?; + let value_states = value_states.transpose(1, 2)?.contiguous()?; let (query_states, key_states) = self .rotary_emb .apply_rotary_emb_qkv(&query_states, &key_states)?; + // The 1.5B is expected to have grouped query attention let (key_states, value_states) = if self.variant == ModelVariant::Large { ( @@ -358,15 +381,18 @@ impl Attention { let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); - let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; - + let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?; + let attn_weights = (attn_weights * scale)?; + let attn_weights = match attention_mask { None => attn_weights, Some(mask) => attn_weights.broadcast_add(mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? }; + attn_output .transpose(1, 2)? .reshape((b_sz, q_len, self.hidden_size))? @@ -393,7 +419,7 @@ struct Layer { impl Layer { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { - let attention = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let attention = Attention::new(rotary_emb, cfg, vb.pp(if cfg.variant == ModelVariant::Large { "self_attn"} else { "attention" }))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; let (layernorm, post_attention_layernorm) = match cfg.variant { ModelVariant::Large => { @@ -411,12 +437,12 @@ impl Layer { NormType::Layer(layer_norm( cfg.hidden_size, candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("attn_ln") + vb.pp("mlp_ln") )?), NormType::Layer(layer_norm( cfg.hidden_size, candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("mlp_ln") + vb.pp("attn_ln") )?) ) } @@ -487,8 +513,8 @@ impl Layer { return Err(candle::error::Error::Msg("Stella 400M expects RMSNorm".to_string())); }; - let xs = self.attention.forward(xs, attention_mask)?; - let xs = attn_ln.forward(&(xs + residual)?)?; + let xs = (self.attention.forward(xs, attention_mask)? + residual)?; + let xs = attn_ln.forward(&xs)?; let residual = &xs; let xs = (self.mlp.forward(&xs)? 
+ residual)?; @@ -499,9 +525,110 @@ impl Layer { } } +#[derive(Debug, Clone)] +pub struct Embeddings { + variant: ModelVariant, + // For 1.5B: this is the `embed_tokens` + // For 400M: this is the `word_embeddings` + embeddings: candle_nn::Embedding, + // folloing are specifically for 400M + token_type_embeddings: Option, + layer_norm: Option, + position_ids: Option +} + +impl Embeddings { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let ( + embeddings, + token_type_embeddings, + layer_norm, + position_ids + ) = match cfg.variant { + ModelVariant::Large => { + ( + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?, + None, + None, + None + ) + }, + ModelVariant::Small => { + let vb = vb.pp("embeddings"); + let weight = vb + .pp("LayerNorm") + .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; + let bias = vb + .pp("LayerNorm") + .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; + let dev = bias.device().clone(); + + let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); + + ( + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?, + Some(candle_nn::embedding(cfg.type_vocab_size, cfg.hidden_size, vb.pp("token_type_embeddings"))?), + Some(layer_norm), + Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, &dev)?) + ) + } + }; + + Ok(Self { + variant: cfg.variant, + embeddings, + token_type_embeddings, + layer_norm, + position_ids + }) + } +} + +impl Module for Embeddings { + fn forward(&self, xs: &Tensor) -> Result { + let embd = self.embeddings.forward(xs)?; + // For 1.5B just forward the embeddings + if self.variant == ModelVariant::Large { + return Ok(embd); + } + + let (token_type_embed, layer_norm, pos_ids) = if let ( + Some(token_type_embd), + Some(layer_norm), + Some(position_ids) + ) = ( + &self.token_type_embeddings, + &self.layer_norm, + &self.position_ids + ) { + ( + token_type_embd, + layer_norm, + position_ids + ) + } else { + return Err(Error::Msg("Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`".to_string())); + }; + + let (batch_size, seq_length) = xs.dims2()?; + + let pos_ids = pos_ids + .as_ref() + .narrow(0, 0, seq_length)? + .expand((batch_size, seq_length))?; + + layer_norm.forward( + &embd.add( + &token_type_embed.forward(&pos_ids.zeros_like()?)? + )? + ) + + } +} + #[derive(Debug, Clone)] pub struct Model { - embed_tokens: candle_nn::Embedding, + embeddings: Embeddings, layers: Vec, norm: Option, device: Device, @@ -510,13 +637,21 @@ pub struct Model { impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let vb_m = vb.pp("model"); - // 400M Notes: Embedding seems to follow different path? 
Investigate and integrate - let embed_tokens = - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let vb_m = match cfg.variant { + ModelVariant::Large => vb.pp("model"), + ModelVariant::Small => vb.pp("new") + }; + // let embed_tokens = + // candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let embeddings = Embeddings::new(cfg, vb_m.clone())?; let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb_m.pp("layers"); + let vb_l = match cfg.variant { + ModelVariant::Large => vb_m.pp("layers"), + ModelVariant::Small => { + vb_m.pp("encoder").pp("layer") + } + }; for layer_idx in 0..cfg.num_hidden_layers { let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) @@ -526,7 +661,7 @@ impl Model { ModelVariant::Small => None }; Ok(Self { - embed_tokens, + embeddings, layers, norm, device: vb.device().clone(), @@ -557,7 +692,7 @@ impl Model { Some(self.prepare_attention_mask(mask)?) }; - let mut xs = self.embed_tokens.forward(input_ids)?; + let mut xs = self.embeddings.forward(input_ids)?; for layer in self.layers.iter_mut() { xs = layer.forward(&xs, attention_mask.as_ref())? } From ed7fd9bcfc507b09df6177ccf5588d2651a29433 Mon Sep 17 00:00:00 2001 From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com> Date: Tue, 19 Nov 2024 02:18:15 +0530 Subject: [PATCH 5/7] Cargo fmt for the CI --- candle-examples/examples/stella-en-v5/main.rs | 35 +- .../src/models/stella_en_v5.rs | 345 ++++++++++-------- 2 files changed, 206 insertions(+), 174 deletions(-) diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index e5b4259e89..68ed7e70c6 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -217,7 +217,7 @@ enum Which { #[value(name = "1.5b")] Large, #[value(name = "400m")] - Small + Small, } #[derive(Parser, Debug)] @@ -272,7 +272,7 @@ fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { "Tokenizer doesn't contain expected `<|endoftext|>` token" )); }; - + // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding tokenizer.with_padding(Some(PaddingParams { strategy: PaddingStrategy::BatchLongest, @@ -282,15 +282,12 @@ fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { ..Default::default() })); } else { - tokenizer.with_padding( - Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Right, - ..Default::default() - }) - ); + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + })); } - Ok(tokenizer) } @@ -321,10 +318,16 @@ fn main() -> Result<()> { Some(d) => d, None => EmbedDim::Dim1024, }; - + let (repo, cfg) = match args.which { - Which::Large => ("dunzhang/stella_en_1.5B_v5", Config::new_1_5_b_v5(embed_dim.embed_dim())), - Which::Small => ("dunzhang/stella_en_400M_v5", Config::new_400_m_v5(embed_dim.embed_dim())) + Which::Large => ( + "dunzhang/stella_en_1.5B_v5", + Config::new_1_5_b_v5(embed_dim.embed_dim()), + ), + Which::Small => ( + "dunzhang/stella_en_400M_v5", + Config::new_400_m_v5(embed_dim.embed_dim()), + ), }; let repo = api.repo(Repo::model(repo.to_string())); @@ -372,11 +375,7 @@ fn main() -> Result<()> { let embed_vb = unsafe { 
VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; - let model = EmbeddingModel::new( - &cfg, - base_vb, - embed_vb, - )?; + let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index fb4d539a26..08cd808c43 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -3,11 +3,11 @@ use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D}; use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder}; use std::sync::Arc; - // internal representation for identifying which model is being used - #[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)] +// internal representation for identifying which model is being used +#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)] pub enum ModelVariant { Large, // 1.5B - Small // 400M + Small, // 400M } impl Default for ModelVariant { @@ -29,7 +29,7 @@ pub struct Config { pub max_position_embeddings: usize, pub rope_theta: f64, pub embed_head: EmbedHead, - pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M + pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M // Unique to 1.5B pub num_key_value_heads: usize, @@ -148,7 +148,11 @@ impl RotaryEmbedding { .step_by(2) .map(|i| { // Scaled rope_theta for 400M variant - let rope_theta = if cfg.scaling_factor == 0. { cfg.rope_theta } else { cfg.rope_theta * cfg.scaling_factor }; + let rope_theta = if cfg.scaling_factor == 0. { + cfg.rope_theta + } else { + cfg.rope_theta * cfg.scaling_factor + }; let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64); if cfg.scaling_factor != 0. { @@ -156,7 +160,6 @@ impl RotaryEmbedding { } freq as f32 - }) .collect(); @@ -177,7 +180,7 @@ impl RotaryEmbedding { cos: freqs.cos()?, }) } - + // TODO: re-visit this fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; @@ -206,22 +209,22 @@ impl MLP { let intermediate_sz = cfg.intermediate_size; let (gate_proj, up_proj, down_proj) = match cfg.variant { - ModelVariant::Large => { - ( - linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?, - Some(linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?), - linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))? - ) - } - ModelVariant::Small => { - ( - linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?, - None, - linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))? - ) - } + ModelVariant::Large => ( + linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?, + Some(linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + )?), + linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?, + ), + ModelVariant::Small => ( + linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?, + None, + linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?, + ), }; - + Ok(Self { variant: cfg.variant, gate_proj, @@ -250,8 +253,7 @@ impl Module for MLP { // Split along the last dimension (hidden_dim) let up_states = up.narrow(2, 0, split_size)?; - let gate = up.narrow(2, split_size, split_size)? 
- .apply(&self.act_fn)?; + let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?; (up_states, gate) } @@ -271,7 +273,7 @@ struct Attention { head_dim: usize, hidden_size: usize, rotary_emb: Arc, - variant: ModelVariant + variant: ModelVariant, } impl Attention { @@ -279,33 +281,43 @@ impl Attention { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; - let num_kv_groups = if num_kv_heads > 0 { num_heads / num_kv_heads } else { 0 }; + let num_kv_groups = if num_kv_heads > 0 { + num_heads / num_kv_heads + } else { + 0 + }; let head_dim = hidden_sz / num_heads; - + let (qkv_proj, o_proj) = match cfg.variant { ModelVariant::Large => { // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize // Weights - let q_w = vb.pp("q_proj").get((num_heads * head_dim, hidden_sz), "weight")?; - let k_w = vb.pp("k_proj").get((num_kv_heads * head_dim, hidden_sz), "weight")?; - let v_w = vb.pp("v_proj").get((num_kv_heads * head_dim, hidden_sz), "weight")?; + let q_w = vb + .pp("q_proj") + .get((num_heads * head_dim, hidden_sz), "weight")?; + let k_w = vb + .pp("k_proj") + .get((num_kv_heads * head_dim, hidden_sz), "weight")?; + let v_w = vb + .pp("v_proj") + .get((num_kv_heads * head_dim, hidden_sz), "weight")?; // Biases let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?; let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?; let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?; - + let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?; - + ( Linear::from_weights(qkv_w, Some(qkv_b)), - linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))? + linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?, ) } ModelVariant::Small => ( linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?, - linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))? - ) + linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?, + ), }; Ok(Self { @@ -317,7 +329,7 @@ impl Attention { head_dim, hidden_size: hidden_sz, rotary_emb, - variant: cfg.variant + variant: cfg.variant, }) } @@ -328,20 +340,32 @@ impl Attention { let n_kv_heads = match self.variant { ModelVariant::Large => self.num_kv_heads, - ModelVariant::Small => self.num_heads + ModelVariant::Small => self.num_heads, }; let (query_states, key_states, value_states) = match self.variant { ModelVariant::Large => { let q_sz = self.num_heads * self.head_dim; let kv_sz = n_kv_heads * self.head_dim; - - let q = qkv.narrow(D::Minus1, 0, q_sz)? - .reshape((b_sz, q_len, self.num_heads, self.head_dim))?; - let k = qkv.narrow(D::Minus1, q_sz, kv_sz)? - .reshape((b_sz, q_len, n_kv_heads, self.head_dim))?; - let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)? - .reshape((b_sz, q_len, n_kv_heads, self.head_dim))?; + + let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape(( + b_sz, + q_len, + self.num_heads, + self.head_dim, + ))?; + let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape(( + b_sz, + q_len, + n_kv_heads, + self.head_dim, + ))?; + let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape(( + b_sz, + q_len, + n_kv_heads, + self.head_dim, + ))?; (q, k, v) } @@ -352,7 +376,7 @@ impl Attention { ( qkv.i((.., .., 0, .., ..))?, qkv.i((.., .., 1, .., ..))?, - qkv.i((.., .., 2, .., ..))? 
+ qkv.i((.., .., 2, .., ..))?, ) } }; @@ -365,31 +389,27 @@ impl Attention { .rotary_emb .apply_rotary_emb_qkv(&query_states, &key_states)?; - // The 1.5B is expected to have grouped query attention let (key_states, value_states) = if self.variant == ModelVariant::Large { ( crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?, - crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()? + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?, ) } else { - ( - key_states, - value_states - ) + (key_states, value_states) }; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?; let attn_weights = (attn_weights * scale)?; - + let attn_weights = match attention_mask { None => attn_weights, Some(mask) => attn_weights.broadcast_add(mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - + attn_weights.matmul(&value_states)? }; @@ -403,7 +423,7 @@ impl Attention { #[derive(Debug, Clone)] enum NormType { Layer(LayerNorm), - Rms(RmsNorm) + Rms(RmsNorm), } #[derive(Debug, Clone)] @@ -419,35 +439,49 @@ struct Layer { impl Layer { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { - let attention = Attention::new(rotary_emb, cfg, vb.pp(if cfg.variant == ModelVariant::Large { "self_attn"} else { "attention" }))?; + let attention = Attention::new( + rotary_emb, + cfg, + vb.pp(if cfg.variant == ModelVariant::Large { + "self_attn" + } else { + "attention" + }), + )?; let mlp = MLP::new(cfg, vb.pp("mlp"))?; let (layernorm, post_attention_layernorm) = match cfg.variant { - ModelVariant::Large => { - ( - NormType::Rms(RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb.pp("input_layernorm"))?), - NormType::Rms(RmsNorm::new( - cfg.hidden_size, - cfg.norm_eps, - vb.pp("post_attention_layernorm"), - )?) - ) - } - ModelVariant::Small => { - ( - NormType::Layer(layer_norm( - cfg.hidden_size, - candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("mlp_ln") - )?), - NormType::Layer(layer_norm( - cfg.hidden_size, - candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("attn_ln") - )?) 
- ) - } + ModelVariant::Large => ( + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("input_layernorm"), + )?), + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("post_attention_layernorm"), + )?), + ), + ModelVariant::Small => ( + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { + eps: cfg.norm_eps, + ..Default::default() + }, + vb.pp("mlp_ln"), + )?), + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { + eps: cfg.norm_eps, + ..Default::default() + }, + vb.pp("attn_ln"), + )?), + ), }; - + Ok(Self { variant: cfg.variant, attention, @@ -480,16 +514,14 @@ impl Layer { match self.variant { ModelVariant::Large => { - let (attn_ln, input_ln) = if let ( - NormType::Rms(attn_ln), - NormType::Rms(input_ln) - ) = ( - &self.post_attention_layernorm, - &self.layernorm - ) { + let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) = + (&self.post_attention_layernorm, &self.layernorm) + { (attn_ln, input_ln) } else { - return Err(candle::error::Error::Msg("Stella 1.5B expects RMSNorm".to_string())); + return Err(candle::error::Error::Msg( + "Stella 1.5B expects RMSNorm".to_string(), + )); }; let xs = input_ln.forward(xs)?; @@ -501,17 +533,16 @@ impl Layer { residual + xs } ModelVariant::Small => { - let (attn_ln, output_ln) = if let ( - NormType::Layer(attn_ln), - NormType::Layer(input_ln) - ) = ( - &self.post_attention_layernorm, - &self.layernorm - ) { - (attn_ln, input_ln) - } else { - return Err(candle::error::Error::Msg("Stella 400M expects RMSNorm".to_string())); - }; + let (attn_ln, output_ln) = + if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) = + (&self.post_attention_layernorm, &self.layernorm) + { + (attn_ln, input_ln) + } else { + return Err(candle::error::Error::Msg( + "Stella 400M expects RMSNorm".to_string(), + )); + }; let xs = (self.attention.forward(xs, attention_mask)? 
+ residual)?; let xs = attn_ln.forward(&xs)?; @@ -534,42 +565,51 @@ pub struct Embeddings { // folloing are specifically for 400M token_type_embeddings: Option, layer_norm: Option, - position_ids: Option + position_ids: Option, } impl Embeddings { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let ( - embeddings, - token_type_embeddings, - layer_norm, - position_ids - ) = match cfg.variant { - ModelVariant::Large => { - ( - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?, - None, - None, - None - ) - }, + let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant { + ModelVariant::Large => ( + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?, + None, + None, + None, + ), ModelVariant::Small => { let vb = vb.pp("embeddings"); - let weight = vb - .pp("LayerNorm") - .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; - let bias = vb - .pp("LayerNorm") - .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; + let weight = vb.pp("LayerNorm").get_with_hints( + cfg.hidden_size, + "weight", + candle_nn::Init::Const(1.0), + )?; + let bias = vb.pp("LayerNorm").get_with_hints( + cfg.hidden_size, + "bias", + candle_nn::Init::Const(0.0), + )?; let dev = bias.device().clone(); let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); - + ( - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?, - Some(candle_nn::embedding(cfg.type_vocab_size, cfg.hidden_size, vb.pp("token_type_embeddings"))?), + candle_nn::embedding( + cfg.vocab_size, + cfg.hidden_size, + vb.pp("word_embeddings"), + )?, + Some(candle_nn::embedding( + cfg.type_vocab_size, + cfg.hidden_size, + vb.pp("token_type_embeddings"), + )?), Some(layer_norm), - Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, &dev)?) + Some(Tensor::arange( + 0u32, + cfg.max_position_embeddings as u32, + &dev, + )?), ) } }; @@ -579,7 +619,7 @@ impl Embeddings { embeddings, token_type_embeddings, layer_norm, - position_ids + position_ids, }) } } @@ -592,23 +632,19 @@ impl Module for Embeddings { return Ok(embd); } - let (token_type_embed, layer_norm, pos_ids) = if let ( - Some(token_type_embd), - Some(layer_norm), - Some(position_ids) - ) = ( - &self.token_type_embeddings, - &self.layer_norm, - &self.position_ids - ) { - ( - token_type_embd, - layer_norm, - position_ids - ) - } else { - return Err(Error::Msg("Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`".to_string())); - }; + let (token_type_embed, layer_norm, pos_ids) = + if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = ( + &self.token_type_embeddings, + &self.layer_norm, + &self.position_ids, + ) { + (token_type_embd, layer_norm, position_ids) + } else { + return Err(Error::Msg( + "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`" + .to_string(), + )); + }; let (batch_size, seq_length) = xs.dims2()?; @@ -617,12 +653,7 @@ impl Module for Embeddings { .narrow(0, 0, seq_length)? .expand((batch_size, seq_length))?; - layer_norm.forward( - &embd.add( - &token_type_embed.forward(&pos_ids.zeros_like()?)? - )? - ) - + layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?) 
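With the variants unified, building either model now differs only in the `Config` constructor and the weight files pulled from the hub. A hedged sketch of the wiring (the `EmbedDim::Dim1024` default and the two weight-file lists stand in for whatever `--embed-dim` and the hub download actually provide):

```rust
use std::path::PathBuf;

use anyhow::Result;
use candle::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::stella_en_v5::{Config, EmbedDim, EmbeddingModel};

// base_weights: the transformer safetensors; head_weights: the dense projection head.
fn build_model(
    base_weights: &[PathBuf],
    head_weights: &[PathBuf],
    device: &Device,
) -> Result<EmbeddingModel> {
    // Config::new_1_5_b_v5(EmbedDim::Dim1024) would select the 1.5B variant instead.
    let cfg = Config::new_400_m_v5(EmbedDim::Dim1024);
    let base_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(base_weights, DType::F32, device)? };
    let head_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(head_weights, DType::F32, device)? };
    Ok(EmbeddingModel::new(&cfg, base_vb, head_vb)?)
}
```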
} } @@ -639,7 +670,7 @@ impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { let vb_m = match cfg.variant { ModelVariant::Large => vb.pp("model"), - ModelVariant::Small => vb.pp("new") + ModelVariant::Small => vb.pp("new"), }; // let embed_tokens = // candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; @@ -648,17 +679,19 @@ impl Model { let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = match cfg.variant { ModelVariant::Large => vb_m.pp("layers"), - ModelVariant::Small => { - vb_m.pp("encoder").pp("layer") - } + ModelVariant::Small => vb_m.pp("encoder").pp("layer"), }; for layer_idx in 0..cfg.num_hidden_layers { let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) } let norm = match cfg.variant { - ModelVariant::Large => Some(RmsNorm::new(cfg.hidden_size, cfg.norm_eps, vb_m.pp("norm"))?), - ModelVariant::Small => None + ModelVariant::Large => Some(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb_m.pp("norm"), + )?), + ModelVariant::Small => None, }; Ok(Self { embeddings, @@ -717,7 +750,7 @@ impl EmbeddingModel { let lm_head = linear( cfg.embed_head.in_features, cfg.embed_head.out_features, - embed_vb.pp("linear") + embed_vb.pp("linear"), )?; Ok(Self { @@ -758,4 +791,4 @@ impl EmbeddingModel { .expand((batch_size, hidden_dim))?; x.sum(1)? / sum_mask } -} \ No newline at end of file +} From aec2aa20784e17f7b8d95f690ba681ef5a706dcf Mon Sep 17 00:00:00 2001 From: lynxeco Date: Fri, 22 Nov 2024 19:41:03 -0800 Subject: [PATCH 6/7] removed redundant stella-400m model and example after merge into stella-en-v5 --- .../examples/stella-en-400m-v5/README.md | 46 - .../examples/stella-en-400m-v5/main.rs | 384 -------- candle-examples/examples/stella-en-v5/main.rs | 102 +- candle-transformers/src/models/mod.rs | 1 - .../src/models/stella_en_v5_400m.rs | 914 ------------------ 5 files changed, 45 insertions(+), 1402 deletions(-) delete mode 100644 candle-examples/examples/stella-en-400m-v5/README.md delete mode 100644 candle-examples/examples/stella-en-400m-v5/main.rs delete mode 100644 candle-transformers/src/models/stella_en_v5_400m.rs diff --git a/candle-examples/examples/stella-en-400m-v5/README.md b/candle-examples/examples/stella-en-400m-v5/README.md deleted file mode 100644 index ef1de31d09..0000000000 --- a/candle-examples/examples/stella-en-400m-v5/README.md +++ /dev/null @@ -1,46 +0,0 @@ ---- -model-index: -- name: stella_en_400M_v5 -license: mit ---- - - -# Introduction - -The models are trained based on `Alibaba-NLP/gte-large-en-v1.5` and `Alibaba-NLP/gte-Qwen2-1.5B-instruct`. Thanks for -their contributions! - -**We simplify usage of prompts, providing two prompts for most general tasks, one is for s2p, another one is for s2s.** - - - -The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5).for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. - -```bash -$ cargo run --example stella-en-v5 --release --features - -Similarity======== - [[0.8398, 0.2990], - [0.3282, 0.8095]] - -Score: 0.83975387 -Query: What are some ways to reduce stress? -Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up. 
- - - -Score: 0.8095451 -Query: What are the benefits of drinking green tea? -Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. - - -``` - - -The models are finally trained by [MRL](https://arxiv.org/abs/2205.13147), so they have multiple dimensions: 512, 768, -1024, 2048, 4096, 6144 and 8192. - -## Supported options: -- `Stella_en_400m_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. - -- As per the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option. \ No newline at end of file diff --git a/candle-examples/examples/stella-en-400m-v5/main.rs b/candle-examples/examples/stella-en-400m-v5/main.rs deleted file mode 100644 index 6774c62aa3..0000000000 --- a/candle-examples/examples/stella-en-400m-v5/main.rs +++ /dev/null @@ -1,384 +0,0 @@ -#[cfg(feature = "mkl")] -extern crate intel_mkl_src; - -#[cfg(feature = "accelerate")] -extern crate accelerate_src; - -use std::path::Path; - -use anyhow::{ anyhow, Error as E, Result }; -use clap::Parser; - -use candle_transformers::models::stella_en_v5_400m::{ - Config, - EmbedDim as StellaEmbedDim, - EmbeddingModel, -}; - -use candle::{ DType, Device, IndexOp, Tensor }; -use candle_nn::VarBuilder; -use hf_hub::{ api::sync::Api, Repo }; -use tokenizers::{ PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer }; - -struct Embedding { - model: EmbeddingModel, - device: Device, - tokenizer: Tokenizer, -} - -impl Embedding { - fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self { - Self { - model, - tokenizer, - device: device.clone(), - } - } - - fn encode(&mut self, task: EncodeTask, text: Option) -> Result<()> { - // Just shocasing embeddings, this has no real value - if let Some(text) = text { - let qry = task.query_preproc(&[text]); - let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?; - - let shape = (1, encoding.len()); - let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; - let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; - - let result = self.model.forward(&input, &mask)?; - println!("result shape: {:?}", result.shape()); - println!("embeddings: {result}"); - } else { - // Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers) - let queries = [ - "What are some ways to reduce stress?".to_string(), - "What are the benefits of drinking green tea?".to_string(), - ]; - - let docs = [ - "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. 
Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(), - "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(), - ]; - - // We only encode the queries and not the data - let qry = task.query_preproc(&queries); - let mut qry_encoded = self.tokenizer.encode_batch(qry, true).map_err(|e| anyhow!(e))?; - - let mut docs_encoded = self.tokenizer - .encode_batch(docs.to_vec(), true) - .map_err(|e| anyhow!(e))?; - - let qry_embed = { - // Now, we generate the tensors for the `input` and `mask` - let shape = (qry_encoded.len(), qry_encoded[1].len()); - let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; - let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; - - for (i, e) in qry_encoded.drain(..).enumerate() { - let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( - 0 - )?; - let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? - .to_dtype(DType::U8)? - .unsqueeze(0)?; - - ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; - masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; - } - - // Let's generate the embeddings for the query, we are going to be normalizing the result. - // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data - self.model.forward_norm(&ids, &masks)? - }; - - let doc_embed = { - let shape = (docs_encoded.len(), docs_encoded[1].len()); - let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; - let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; - - for (i, e) in docs_encoded.drain(..).enumerate() { - let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( - 0 - )?; - let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? - .to_dtype(DType::U8)? - .unsqueeze(0)?; - - ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; - masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; - } - - // Let's generate the embeddings for the query, we are going to be normalizing the result. - // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data - self.model.forward_norm(&ids, &masks)? 
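The two blocks above pack a batch of tokenizer encodings into rectangular `(batch, max_len)` id and mask tensors with `slice_assign`; this works because the tokenizer is configured with `PaddingStrategy::BatchLongest`, so every encoding in a batch already shares the same padded length (right padding for the 400M tokenizer, left padding with `<|endoftext|>` for the 1.5B one). A plain-Rust sketch of the same padding bookkeeping, with hypothetical token ids and pad id 0 standing in for the real tokenizer output:

```rust
/// Right-pad a ragged batch of token ids to a rectangular (batch, max_len) layout
/// and build the matching attention mask (1 = real token, 0 = padding).
/// Token ids and the pad id are illustrative, not taken from the real tokenizer.
fn pad_batch(batch: &[Vec<u32>], pad_id: u32) -> (Vec<Vec<u32>>, Vec<Vec<u8>>) {
    let max_len = batch.iter().map(|t| t.len()).max().unwrap_or(0);
    let mut ids = Vec::with_capacity(batch.len());
    let mut masks = Vec::with_capacity(batch.len());
    for tokens in batch {
        let mut row = tokens.clone();
        let mut mask = vec![1u8; tokens.len()];
        row.resize(max_len, pad_id); // right padding, like the 400M tokenizer config
        mask.resize(max_len, 0);     // padded positions are masked out of the pooling
        ids.push(row);
        masks.push(mask);
    }
    (ids, masks)
}

fn main() {
    let batch = vec![vec![101, 2054, 2024, 102], vec![101, 2054, 102]];
    let (ids, masks) = pad_batch(&batch, 0);
    println!("{ids:?}");   // [[101, 2054, 2024, 102], [101, 2054, 102, 0]]
    println!("{masks:?}"); // [[1, 1, 1, 1], [1, 1, 1, 0]]
}
```

Padded positions carry a zero mask entry, so they contribute nothing to the attention bias or the later mean pooling.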
- }; - println!("Query vectors:\n {qry_embed}\n"); - println!("Document vectors:\n {doc_embed}\n"); - - println!( - "Embed shapes======\nQuery: {:?}\nDocs: {:?}\n", - qry_embed.shape(), - doc_embed.shape() - ); // [2, 1024] for head dim `1024` - - let answer = self.similarity(&qry_embed, &doc_embed); - println!("Similarity========\n {}", answer.unwrap()); - - // a matmul to generate the `similarity` score - let res = qry_embed.matmul(&doc_embed.t()?)?; - for (k, v) in queries.iter().enumerate() { - let tnsr = res.get(k)?; - let max = tnsr.argmax(0)?.to_scalar::()?; - println!( - "\nScore: {}\nQuery: {}\nAnswer: {}\n\n", - tnsr.get(max as usize)?.to_scalar::()?, - v, - docs[k] - ); - } - } - - Ok(()) - } - - /// Computes the cosine similarity between two tensors of embeddings - /// Similar to sentence-transformers' similarity() function - pub fn similarity(&self, embeddings1: &Tensor, embeddings2: &Tensor) -> Result { - // Normalize the embeddings (L2 norm) - let norm1 = embeddings1.broadcast_div(&embeddings1.sqr()?.sum_keepdim(1)?.sqrt()?)?; - let norm2 = embeddings2.broadcast_div(&embeddings2.sqr()?.sum_keepdim(1)?.sqrt()?)?; - - // Compute cosine similarity: dot product of normalized vectors - Ok(norm1.matmul(&norm2.t()?)?) - } - - fn _encode_single_doc(&mut self) -> Result { - // Example document text - let doc = - "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(); - - // Encode the document - let encoding = self.tokenizer.encode(doc, true).map_err(|e| anyhow!(e))?; - - // Create input tensors - let shape = (1, encoding.len()); - let ids = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; - let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; - - // Get embeddings and print intermediate values - let embeddings = self.model.forward_norm(&ids, &mask)?; - - // Print the shape and first few values - println!("Document embedding shape: {:?}", embeddings.shape()); - println!("First few values: {:?}", embeddings.i(0)?.narrow(0, 0, 3)?.to_vec1::()?); - - // Return normalized embeddings - let norm_embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?; - - Ok(norm_embeddings) - } -} - -#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] -enum EmbedDim { - #[value(name = "256")] - Dim256, - #[value(name = "768")] - Dim768, - #[value(name = "1024")] - Dim1024, - #[value(name = "2048")] - Dim2048, - #[value(name = "4096")] - Dim4096, - #[value(name = "6144")] - Dim6144, - #[value(name = "8192")] - Dim8192, -} - -impl EmbedDim { - /// Returns dir path to the embed head weights int he repo - pub fn embed_dim_default_dir(&self) -> &'static str { - match self { - Self::Dim256 => "2_Dense_256", - Self::Dim768 => "2_Dense_768", - Self::Dim1024 => "2_Dense_1024", - Self::Dim2048 => "2_Dense_2048", - Self::Dim4096 => "2_Dense_4096", - Self::Dim6144 => "2_Dense_6144", - Self::Dim8192 => "2_Dense_8192", - } - } - - /// Resolves the `EmbedDim` for given variant - pub fn embed_dim(&self) -> StellaEmbedDim { - match self { - Self::Dim256 => StellaEmbedDim::Dim256, - Self::Dim768 => StellaEmbedDim::Dim768, - Self::Dim1024 => StellaEmbedDim::Dim1024, - Self::Dim2048 => StellaEmbedDim::Dim2048, - Self::Dim4096 => 
StellaEmbedDim::Dim4096, - Self::Dim6144 => StellaEmbedDim::Dim6144, - Self::Dim8192 => StellaEmbedDim::Dim8192, - } - } -} - -#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] -pub enum EncodeTask { - /// `s2p` is the `retrieval` task - /// Default in this example - #[value(name = "s2p")] - S2P, - /// `s2s` is the semantic similarity task - #[value(name = "s2s")] - S2S, -} - -impl EncodeTask { - /// Preprocess a set of inputs basef on a template suggested by the model authors - /// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction - pub fn query_preproc(&self, txt: &[String]) -> Vec { - let instruct = match self { - Self::S2P => { - "Given a web search query, retrieve relevant passages that answer the query." - } - Self::S2S => "Retrieve semantically similar text.", - }; - - txt.iter() - .map(|s| format!("Instruct: {instruct}\nQuery: {s}")) - .collect::>() - } -} - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, - - /// Enable tracing (generates a trace-timestamp.json file). - #[arg(long)] - tracing: bool, - - #[arg(long)] - use_flash_attn: bool, - - #[arg(long)] - query: Option, - - #[arg(long, default_value = "1024")] - embed_dim: Option, - - #[arg(long)] - tokenizer_file: Option, - - #[arg(long)] - base_weight_files: Option, - - #[arg(long)] - embed_head_weight_files: Option, - - /// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5) - /// `s2s`: Semantic textual similarity - /// `s2p`: Retrieval task - `Default` in this example - #[arg(long, default_value = "s2p")] - task: Option, -} - -// Tokenizer creation is super critical in our case. -// We are going to be `padding: Left` for each batch -fn create_tokenizer(tokenizer_file: &Path) -> Result { - let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; - - tokenizer.with_padding( - Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Right, - ..Default::default() - }) - ); - - Ok(tokenizer) -} - -fn main() -> Result<()> { - use tracing_chrome::ChromeLayerBuilder; - use tracing_subscriber::prelude::*; - - let args = Args::parse(); - let _guard = if args.tracing { - let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - tracing_subscriber::registry().with(chrome_layer).init(); - Some(guard) - } else { - None - }; - println!( - "avx: {}, neon: {}, simd128: {}, f16c: {}", - candle::utils::with_avx(), - candle::utils::with_neon(), - candle::utils::with_simd128(), - candle::utils::with_f16c() - ); - - let start = std::time::Instant::now(); - let api = Api::new()?; - let embed_dim = match args.embed_dim { - Some(d) => d, - None => EmbedDim::Dim1024, - }; - let repo = api.repo(Repo::model("dunzhang/stella_en_400M_v5".to_string())); - let tokenizer_filename = match args.tokenizer_file { - Some(file) => std::path::PathBuf::from(file), - None => repo.get("tokenizer.json")?, - }; - - // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights - // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo - let base_weight_files = match args.base_weight_files { - Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), - None => { vec![repo.get("model.safetensors")?] 
} - }; - - let embed_weight_files = match args.embed_head_weight_files { - Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), - None => { - let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); - vec![repo.get(&head_w_path)?] - } - }; - - println!("retrieved the files in {:?}", start.elapsed()); - - // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding - let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; - - let start = std::time::Instant::now(); - - let device = candle_examples::device(args.cpu)?; - let dtype = DType::F32; - - let base_vb = unsafe { - VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? - }; - // Embedding layer is always built on F32 for accuracy - let embed_vb = unsafe { - VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? - }; - - let model = EmbeddingModel::new(&Config::new_400m(embed_dim.embed_dim()), base_vb, embed_vb)?; - - println!("loaded the model in {:?}", start.elapsed()); - - let mut embedding = Embedding::new(model, tokenizer, &device); - - let task = args.task.map_or(EncodeTask::S2P, |t| t); - - // let _doc_esmbeddings = embedding._encode_single_doc()?; - embedding.encode(task, args.query)?; - Ok(()) -} diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index 68ed7e70c6..64239274e8 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -6,17 +6,19 @@ extern crate accelerate_src; use std::path::Path; -use anyhow::{anyhow, Error as E, Result}; +use anyhow::{ anyhow, Error as E, Result }; use clap::Parser; use candle_transformers::models::stella_en_v5::{ - Config, EmbedDim as StellaEmbedDim, EmbeddingModel, + Config, + EmbedDim as StellaEmbedDim, + EmbeddingModel, }; -use candle::{DType, Device, Tensor}; +use candle::{ DType, Device, Tensor }; use candle_nn::VarBuilder; -use hf_hub::{api::sync::Api, Repo}; -use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer}; +use hf_hub::{ api::sync::Api, Repo }; +use tokenizers::{ PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer }; struct Embedding { model: EmbeddingModel, @@ -59,13 +61,9 @@ impl Embedding { // We only encode the queries and not the data let qry = task.query_preproc(&queries); - let mut qry_encoded = self - .tokenizer - .encode_batch(qry, true) - .map_err(|e| anyhow!(e))?; + let mut qry_encoded = self.tokenizer.encode_batch(qry, true).map_err(|e| anyhow!(e))?; - let mut docs_encoded = self - .tokenizer + let mut docs_encoded = self.tokenizer .encode_batch(docs.to_vec(), true) .map_err(|e| anyhow!(e))?; @@ -76,14 +74,14 @@ impl Embedding { let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; for (i, e) in qry_encoded.drain(..).enumerate() { - let input_id = - Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; + let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( + 0 + )?; let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? .to_dtype(DType::U8)? 
.unsqueeze(0)?; - ids = - ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; } @@ -98,14 +96,14 @@ impl Embedding { let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; for (i, e) in docs_encoded.drain(..).enumerate() { - let input_id = - Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; + let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( + 0 + )?; let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? .to_dtype(DType::U8)? .unsqueeze(0)?; - ids = - ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; } @@ -268,25 +266,27 @@ fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { pad_id } else { - return Err(anyhow!( - "Tokenizer doesn't contain expected `<|endoftext|>` token" - )); + return Err(anyhow!("Tokenizer doesn't contain expected `<|endoftext|>` token")); }; // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - })); + tokenizer.with_padding( + Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + }) + ); } else { - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Right, - ..Default::default() - })); + tokenizer.with_padding( + Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + }) + ); } Ok(tokenizer) @@ -320,14 +320,8 @@ fn main() -> Result<()> { }; let (repo, cfg) = match args.which { - Which::Large => ( - "dunzhang/stella_en_1.5B_v5", - Config::new_1_5_b_v5(embed_dim.embed_dim()), - ), - Which::Small => ( - "dunzhang/stella_en_400M_v5", - Config::new_400_m_v5(embed_dim.embed_dim()), - ), + Which::Large => ("dunzhang/stella_en_1.5B_v5", Config::new_1_5_b_v5(embed_dim.embed_dim())), + Which::Small => ("dunzhang/stella_en_400M_v5", Config::new_400_m_v5(embed_dim.embed_dim())), }; let repo = api.repo(Repo::model(repo.to_string())); @@ -339,20 +333,12 @@ fn main() -> Result<()> { // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo let base_weight_files = match args.base_weight_files { - Some(files) => files - .split(',') - .map(std::path::PathBuf::from) - .collect::>(), - None => { - vec![repo.get("model.safetensors")?] - } + Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), + None => { vec![repo.get("model.safetensors")?] 
} }; let embed_weight_files = match args.embed_head_weight_files { - Some(files) => files - .split(',') - .map(std::path::PathBuf::from) - .collect::>(), + Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), None => { let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); vec![repo.get(&head_w_path)?] @@ -369,11 +355,13 @@ fn main() -> Result<()> { let device = candle_examples::device(args.cpu)?; let dtype = DType::F32; - let base_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? }; + let base_vb = unsafe { + VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? + }; // Embedding layer is always built on F32 for accuracy - let embed_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; + let embed_vb = unsafe { + VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? + }; let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8c0564db84..23edf349ad 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -85,7 +85,6 @@ pub mod stable_diffusion; pub mod stable_lm; pub mod starcoder2; pub mod stella_en_v5; -pub mod stella_en_v5_400m; pub mod t5; pub mod trocr; pub mod vgg; diff --git a/candle-transformers/src/models/stella_en_v5_400m.rs b/candle-transformers/src/models/stella_en_v5_400m.rs deleted file mode 100644 index bbe16377d6..0000000000 --- a/candle-transformers/src/models/stella_en_v5_400m.rs +++ /dev/null @@ -1,914 +0,0 @@ -use candle::{ DType, Device, IndexOp, Module, Result, Tensor }; -use candle_nn::{ layer_norm, Activation, LayerNorm, VarBuilder }; -use std::sync::Arc; -use std::time::Instant; - -use super::with_tracing::{ linear, linear_no_bias, Linear }; - -#[derive(Debug, Clone, Copy)] -pub enum EmbedDim { - Dim256, - Dim768, - Dim1024, - Dim2048, - Dim4096, - Dim6144, - Dim8192, -} - -impl Default for EmbedDim { - fn default() -> Self { - Self::Dim1024 - } -} - -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] -pub struct EmbedHead { - pub in_features: usize, - pub out_features: usize, -} - -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] -pub struct Config { - pub vocab_size: usize, - pub hidden_size: usize, - pub intermediate_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub max_position_embeddings: usize, - pub type_vocab_size: usize, - // pub pad_token_id: usize, - // pub hidden_dropout_prob: f64, - // pub attention_probs_dropout_prob: f64, - pub norm_eps: f64, - // pub initializer_range: f64, - // pub position_embedding_type: String, - pub scaling_factor: f64, - pub rope_theta: f64, - // pub use_memory_efficient_attention: bool, - // pub unpad_inputs: bool, - // pub layer_norm_type: String, - // pub logn_attention_scale: bool, - // pub logn_attention_clip1: bool, - pub activation_fn: Activation, - pub embed_head: EmbedHead, -} - -impl Config { - pub fn new_400m(embed_dim: EmbedDim) -> Self { - let embed_head = EmbedHead { - in_features: 1024, - out_features: match embed_dim { - EmbedDim::Dim256 => 256, - EmbedDim::Dim768 => 768, - EmbedDim::Dim1024 => 1024, - EmbedDim::Dim2048 => 2048, - EmbedDim::Dim4096 => 4096, - EmbedDim::Dim6144 => 6144, - EmbedDim::Dim8192 => 8192, - }, - }; - - Self { - vocab_size: 30528, - hidden_size: 1024, - intermediate_size: 4096, - num_hidden_layers: 24, - num_attention_heads: 
16, - max_position_embeddings: 8192, - type_vocab_size: 2, - // pad_token_id: 0, - // hidden_dropout_prob: 0.1, - // attention_probs_dropout_prob: 0.0, - norm_eps: 1e-12, - // initializer_range: 0.02, - // position_embedding_type: "rope".to_string(), - scaling_factor: 2.0, - rope_theta: 160000.0, - // use_memory_efficient_attention: true, - // unpad_inputs: false, - // layer_norm_type: "layer_norm".to_string(), - // logn_attention_scale: false, - // logn_attention_clip1: false, - activation_fn: Activation::Gelu, - embed_head, - } - } -} - -#[derive(Debug, Clone)] -struct RotaryEmbedding { - sin: Tensor, - cos: Tensor, - // _scaling_factor: f64, - // _mixed_b: Option, - // _dim: usize, - // _base: f64, -} - -impl RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { - let dim = cfg.hidden_size / cfg.num_attention_heads; - let max_seq_len = cfg.max_position_embeddings; - let scaling_factor = cfg.scaling_factor; // Can be configured in Config if needed - let base = cfg.rope_theta; - - // Calculate scaled position embeddings - let scaled_max_seq_len = ((max_seq_len as f64) * scaling_factor) as usize; - - // Calculate inv_freq with NTK scaling - let inv_freq: Vec<_> = (0..dim / 2) - .map(|i| { - // Apply base scaling - let base = base * scaling_factor; - let freq = 1.0 / base.powf((2.0 * (i as f64)) / (dim as f64)); - - // Apply fixed NTK scaling - let freq = freq / scaling_factor.powf(2.0 / (dim as f64)); - - freq as f32 - }) - .collect(); - - let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; - - // Calculate position embeddings with scaled sequence length - let t = Tensor::arange(0u32, scaled_max_seq_len as u32, dev)? - .to_dtype(dtype)? - .reshape((scaled_max_seq_len, 1))?; - - let freqs = t.matmul(&inv_freq)?; - let emb = Tensor::cat(&[&freqs, &freqs], 1)?; - - Ok(Self { - sin: emb.sin()?, - cos: emb.cos()?, - // _scaling_factor: scaling_factor, - // _mixed_b: None, - // _dim: dim, - // _base: base, - }) - } -} - -#[derive(Debug, Clone)] -enum NormType { - LayerNorm(candle_nn::LayerNorm), - // RmsNorm(RmsNorm), -} - -impl NormType { - fn forward(&self, x: &Tensor) -> Result { - match self { - Self::LayerNorm(ln) => ln.forward(x), - // Self::RmsNorm(rms) => rms.forward(x), - } - } -} - -#[derive(Debug)] -pub struct Embeddings { - word_embeddings: candle_nn::Embedding, - // position_embeddings: Option, - token_type_embeddings: candle_nn::Embedding, - layer_norm: LayerNorm, - // _padding_idx: usize, - // _position_embedding_type: String, - rotary_emb: Arc, - position_ids: Tensor, -} - -impl Embeddings { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let word_embeddings = candle_nn::embedding( - cfg.vocab_size, - cfg.hidden_size, - vb.pp("word_embeddings") - )?; - - // let position_embeddings = if cfg.position_embedding_type == "absolute" { - // Some( - // candle_nn::embedding( - // cfg.max_position_embeddings, - // cfg.hidden_size, - // vb.pp("position_embeddings") - // )? - // ) - // } else { - // None - // }; - - let token_type_embeddings = candle_nn::embedding( - cfg.type_vocab_size, - cfg.hidden_size, - vb.pp("token_type_embeddings") - )?; - // if cfg.type_vocab_size > 0 { - // Some( - // candle_nn::embedding( - // cfg.type_vocab_size, - // cfg.hidden_size, - // vb.pp("token_type_embeddings") - // )? 
- // ) - // } else { - // None - // }; - - //if cfg.layer_norm_type == "layer_norm" { - let weight = vb - .pp("LayerNorm") - .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; - let bias = vb - .pp("LayerNorm") - .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; - let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); - // } else { - // NormType::RmsNorm( - // RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? - // ) - // }; - - // let rotary_emb = if cfg.position_embedding_type == "rope" { - // Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) - // } else { - // None - // }; - let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); - - // let position_ids = if cfg.position_embedding_type == "absolute" { - // Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) - // } else { - // None - // }; - let position_ids = Tensor::arange(0u32, cfg.max_position_embeddings as u32, word_embeddings.embeddings().device())?; - - Ok(Self { - word_embeddings, - // position_embeddings, - token_type_embeddings, - layer_norm, - // _padding_idx: cfg.pad_token_id, - // _position_embedding_type: cfg.position_embedding_type.clone(), - rotary_emb, - position_ids, - }) - } - - pub fn forward( - &mut self, - input_ids: &Tensor, - // token_type_ids: Option<&Tensor>, - // position_ids: Option<&Tensor>, - // inputs_embeds: Option<&Tensor>, - // unpad_inputs: bool, - // attention_mask: Option<&Tensor> - ) -> Result<(Tensor, Option<(Tensor, Tensor)>)> { - let (batch_size, seq_length) = input_ids.dims2()?; - - let mut embeddings = self.word_embeddings.forward(input_ids)?; - // match inputs_embeds { - // Some(e) => e.clone(), - // None => self.word_embeddings.forward(input_ids)?, - // }; - - // Get position_ids first - // let position_ids = if let Some(ids) = position_ids { - // ids.clone() - // } else { - // Get device from input_ids which is always available - // let device = input_ids.device(); - - // Initialize position_ids if None - // if self.position_ids.is_none() { - // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - // } - - // // Now check if we need to extend it - // if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { - // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - // } - - // let position_ids = - /*if unpad_inputs { - // For now, just use the same position IDs as padded case since we don't have lengths - self.position_ids - .as_ref() - .unwrap() - .narrow(0, 0, seq_length)? - .expand((batch_size, seq_length))? - } else {*/ - // self.position_ids - // .as_ref() - // .unwrap() - // .narrow(0, 0, seq_length)? - // .expand((batch_size, seq_length))?; - // }; - // }; - - let position_ids = self.position_ids - .as_ref() - .narrow(0, 0, seq_length)? 
- .expand((batch_size, seq_length))?; - - - let rope_embeds = { - // Get the cos and sin for this sequence length - let cos = self.rotary_emb.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] - let sin = self.rotary_emb.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] - - // Index using position_ids if needed - let position_ids = position_ids.flatten_all()?; - let cos = cos.index_select(&position_ids, 0)?; // Use index_select instead of i() - let sin = sin.index_select(&position_ids, 0)?; // Use index_select instead of i() - - Some((cos, sin)) - }; - // // Get rotary embeddings if using RoPE - // let rope_embeds = if let Some(rotary) = &self.rotary_emb { - - // } else { - // None - // }; - - // Handle token type embeddings - embeddings = embeddings.add(&self.token_type_embeddings.forward(&position_ids.zeros_like()?)?).unwrap(); - // if let Some(token_emb) = &self.token_type_embeddings { - // let token_type_ids = if let Some(ids) = token_type_ids { - // ids.clone() - // } else { - // position_ids.zeros_like()? // Use mul(0) equivalent - // }; - // if unpad_inputs { - // todo!("Implement unpadded case"); - // } else { - // embeddings = embeddings.add(&token_emb.forward(&position_ids.zeros_like()?)?).unwrap(); - // } - // } - - // Handle absolute position embeddings - // if let Some(pos_emb) = &self.position_embeddings { - // let position_embeddings = pos_emb.forward(&position_ids)?; - // embeddings = embeddings.add(&position_embeddings)?; - // } - - let embeddings = self.layer_norm.forward(&embeddings)?; - - Ok((embeddings, rope_embeds)) - } -} - -#[derive(Debug)] -struct NewAttention { - qkv_proj: Linear, - o_proj: Linear, - num_heads: usize, - head_dim: usize, - hidden_size: usize, - // _use_memory_efficient_attention: bool, -} - -impl NewAttention { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let hidden_sz = cfg.hidden_size; - let num_heads = cfg.num_attention_heads; - let head_dim = hidden_sz / num_heads; - - let qkv_proj = linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?; - let o_proj = linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; - - Ok(Self { - qkv_proj, - o_proj, - num_heads, - head_dim, - hidden_size: hidden_sz, - // _use_memory_efficient_attention: cfg.use_memory_efficient_attention, - }) - } - - fn forward( - &self, - hidden_states: &Tensor, - attention_bias: Option<&Tensor>, - rope_embeds: Option<&(Tensor, Tensor)>, - // _attention_scale: Option<&Tensor> - ) -> Result { - let (b_sz, seq_len, _) = hidden_states.dims3()?; - - // QKV projection - let qkv = self.qkv_proj.forward(hidden_states)?; - - // Split into Q, K, V and reshape to match PyTorch shapes - let qkv = qkv.reshape((b_sz, seq_len, 3, self.num_heads, self.head_dim))?; - - // Get Q, K, V with shape [batch, seq_len, num_heads, head_dim] - let query_states = qkv.i((.., .., 0, .., ..))?.contiguous()?; - let key_states = qkv.i((.., .., 1, .., ..))?.contiguous()?; - let value_states = qkv.i((.., .., 2, .., ..))?.contiguous()?; - - // Apply RoPE if provided - let (query_states, key_states) = if let Some((cos, sin)) = rope_embeds { - apply_rotary_pos_emb(&query_states, &key_states, cos, sin)? 
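The `apply_rotary_pos_emb` call above is the rotate-half form of RoPE: each head vector is split into two halves and rotated by position-dependent angles taken from cos/sin tables that `RotaryEmbedding::new` builds from NTK-scaled inverse frequencies (the base is multiplied by `scaling_factor`, then every frequency is divided by `scaling_factor^(2/dim)`). A condensed plain-Rust sketch of both steps for a single vector at one position, with a toy `head_dim` of 8; batching, head reshapes and the cached tensor tables are collapsed into simple loops:

```rust
/// Build NTK-scaled inverse frequencies as in RotaryEmbedding::new, then apply
/// the rotate-half rotation used by apply_rotary_pos_emb to one query vector.
fn inv_freqs(dim: usize, base: f64, scaling_factor: f64) -> Vec<f64> {
    (0..dim / 2)
        .map(|i| {
            let scaled_base = base * scaling_factor;
            let freq = 1.0 / scaled_base.powf(2.0 * i as f64 / dim as f64);
            freq / scaling_factor.powf(2.0 / dim as f64) // fixed NTK correction
        })
        .collect()
}

/// q is split into halves (q1, q2); output is [q1*cos - q2*sin, q2*cos + q1*sin].
fn rotate_half(q: &[f64], pos: usize, inv_freq: &[f64]) -> Vec<f64> {
    let half = q.len() / 2;
    let mut out = vec![0.0; q.len()];
    for i in 0..half {
        let angle = pos as f64 * inv_freq[i];
        let (sin, cos) = angle.sin_cos();
        out[i] = q[i] * cos - q[i + half] * sin;
        out[i + half] = q[i + half] * cos + q[i] * sin;
    }
    out
}

fn main() {
    // Toy sizes; the 400M config uses head_dim = 1024 / 16 = 64,
    // rope_theta = 160000.0 and scaling_factor = 2.0.
    let freqs = inv_freqs(8, 160000.0, 2.0);
    let q = [0.5, -0.1, 0.3, 0.2, 0.0, 0.7, -0.4, 0.1];
    let q_rotated = rotate_half(&q, 5, &freqs);
    println!("{q_rotated:?}");
}
```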
- } else { - (query_states, key_states) - }; - - // Transpose for attention computation [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim] - let query_states = query_states.transpose(1, 2)?.contiguous()?; - let key_states = key_states.transpose(1, 2)?.contiguous()?; - let value_states = value_states.transpose(1, 2)?.contiguous()?; - - // For key, we want to transpose the last two dimensions for the matmul - // Is this equivalent to PyTorch's transpose(-1, -2)? - let key_states_t = key_states.transpose(2, 3)?.contiguous()?; - - // Prepare tensors for batched matmul using matmul - // Reshape tensors to merge batch and head dimensions - let bsz = b_sz; - let nh = self.num_heads; - let s_len = seq_len; - let h_dim = self.head_dim; - - // Reshape tensors to [batch_size * num_heads, seq_len, head_dim] - let query_states_reshaped = query_states.reshape((bsz * nh, s_len, h_dim))?; - let key_states_t_reshaped = key_states_t.reshape((bsz * nh, h_dim, s_len))?; - - // Perform batched matmul using matmul - // The matmul should handle batch dimensions if tensors are 3D - let attn_weights = query_states_reshaped.matmul(&key_states_t_reshaped)?; - - // Reshape attn_weights back to [batch_size, num_heads, seq_len, seq_len] - let attn_weights = attn_weights.reshape((bsz, nh, s_len, s_len))?; - - // Scale attention scores - let scale = 1f32 / (self.head_dim as f32).sqrt(); - - let scale_tensor = Tensor::new(scale, attn_weights.device())? - .to_dtype(attn_weights.dtype())? - .broadcast_as(attn_weights.shape())?; - - // Multiply the attention weights by the scalar tensor - let attn_weights = attn_weights.mul(&scale_tensor)?; - - // Apply attention mask - let mut attn_weights = if let Some(bias) = attention_bias { - // let attn_weights = attn_weights.broadcast_add(bias)?; - attn_weights.broadcast_add(bias)? 
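Everything from the batched matmul down to `broadcast_add(bias)` above is ordinary scaled dot-product attention: scores are `q·kᵀ / sqrt(head_dim)` plus an additive bias that `prepare_attention_mask` sets to 0 for real tokens and negative infinity for padded positions, followed by a softmax and a weighted sum over the value rows. A single-head, loop-based sketch with made-up numbers, ignoring batching and the head reshapes:

```rust
/// Single-head scaled dot-product attention with an additive padding bias.
/// q, k, v: (seq_len, head_dim); bias: (seq_len) with 0.0 for tokens and -inf for pads.
fn attention(q: &[Vec<f64>], k: &[Vec<f64>], v: &[Vec<f64>], bias: &[f64]) -> Vec<Vec<f64>> {
    let head_dim = q[0].len() as f64;
    let scale = 1.0 / head_dim.sqrt();
    let mut out = Vec::with_capacity(q.len());
    for qi in q {
        // scores_j = (q_i . k_j) * scale + bias_j
        let scores: Vec<f64> = k
            .iter()
            .zip(bias)
            .map(|(kj, b)| qi.iter().zip(kj).map(|(a, c)| a * c).sum::<f64>() * scale + b)
            .collect();
        // softmax over key positions (the -inf bias zeroes out padded columns)
        let max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let exps: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let denom: f64 = exps.iter().sum();
        // weighted sum of the value rows
        let mut row = vec![0.0; v[0].len()];
        for (w, vj) in exps.iter().zip(v) {
            for (r, val) in row.iter_mut().zip(vj) {
                *r += w / denom * val;
            }
        }
        out.push(row);
    }
    out
}

fn main() {
    let q = vec![vec![0.1, 0.2], vec![0.0, 0.3]];
    let k = vec![vec![0.2, 0.1], vec![0.4, 0.0]];
    let v = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    let bias = vec![0.0, f64::NEG_INFINITY]; // second key position is padding
    println!("{:?}", attention(&q, &k, &v, &bias));
}
```

With the real tensors this per-position bias is the `[batch, 1, 1, seq_len]` mask broadcast across heads and query positions.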
- } else { - attn_weights - }; - - // Normalize attention scores - attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - - // Reshape value_states for batched matmul - let value_states_reshaped = value_states.reshape((bsz * nh, s_len, h_dim))?; - - // Reshape attn_weights to [batch_size * num_heads, seq_len, seq_len] - let attn_weights_reshaped = attn_weights.reshape((bsz * nh, s_len, s_len))?; - - // Compute attention output - let attn_output = attn_weights_reshaped.matmul(&value_states_reshaped)?; - - // Reshape attn_output back to [batch_size, num_heads, seq_len, head_dim] - let attn_output = attn_output.reshape((bsz, nh, s_len, h_dim))?; - - // Transpose back to [batch_size, seq_len, num_heads, head_dim] - let attn_output = attn_output.transpose(1, 2)?; - - // Project to final dimension - let attn_output = attn_output.reshape((b_sz, seq_len, self.hidden_size))?; - self.o_proj.forward(&attn_output) - } -} - -#[derive(Debug)] -struct NewGatedMLP { - up_gate_proj: Linear, - down_proj: Linear, - act_fn: Activation, -} - -impl NewGatedMLP { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let hidden_sz = cfg.hidden_size; - let intermediate_size = cfg.intermediate_size; - let act_fn = cfg.activation_fn; - let up_gate_proj = linear_no_bias(hidden_sz, intermediate_size * 2, vb.pp("up_gate_proj"))?; - let down_proj = linear(intermediate_size, hidden_sz, vb.pp("down_proj"))?; - - Ok(Self { - up_gate_proj, - down_proj, - act_fn, - }) - } -} - -impl Module for NewGatedMLP { - fn forward(&self, xs: &Tensor) -> Result { - let up_gate = self.up_gate_proj.forward(xs)?; - - // Get the dimensions - let (_batch_size, _seq_len, hidden_dim) = up_gate.dims3()?; - let split_size = hidden_dim / 2; - - // Split along the last dimension (hidden_dim) - let up_states = up_gate.narrow(2, 0, split_size)?; - let gate = up_gate.narrow(2, split_size, split_size)?; - - // Apply activation to gate and multiply - let gate = gate.apply(&self.act_fn)?; - - let gated_states = up_states.mul(&gate)?; - - // Project back to hidden dimension - let output = self.down_proj.forward(&gated_states)?; - - Ok(output) - } -} - -#[derive(Debug)] -struct NewLayer { - attention: NewAttention, - mlp: NewGatedMLP, - attn_ln: NormType, - mlp_ln: NormType, -} - -impl NewLayer { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let attention = NewAttention::new(cfg, vb.pp("attention"))?; - let mlp = NewGatedMLP::new(cfg, vb.pp("mlp"))?; - - // let ln_eps = cfg.layer_norm_eps; - - // Use LayerNorm or RmsNorm based on config - let (attn_ln, mlp_ln) = { - let attn_ln = layer_norm( - cfg.hidden_size, - candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("attn_ln") - )?; - let mlp_ln = layer_norm( - cfg.hidden_size, - candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("mlp_ln") - )?; - (NormType::LayerNorm(attn_ln), NormType::LayerNorm(mlp_ln)) - }; - // else - // { - // let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; - // let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; - // (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) - // }; - - Ok(Self { - attention, - mlp, - attn_ln, - mlp_ln, - }) - } - - fn forward( - &self, - hidden_states: &Tensor, - attention_bias: Option<&Tensor>, - rope_embeds: Option<&(Tensor, Tensor)>, - // attention_scale: Option<&Tensor> - ) -> Result { - // Store original input - let original = hidden_states; - - // Use normalized states for attention - let hidden_states = self.attention.forward( - 
original, - attention_bias, - rope_embeds, - // attention_scale - )?; - - let hidden_states = original.add(&hidden_states)?; - - // Apply layer norm - let hidden_states = self.attn_ln.forward(&hidden_states)?; - - // Store residual - let residual = &hidden_states; - - // Pass through MLP - let hidden_states = self.mlp.forward(&hidden_states)?; - - // Add residual connection - let hidden_states = residual.add(&hidden_states)?; - - // Final layer norm - self.mlp_ln.forward(&hidden_states) - } -} - -#[derive(Debug)] -struct NewEncoder { - layers: Vec, -} - -impl NewEncoder { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb.pp("layer"); - for layer_idx in 0..cfg.num_hidden_layers { - layers.push(NewLayer::new(cfg, vb_l.pp(layer_idx))?); - } - Ok(Self { layers }) - } - - fn forward( - &self, - hidden_states: &Tensor, - attention_bias: Option<&Tensor>, - rope_embeds: Option<&(Tensor, Tensor)>, - // attention_scale: Option<&Tensor> - ) -> Result { - let mut hidden_states = hidden_states.clone(); - - for layer in self.layers.iter() { - hidden_states = layer.forward( - &hidden_states, - attention_bias, - rope_embeds, - // attention_scale - )?; - } - - Ok(hidden_states) - } -} - -#[derive(Debug)] -pub struct NewModel { - embeddings: Embeddings, - encoder: NewEncoder, - device: Device, - dtype: DType, - // config: Config, -} - -impl NewModel { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let vb_m = vb.pp("new"); - let embeddings = Embeddings::new(cfg, vb_m.pp("embeddings"))?; - let encoder = NewEncoder::new(cfg, vb_m.pp("encoder"))?; - Ok(Self { - embeddings, - encoder, - device: vb.device().clone(), - dtype: vb.dtype(), - // config: cfg.clone(), - }) - } - - fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result { - let (b_sz, seq_len) = attn_mask.dims2()?; - let mask = attn_mask - .unsqueeze(1) - ? // [b_sz, 1, seq_len] - .unsqueeze(2) - ? // [b_sz, 1, 1, seq_len] - .broadcast_as((b_sz, 1, 1, seq_len))?; // [b_sz, 1, 1, seq_len] - - // Use a large negative value for mask instead of -0.0 - let on_true = mask.zeros_like()?.to_dtype(self.dtype)?; - let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)? - .broadcast_as(mask.shape())? - .to_dtype(self.dtype)?; - - mask.where_cond(&on_true, &on_false) - } - - pub fn forward( - &mut self, - input_ids: &Tensor, - attention_mask: &Tensor, - // token_type_ids: Option<&Tensor>, - // position_ids: Option<&Tensor> - ) -> Result { - let (_, seq_length) = input_ids.dims2()?; - - // Get attention mask if not provided - // let attention_mask = mask; - - // Prepare attention bias - let attention_bias = if seq_length <= 1 { - None - } else { - Some(self.prepare_attention_mask(attention_mask)?) - }; - - // Get embeddings and rotary embeddings - let (hidden_states, rope_embeds) = self.embeddings.forward( - input_ids, - // token_type_ids, - // position_ids, - // None, - // self.config.unpad_inputs, - // Some(&attention_mask) - )?; - - // Compute attention scale if needed - // let attention_scale = if self.config.logn_attention_scale { - // let scale = - // attention_mask.sum_keepdim(1)?.log()? / - // (self.config.max_position_embeddings as f64).ln(); - // if self.config.logn_attention_clip1 { - // let scale = scale?; - // Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) - // } else { - // Some(scale?) 
- // } - // } else { - // None - // }; - - // Forward through encoder - let hidden_states = self.encoder.forward( - &hidden_states, - attention_bias.as_ref(), - rope_embeds.as_ref(), - // attention_scale.as_ref() - )?; - - Ok(hidden_states) - } -} - -// Optional pooler implementation -// #[derive(Debug)] -// pub struct NewPooler { -// dense: Linear, -// } - -// impl NewPooler { -// pub fn new(cfg: &Config, vb: VarBuilder) -> Result { -// let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; -// Ok(Self { dense }) -// } - -// pub fn forward(&self, hidden_states: &Tensor) -> Result { -// let first_token = hidden_states.i((.., 0, ..))?; -// let pooled = self.dense.forward(&first_token)?; -// pooled.tanh() -// } -// } - -// // Complete model with pooler -// #[derive(Debug)] -// pub struct NewModelWithPooler { -// model: NewModel, -// pooler: Option, -// } - -// impl NewModelWithPooler { -// pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { -// let vb_m = vb.pp("new"); -// let model = NewModel::new(cfg, vb_m.pp("model"))?; -// let pooler = if add_pooling_layer { -// Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) -// } else { -// None -// }; -// Ok(Self { model, pooler }) -// } - -// pub fn forward( -// &mut self, -// input_ids: &Tensor, -// attention_mask: Option<&Tensor>, -// token_type_ids: Option<&Tensor>, -// position_ids: Option<&Tensor> -// ) -> Result<(Tensor, Option)> { -// let hidden_states = self.model.forward( -// input_ids, -// attention_mask, -// token_type_ids, -// position_ids -// )?; - -// let pooled_output = match &self.pooler { -// Some(pooler) => Some(pooler.forward(&hidden_states)?), -// None => None, -// }; - -// Ok((hidden_states, pooled_output)) -// } -// } - -#[derive(Debug)] -pub struct EmbeddingModel { - base_model: NewModel, - lm_head: Linear, -} - -impl EmbeddingModel { - pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result { - let base_model = NewModel::new(cfg, base_vb)?; - let lm_head = linear( - cfg.embed_head.in_features, - cfg.embed_head.out_features, - embed_vb.pp("linear") - )?; - - Ok(Self { - base_model, - lm_head, - }) - } - - pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { - let x = self.base_model.forward(input_ids, mask)?;//, None, None)?; - let x = self.pool(&x, mask)?; - self.lm_head.forward(&x.to_dtype(DType::F32)?) - } - - pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { - let x = self.forward(input_ids, mask)?; - x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?) - } - fn pool(&self, x: &Tensor, mask: &Tensor) -> Result { - let mask = mask.to_dtype(x.dtype())?; - let (batch_size, seq_len, hidden_dim) = x.dims3()?; - let mask_expanded = mask.unsqueeze(2)?.broadcast_as((batch_size, seq_len, hidden_dim))?; // [B_Sz, Seq_len, Hidden_dim] - let x = x.mul(&mask_expanded)?; - let sum_mask = mask.sum(1)?.unsqueeze(1)?.expand((batch_size, hidden_dim))?; - x.sum(1)? 
/ sum_mask - } -} - -pub fn time_run(f: F) -> (T, std::time::Duration) where F: FnOnce() -> T { - let start = Instant::now(); - let result = f(); - let duration = start.elapsed(); - (result, duration) -} - -fn apply_rotary_pos_emb( - q: &Tensor, - k: &Tensor, - cos: &Tensor, - sin: &Tensor -) -> Result<(Tensor, Tensor)> { - let cos = cos.to_dtype(q.dtype())?; - let sin = sin.to_dtype(q.dtype())?; - - let (batch_size, seq_len, num_heads, head_dim) = q.dims4()?; - let half_dim = head_dim / 2; - - // Reshape q and k to split the head dim for rotation - let q_split = q.chunk(2, 3)?; // Split along head_dim - let k_split = k.chunk(2, 3)?; - - let q1 = &q_split[0]; - let q2 = &q_split[1]; - let k1 = &k_split[0]; - let k2 = &k_split[1]; - - // Handle cos/sin for the sequence length we have - let cos = cos.narrow(0, 0, seq_len)?; - let sin = sin.narrow(0, 0, seq_len)?; - - // Reshape cos/sin to match the dimensions we need - let cos = cos - .reshape((seq_len, head_dim))? - .chunk(2, 1)? - [0].reshape((seq_len, 1, half_dim))? - .broadcast_as((seq_len, num_heads, half_dim))? - .unsqueeze(0)? - .broadcast_as((batch_size, seq_len, num_heads, half_dim))?; - - let sin = sin - .reshape((seq_len, head_dim))? - .chunk(2, 1)? - [0].reshape((seq_len, 1, half_dim))? - .broadcast_as((seq_len, num_heads, half_dim))? - .unsqueeze(0)? - .broadcast_as((batch_size, seq_len, num_heads, half_dim))?; - - // Apply rotation using the formulas: - // q = q * cos + rotate_half(q) * sin - // k = k * cos + rotate_half(k) * sin - let q_out = Tensor::cat( - &[&q1.mul(&cos)?.sub(&q2.mul(&sin)?)?, &q2.mul(&cos)?.add(&q1.mul(&sin)?)?], - 3 - )?; - - let k_out = Tensor::cat( - &[&k1.mul(&cos)?.sub(&k2.mul(&sin)?)?, &k2.mul(&cos)?.add(&k1.mul(&sin)?)?], - 3 - )?; - - Ok((q_out, k_out)) -} From 339b435dc530a96aaef23cf602c7e05db7ed704b Mon Sep 17 00:00:00 2001 From: lynxeco Date: Tue, 26 Nov 2024 09:41:34 -0800 Subject: [PATCH 7/7] cargo fmt --all --- candle-examples/examples/stella-en-v5/main.rs | 102 ++++++++++-------- 1 file changed, 57 insertions(+), 45 deletions(-) diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index 64239274e8..68ed7e70c6 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -6,19 +6,17 @@ extern crate accelerate_src; use std::path::Path; -use anyhow::{ anyhow, Error as E, Result }; +use anyhow::{anyhow, Error as E, Result}; use clap::Parser; use candle_transformers::models::stella_en_v5::{ - Config, - EmbedDim as StellaEmbedDim, - EmbeddingModel, + Config, EmbedDim as StellaEmbedDim, EmbeddingModel, }; -use candle::{ DType, Device, Tensor }; +use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; -use hf_hub::{ api::sync::Api, Repo }; -use tokenizers::{ PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer }; +use hf_hub::{api::sync::Api, Repo}; +use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer}; struct Embedding { model: EmbeddingModel, @@ -61,9 +59,13 @@ impl Embedding { // We only encode the queries and not the data let qry = task.query_preproc(&queries); - let mut qry_encoded = self.tokenizer.encode_batch(qry, true).map_err(|e| anyhow!(e))?; + let mut qry_encoded = self + .tokenizer + .encode_batch(qry, true) + .map_err(|e| anyhow!(e))?; - let mut docs_encoded = self.tokenizer + let mut docs_encoded = self + .tokenizer .encode_batch(docs.to_vec(), true) .map_err(|e| anyhow!(e))?; @@ -74,14 +76,14 @@ impl Embedding { let mut masks = 
Tensor::zeros(shape, DType::U8, &self.device)?; for (i, e) in qry_encoded.drain(..).enumerate() { - let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( - 0 - )?; + let input_id = + Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? .to_dtype(DType::U8)? .unsqueeze(0)?; - ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + ids = + ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; } @@ -96,14 +98,14 @@ impl Embedding { let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; for (i, e) in docs_encoded.drain(..).enumerate() { - let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( - 0 - )?; + let input_id = + Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? .to_dtype(DType::U8)? .unsqueeze(0)?; - ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + ids = + ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; } @@ -266,27 +268,25 @@ fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { pad_id } else { - return Err(anyhow!("Tokenizer doesn't contain expected `<|endoftext|>` token")); + return Err(anyhow!( + "Tokenizer doesn't contain expected `<|endoftext|>` token" + )); }; // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding( - Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - }) - ); + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + })); } else { - tokenizer.with_padding( - Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Right, - ..Default::default() - }) - ); + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + })); } Ok(tokenizer) @@ -320,8 +320,14 @@ fn main() -> Result<()> { }; let (repo, cfg) = match args.which { - Which::Large => ("dunzhang/stella_en_1.5B_v5", Config::new_1_5_b_v5(embed_dim.embed_dim())), - Which::Small => ("dunzhang/stella_en_400M_v5", Config::new_400_m_v5(embed_dim.embed_dim())), + Which::Large => ( + "dunzhang/stella_en_1.5B_v5", + Config::new_1_5_b_v5(embed_dim.embed_dim()), + ), + Which::Small => ( + "dunzhang/stella_en_400M_v5", + Config::new_400_m_v5(embed_dim.embed_dim()), + ), }; let repo = api.repo(Repo::model(repo.to_string())); @@ -333,12 +339,20 @@ fn main() -> Result<()> { // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights // E.g. 
if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo let base_weight_files = match args.base_weight_files { - Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), - None => { vec![repo.get("model.safetensors")?] } + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => { + vec![repo.get("model.safetensors")?] + } }; let embed_weight_files = match args.embed_head_weight_files { - Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), None => { let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); vec![repo.get(&head_w_path)?] @@ -355,13 +369,11 @@ fn main() -> Result<()> { let device = candle_examples::device(args.cpu)?; let dtype = DType::F32; - let base_vb = unsafe { - VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? - }; + let base_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? }; // Embedding layer is always built on F32 for accuracy - let embed_vb = unsafe { - VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? - }; + let embed_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
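Past the transformer stack, both Stella variants finish the same way in `EmbeddingModel`: mask-weighted mean pooling over the sequence, a linear head that projects to the chosen MRL embedding dimension, L2 normalization in `forward_norm`, and a dot product between query and document vectors as the similarity score printed by the example. A plain-Rust sketch of that tail end with a toy hidden size; the learned `2_Dense_*` projection is replaced here by an identity stand-in, so the numbers are purely illustrative:

```rust
/// Masked mean pooling, L2 normalization and cosine similarity, mirroring
/// EmbeddingModel::forward / forward_norm and the matmul in the example.
fn mean_pool(hidden: &[Vec<f64>], mask: &[u8]) -> Vec<f64> {
    let dim = hidden[0].len();
    let mut sum = vec![0.0; dim];
    let mut count = 0.0;
    for (h, &m) in hidden.iter().zip(mask) {
        if m == 1 {
            for (s, v) in sum.iter_mut().zip(h) {
                *s += v;
            }
            count += 1.0;
        }
    }
    sum.iter().map(|s| s / count).collect()
}

fn l2_normalize(x: &[f64]) -> Vec<f64> {
    let norm = x.iter().map(|v| v * v).sum::<f64>().sqrt();
    x.iter().map(|v| v / norm).collect()
}

fn main() {
    // Toy per-token hidden states for one query and one document (hidden size 3).
    let query_hidden = vec![vec![0.2, 0.1, 0.4], vec![0.0, 0.3, 0.1], vec![0.0, 0.0, 0.0]];
    let doc_hidden = vec![vec![0.3, 0.0, 0.5], vec![0.1, 0.2, 0.2], vec![0.2, 0.1, 0.3]];
    let query_mask = [1u8, 1, 0]; // last query position is padding
    let doc_mask = [1u8, 1, 1];

    // Pool, "project" (identity here instead of the learned linear head), normalize.
    let q = l2_normalize(&mean_pool(&query_hidden, &query_mask));
    let d = l2_normalize(&mean_pool(&doc_hidden, &doc_mask));

    // Cosine similarity = dot product of the normalized embeddings.
    let score: f64 = q.iter().zip(&d).map(|(a, b)| a * b).sum();
    println!("similarity: {score:.4}");
}
```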