From cae59157a1752f5b63094357b13c08e5250fd0c4 Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Mon, 7 Oct 2024 17:54:01 +0530
Subject: [PATCH 1/6] Stella_en_1.5B_v5

---
 .../examples/stella-en-v5/README.md           |  33 ++
 candle-examples/examples/stella-en-v5/main.rs | 338 +++++++++++++++
 candle-transformers/src/models/mod.rs         |   1 +
 .../src/models/stella_en_v5.rs                | 398 ++++++++++++++++++
 4 files changed, 770 insertions(+)
 create mode 100644 candle-examples/examples/stella-en-v5/README.md
 create mode 100644 candle-examples/examples/stella-en-v5/main.rs
 create mode 100644 candle-transformers/src/models/stella_en_v5.rs

diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md
new file mode 100644
index 0000000000..90f55ba894
--- /dev/null
+++ b/candle-examples/examples/stella-en-v5/README.md
@@ -0,0 +1,33 @@
+# candle-qwen: large language model series from Alibaba Cloud
+
+Qwen 1.5 is a series of large language models that provide strong performances
+on English and Chinese.
+
+- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
+- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
+- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
+  mixture-of-experts (MoE) variant.
+
+## Running the example
+
+Stella_en_1.5B_v5 is used to generate text embeddings for a prompt. The model weights
+are downloaded from the hub on the first run.
+
+```bash
+$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
+
+> [[ 0.3905, -0.0130,  0.2072, ..., -0.1100, -0.0086,  0.6002]]
+> Tensor[[1, 1024], f32]
+```
+
+Various model sizes are available via the `--model` argument, including the MoE
+variant.
+
+```bash
+$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
+def print_prime(n: int): # n is the number of primes to be printed
+    for i in range(2, n + 1):
+        if all(i % j != 0 for j in range(2, i)):
+            print(i)
+```
diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
new file mode 100644
index 0000000000..bfa750e0ab
--- /dev/null
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -0,0 +1,338 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::{anyhow, Error as E, Result};
+use clap::Parser;
+
+use candle_transformers::models::stella_en_v5::{
+    Config, EmbedDim as StellaEmbedDim, EmbeddingModel,
+};
+
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use hf_hub::{api::sync::Api, Repo};
+use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};
+
+struct Embedding {
+    model: EmbeddingModel,
+    device: Device,
+    tokenizer: Tokenizer,
+}
+
+impl Embedding {
+    fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self {
+        Self {
+            model,
+            tokenizer,
+            device: device.clone(),
+        }
+    }
+
+    fn encode(&mut self, task: EncodeTask, text: Option<String>) -> Result<()> {
+        // Just showcasing embeddings, this has no real value
+        if let Some(text) = text {
+            let qry = task.query_preproc(&[text]);
+            let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?;
+
+            let shape = (1, encoding.len());
+            let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?;
+            let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?;
+
+            let result = self.model.forward(&input, &mask)?;
+            println!("embeddings: {result}");
+        } else {
+            // Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers)
+            let queries = [
+                "What are some ways to reduce stress?".to_string(),
+                "What are the benefits of drinking green tea?".to_string(),
+            ];
+
+            let docs = [
+                "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(),
+                "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(),
+            ];
+
+            // We only encode the queries and not the data
+            let qry = task.query_preproc(&queries);
+            let mut qry_encoded = self
+                .tokenizer
+                .encode_batch(qry, true)
+                .map_err(|e| anyhow!(e))?;
+
+            let mut docs_encoded = self
+                .tokenizer
+                .encode_batch(docs.to_vec(), true)
+                .map_err(|e| anyhow!(e))?;
+
+            let qry_embed = {
+                // Now, we generate the tensors for the `input` and `mask`
+                let shape = (qry_encoded.len(), qry_encoded[1].len());
+                let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
+                let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
+
+                for (i, e) in qry_encoded.drain(..).enumerate() {
+                    let input_id =
+                        Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
+                    let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
+                        .to_dtype(DType::U8)?
+                        .unsqueeze(0)?;
+
+                    ids =
+                        ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
+                    masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
+                }
+
+                // Let's generate the embeddings for the queries, we are going to be normalizing the result.
+                // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
+                self.model.forward_norm(&ids, &masks)?
+            };
+
+            let doc_embed = {
+                let shape = (docs_encoded.len(), docs_encoded[1].len());
+                let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?;
+                let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?;
+
+                for (i, e) in docs_encoded.drain(..).enumerate() {
+                    let input_id =
+                        Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?;
+                    let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)?
+                        .to_dtype(DType::U8)?
+                        .unsqueeze(0)?;
+
+                    ids =
+                        ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?;
+                    masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?;
+                }
+
+                // Let's generate the embeddings for the docs, we are going to be normalizing the result.
+                // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data
+                self.model.forward_norm(&ids, &masks)?
+            };
+
+            println!(
+                "Embed shapes:\nQuery: {:?}\nDocs: {:?}",
+                qry_embed.shape(),
+                doc_embed.shape()
+            ); // [2, 1024] for head dim `1024`
+            // a matmul to generate the `similarity` score
+            let res = qry_embed.matmul(&doc_embed.t()?)?;
+            println!("Similarity: {res}");
+        }
+
+        Ok(())
+    }
+}
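+
+// A minimal sketch (not part of the model-card example) of the batched flow the
+// comments above hint at: for larger datasets, call `forward` over fixed-size
+// chunks and run a single `l2 norm` pass over the concatenated rows. The helper
+// name `embed_chunked` and the chunking scheme are illustrative assumptions.
+#[allow(dead_code)]
+fn embed_chunked(
+    model: &mut EmbeddingModel,
+    ids: &Tensor,
+    masks: &Tensor,
+    chunk: usize,
+) -> Result<Tensor> {
+    let n = ids.dim(0)?;
+    let mut outs = Vec::new();
+    for start in (0..n).step_by(chunk) {
+        let len = chunk.min(n - start);
+        // un-normalized embeddings for this chunk
+        outs.push(model.forward(&ids.narrow(0, start, len)?, &masks.narrow(0, start, len)?)?);
+    }
+    let all = Tensor::cat(&outs, 0)?;
+    // a single `l2 norm` pass over the entire data
+    Ok(all.broadcast_div(&all.sqr()?.sum_keepdim(1)?.sqrt()?)?)
+}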
+
+#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
+enum EmbedDim {
+    #[value(name = "256")]
+    Dim256,
+    #[value(name = "768")]
+    Dim768,
+    #[value(name = "1024")]
+    Dim1024,
+    #[value(name = "2048")]
+    Dim2048,
+    #[value(name = "4096")]
+    Dim4096,
+    #[value(name = "6144")]
+    Dim6144,
+    #[value(name = "8192")]
+    Dim8192,
+}
+
+impl EmbedDim {
+    /// Returns the dir path to the embed head weights in the repo
+    pub fn embed_dim_default_dir(&self) -> &'static str {
+        match self {
+            Self::Dim256 => "2_Dense_256",
+            Self::Dim768 => "2_Dense_768",
+            Self::Dim1024 => "2_Dense_1024",
+            Self::Dim2048 => "2_Dense_2048",
+            Self::Dim4096 => "2_Dense_4096",
+            Self::Dim6144 => "2_Dense_6144",
+            Self::Dim8192 => "2_Dense_8192",
+        }
+    }
+
+    /// Resolves the `EmbedDim` for the given variant
+    pub fn embed_dim(&self) -> StellaEmbedDim {
+        match self {
+            Self::Dim256 => StellaEmbedDim::Dim256,
+            Self::Dim768 => StellaEmbedDim::Dim768,
+            Self::Dim1024 => StellaEmbedDim::Dim1024,
+            Self::Dim2048 => StellaEmbedDim::Dim2048,
+            Self::Dim4096 => StellaEmbedDim::Dim4096,
+            Self::Dim6144 => StellaEmbedDim::Dim6144,
+            Self::Dim8192 => StellaEmbedDim::Dim8192,
+        }
+    }
+}
+
+#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
+pub enum EncodeTask {
+    /// `s2p` is the `retrieval` task
+    /// Default in this example
+    #[value(name = "s2p")]
+    S2P,
+    /// `s2s` is the semantic similarity task
+    #[value(name = "s2s")]
+    S2S,
+}
+
+impl EncodeTask {
+    /// Preprocess a set of inputs based on a template suggested by the model authors
+    /// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction
+    pub fn query_preproc(&self, txt: &[String]) -> Vec<String> {
+        let instruct = match self {
+            Self::S2P => {
+                "Given a web search query, retrieve relevant passages that answer the query."
+            }
+            Self::S2S => "Retrieve semantically similar text.",
+        };
+
+        txt.iter()
+            .map(|s| format!("Instruct: {instruct}\nQuery: {s}"))
+            .collect::<Vec<_>>()
+    }
+}
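+
+// For instance, with the default `s2p` task, `query_preproc` turns
+// "What are safetensors?" into:
+//   "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: What are safetensors?"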
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    /// Run on CPU rather than on GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
+    #[arg(long)]
+    use_flash_attn: bool,
+
+    #[arg(long)]
+    query: Option<String>,
+
+    #[arg(long, default_value = "1024")]
+    embed_dim: Option<EmbedDim>,
+
+    #[arg(long)]
+    tokenizer_file: Option<String>,
+
+    #[arg(long)]
+    base_weight_files: Option<String>,
+
+    #[arg(long)]
+    embed_head_weight_files: Option<String>,
+
+    /// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5)
+    /// `s2s`: Semantic textual similarity
+    /// `s2p`: Retrieval task - `Default` in this example
+    #[arg(long, default_value = "s2p")]
+    task: Option<EncodeTask>,
+}
+
+fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
+    let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+    println!(
+        "avx: {}, neon: {}, simd128: {}, f16c: {}",
+        candle::utils::with_avx(),
+        candle::utils::with_neon(),
+        candle::utils::with_simd128(),
+        candle::utils::with_f16c()
+    );
+
+    let start = std::time::Instant::now();
+    let api = Api::new()?;
+    let embed_dim = match args.embed_dim {
+        Some(d) => d,
+        None => EmbedDim::Dim1024,
+    };
+    let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
+    let tokenizer_filename = match args.tokenizer_file {
+        Some(file) => std::path::PathBuf::from(file),
+        None => repo.get("tokenizer.json")?,
+    };
+
+    // Note: if you are providing `weight_files`, ensure that the `--embed_dim` dimension provided matches the weights
+    // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from the `2_Dense_1024` dir of the repo
+    let base_weight_files = match args.base_weight_files {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => {
+            vec![repo.get("model.safetensors")?]
+        }
+    };
+
+    let embed_weight_files = match args.embed_head_weight_files {
+        Some(files) => files
+            .split(',')
+            .map(std::path::PathBuf::from)
+            .collect::<Vec<_>>(),
+        None => {
+            let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir());
+            vec![repo.get(&head_w_path)?]
+        }
+    };
+
+    println!("retrieved the files in {:?}", start.elapsed());
+
+    // Initializing the tokenizer which would require us to add padding to the right for batch encoding
+    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+    let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
+        pad_id
+    } else {
+        return Err(anyhow!(
+            "Tokenizer doesn't contain expected `<|endoftext|>` token"
+        ));
+    };
+
+    tokenizer.with_padding(Some(PaddingParams {
+        strategy: PaddingStrategy::BatchLongest,
+        direction: PaddingDirection::Left,
+        pad_id,
+        pad_token: "<|endoftext|>".to_string(),
+        ..Default::default()
+    }));
+
+    let start = std::time::Instant::now();
+
+    let device = candle_examples::device(args.cpu)?;
+    let dtype = DType::F32;
+
+    let base_vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? };
+    // Embedding layer is always built on F32 for accuracy
+    let embed_vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
+
+    let model = EmbeddingModel::new(
+        &Config::new_1_5_b_v5(embed_dim.embed_dim()),
+        base_vb,
+        embed_vb,
+    )?;
+
+    println!("loaded the model in {:?}", start.elapsed());
+
+    let mut embedding = Embedding::new(model, tokenizer, &device);
+
+    let task = args.task.map_or(EncodeTask::S2P, |t| t);
+
+    embedding.encode(task, args.query)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 80cd4f810c..bd99eadff1 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -83,6 +83,7 @@ pub mod siglip;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod starcoder2;
+pub mod stella_en_v5;
 pub mod t5;
 pub mod trocr;
 pub mod vgg;
diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs
new file mode 100644
index 0000000000..1a0c53e3ad
--- /dev/null
+++ b/candle-transformers/src/models/stella_en_v5.rs
@@ -0,0 +1,398 @@
+use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
+use candle::{DType, Device, IndexOp, Module, Result, Tensor};
+use candle_nn::{Activation, VarBuilder};
+use std::sync::Arc;
+
+// Same as the `qwen2` family of models, with the exception being the `embed_head`:
+// the final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head`
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct Config {
+    pub vocab_size: usize,
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub num_key_value_heads: usize,
+    pub max_position_embeddings: usize,
+    pub max_window_layers: usize,
+    pub tie_word_embeddings: bool,
+    pub rope_theta: f64,
+    pub rms_norm_eps: f64,
+    pub hidden_act: Activation,
+    pub embed_head: EmbedHead,
+}
+
+// Excerpt from the `stella` model card:
+// `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions
+// Embed head represents the config for the various embedding dims supported
+#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+pub struct EmbedHead {
+    pub in_features: usize,
+    pub out_features: usize,
+}
+
+/// An enum variant representing the Embedding head dimensions `stella` is trained on
+/// As the [model-card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction) suggests, D1024 is good enough for most cases
+pub enum EmbedDim {
+    Dim256,
+    Dim768,
+    Dim1024,
+    Dim2048,
+    Dim4096,
+    Dim6144,
+    Dim8192,
+}
+
+impl Default for EmbedDim {
+    fn default() -> Self {
+        Self::Dim1024
+    }
+}
+
+impl EmbedDim {
+    pub fn config(&self) -> EmbedHead {
+        EmbedHead {
+            in_features: 1536,
+            out_features: match &self {
+                Self::Dim256 => 256,
+                Self::Dim768 => 768,
+                Self::Dim1024 => 1024,
+                Self::Dim2048 => 2048,
+                Self::Dim4096 => 4096,
+                Self::Dim6144 => 6144,
+                Self::Dim8192 => 8192,
+            },
+        }
+    }
+}
+
+// Initialize a new `stella_en` model - with the 400M variant or the 1.5B variant
+impl Config {
+    /// Initialize a new `stella_en_1.5B_v5` model with the given embedding dim
+    pub fn new_1_5_b_v5(embed_dim: EmbedDim) -> Self {
+        // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json
+        // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here
+        Self {
+            hidden_act: candle_nn::Activation::Silu,
+            vocab_size: 151646,
+            hidden_size: 1536,
+            intermediate_size: 8960,
+            num_hidden_layers: 28,
+            num_attention_heads: 12,
+            num_key_value_heads: 2,
+            max_position_embeddings: 131072,
+            max_window_layers: 21,
+            tie_word_embeddings: false,
+            rope_theta: 1000000.,
+            rms_norm_eps: 1e-06,
+            embed_head: embed_dim.config(),
+        }
+    }
+}
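+
+// (note) With the 1.5B config above, head_dim = hidden_size / num_attention_heads
+// = 1536 / 12 = 128; that quotient is the rotary `dim` computed in `RotaryEmbedding::new` below.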
+
+#[derive(Debug, Clone)]
+struct RotaryEmbedding {
+    sin: Tensor,
+    cos: Tensor,
+}
+
+impl RotaryEmbedding {
+    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
+        let dim = cfg.hidden_size / cfg.num_attention_heads;
+        let max_seq_len = cfg.max_position_embeddings;
+        let inv_freq: Vec<_> = (0..dim)
+            .step_by(2)
+            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .collect();
+        let inv_freq_len = inv_freq.len();
+        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
+            .to_dtype(dtype)?
+            .reshape((max_seq_len, 1))?;
+        let freqs = t.matmul(&inv_freq)?;
+        Ok(Self {
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
+        })
+    }
+
+    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+        let cos = self.cos.narrow(0, 0, seq_len)?;
+        let sin = self.sin.narrow(0, 0, seq_len)?;
+        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
+        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
+        Ok((q_embed, k_embed))
+    }
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+    gate_proj: Linear,
+    up_proj: Linear,
+    down_proj: Linear,
+    act_fn: Activation,
+}
+
+impl MLP {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_sz = cfg.hidden_size;
+        let intermediate_sz = cfg.intermediate_size;
+        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
+        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+        Ok(Self {
+            gate_proj,
+            up_proj,
+            down_proj,
+            act_fn: cfg.hidden_act,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
+        let rhs = xs.apply(&self.up_proj)?;
+        (lhs * rhs)?.apply(&self.down_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    o_proj: Linear,
+    num_heads: usize,
+    num_kv_heads: usize,
+    num_kv_groups: usize,
+    head_dim: usize,
+    hidden_size: usize,
+    rotary_emb: Arc<RotaryEmbedding>,
+}
+
+impl Attention {
+    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let hidden_sz = cfg.hidden_size;
+        let num_heads = cfg.num_attention_heads;
+        let num_kv_heads = cfg.num_key_value_heads;
+        let num_kv_groups = num_heads / num_kv_heads;
+        let head_dim = hidden_sz / num_heads;
+        let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            o_proj,
+            num_heads,
+            num_kv_heads,
+            num_kv_groups,
+            head_dim,
+            hidden_size: hidden_sz,
+            rotary_emb,
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let (b_sz, q_len, _) = xs.dims3()?;
+
+        let query_states = self.q_proj.forward(xs)?;
+        let key_states = self.k_proj.forward(xs)?;
+        let value_states = self.v_proj.forward(xs)?;
+
+        let query_states = query_states
+            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        let key_states = key_states
+            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?;
+        let value_states = value_states
+            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
+            .transpose(1, 2)?;
+
+        let (query_states, key_states) = self
+            .rotary_emb
+            .apply_rotary_emb_qkv(&query_states, &key_states)?;
+
+        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
+        let value_states =
+            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
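+        // (note) Qwen2-style grouped-query attention: the 12 query heads share
+        // 2 KV heads, so each KV head is repeated num_kv_groups = 6 times above
+        // to make the shapes line up for the attention matmul below.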
+
+        let attn_output = {
+            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
+            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+
+            let attn_weights = match attention_mask {
+                None => attn_weights,
+                Some(mask) => attn_weights.broadcast_add(mask)?,
+            };
+            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+            attn_weights.matmul(&value_states)?
+        };
+        attn_output
+            .transpose(1, 2)?
+            .reshape((b_sz, q_len, self.hidden_size))?
+            .apply(&self.o_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+    self_attn: Attention,
+    mlp: MLP,
+    input_layernorm: RmsNorm,
+    post_attention_layernorm: RmsNorm,
+}
+
+impl DecoderLayer {
+    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        let input_layernorm =
+            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
+        let post_attention_layernorm = RmsNorm::new(
+            cfg.hidden_size,
+            cfg.rms_norm_eps,
+            vb.pp("post_attention_layernorm"),
+        )?;
+        Ok(Self {
+            self_attn,
+            mlp,
+            input_layernorm,
+            post_attention_layernorm,
+        })
+    }
+
+    fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs;
+        let xs = self.input_layernorm.forward(xs)?;
+        let xs = self.self_attn.forward(&xs, attention_mask)?;
+        let xs = (xs + residual)?;
+        let residual = &xs;
+        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
+        residual + xs
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Model {
+    embed_tokens: candle_nn::Embedding,
+    layers: Vec<DecoderLayer>,
+    norm: RmsNorm,
+    device: Device,
+    dtype: DType,
+}
+
+impl Model {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_m = vb.pp("model");
+        let embed_tokens =
+            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb_l = vb_m.pp("layers");
+        for layer_idx in 0..cfg.num_hidden_layers {
+            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+            layers.push(layer)
+        }
+        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+        Ok(Self {
+            embed_tokens,
+            layers,
+            norm,
+            // sliding_window: 0,
+            device: vb.device().clone(),
+            dtype: vb.dtype(),
+        })
+    }
+
+    fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
+        let (b_sz, sql_len) = attn_mask.dims2()?;
+        let mut mask: Vec<Tensor> = vec![];
+        for b in 0..b_sz {
+            mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?);
+        }
+        let mask = Tensor::cat(&mask, 0)?;
+        let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
+        let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
+            .broadcast_as(mask.shape())?
+            .to_dtype(self.dtype)?;
+        mask.where_cond(&on_true, &on_false)
+    }
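+    // (worked example) For a left-padded row with attention mask [0, 1, 1],
+    // every query position gets the same key mask, so the (seq_len, seq_len)
+    // slice built above becomes
+    //   [[-inf, 0, 0],
+    //    [-inf, 0, 0],
+    //    [-inf, 0, 0]]
+    // i.e. padded keys are masked out while all real tokens attend to each
+    // other in both directions (non-causal).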
+
+    pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
+        let (_, seq_len) = input_ids.dims2()?;
+        let attention_mask = if seq_len <= 1 {
+            None
+        } else {
+            // This is not a `causal language modelling` task, we'll need to prepare a `non-causal` attention
+            Some(self.prepare_attention_mask(mask)?)
+        };
+
+        let mut xs = self.embed_tokens.forward(input_ids)?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attention_mask.as_ref())?
+        }
+        xs.apply(&self.norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct EmbeddingModel {
+    base_model: Model,
+    lm_head: Linear,
+}
+
+impl EmbeddingModel {
+    pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result<Self> {
+        let base_model = Model::new(cfg, base_vb.clone())?;
+        let lm_head = linear(
+            cfg.embed_head.in_features,
+            cfg.embed_head.out_features,
+            embed_vb.pp("linear"),
+        )?;
+
+        Ok(Self {
+            base_model,
+            lm_head,
+        })
+    }
+
+    pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
+        let x = self.base_model.forward(input_ids, mask)?;
+        let x = self.pool(&x, mask)?;
+
+        // Keeping the final activations as F32 helps with the accuracy, no matter the base dtype
+        self.lm_head.forward(&x.to_dtype(DType::F32)?) // [B_sz, dim_size]
+    }
+
+    /// Same as the forward pass but normalizes the output
+    pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result<Tensor> {
+        let x = self.forward(input_ids, mask)?;
+        // Normalize
+        x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?)
+    }
+
+    fn pool(&self, x: &Tensor, mask: &Tensor) -> Result<Tensor> {
+        let mask = mask.to_dtype(x.dtype())?; // [B_Sz, Seq_len]
+        let (batch_size, seq_len, hidden_dim) = x.dims3()?;
+        // expanding the shape of the mask from [B_Sz, Seq_len] -> [B_Sz, Seq_len, Hidden_size]
+        let mask_expanded = mask
+            .unsqueeze(2)?
+            .broadcast_as((batch_size, seq_len, hidden_dim))?; // [B_Sz, Seq_len, Hidden_dim]
+
+        let x = (x * &mask_expanded)?;
+
+        // Sum
+        let sum_mask = mask
+            .sum(1)?
+            .unsqueeze(1)?
+            .expand((batch_size, hidden_dim))?;
+        x.sum(1)? / sum_mask
+    }
+}

From 18edaeb685cadd468a17ce89198d53033cef4aaf Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Mon, 7 Oct 2024 18:05:00 +0530
Subject: [PATCH 2/6] Separated creation. This is a critical step for numerical
 accuracy and would be documented in the readme

---
 candle-examples/examples/stella-en-v5/main.rs | 45 ++++++++++++-------
 1 file changed, 28 insertions(+), 17 deletions(-)

diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
index bfa750e0ab..dd41dae7ab 100644
--- a/candle-examples/examples/stella-en-v5/main.rs
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -4,6 +4,8 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

+use std::path::Path;
+
 use anyhow::{anyhow, Error as E, Result};
 use clap::Parser;

@@ -236,6 +238,30 @@ struct Args {
     task: Option<EncodeTask>,
 }

+// Tokenizer creation is super critical in our case.
+// We are going to be using `padding: Left` for each batch
+fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
+    let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
+    let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
+        pad_id
+    } else {
+        return Err(anyhow!(
+            "Tokenizer doesn't contain expected `<|endoftext|>` token"
+        ));
+    };
+
+    // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
+    tokenizer.with_padding(Some(PaddingParams {
+        strategy: PaddingStrategy::BatchLongest,
+        direction: PaddingDirection::Left,
+        pad_id,
+        pad_token: "<|endoftext|>".to_string(),
+        ..Default::default()
+    }));
+
+    Ok(tokenizer)
+}
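+
+// (illustrative) With `PaddingDirection::Left` and `BatchLongest`, a two-row
+// batch encodes with pad tokens in front of the shorter row, e.g.
+//   [<|endoftext|>, <|endoftext|>, ..., t1, t2]
+//   [t1,            t2,            ..., tn-1, tn]
+// keeping real tokens right-aligned across the batch.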
+
 fn main() -> Result<()> {
     use tracing_chrome::ChromeLayerBuilder;
     use tracing_subscriber::prelude::*;
@@ -293,23 +319,8 @@ fn main() -> Result<()> {

     println!("retrieved the files in {:?}", start.elapsed());

-    // Initializing the tokenizer which would require us to add padding to the right for batch encoding
-    let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
-        pad_id
-    } else {
-        return Err(anyhow!(
-            "Tokenizer doesn't contain expected `<|endoftext|>` token"
-        ));
-    };
-
-    tokenizer.with_padding(Some(PaddingParams {
-        strategy: PaddingStrategy::BatchLongest,
-        direction: PaddingDirection::Left,
-        pad_id,
-        pad_token: "<|endoftext|>".to_string(),
-        ..Default::default()
-    }));
+    // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
+    let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;

     let start = std::time::Instant::now();

From 4155377135cbef8444fb28e15ac8e4087fcc19da Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Mon, 7 Oct 2024 19:16:10 +0530
Subject: [PATCH 3/6] EmbedDim would require clone and copy

---
 candle-transformers/src/models/stella_en_v5.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs
index 1a0c53e3ad..9d933fade5 100644
--- a/candle-transformers/src/models/stella_en_v5.rs
+++ b/candle-transformers/src/models/stella_en_v5.rs
@@ -33,6 +33,7 @@ pub struct EmbedHead {

 /// An enum variant representing the Embedding head dimensions `stella` is trained on
 /// As the [model-card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction) suggests, D1024 is good enough for most cases
+#[derive(Debug, Clone, Copy)]
 pub enum EmbedDim {

From 4a33e8a5435d411c89e15fc689c18c2c765da032 Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Tue, 8 Oct 2024 00:27:07 +0530
Subject: [PATCH 4/6] WIP: example

---
 .../examples/stella-en-v5/README.md           | 24 +++++++-----------
 candle-examples/examples/stella-en-v5/main.rs |  7 +++++-
 2 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md
index 90f55ba894..d5ff9dcefb 100644
--- a/candle-examples/examples/stella-en-v5/README.md
+++ b/candle-examples/examples/stella-en-v5/README.md
@@ -1,12 +1,8 @@
-# candle-qwen: large language model series from Alibaba Cloud
+# candle-stella-en-v5: Implementation of [stella_en_1.5B_v5](https://huggingface.co/dunzhang/stella_en_1.5B_v5) embedding model

-Qwen 1.5 is a series of large language models that provide strong performances
-on English and Chinese.
+As of 7th Oct 2024, *Stella_en_1.5B_v5* is one of the top-ranking models on the `retrieval` and `reranking` tasks in the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) leaderboard.

-- [Blog post](https://qwenlm.github.io/blog/qwen1.5/) introducing Qwen1.5.
-- [Model card](https://huggingface.co/Qwen/Qwen1.5-0.5B) on the HuggingFace Hub.
-- [Blog post](https://qwenlm.github.io/blog/qwen-moe/) for the
-  mixture-of-experts (MoE) variant.
+[Model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) on the HuggingFace Hub.

 ## Running the example

@@ -20,14 +16,10 @@ $ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"
 > Tensor[[1, 1024], f32]
 ```

-Various model sizes are available via the `--model` argument, including the MoE
-variant.
+Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling multiple embedding dimensions.

-```bash
-$ cargo run --example qwen --release -- --model moe-a2.7b --prompt 'def print_prime(n: int): '
-def print_prime(n: int): # n is the number of primes to be printed
-    for i in range(2, n + 1):
-        if all(i % j != 0 for j in range(2, i)):
-            print(i)
-```
+The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p).

+```bash
+$ cargo run --example stella-en-v5 --release --features
+```
\ No newline at end of file
diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
index dd41dae7ab..d2dc857b30 100644
--- a/candle-examples/examples/stella-en-v5/main.rs
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -121,7 +121,12 @@ impl Embedding {
         ); // [2, 1024] for head dim `1024`
             // a matmul to generate the `similarity` score
         let res = qry_embed.matmul(&doc_embed.t()?)?;
-        println!("Similarity: {res}");
+        for (k, v) in queries.iter().enumerate() {
+            let tnsr = res.get(k)?;
+            let max = tnsr.argmax(0)?;
+            println!("Score: {}\tQuery: {}\tAnswer: {}", )
+        }
+
     }

     Ok(())

From f5b5d559cc56aedaf1673a3c7c0bca32e7d2f7d9 Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Tue, 8 Oct 2024 00:43:46 +0530
Subject: [PATCH 5/6] Examples added

---
 candle-examples/examples/stella-en-v5/README.md | 17 ++++++++++++++++-
 candle-examples/examples/stella-en-v5/main.rs   | 13 +++++++++----
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md
index d5ff9dcefb..ff4314fd54 100644
--- a/candle-examples/examples/stella-en-v5/README.md
+++ b/candle-examples/examples/stella-en-v5/README.md
@@ -18,8 +18,23 @@ $ cargo run --example stella-en-v5 --release -- --query "What are safetensors?"

 Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling multiple embedding dimensions.

-The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p).
+The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.

 ```bash
 $ cargo run --example stella-en-v5 --release --features
+
+>
+> Score: 0.8178786
+> Query: What are some ways to reduce stress?
+> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
+> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
+> stress from building up.
+>
+>
+> Score: 0.7853528
+> Query: What are the benefits of drinking green tea?
+> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
+> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
+> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
+>
 ```
\ No newline at end of file
diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs
index d2dc857b30..2408262b1a 100644
--- a/candle-examples/examples/stella-en-v5/main.rs
+++ b/candle-examples/examples/stella-en-v5/main.rs
@@ -119,14 +119,19 @@ impl Embedding {
                 qry_embed.shape(),
                 doc_embed.shape()
             ); // [2, 1024] for head dim `1024`
-            // a matmul to generate the `similarity` score
+
+            // a matmul to generate the `similarity` score
         let res = qry_embed.matmul(&doc_embed.t()?)?;
         for (k, v) in queries.iter().enumerate() {
             let tnsr = res.get(k)?;
-            let max = tnsr.argmax(0)?;
-            println!("Score: {}\tQuery: {}\tAnswer: {}", )
+            let max = tnsr.argmax(0)?.to_scalar::<u32>()?;
+            println!(
+                "\nScore: {}\nQuery: {}\nAnswer: {}\n\n",
+                tnsr.get(max as usize)?.to_scalar::<f32>()?,
+                v,
+                docs[k]
+            );
         }
-        }

     Ok(())

From d74695a47cdee8a8407ac5b47185ea94d9eb334f Mon Sep 17 00:00:00 2001
From: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Date: Tue, 8 Oct 2024 09:31:55 +0530
Subject: [PATCH 6/6] a little more in README

---
 candle-examples/examples/stella-en-v5/README.md | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md
index ff4314fd54..5fcc67c351 100644
--- a/candle-examples/examples/stella-en-v5/README.md
+++ b/candle-examples/examples/stella-en-v5/README.md
@@ -37,4 +37,9 @@ $ cargo run --example stella-en-v5 --release --features
 > caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
 > of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
 >
-```
\ No newline at end of file
+```
+
+## Supported options:
+- `Stella_en_1.5B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with the `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
+
+- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled through the `--task` option.
\ No newline at end of file