diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md index 5fcc67c351..3a87b2956a 100644 --- a/candle-examples/examples/stella-en-v5/README.md +++ b/candle-examples/examples/stella-en-v5/README.md @@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. ```bash -$ cargo run --example stella-en-v5 --release --features +$ cargo run --example stella-en-v5 --release --features -- --which 1.5b > > Score: 0.8178786 @@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features > caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types > > of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. > + +$ cargo run --example stella-en-v5 --release --features -- --which 400m + +> +> Score: 0.8397539 +> Query: What are some ways to reduce stress? +> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending +> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent +> stress from building up. +> +> +> +> Score: 0.809545 +> Query: What are the benefits of drinking green tea? +> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage +> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types +> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. +> ``` ## Supported options: -- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. +- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`. + +- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. - As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option. \ No newline at end of file diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs index 2408262b1a..68ed7e70c6 100644 --- a/candle-examples/examples/stella-en-v5/main.rs +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -212,6 +212,14 @@ impl EncodeTask { } } +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "1.5b")] + Large, + #[value(name = "400m")] + Small, +} + #[derive(Parser, Debug)] #[command(author, version, about, long_about = None)] struct Args { @@ -219,6 +227,9 @@ struct Args { #[arg(long)] cpu: bool, + #[arg(long)] + which: Which, + /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] tracing: bool, @@ -250,24 +261,33 @@ struct Args { // Tokenizer creation is super critical in our case. // We are going to be `padding: Left` for each batch -fn create_tokenizer(tokenizer_file: &Path) -> Result { +fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result { let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; - let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { - pad_id - } else { - return Err(anyhow!( - "Tokenizer doesn't contain expected `<|endoftext|>` token" - )); - }; - // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding - tokenizer.with_padding(Some(PaddingParams { - strategy: PaddingStrategy::BatchLongest, - direction: PaddingDirection::Left, - pad_id, - pad_token: "<|endoftext|>".to_string(), - ..Default::default() - })); + if which == Which::Large { + let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { + pad_id + } else { + return Err(anyhow!( + "Tokenizer doesn't contain expected `<|endoftext|>` token" + )); + }; + + // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + })); + } else { + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Right, + ..Default::default() + })); + } Ok(tokenizer) } @@ -298,7 +318,19 @@ fn main() -> Result<()> { Some(d) => d, None => EmbedDim::Dim1024, }; - let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string())); + + let (repo, cfg) = match args.which { + Which::Large => ( + "dunzhang/stella_en_1.5B_v5", + Config::new_1_5_b_v5(embed_dim.embed_dim()), + ), + Which::Small => ( + "dunzhang/stella_en_400M_v5", + Config::new_400_m_v5(embed_dim.embed_dim()), + ), + }; + + let repo = api.repo(Repo::model(repo.to_string())); let tokenizer_filename = match args.tokenizer_file { Some(file) => std::path::PathBuf::from(file), None => repo.get("tokenizer.json")?, @@ -330,7 +362,7 @@ fn main() -> Result<()> { println!("retrieved the files in {:?}", start.elapsed()); // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding - let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; + let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?; let start = std::time::Instant::now(); @@ -343,11 +375,7 @@ fn main() -> Result<()> { let embed_vb = unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; - let model = EmbeddingModel::new( - &Config::new_1_5_b_v5(embed_dim.embed_dim()), - base_vb, - embed_vb, - )?; + let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs index 7c1d2b5ae9..761e44a918 100644 --- a/candle-transformers/src/models/stella_en_v5.rs +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -16,33 +16,49 @@ //! use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; -use candle::{DType, Device, IndexOp, Module, Result, Tensor}; -use candle_nn::{Activation, VarBuilder}; +use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder}; use std::sync::Arc; +// internal representation for identifying which model is being used +#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)] +pub enum ModelVariant { + Large, // 1.5B + Small, // 400M +} + +impl Default for ModelVariant { + fn default() -> Self { + Self::Large + } +} + // Same as `qwen2` family of models with the exception being the `embed_head` // The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head` -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)] pub struct Config { + pub variant: ModelVariant, pub vocab_size: usize, pub hidden_size: usize, pub intermediate_size: usize, pub num_hidden_layers: usize, pub num_attention_heads: usize, - pub num_key_value_heads: usize, pub max_position_embeddings: usize, - pub max_window_layers: usize, - pub tie_word_embeddings: bool, pub rope_theta: f64, - pub rms_norm_eps: f64, - pub hidden_act: Activation, pub embed_head: EmbedHead, + pub norm_eps: f64, // RMSNorm for 1.5B || LayerNorm for 400M + pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M + // Unique to 1.5B + pub num_key_value_heads: usize, + // Unique to 400M + pub type_vocab_size: usize, + pub scaling_factor: f64, } // Excerpt from `stella` model card: // `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions // Embed head represents the config for various embedding dims supported -#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)] pub struct EmbedHead { pub in_features: usize, pub out_features: usize, @@ -68,9 +84,9 @@ impl Default for EmbedDim { } impl EmbedDim { - pub fn config(&self) -> EmbedHead { + pub fn config(&self, in_features: usize) -> EmbedHead { EmbedHead { - in_features: 1536, + in_features, out_features: match &self { Self::Dim256 => 256, Self::Dim768 => 768, @@ -91,7 +107,8 @@ impl Config { // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here Self { - hidden_act: candle_nn::Activation::Silu, + variant: ModelVariant::Large, + activation_fn: candle_nn::Activation::Silu, vocab_size: 151646, hidden_size: 1536, intermediate_size: 8960, @@ -99,11 +116,30 @@ impl Config { num_attention_heads: 12, num_key_value_heads: 2, max_position_embeddings: 131072, - max_window_layers: 21, - tie_word_embeddings: false, rope_theta: 1000000., - rms_norm_eps: 1e-06, - embed_head: embed_dim.config(), + norm_eps: 1e-06, + embed_head: embed_dim.config(1536), + ..Default::default() + } + } + + /// Initialize new `stella_en_400M_v5` + pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self { + Self { + variant: ModelVariant::Small, + vocab_size: 30528, + hidden_size: 1024, + intermediate_size: 4096, + num_hidden_layers: 24, + num_attention_heads: 16, + max_position_embeddings: 8192, + type_vocab_size: 2, + norm_eps: 1e-12, + scaling_factor: 2.0, + rope_theta: 160000.0, + activation_fn: Activation::Gelu, + embed_head: embed_dim.config(1024), + ..Default::default() } } } @@ -117,27 +153,57 @@ struct RotaryEmbedding { impl RotaryEmbedding { fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { let dim = cfg.hidden_size / cfg.num_attention_heads; - let max_seq_len = cfg.max_position_embeddings; + // Factoring in `scaling factor` for `400M` variant + let max_seq_len = if cfg.scaling_factor == 0. { + cfg.max_position_embeddings + } else { + ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize + }; + + // let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim }; let inv_freq: Vec<_> = (0..dim) .step_by(2) - .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .map(|i| { + // Scaled rope_theta for 400M variant + let rope_theta = if cfg.scaling_factor == 0. { + cfg.rope_theta + } else { + cfg.rope_theta * cfg.scaling_factor + }; + let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64); + + if cfg.scaling_factor != 0. { + freq /= cfg.scaling_factor.powf(2.0 / (dim as f64)) + } + + freq as f32 + }) .collect(); + let inv_freq_len = inv_freq.len(); let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + + // Calculate position embeddings with scaled sequence length let t = Tensor::arange(0u32, max_seq_len as u32, dev)? .to_dtype(dtype)? .reshape((max_seq_len, 1))?; let freqs = t.matmul(&inv_freq)?; + // if cfg.variant == ModelVariant::Small { + // freqs = Tensor::cat(&[&freqs, &freqs], 1)? + // } + Ok(Self { sin: freqs.sin()?, cos: freqs.cos()?, }) } + // TODO: re-visit this fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; let cos = self.cos.narrow(0, 0, seq_len)?; let sin = self.sin.narrow(0, 0, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; Ok((q_embed, k_embed)) @@ -147,8 +213,9 @@ impl RotaryEmbedding { #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] struct MLP { + variant: ModelVariant, gate_proj: Linear, - up_proj: Linear, + up_proj: Option, // `up_proj` only for 1.5B variant down_proj: Linear, act_fn: Activation, } @@ -157,31 +224,65 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + + let (gate_proj, up_proj, down_proj) = match cfg.variant { + ModelVariant::Large => ( + linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?, + Some(linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + )?), + linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?, + ), + ModelVariant::Small => ( + linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?, + None, + linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?, + ), + }; + Ok(Self { + variant: cfg.variant, gate_proj, up_proj, down_proj, - act_fn: cfg.hidden_act, + act_fn: cfg.activation_fn, }) } } impl Module for MLP { fn forward(&self, xs: &Tensor) -> Result { - let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; - let rhs = xs.apply(&self.up_proj)?; + let up = self.gate_proj.forward(xs)?; + + let (lhs, rhs) = match self.variant { + ModelVariant::Large => { + let lhs = up.apply(&self.act_fn)?; + let rhs = xs.apply(self.up_proj.as_ref().unwrap())?; + + (lhs, rhs) + } + ModelVariant::Small => { + // Get the dimensions + let (_batch_size, _seq_len, hidden_dim) = up.dims3()?; + let split_size = hidden_dim / 2; + + // Split along the last dimension (hidden_dim) + let up_states = up.narrow(2, 0, split_size)?; + let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?; + + (up_states, gate) + } + }; + (lhs * rhs)?.apply(&self.down_proj) } } #[derive(Debug, Clone)] struct Attention { - q_proj: Linear, - k_proj: Linear, - v_proj: Linear, + qkv_proj: Linear, o_proj: Linear, num_heads: usize, num_kv_heads: usize, @@ -189,6 +290,7 @@ struct Attention { head_dim: usize, hidden_size: usize, rotary_emb: Arc, + variant: ModelVariant, } impl Attention { @@ -196,16 +298,47 @@ impl Attention { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; - let num_kv_groups = num_heads / num_kv_heads; + let num_kv_groups = if num_kv_heads > 0 { + num_heads / num_kv_heads + } else { + 0 + }; let head_dim = hidden_sz / num_heads; - let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + + let (qkv_proj, o_proj) = match cfg.variant { + ModelVariant::Large => { + // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize + // Weights + let q_w = vb + .pp("q_proj") + .get((num_heads * head_dim, hidden_sz), "weight")?; + let k_w = vb + .pp("k_proj") + .get((num_kv_heads * head_dim, hidden_sz), "weight")?; + let v_w = vb + .pp("v_proj") + .get((num_kv_heads * head_dim, hidden_sz), "weight")?; + // Biases + let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?; + let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?; + let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?; + + let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?; + let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?; + + ( + Linear::from_weights(qkv_w, Some(qkv_b)), + linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?, + ) + } + ModelVariant::Small => ( + linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?, + linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?, + ), + }; + Ok(Self { - q_proj, - k_proj, - v_proj, + qkv_proj, o_proj, num_heads, num_kv_heads, @@ -213,45 +346,90 @@ impl Attention { head_dim, hidden_size: hidden_sz, rotary_emb, + variant: cfg.variant, }) } fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { let (b_sz, q_len, _) = xs.dims3()?; - let query_states = self.q_proj.forward(xs)?; - let key_states = self.k_proj.forward(xs)?; - let value_states = self.v_proj.forward(xs)?; + let qkv = self.qkv_proj.forward(xs)?; - let query_states = query_states - .reshape((b_sz, q_len, self.num_heads, self.head_dim))? - .transpose(1, 2)?; - let key_states = key_states - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; - let value_states = value_states - .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? - .transpose(1, 2)?; + let n_kv_heads = match self.variant { + ModelVariant::Large => self.num_kv_heads, + ModelVariant::Small => self.num_heads, + }; + + let (query_states, key_states, value_states) = match self.variant { + ModelVariant::Large => { + let q_sz = self.num_heads * self.head_dim; + let kv_sz = n_kv_heads * self.head_dim; + + let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape(( + b_sz, + q_len, + self.num_heads, + self.head_dim, + ))?; + let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape(( + b_sz, + q_len, + n_kv_heads, + self.head_dim, + ))?; + let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape(( + b_sz, + q_len, + n_kv_heads, + self.head_dim, + ))?; + + (q, k, v) + } + ModelVariant::Small => { + // Split into Q, K, V and reshape to match PyTorch shapes + let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?; + + ( + qkv.i((.., .., 0, .., ..))?, + qkv.i((.., .., 1, .., ..))?, + qkv.i((.., .., 2, .., ..))?, + ) + } + }; + + let query_states = query_states.transpose(1, 2)?.contiguous()?; + let key_states = key_states.transpose(1, 2)?.contiguous()?; + let value_states = value_states.transpose(1, 2)?.contiguous()?; let (query_states, key_states) = self .rotary_emb .apply_rotary_emb_qkv(&query_states, &key_states)?; - let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; - let value_states = - crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + // The 1.5B is expected to have grouped query attention + let (key_states, value_states) = if self.variant == ModelVariant::Large { + ( + crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?, + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?, + ) + } else { + (key_states, value_states) + }; let attn_output = { let scale = 1f64 / f64::sqrt(self.head_dim as f64); - let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?; + let attn_weights = (attn_weights * scale)?; let attn_weights = match attention_mask { None => attn_weights, Some(mask) => attn_weights.broadcast_add(mask)?, }; let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? }; + attn_output .transpose(1, 2)? .reshape((b_sz, q_len, self.hidden_size))? @@ -260,70 +438,282 @@ impl Attention { } #[derive(Debug, Clone)] -struct DecoderLayer { - self_attn: Attention, +enum NormType { + Layer(LayerNorm), + Rms(RmsNorm), +} + +#[derive(Debug, Clone)] +struct Layer { + variant: ModelVariant, + attention: Attention, mlp: MLP, - input_layernorm: RmsNorm, - post_attention_layernorm: RmsNorm, + // For 1.5B: this is `input_layernorm` + // For 400M: this is `output_layernorm` + layernorm: NormType, + post_attention_layernorm: NormType, } -impl DecoderLayer { +impl Layer { fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { - let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; - let mlp = MLP::new(cfg, vb.pp("mlp"))?; - let input_layernorm = - RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; - let post_attention_layernorm = RmsNorm::new( - cfg.hidden_size, - cfg.rms_norm_eps, - vb.pp("post_attention_layernorm"), + let attention = Attention::new( + rotary_emb, + cfg, + vb.pp(if cfg.variant == ModelVariant::Large { + "self_attn" + } else { + "attention" + }), )?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let (layernorm, post_attention_layernorm) = match cfg.variant { + ModelVariant::Large => ( + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("input_layernorm"), + )?), + NormType::Rms(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb.pp("post_attention_layernorm"), + )?), + ), + ModelVariant::Small => ( + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { + eps: cfg.norm_eps, + ..Default::default() + }, + vb.pp("mlp_ln"), + )?), + NormType::Layer(layer_norm( + cfg.hidden_size, + candle_nn::LayerNormConfig { + eps: cfg.norm_eps, + ..Default::default() + }, + vb.pp("attn_ln"), + )?), + ), + }; + Ok(Self { - self_attn, + variant: cfg.variant, + attention, mlp, - input_layernorm, + layernorm, post_attention_layernorm, }) } fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { + // Here, the application of normalizations and activation calculations differ + // For Large [1.5B]: + // residual = x + // state = other_layernorm(xs) + // state = attention(state) + // state += residual + // residual = state + // state = mlp(attention_layernorm(state)) + // -> residual + state + // For Small [400M]: + // residual = x; + // state = attention(x) + // state += residual + // state = attention_layernorm(state) + // residual = state + // state = mlp(state) + // state += residual + // -> other_layernorm(state) let residual = xs; - let xs = self.input_layernorm.forward(xs)?; - let xs = self.self_attn.forward(&xs, attention_mask)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; - residual + xs + + match self.variant { + ModelVariant::Large => { + let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) = + (&self.post_attention_layernorm, &self.layernorm) + { + (attn_ln, input_ln) + } else { + return Err(candle::error::Error::Msg( + "Stella 1.5B expects RMSNorm".to_string(), + )); + }; + + let xs = input_ln.forward(xs)?; + let xs = (self.attention.forward(&xs, attention_mask)? + residual)?; + + let residual = &xs; + let xs = xs.apply(attn_ln)?.apply(&self.mlp)?; + + residual + xs + } + ModelVariant::Small => { + let (attn_ln, output_ln) = + if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) = + (&self.post_attention_layernorm, &self.layernorm) + { + (attn_ln, input_ln) + } else { + return Err(candle::error::Error::Msg( + "Stella 400M expects RMSNorm".to_string(), + )); + }; + + let xs = (self.attention.forward(xs, attention_mask)? + residual)?; + let xs = attn_ln.forward(&xs)?; + + let residual = &xs; + let xs = (self.mlp.forward(&xs)? + residual)?; + + output_ln.forward(&xs) + } + } + } +} + +#[derive(Debug, Clone)] +pub struct Embeddings { + variant: ModelVariant, + // For 1.5B: this is the `embed_tokens` + // For 400M: this is the `word_embeddings` + embeddings: candle_nn::Embedding, + // folloing are specifically for 400M + token_type_embeddings: Option, + layer_norm: Option, + position_ids: Option, +} + +impl Embeddings { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant { + ModelVariant::Large => ( + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?, + None, + None, + None, + ), + ModelVariant::Small => { + let vb = vb.pp("embeddings"); + let weight = vb.pp("LayerNorm").get_with_hints( + cfg.hidden_size, + "weight", + candle_nn::Init::Const(1.0), + )?; + let bias = vb.pp("LayerNorm").get_with_hints( + cfg.hidden_size, + "bias", + candle_nn::Init::Const(0.0), + )?; + let dev = bias.device().clone(); + + let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps); + + ( + candle_nn::embedding( + cfg.vocab_size, + cfg.hidden_size, + vb.pp("word_embeddings"), + )?, + Some(candle_nn::embedding( + cfg.type_vocab_size, + cfg.hidden_size, + vb.pp("token_type_embeddings"), + )?), + Some(layer_norm), + Some(Tensor::arange( + 0u32, + cfg.max_position_embeddings as u32, + &dev, + )?), + ) + } + }; + + Ok(Self { + variant: cfg.variant, + embeddings, + token_type_embeddings, + layer_norm, + position_ids, + }) + } +} + +impl Module for Embeddings { + fn forward(&self, xs: &Tensor) -> Result { + let embd = self.embeddings.forward(xs)?; + // For 1.5B just forward the embeddings + if self.variant == ModelVariant::Large { + return Ok(embd); + } + + let (token_type_embed, layer_norm, pos_ids) = + if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = ( + &self.token_type_embeddings, + &self.layer_norm, + &self.position_ids, + ) { + (token_type_embd, layer_norm, position_ids) + } else { + return Err(Error::Msg( + "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`" + .to_string(), + )); + }; + + let (batch_size, seq_length) = xs.dims2()?; + + let pos_ids = pos_ids + .as_ref() + .narrow(0, 0, seq_length)? + .expand((batch_size, seq_length))?; + + layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?) } } #[derive(Debug, Clone)] pub struct Model { - embed_tokens: candle_nn::Embedding, - layers: Vec, - norm: RmsNorm, + embeddings: Embeddings, + layers: Vec, + norm: Option, device: Device, dtype: DType, } impl Model { pub fn new(cfg: &Config, vb: VarBuilder) -> Result { - let vb_m = vb.pp("model"); - let embed_tokens = - candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let vb_m = match cfg.variant { + ModelVariant::Large => vb.pp("model"), + ModelVariant::Small => vb.pp("new"), + }; + // let embed_tokens = + // candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let embeddings = Embeddings::new(cfg, vb_m.clone())?; let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); let mut layers = Vec::with_capacity(cfg.num_hidden_layers); - let vb_l = vb_m.pp("layers"); + let vb_l = match cfg.variant { + ModelVariant::Large => vb_m.pp("layers"), + ModelVariant::Small => vb_m.pp("encoder").pp("layer"), + }; for layer_idx in 0..cfg.num_hidden_layers { - let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; layers.push(layer) } - let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let norm = match cfg.variant { + ModelVariant::Large => Some(RmsNorm::new( + cfg.hidden_size, + cfg.norm_eps, + vb_m.pp("norm"), + )?), + ModelVariant::Small => None, + }; Ok(Self { - embed_tokens, + embeddings, layers, norm, - // sliding_window: 0, device: vb.device().clone(), dtype: vb.dtype(), }) @@ -352,15 +742,20 @@ impl Model { Some(self.prepare_attention_mask(mask)?) }; - let mut xs = self.embed_tokens.forward(input_ids)?; + let mut xs = self.embeddings.forward(input_ids)?; for layer in self.layers.iter_mut() { xs = layer.forward(&xs, attention_mask.as_ref())? } - xs.apply(&self.norm) + + if let Some(n) = &self.norm { + xs.apply(n) + } else { + Ok(xs) + } } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct EmbeddingModel { base_model: Model, lm_head: Linear,