From aec2aa20784e17f7b8d95f690ba681ef5a706dcf Mon Sep 17 00:00:00 2001 From: lynxeco Date: Fri, 22 Nov 2024 19:41:03 -0800 Subject: [PATCH] removed redundant stella-400m model and example after merge into stella-en-v5 --- .../examples/stella-en-400m-v5/README.md | 46 - .../examples/stella-en-400m-v5/main.rs | 384 -------- candle-examples/examples/stella-en-v5/main.rs | 102 +- candle-transformers/src/models/mod.rs | 1 - .../src/models/stella_en_v5_400m.rs | 914 ------------------ 5 files changed, 45 insertions(+), 1402 deletions(-) delete mode 100644 candle-examples/examples/stella-en-400m-v5/README.md delete mode 100644 candle-examples/examples/stella-en-400m-v5/main.rs delete mode 100644 candle-transformers/src/models/stella_en_v5_400m.rs diff --git a/candle-examples/examples/stella-en-400m-v5/README.md b/candle-examples/examples/stella-en-400m-v5/README.md deleted file mode 100644 index ef1de31d09..0000000000 --- a/candle-examples/examples/stella-en-400m-v5/README.md +++ /dev/null @@ -1,46 +0,0 @@ ---- -model-index: -- name: stella_en_400M_v5 -license: mit ---- - - -# Introduction - -The models are trained based on `Alibaba-NLP/gte-large-en-v1.5` and `Alibaba-NLP/gte-Qwen2-1.5B-instruct`. Thanks for -their contributions! - -**We simplify usage of prompts, providing two prompts for most general tasks, one is for s2p, another one is for s2s.** - - - -The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5).for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. - -```bash -$ cargo run --example stella-en-v5 --release --features - -Similarity======== - [[0.8398, 0.2990], - [0.3282, 0.8095]] - -Score: 0.83975387 -Query: What are some ways to reduce stress? -Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up. - - - -Score: 0.8095451 -Query: What are the benefits of drinking green tea? -Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. - - -``` - - -The models are finally trained by [MRL](https://arxiv.org/abs/2205.13147), so they have multiple dimensions: 512, 768, -1024, 2048, 4096, 6144 and 8192. - -## Supported options: -- `Stella_en_400m_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. - -- As per the [model card](https://huggingface.co/dunzhang/stella_en_400M_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option. \ No newline at end of file diff --git a/candle-examples/examples/stella-en-400m-v5/main.rs b/candle-examples/examples/stella-en-400m-v5/main.rs deleted file mode 100644 index 6774c62aa3..0000000000 --- a/candle-examples/examples/stella-en-400m-v5/main.rs +++ /dev/null @@ -1,384 +0,0 @@ -#[cfg(feature = "mkl")] -extern crate intel_mkl_src; - -#[cfg(feature = "accelerate")] -extern crate accelerate_src; - -use std::path::Path; - -use anyhow::{ anyhow, Error as E, Result }; -use clap::Parser; - -use candle_transformers::models::stella_en_v5_400m::{ - Config, - EmbedDim as StellaEmbedDim, - EmbeddingModel, -}; - -use candle::{ DType, Device, IndexOp, Tensor }; -use candle_nn::VarBuilder; -use hf_hub::{ api::sync::Api, Repo }; -use tokenizers::{ PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer }; - -struct Embedding { - model: EmbeddingModel, - device: Device, - tokenizer: Tokenizer, -} - -impl Embedding { - fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self { - Self { - model, - tokenizer, - device: device.clone(), - } - } - - fn encode(&mut self, task: EncodeTask, text: Option) -> Result<()> { - // Just shocasing embeddings, this has no real value - if let Some(text) = text { - let qry = task.query_preproc(&[text]); - let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?; - - let shape = (1, encoding.len()); - let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; - let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; - - let result = self.model.forward(&input, &mask)?; - println!("result shape: {:?}", result.shape()); - println!("embeddings: {result}"); - } else { - // Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers) - let queries = [ - "What are some ways to reduce stress?".to_string(), - "What are the benefits of drinking green tea?".to_string(), - ]; - - let docs = [ - "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(), - "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(), - ]; - - // We only encode the queries and not the data - let qry = task.query_preproc(&queries); - let mut qry_encoded = self.tokenizer.encode_batch(qry, true).map_err(|e| anyhow!(e))?; - - let mut docs_encoded = self.tokenizer - .encode_batch(docs.to_vec(), true) - .map_err(|e| anyhow!(e))?; - - let qry_embed = { - // Now, we generate the tensors for the `input` and `mask` - let shape = (qry_encoded.len(), qry_encoded[1].len()); - let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; - let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; - - for (i, e) in qry_encoded.drain(..).enumerate() { - let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( - 0 - )?; - let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? - .to_dtype(DType::U8)? - .unsqueeze(0)?; - - ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; - masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; - } - - // Let's generate the embeddings for the query, we are going to be normalizing the result. - // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data - self.model.forward_norm(&ids, &masks)? - }; - - let doc_embed = { - let shape = (docs_encoded.len(), docs_encoded[1].len()); - let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; - let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; - - for (i, e) in docs_encoded.drain(..).enumerate() { - let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( - 0 - )?; - let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? - .to_dtype(DType::U8)? - .unsqueeze(0)?; - - ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; - masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; - } - - // Let's generate the embeddings for the query, we are going to be normalizing the result. - // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data - self.model.forward_norm(&ids, &masks)? - }; - println!("Query vectors:\n {qry_embed}\n"); - println!("Document vectors:\n {doc_embed}\n"); - - println!( - "Embed shapes======\nQuery: {:?}\nDocs: {:?}\n", - qry_embed.shape(), - doc_embed.shape() - ); // [2, 1024] for head dim `1024` - - let answer = self.similarity(&qry_embed, &doc_embed); - println!("Similarity========\n {}", answer.unwrap()); - - // a matmul to generate the `similarity` score - let res = qry_embed.matmul(&doc_embed.t()?)?; - for (k, v) in queries.iter().enumerate() { - let tnsr = res.get(k)?; - let max = tnsr.argmax(0)?.to_scalar::()?; - println!( - "\nScore: {}\nQuery: {}\nAnswer: {}\n\n", - tnsr.get(max as usize)?.to_scalar::()?, - v, - docs[k] - ); - } - } - - Ok(()) - } - - /// Computes the cosine similarity between two tensors of embeddings - /// Similar to sentence-transformers' similarity() function - pub fn similarity(&self, embeddings1: &Tensor, embeddings2: &Tensor) -> Result { - // Normalize the embeddings (L2 norm) - let norm1 = embeddings1.broadcast_div(&embeddings1.sqr()?.sum_keepdim(1)?.sqrt()?)?; - let norm2 = embeddings2.broadcast_div(&embeddings2.sqr()?.sum_keepdim(1)?.sqrt()?)?; - - // Compute cosine similarity: dot product of normalized vectors - Ok(norm1.matmul(&norm2.t()?)?) - } - - fn _encode_single_doc(&mut self) -> Result { - // Example document text - let doc = - "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(); - - // Encode the document - let encoding = self.tokenizer.encode(doc, true).map_err(|e| anyhow!(e))?; - - // Create input tensors - let shape = (1, encoding.len()); - let ids = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; - let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; - - // Get embeddings and print intermediate values - let embeddings = self.model.forward_norm(&ids, &mask)?; - - // Print the shape and first few values - println!("Document embedding shape: {:?}", embeddings.shape()); - println!("First few values: {:?}", embeddings.i(0)?.narrow(0, 0, 3)?.to_vec1::()?); - - // Return normalized embeddings - let norm_embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?; - - Ok(norm_embeddings) - } -} - -#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] -enum EmbedDim { - #[value(name = "256")] - Dim256, - #[value(name = "768")] - Dim768, - #[value(name = "1024")] - Dim1024, - #[value(name = "2048")] - Dim2048, - #[value(name = "4096")] - Dim4096, - #[value(name = "6144")] - Dim6144, - #[value(name = "8192")] - Dim8192, -} - -impl EmbedDim { - /// Returns dir path to the embed head weights int he repo - pub fn embed_dim_default_dir(&self) -> &'static str { - match self { - Self::Dim256 => "2_Dense_256", - Self::Dim768 => "2_Dense_768", - Self::Dim1024 => "2_Dense_1024", - Self::Dim2048 => "2_Dense_2048", - Self::Dim4096 => "2_Dense_4096", - Self::Dim6144 => "2_Dense_6144", - Self::Dim8192 => "2_Dense_8192", - } - } - - /// Resolves the `EmbedDim` for given variant - pub fn embed_dim(&self) -> StellaEmbedDim { - match self { - Self::Dim256 => StellaEmbedDim::Dim256, - Self::Dim768 => StellaEmbedDim::Dim768, - Self::Dim1024 => StellaEmbedDim::Dim1024, - Self::Dim2048 => StellaEmbedDim::Dim2048, - Self::Dim4096 => StellaEmbedDim::Dim4096, - Self::Dim6144 => StellaEmbedDim::Dim6144, - Self::Dim8192 => StellaEmbedDim::Dim8192, - } - } -} - -#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] -pub enum EncodeTask { - /// `s2p` is the `retrieval` task - /// Default in this example - #[value(name = "s2p")] - S2P, - /// `s2s` is the semantic similarity task - #[value(name = "s2s")] - S2S, -} - -impl EncodeTask { - /// Preprocess a set of inputs basef on a template suggested by the model authors - /// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction - pub fn query_preproc(&self, txt: &[String]) -> Vec { - let instruct = match self { - Self::S2P => { - "Given a web search query, retrieve relevant passages that answer the query." - } - Self::S2S => "Retrieve semantically similar text.", - }; - - txt.iter() - .map(|s| format!("Instruct: {instruct}\nQuery: {s}")) - .collect::>() - } -} - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - /// Run on CPU rather than on GPU. - #[arg(long)] - cpu: bool, - - /// Enable tracing (generates a trace-timestamp.json file). - #[arg(long)] - tracing: bool, - - #[arg(long)] - use_flash_attn: bool, - - #[arg(long)] - query: Option, - - #[arg(long, default_value = "1024")] - embed_dim: Option, - - #[arg(long)] - tokenizer_file: Option, - - #[arg(long)] - base_weight_files: Option, - - #[arg(long)] - embed_head_weight_files: Option, - - /// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5) - /// `s2s`: Semantic textual similarity - /// `s2p`: Retrieval task - `Default` in this example - #[arg(long, default_value = "s2p")] - task: Option, -} - -// Tokenizer creation is super critical in our case. -// We are going to be `padding: Left` for each batch -fn create_tokenizer(tokenizer_file: &Path) -> Result { - let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; - - tokenizer.with_padding( - Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Right, - ..Default::default() - }) - ); - - Ok(tokenizer) -} - -fn main() -> Result<()> { - use tracing_chrome::ChromeLayerBuilder; - use tracing_subscriber::prelude::*; - - let args = Args::parse(); - let _guard = if args.tracing { - let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - tracing_subscriber::registry().with(chrome_layer).init(); - Some(guard) - } else { - None - }; - println!( - "avx: {}, neon: {}, simd128: {}, f16c: {}", - candle::utils::with_avx(), - candle::utils::with_neon(), - candle::utils::with_simd128(), - candle::utils::with_f16c() - ); - - let start = std::time::Instant::now(); - let api = Api::new()?; - let embed_dim = match args.embed_dim { - Some(d) => d, - None => EmbedDim::Dim1024, - }; - let repo = api.repo(Repo::model("dunzhang/stella_en_400M_v5".to_string())); - let tokenizer_filename = match args.tokenizer_file { - Some(file) => std::path::PathBuf::from(file), - None => repo.get("tokenizer.json")?, - }; - - // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights - // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo - let base_weight_files = match args.base_weight_files { - Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), - None => { vec![repo.get("model.safetensors")?] } - }; - - let embed_weight_files = match args.embed_head_weight_files { - Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), - None => { - let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); - vec![repo.get(&head_w_path)?] - } - }; - - println!("retrieved the files in {:?}", start.elapsed()); - - // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding - let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; - - let start = std::time::Instant::now(); - - let device = candle_examples::device(args.cpu)?; - let dtype = DType::F32; - - let base_vb = unsafe { - VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? - }; - // Embedding layer is always built on F32 for accuracy - let embed_vb = unsafe { - VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? - }; - - let model = EmbeddingModel::new(&Config::new_400m(embed_dim.embed_dim()), base_vb, embed_vb)?; - - println!("loaded the model in {:?}", start.elapsed()); - - let mut embedding = Embedding::new(model, tokenizer, &device); - - let task = args.task.map_or(EncodeTask::S2P, |t| t); - - // let _doc_esmbeddings = embedding._encode_single_doc()?; - embedding.encode(task, args.query)?; - Ok(()) -} diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index 68ed7e70c6..64239274e8 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -6,17 +6,19 @@ extern crate accelerate_src; use std::path::Path; -use anyhow::{anyhow, Error as E, Result}; +use anyhow::{ anyhow, Error as E, Result }; use clap::Parser; use candle_transformers::models::stella_en_v5::{ - Config, EmbedDim as StellaEmbedDim, EmbeddingModel, + Config, + EmbedDim as StellaEmbedDim, + EmbeddingModel, }; -use candle::{DType, Device, Tensor}; +use candle::{ DType, Device, Tensor }; use candle_nn::VarBuilder; -use hf_hub::{api::sync::Api, Repo}; -use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer}; +use hf_hub::{ api::sync::Api, Repo }; +use tokenizers::{ PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer }; struct Embedding { model: EmbeddingModel, @@ -59,13 +61,9 @@ impl Embedding { // We only encode the queries and not the data let qry = task.query_preproc(&queries); - let mut qry_encoded = self - .tokenizer - .encode_batch(qry, true) - .map_err(|e| anyhow!(e))?; + let mut qry_encoded = self.tokenizer.encode_batch(qry, true).map_err(|e| anyhow!(e))?; - let mut docs_encoded = self - .tokenizer + let mut docs_encoded = self.tokenizer .encode_batch(docs.to_vec(), true) .map_err(|e| anyhow!(e))?; @@ -76,14 +74,14 @@ impl Embedding { let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; for (i, e) in qry_encoded.drain(..).enumerate() { - let input_id = - Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; + let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( + 0 + )?; let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? .to_dtype(DType::U8)? .unsqueeze(0)?; - ids = - ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; } @@ -98,14 +96,14 @@ impl Embedding { let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; for (i, e) in docs_encoded.drain(..).enumerate() { - let input_id = - Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; + let input_id = Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze( + 0 + )?; let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? .to_dtype(DType::U8)? .unsqueeze(0)?; - ids = - ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + ids = ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; } @@ -268,25 +266,27 @@ fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { pad_id } else { - return Err(anyhow!( - "Tokenizer doesn't contain expected `<|endoftext|>` token" - )); + return Err(anyhow!("Tokenizer doesn't contain expected `<|endoftext|>` token")); }; // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - })); + tokenizer.with_padding( + Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + }) + ); } else { - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Right, - ..Default::default() - })); + tokenizer.with_padding( + Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + }) + ); } Ok(tokenizer) @@ -320,14 +320,8 @@ fn main() -> Result<()> { }; let (repo, cfg) = match args.which { - Which::Large => ( - "dunzhang/stella_en_1.5B_v5", - Config::new_1_5_b_v5(embed_dim.embed_dim()), - ), - Which::Small => ( - "dunzhang/stella_en_400M_v5", - Config::new_400_m_v5(embed_dim.embed_dim()), - ), + Which::Large => ("dunzhang/stella_en_1.5B_v5", Config::new_1_5_b_v5(embed_dim.embed_dim())), + Which::Small => ("dunzhang/stella_en_400M_v5", Config::new_400_m_v5(embed_dim.embed_dim())), }; let repo = api.repo(Repo::model(repo.to_string())); @@ -339,20 +333,12 @@ fn main() -> Result<()> { // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo let base_weight_files = match args.base_weight_files { - Some(files) => files - .split(',') - .map(std::path::PathBuf::from) - .collect::>(), - None => { - vec![repo.get("model.safetensors")?] - } + Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), + None => { vec![repo.get("model.safetensors")?] } }; let embed_weight_files = match args.embed_head_weight_files { - Some(files) => files - .split(',') - .map(std::path::PathBuf::from) - .collect::>(), + Some(files) => files.split(',').map(std::path::PathBuf::from).collect::>(), None => { let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); vec![repo.get(&head_w_path)?] @@ -369,11 +355,13 @@ fn main() -> Result<()> { let device = candle_examples::device(args.cpu)?; let dtype = DType::F32; - let base_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? }; + let base_vb = unsafe { + VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? + }; // Embedding layer is always built on F32 for accuracy - let embed_vb = - unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; + let embed_vb = unsafe { + VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? + }; let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 8c0564db84..23edf349ad 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -85,7 +85,6 @@ pub mod stable_diffusion; pub mod stable_lm; pub mod starcoder2; pub mod stella_en_v5; -pub mod stella_en_v5_400m; pub mod t5; pub mod trocr; pub mod vgg; diff --git a/candle-transformers/src/models/stella_en_v5_400m.rs b/candle-transformers/src/models/stella_en_v5_400m.rs deleted file mode 100644 index bbe16377d6..0000000000 --- a/candle-transformers/src/models/stella_en_v5_400m.rs +++ /dev/null @@ -1,914 +0,0 @@ -use candle::{ DType, Device, IndexOp, Module, Result, Tensor }; -use candle_nn::{ layer_norm, Activation, LayerNorm, VarBuilder }; -use std::sync::Arc; -use std::time::Instant; - -use super::with_tracing::{ linear, linear_no_bias, Linear }; - -#[derive(Debug, Clone, Copy)] -pub enum EmbedDim { - Dim256, - Dim768, - Dim1024, - Dim2048, - Dim4096, - Dim6144, - Dim8192, -} - -impl Default for EmbedDim { - fn default() -> Self { - Self::Dim1024 - } -} - -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] -pub struct EmbedHead { - pub in_features: usize, - pub out_features: usize, -} - -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] -pub struct Config { - pub vocab_size: usize, - pub hidden_size: usize, - pub intermediate_size: usize, - pub num_hidden_layers: usize, - pub num_attention_heads: usize, - pub max_position_embeddings: usize, - pub type_vocab_size: usize, - // pub pad_token_id: usize, - // pub hidden_dropout_prob: f64, - // pub attention_probs_dropout_prob: f64, - pub norm_eps: f64, - // pub initializer_range: f64, - // pub position_embedding_type: String, - pub scaling_factor: f64, - pub rope_theta: f64, - // pub use_memory_efficient_attention: bool, - // pub unpad_inputs: bool, - // pub layer_norm_type: String, - // pub logn_attention_scale: bool, - // pub logn_attention_clip1: bool, - pub activation_fn: Activation, - pub embed_head: EmbedHead, -} - -impl Config { - pub fn new_400m(embed_dim: EmbedDim) -> Self { - let embed_head = EmbedHead { - in_features: 1024, - out_features: match embed_dim { - EmbedDim::Dim256 => 256, - EmbedDim::Dim768 => 768, - EmbedDim::Dim1024 => 1024, - EmbedDim::Dim2048 => 2048, - EmbedDim::Dim4096 => 4096, - EmbedDim::Dim6144 => 6144, - EmbedDim::Dim8192 => 8192, - }, - }; - - Self { - vocab_size: 30528, - hidden_size: 1024, - intermediate_size: 4096, - num_hidden_layers: 24, - num_attention_heads: 16, - max_position_embeddings: 8192, - type_vocab_size: 2, - // pad_token_id: 0, - // hidden_dropout_prob: 0.1, - // attention_probs_dropout_prob: 0.0, - norm_eps: 1e-12, - // initializer_range: 0.02, - // position_embedding_type: "rope".to_string(), - scaling_factor: 2.0, - rope_theta: 160000.0, - // use_memory_efficient_attention: true, - // unpad_inputs: false, - // layer_norm_type: "layer_norm".to_string(), - // logn_attention_scale: false, - // logn_attention_clip1: false, - activation_fn: Activation::Gelu, - embed_head, - } - } -} - -#[derive(Debug, Clone)] -struct RotaryEmbedding { - sin: Tensor, - cos: Tensor, - // _scaling_factor: f64, - // _mixed_b: Option, - // _dim: usize, - // _base: f64, -} - -impl RotaryEmbedding { - fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { - let dim = cfg.hidden_size / cfg.num_attention_heads; - let max_seq_len = cfg.max_position_embeddings; - let scaling_factor = cfg.scaling_factor; // Can be configured in Config if needed - let base = cfg.rope_theta; - - // Calculate scaled position embeddings - let scaled_max_seq_len = ((max_seq_len as f64) * scaling_factor) as usize; - - // Calculate inv_freq with NTK scaling - let inv_freq: Vec<_> = (0..dim / 2) - .map(|i| { - // Apply base scaling - let base = base * scaling_factor; - let freq = 1.0 / base.powf((2.0 * (i as f64)) / (dim as f64)); - - // Apply fixed NTK scaling - let freq = freq / scaling_factor.powf(2.0 / (dim as f64)); - - freq as f32 - }) - .collect(); - - let inv_freq_len = inv_freq.len(); - let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; - - // Calculate position embeddings with scaled sequence length - let t = Tensor::arange(0u32, scaled_max_seq_len as u32, dev)? - .to_dtype(dtype)? - .reshape((scaled_max_seq_len, 1))?; - - let freqs = t.matmul(&inv_freq)?; - let emb = Tensor::cat(&[&freqs, &freqs], 1)?; - - Ok(Self { - sin: emb.sin()?, - cos: emb.cos()?, - // _scaling_factor: scaling_factor, - // _mixed_b: None, - // _dim: dim, - // _base: base, - }) - } -} - -#[derive(Debug, Clone)] -enum NormType { - LayerNorm(candle_nn::LayerNorm), - // RmsNorm(RmsNorm), -} - -impl NormType { - fn forward(&self, x: &Tensor) -> Result { - match self { - Self::LayerNorm(ln) => ln.forward(x), - // Self::RmsNorm(rms) => rms.forward(x), - } - } -} - -#[derive(Debug)] -pub struct Embeddings { - word_embeddings: candle_nn::Embedding, - // position_embeddings: Option, - token_type_embeddings: candle_nn::Embedding, - layer_norm: LayerNorm, - // _padding_idx: usize, - // _position_embedding_type: String, - rotary_emb: Arc, - position_ids: Tensor, -} - -impl Embeddings { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let word_embeddings = candle_nn::embedding( - cfg.vocab_size, - cfg.hidden_size, - vb.pp("word_embeddings") - )?; - - // let position_embeddings = if cfg.position_embedding_type == "absolute" { - // Some( - // candle_nn::embedding( - // cfg.max_position_embeddings, - // cfg.hidden_size, - // vb.pp("position_embeddings") - // )? - // ) - // } else { - // None - // }; - - let token_type_embeddings = candle_nn::embedding( - cfg.type_vocab_size, - cfg.hidden_size, - vb.pp("token_type_embeddings") - )?; - // if cfg.type_vocab_size > 0 { - // Some( - // candle_nn::embedding( - // cfg.type_vocab_size, - // cfg.hidden_size, - // vb.pp("token_type_embeddings") - // )? - // ) - // } else { - // None - // }; - - //if cfg.layer_norm_type == "layer_norm" { - let weight = vb - .pp("LayerNorm") - .get_with_hints(cfg.hidden_size, "weight", candle_nn::Init::Const(1.0))?; - let bias = vb - .pp("LayerNorm") - .get_with_hints(cfg.hidden_size, "bias", candle_nn::Init::Const(0.0))?; - let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); - // } else { - // NormType::RmsNorm( - // RmsNorm::new(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))? - // ) - // }; - - // let rotary_emb = if cfg.position_embedding_type == "rope" { - // Some(Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?)) - // } else { - // None - // }; - let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); - - // let position_ids = if cfg.position_embedding_type == "absolute" { - // Some(Tensor::arange(0u32, cfg.max_position_embeddings as u32, vb.device())?) - // } else { - // None - // }; - let position_ids = Tensor::arange(0u32, cfg.max_position_embeddings as u32, word_embeddings.embeddings().device())?; - - Ok(Self { - word_embeddings, - // position_embeddings, - token_type_embeddings, - layer_norm, - // _padding_idx: cfg.pad_token_id, - // _position_embedding_type: cfg.position_embedding_type.clone(), - rotary_emb, - position_ids, - }) - } - - pub fn forward( - &mut self, - input_ids: &Tensor, - // token_type_ids: Option<&Tensor>, - // position_ids: Option<&Tensor>, - // inputs_embeds: Option<&Tensor>, - // unpad_inputs: bool, - // attention_mask: Option<&Tensor> - ) -> Result<(Tensor, Option<(Tensor, Tensor)>)> { - let (batch_size, seq_length) = input_ids.dims2()?; - - let mut embeddings = self.word_embeddings.forward(input_ids)?; - // match inputs_embeds { - // Some(e) => e.clone(), - // None => self.word_embeddings.forward(input_ids)?, - // }; - - // Get position_ids first - // let position_ids = if let Some(ids) = position_ids { - // ids.clone() - // } else { - // Get device from input_ids which is always available - // let device = input_ids.device(); - - // Initialize position_ids if None - // if self.position_ids.is_none() { - // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - // } - - // // Now check if we need to extend it - // if seq_length > self.position_ids.as_ref().unwrap().dim(0)? { - // self.position_ids = Some(Tensor::arange(0u32, seq_length as u32, device)?); - // } - - // let position_ids = - /*if unpad_inputs { - // For now, just use the same position IDs as padded case since we don't have lengths - self.position_ids - .as_ref() - .unwrap() - .narrow(0, 0, seq_length)? - .expand((batch_size, seq_length))? - } else {*/ - // self.position_ids - // .as_ref() - // .unwrap() - // .narrow(0, 0, seq_length)? - // .expand((batch_size, seq_length))?; - // }; - // }; - - let position_ids = self.position_ids - .as_ref() - .narrow(0, 0, seq_length)? - .expand((batch_size, seq_length))?; - - - let rope_embeds = { - // Get the cos and sin for this sequence length - let cos = self.rotary_emb.cos.narrow(0, 0, seq_length)?; // [seq_len, head_dim] - let sin = self.rotary_emb.sin.narrow(0, 0, seq_length)?; // [seq_len, head_dim] - - // Index using position_ids if needed - let position_ids = position_ids.flatten_all()?; - let cos = cos.index_select(&position_ids, 0)?; // Use index_select instead of i() - let sin = sin.index_select(&position_ids, 0)?; // Use index_select instead of i() - - Some((cos, sin)) - }; - // // Get rotary embeddings if using RoPE - // let rope_embeds = if let Some(rotary) = &self.rotary_emb { - - // } else { - // None - // }; - - // Handle token type embeddings - embeddings = embeddings.add(&self.token_type_embeddings.forward(&position_ids.zeros_like()?)?).unwrap(); - // if let Some(token_emb) = &self.token_type_embeddings { - // let token_type_ids = if let Some(ids) = token_type_ids { - // ids.clone() - // } else { - // position_ids.zeros_like()? // Use mul(0) equivalent - // }; - // if unpad_inputs { - // todo!("Implement unpadded case"); - // } else { - // embeddings = embeddings.add(&token_emb.forward(&position_ids.zeros_like()?)?).unwrap(); - // } - // } - - // Handle absolute position embeddings - // if let Some(pos_emb) = &self.position_embeddings { - // let position_embeddings = pos_emb.forward(&position_ids)?; - // embeddings = embeddings.add(&position_embeddings)?; - // } - - let embeddings = self.layer_norm.forward(&embeddings)?; - - Ok((embeddings, rope_embeds)) - } -} - -#[derive(Debug)] -struct NewAttention { - qkv_proj: Linear, - o_proj: Linear, - num_heads: usize, - head_dim: usize, - hidden_size: usize, - // _use_memory_efficient_attention: bool, -} - -impl NewAttention { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let hidden_sz = cfg.hidden_size; - let num_heads = cfg.num_attention_heads; - let head_dim = hidden_sz / num_heads; - - let qkv_proj = linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?; - let o_proj = linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; - - Ok(Self { - qkv_proj, - o_proj, - num_heads, - head_dim, - hidden_size: hidden_sz, - // _use_memory_efficient_attention: cfg.use_memory_efficient_attention, - }) - } - - fn forward( - &self, - hidden_states: &Tensor, - attention_bias: Option<&Tensor>, - rope_embeds: Option<&(Tensor, Tensor)>, - // _attention_scale: Option<&Tensor> - ) -> Result { - let (b_sz, seq_len, _) = hidden_states.dims3()?; - - // QKV projection - let qkv = self.qkv_proj.forward(hidden_states)?; - - // Split into Q, K, V and reshape to match PyTorch shapes - let qkv = qkv.reshape((b_sz, seq_len, 3, self.num_heads, self.head_dim))?; - - // Get Q, K, V with shape [batch, seq_len, num_heads, head_dim] - let query_states = qkv.i((.., .., 0, .., ..))?.contiguous()?; - let key_states = qkv.i((.., .., 1, .., ..))?.contiguous()?; - let value_states = qkv.i((.., .., 2, .., ..))?.contiguous()?; - - // Apply RoPE if provided - let (query_states, key_states) = if let Some((cos, sin)) = rope_embeds { - apply_rotary_pos_emb(&query_states, &key_states, cos, sin)? - } else { - (query_states, key_states) - }; - - // Transpose for attention computation [batch, seq_len, num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim] - let query_states = query_states.transpose(1, 2)?.contiguous()?; - let key_states = key_states.transpose(1, 2)?.contiguous()?; - let value_states = value_states.transpose(1, 2)?.contiguous()?; - - // For key, we want to transpose the last two dimensions for the matmul - // Is this equivalent to PyTorch's transpose(-1, -2)? - let key_states_t = key_states.transpose(2, 3)?.contiguous()?; - - // Prepare tensors for batched matmul using matmul - // Reshape tensors to merge batch and head dimensions - let bsz = b_sz; - let nh = self.num_heads; - let s_len = seq_len; - let h_dim = self.head_dim; - - // Reshape tensors to [batch_size * num_heads, seq_len, head_dim] - let query_states_reshaped = query_states.reshape((bsz * nh, s_len, h_dim))?; - let key_states_t_reshaped = key_states_t.reshape((bsz * nh, h_dim, s_len))?; - - // Perform batched matmul using matmul - // The matmul should handle batch dimensions if tensors are 3D - let attn_weights = query_states_reshaped.matmul(&key_states_t_reshaped)?; - - // Reshape attn_weights back to [batch_size, num_heads, seq_len, seq_len] - let attn_weights = attn_weights.reshape((bsz, nh, s_len, s_len))?; - - // Scale attention scores - let scale = 1f32 / (self.head_dim as f32).sqrt(); - - let scale_tensor = Tensor::new(scale, attn_weights.device())? - .to_dtype(attn_weights.dtype())? - .broadcast_as(attn_weights.shape())?; - - // Multiply the attention weights by the scalar tensor - let attn_weights = attn_weights.mul(&scale_tensor)?; - - // Apply attention mask - let mut attn_weights = if let Some(bias) = attention_bias { - // let attn_weights = attn_weights.broadcast_add(bias)?; - attn_weights.broadcast_add(bias)? - } else { - attn_weights - }; - - // Normalize attention scores - attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; - - // Reshape value_states for batched matmul - let value_states_reshaped = value_states.reshape((bsz * nh, s_len, h_dim))?; - - // Reshape attn_weights to [batch_size * num_heads, seq_len, seq_len] - let attn_weights_reshaped = attn_weights.reshape((bsz * nh, s_len, s_len))?; - - // Compute attention output - let attn_output = attn_weights_reshaped.matmul(&value_states_reshaped)?; - - // Reshape attn_output back to [batch_size, num_heads, seq_len, head_dim] - let attn_output = attn_output.reshape((bsz, nh, s_len, h_dim))?; - - // Transpose back to [batch_size, seq_len, num_heads, head_dim] - let attn_output = attn_output.transpose(1, 2)?; - - // Project to final dimension - let attn_output = attn_output.reshape((b_sz, seq_len, self.hidden_size))?; - self.o_proj.forward(&attn_output) - } -} - -#[derive(Debug)] -struct NewGatedMLP { - up_gate_proj: Linear, - down_proj: Linear, - act_fn: Activation, -} - -impl NewGatedMLP { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let hidden_sz = cfg.hidden_size; - let intermediate_size = cfg.intermediate_size; - let act_fn = cfg.activation_fn; - let up_gate_proj = linear_no_bias(hidden_sz, intermediate_size * 2, vb.pp("up_gate_proj"))?; - let down_proj = linear(intermediate_size, hidden_sz, vb.pp("down_proj"))?; - - Ok(Self { - up_gate_proj, - down_proj, - act_fn, - }) - } -} - -impl Module for NewGatedMLP { - fn forward(&self, xs: &Tensor) -> Result { - let up_gate = self.up_gate_proj.forward(xs)?; - - // Get the dimensions - let (_batch_size, _seq_len, hidden_dim) = up_gate.dims3()?; - let split_size = hidden_dim / 2; - - // Split along the last dimension (hidden_dim) - let up_states = up_gate.narrow(2, 0, split_size)?; - let gate = up_gate.narrow(2, split_size, split_size)?; - - // Apply activation to gate and multiply - let gate = gate.apply(&self.act_fn)?; - - let gated_states = up_states.mul(&gate)?; - - // Project back to hidden dimension - let output = self.down_proj.forward(&gated_states)?; - - Ok(output) - } -} - -#[derive(Debug)] -struct NewLayer { - attention: NewAttention, - mlp: NewGatedMLP, - attn_ln: NormType, - mlp_ln: NormType, -} - -impl NewLayer { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let attention = NewAttention::new(cfg, vb.pp("attention"))?; - let mlp = NewGatedMLP::new(cfg, vb.pp("mlp"))?; - - // let ln_eps = cfg.layer_norm_eps; - - // Use LayerNorm or RmsNorm based on config - let (attn_ln, mlp_ln) = { - let attn_ln = layer_norm( - cfg.hidden_size, - candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("attn_ln") - )?; - let mlp_ln = layer_norm( - cfg.hidden_size, - candle_nn::LayerNormConfig { eps: cfg.norm_eps, ..Default::default() }, - vb.pp("mlp_ln") - )?; - (NormType::LayerNorm(attn_ln), NormType::LayerNorm(mlp_ln)) - }; - // else - // { - // let attn_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("attn_ln"))?; - // let mlp_ln = RmsNorm::new(cfg.hidden_size, ln_eps, vb.pp("mlp_ln"))?; - // (NormType::RmsNorm(attn_ln), NormType::RmsNorm(mlp_ln)) - // }; - - Ok(Self { - attention, - mlp, - attn_ln, - mlp_ln, - }) - } - - fn forward( - &self, - hidden_states: &Tensor, - attention_bias: Option<&Tensor>, - rope_embeds: Option<&(Tensor, Tensor)>, - // attention_scale: Option<&Tensor> - ) -> Result { - // Store original input - let original = hidden_states; - - // Use normalized states for attention - let hidden_states = self.attention.forward( - original, - attention_bias, - rope_embeds, - // attention_scale - )?; - - let hidden_states = original.add(&hidden_states)?; - - // Apply layer norm - let hidden_states = self.attn_ln.forward(&hidden_states)?; - - // Store residual - let residual = &hidden_states; - - // Pass through MLP - let hidden_states = self.mlp.forward(&hidden_states)?; - - // Add residual connection - let hidden_states = residual.add(&hidden_states)?; - - // Final layer norm - self.mlp_ln.forward(&hidden_states) - } -} - -#[derive(Debug)] -struct NewEncoder { - layers: Vec, -} - -impl NewEncoder { - fn new(cfg: &Config, vb: VarBuilder) -> Result { - let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb.pp("layer"); - for layer_idx in 0..cfg.num_hidden_layers { - layers.push(NewLayer::new(cfg, vb_l.pp(layer_idx))?); - } - Ok(Self { layers }) - } - - fn forward( - &self, - hidden_states: &Tensor, - attention_bias: Option<&Tensor>, - rope_embeds: Option<&(Tensor, Tensor)>, - // attention_scale: Option<&Tensor> - ) -> Result { - let mut hidden_states = hidden_states.clone(); - - for layer in self.layers.iter() { - hidden_states = layer.forward( - &hidden_states, - attention_bias, - rope_embeds, - // attention_scale - )?; - } - - Ok(hidden_states) - } -} - -#[derive(Debug)] -pub struct NewModel { - embeddings: Embeddings, - encoder: NewEncoder, - device: Device, - dtype: DType, - // config: Config, -} - -impl NewModel { - pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let vb_m = vb.pp("new"); - let embeddings = Embeddings::new(cfg, vb_m.pp("embeddings"))?; - let encoder = NewEncoder::new(cfg, vb_m.pp("encoder"))?; - Ok(Self { - embeddings, - encoder, - device: vb.device().clone(), - dtype: vb.dtype(), - // config: cfg.clone(), - }) - } - - fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result { - let (b_sz, seq_len) = attn_mask.dims2()?; - let mask = attn_mask - .unsqueeze(1) - ? // [b_sz, 1, seq_len] - .unsqueeze(2) - ? // [b_sz, 1, 1, seq_len] - .broadcast_as((b_sz, 1, 1, seq_len))?; // [b_sz, 1, 1, seq_len] - - // Use a large negative value for mask instead of -0.0 - let on_true = mask.zeros_like()?.to_dtype(self.dtype)?; - let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)? - .broadcast_as(mask.shape())? - .to_dtype(self.dtype)?; - - mask.where_cond(&on_true, &on_false) - } - - pub fn forward( - &mut self, - input_ids: &Tensor, - attention_mask: &Tensor, - // token_type_ids: Option<&Tensor>, - // position_ids: Option<&Tensor> - ) -> Result { - let (_, seq_length) = input_ids.dims2()?; - - // Get attention mask if not provided - // let attention_mask = mask; - - // Prepare attention bias - let attention_bias = if seq_length <= 1 { - None - } else { - Some(self.prepare_attention_mask(attention_mask)?) - }; - - // Get embeddings and rotary embeddings - let (hidden_states, rope_embeds) = self.embeddings.forward( - input_ids, - // token_type_ids, - // position_ids, - // None, - // self.config.unpad_inputs, - // Some(&attention_mask) - )?; - - // Compute attention scale if needed - // let attention_scale = if self.config.logn_attention_scale { - // let scale = - // attention_mask.sum_keepdim(1)?.log()? / - // (self.config.max_position_embeddings as f64).ln(); - // if self.config.logn_attention_clip1 { - // let scale = scale?; - // Some(scale.maximum(&Tensor::new(1f64, &self.device)?)?) - // } else { - // Some(scale?) - // } - // } else { - // None - // }; - - // Forward through encoder - let hidden_states = self.encoder.forward( - &hidden_states, - attention_bias.as_ref(), - rope_embeds.as_ref(), - // attention_scale.as_ref() - )?; - - Ok(hidden_states) - } -} - -// Optional pooler implementation -// #[derive(Debug)] -// pub struct NewPooler { -// dense: Linear, -// } - -// impl NewPooler { -// pub fn new(cfg: &Config, vb: VarBuilder) -> Result { -// let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?; -// Ok(Self { dense }) -// } - -// pub fn forward(&self, hidden_states: &Tensor) -> Result { -// let first_token = hidden_states.i((.., 0, ..))?; -// let pooled = self.dense.forward(&first_token)?; -// pooled.tanh() -// } -// } - -// // Complete model with pooler -// #[derive(Debug)] -// pub struct NewModelWithPooler { -// model: NewModel, -// pooler: Option, -// } - -// impl NewModelWithPooler { -// pub fn new(cfg: &Config, vb: VarBuilder, add_pooling_layer: bool) -> Result { -// let vb_m = vb.pp("new"); -// let model = NewModel::new(cfg, vb_m.pp("model"))?; -// let pooler = if add_pooling_layer { -// Some(NewPooler::new(cfg, vb.pp("new").pp("pooler"))?) -// } else { -// None -// }; -// Ok(Self { model, pooler }) -// } - -// pub fn forward( -// &mut self, -// input_ids: &Tensor, -// attention_mask: Option<&Tensor>, -// token_type_ids: Option<&Tensor>, -// position_ids: Option<&Tensor> -// ) -> Result<(Tensor, Option)> { -// let hidden_states = self.model.forward( -// input_ids, -// attention_mask, -// token_type_ids, -// position_ids -// )?; - -// let pooled_output = match &self.pooler { -// Some(pooler) => Some(pooler.forward(&hidden_states)?), -// None => None, -// }; - -// Ok((hidden_states, pooled_output)) -// } -// } - -#[derive(Debug)] -pub struct EmbeddingModel { - base_model: NewModel, - lm_head: Linear, -} - -impl EmbeddingModel { - pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result { - let base_model = NewModel::new(cfg, base_vb)?; - let lm_head = linear( - cfg.embed_head.in_features, - cfg.embed_head.out_features, - embed_vb.pp("linear") - )?; - - Ok(Self { - base_model, - lm_head, - }) - } - - pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { - let x = self.base_model.forward(input_ids, mask)?;//, None, None)?; - let x = self.pool(&x, mask)?; - self.lm_head.forward(&x.to_dtype(DType::F32)?) - } - - pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { - let x = self.forward(input_ids, mask)?; - x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?) - } - fn pool(&self, x: &Tensor, mask: &Tensor) -> Result { - let mask = mask.to_dtype(x.dtype())?; - let (batch_size, seq_len, hidden_dim) = x.dims3()?; - let mask_expanded = mask.unsqueeze(2)?.broadcast_as((batch_size, seq_len, hidden_dim))?; // [B_Sz, Seq_len, Hidden_dim] - let x = x.mul(&mask_expanded)?; - let sum_mask = mask.sum(1)?.unsqueeze(1)?.expand((batch_size, hidden_dim))?; - x.sum(1)? / sum_mask - } -} - -pub fn time_run(f: F) -> (T, std::time::Duration) where F: FnOnce() -> T { - let start = Instant::now(); - let result = f(); - let duration = start.elapsed(); - (result, duration) -} - -fn apply_rotary_pos_emb( - q: &Tensor, - k: &Tensor, - cos: &Tensor, - sin: &Tensor -) -> Result<(Tensor, Tensor)> { - let cos = cos.to_dtype(q.dtype())?; - let sin = sin.to_dtype(q.dtype())?; - - let (batch_size, seq_len, num_heads, head_dim) = q.dims4()?; - let half_dim = head_dim / 2; - - // Reshape q and k to split the head dim for rotation - let q_split = q.chunk(2, 3)?; // Split along head_dim - let k_split = k.chunk(2, 3)?; - - let q1 = &q_split[0]; - let q2 = &q_split[1]; - let k1 = &k_split[0]; - let k2 = &k_split[1]; - - // Handle cos/sin for the sequence length we have - let cos = cos.narrow(0, 0, seq_len)?; - let sin = sin.narrow(0, 0, seq_len)?; - - // Reshape cos/sin to match the dimensions we need - let cos = cos - .reshape((seq_len, head_dim))? - .chunk(2, 1)? - [0].reshape((seq_len, 1, half_dim))? - .broadcast_as((seq_len, num_heads, half_dim))? - .unsqueeze(0)? - .broadcast_as((batch_size, seq_len, num_heads, half_dim))?; - - let sin = sin - .reshape((seq_len, head_dim))? - .chunk(2, 1)? - [0].reshape((seq_len, 1, half_dim))? - .broadcast_as((seq_len, num_heads, half_dim))? - .unsqueeze(0)? - .broadcast_as((batch_size, seq_len, num_heads, half_dim))?; - - // Apply rotation using the formulas: - // q = q * cos + rotate_half(q) * sin - // k = k * cos + rotate_half(k) * sin - let q_out = Tensor::cat( - &[&q1.mul(&cos)?.sub(&q2.mul(&sin)?)?, &q2.mul(&cos)?.add(&q1.mul(&sin)?)?], - 3 - )?; - - let k_out = Tensor::cat( - &[&k1.mul(&cos)?.sub(&k2.mul(&sin)?)?, &k2.mul(&cos)?.add(&k1.mul(&sin)?)?], - 3 - )?; - - Ok((q_out, k_out)) -}