From 1cd03b2c35b95fac6119a65b9804af552fb49210 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:26:08 -0400 Subject: [PATCH 01/11] Update mod.rs --- candle-transformers/src/models/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 571a88614d..be1f15c413 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -62,6 +62,7 @@ pub mod mobilenetv4; pub mod mobileone; pub mod moondream; pub mod mpt; +pub mod nvembed_v2; pub mod olmo; pub mod openclip; pub mod paligemma; From cfc49e929f6ddf3aaa54485ecf35c47bb3a71257 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:29:56 -0400 Subject: [PATCH 02/11] Create mod.rs --- .../src/models/nvembed_v2/mod.rs | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 candle-transformers/src/models/nvembed_v2/mod.rs diff --git a/candle-transformers/src/models/nvembed_v2/mod.rs b/candle-transformers/src/models/nvembed_v2/mod.rs new file mode 100644 index 0000000000..610a44d94c --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/mod.rs @@ -0,0 +1,22 @@ +//! NV-Embed-v2 +//! +//! NV-Embed-v2 is a text embedding model that combines a Mistral decoder with a latent attention mechanism to produce high-quality text embeddings. +//! +//! - [HuggingFace Model Card](https://huggingface.co/nvidia/NV-Embed-v2) +//! +//! # Query-Passage Retrieval Example +//! ```bash +//! cargo run --example nvembed_v2 --release +//! ``` +//! +//! # Sentence Embedding Example +//! ```bash +//! cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence" +//! ``` + +// Copyright (c) NVIDIA CORPORATION, all rights reserved. +// This source code is licensed under the CC-BY-NC-4.0 license. +// See https://spdx.org/licenses/CC-BY-NC-4.0 for details. + +pub mod decoder; +pub mod model; From 6123269b0c564984c0231a03462ee9263449b263 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:33:43 -0400 Subject: [PATCH 03/11] Create decoder.rs --- .../src/models/nvembed_v2/decoder.rs | 298 ++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 candle-transformers/src/models/nvembed_v2/decoder.rs diff --git a/candle-transformers/src/models/nvembed_v2/decoder.rs b/candle-transformers/src/models/nvembed_v2/decoder.rs new file mode 100644 index 0000000000..303d74bf36 --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/decoder.rs @@ -0,0 +1,298 @@ +// Copyright (c) NVIDIA CORPORATION, all rights reserved. +// This source code is licensed under the CC-BY-NC-4.0 license. +// See https://spdx.org/licenses/CC-BY-NC-4.0 for details. + +/// Mistral LLM, https://github.com/mistralai/mistral-src +use crate::models::{ + mistral::Config, + with_tracing::{linear_no_bias, Linear, RmsNorm}, +}; +use crate::utils::repeat_kv; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let rope_theta = cfg.rope_theta as f32; + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offset: usize, + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, seqlen_offset, seq_len)?; + let sin = self.sin.narrow(0, seqlen_offset, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(q, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(k, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; + let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; + let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous()?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = + self.rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?; + + let key_states = repeat_kv(key_states, self.num_kv_groups)?; + let value_states = repeat_kv(value_states, self.num_kv_groups)?; + + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + let attn_output = attn_weights.matmul(&value_states)?; + + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + + let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?; + + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + pub cfg: Config, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("norm"))?; + Ok(Self { + embed_tokens, + layers, + norm, + cfg: cfg.clone(), + }) + } + + // Attn mask used to mask out padding tokens + pub fn forward( + &mut self, + attn_mask: &Tensor, + input_ids: &Tensor, + dtype: DType, + ) -> Result { + let mut xs = self.embed_tokens.forward(input_ids)?; + + // Expand to 4d mask for sdpa + let attn_mask = prepare_4d_attention_mask(attn_mask, dtype, None)?; + + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, Some(&attn_mask), 0)?; + } + + // Return hiddens instead of logits + xs.apply(&self.norm) + } +} + +fn prepare_4d_attention_mask( + mask: &Tensor, + dtype: DType, + tgt_len: Option, +) -> Result { + let bsz = mask.dims()[0]; + let src_len = mask.dims()[1]; + let tgt_len = tgt_len.unwrap_or(src_len); + + let expanded_mask = mask + .unsqueeze(1)? + .unsqueeze(2)? + .expand((bsz, 1, tgt_len, src_len))? + .to_dtype(dtype)?; + + let inverted_mask = (1.0 - expanded_mask)?; + + (inverted_mask * get_dtype_min_val(dtype))?.to_dtype(dtype) +} + +fn get_dtype_min_val(dtype: DType) -> f64 { + match dtype { + DType::F32 => f32::MIN as f64, + DType::F64 => f64::MIN, + _ => panic!("Unsupported data type"), + } +} From 8912d2d56499af82c9892ee5c503298f044e14c9 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:33:59 -0400 Subject: [PATCH 04/11] Create model.rs --- .../src/models/nvembed_v2/model.rs | 356 ++++++++++++++++++ 1 file changed, 356 insertions(+) create mode 100644 candle-transformers/src/models/nvembed_v2/model.rs diff --git a/candle-transformers/src/models/nvembed_v2/model.rs b/candle-transformers/src/models/nvembed_v2/model.rs new file mode 100644 index 0000000000..39c7f88bd3 --- /dev/null +++ b/candle-transformers/src/models/nvembed_v2/model.rs @@ -0,0 +1,356 @@ +// Copyright (c) NVIDIA CORPORATION, all rights reserved. +// This source code is licensed under the CC-BY-NC-4.0 license. +// See https://spdx.org/licenses/CC-BY-NC-4.0 for details. + +use super::decoder::Model as MistralModel; +use crate::models::{ + mistral::Config, + with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear}, +}; +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{ops::softmax_last_dim, Module, VarBuilder}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +struct LatentAttentionConfig { + num_latents_value: usize, + num_cross_heads: usize, + output_normalize: bool, + hidden_dim: usize, + latent_dim: usize, + cross_dim_head: usize, + hidden_size: usize, +} + +impl LatentAttentionConfig { + fn new(hidden_size: usize, output_normalize: bool) -> Self { + Self { + num_latents_value: 512, + num_cross_heads: 8, + output_normalize, + hidden_dim: 4096, + latent_dim: 4096, + cross_dim_head: 4096, + hidden_size, + } + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct GEGLU {} + +impl GEGLU { + fn new() -> Self { + Self {} + } +} +impl Module for GEGLU { + fn forward(&self, x: &Tensor) -> Result { + let last_dim = x.dims().len() - 1; + let chunks = x.chunk(2, last_dim)?; + let (x, gates) = (chunks[0].clone(), chunks[1].clone()); + + let gates = gates.gelu()?; + + x * gates + } +} + +#[derive(Debug, Clone)] +struct FeedForward { + linear1: Linear, + gelu: GEGLU, + linear2: Linear, +} + +impl FeedForward { + fn new(dim: usize, vb1: VarBuilder, vb2: VarBuilder) -> Result { + let linear1 = linear(dim, dim * 4 * 2, vb1)?; + let gelu = GEGLU::new(); + let linear2 = linear(dim * 4, dim, vb2)?; + + Ok(Self { + linear1, + gelu, + linear2, + }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.linear1.forward(xs)?; + let xs = self.gelu.forward(&xs)?; + let xs = self.linear2.forward(&xs)?; + Ok(xs) + } +} + +#[derive(Debug, Clone)] +struct Attention { + heads: usize, + to_q: Linear, + to_kv: Linear, + to_out: Linear, + dim_head: usize, +} + +#[allow(clippy::too_many_arguments)] +impl Attention { + fn new( + query_dim: usize, + context_dim: Option, + heads: Option, + dim_head: Option, + vb_to_q: VarBuilder, + vb_to_kv: VarBuilder, + vb_to_out: VarBuilder, + ) -> Result { + let heads = heads.unwrap_or(8); + let dim_head = dim_head.unwrap_or(64); + let inner_dim = dim_head * heads; + let context_dim = context_dim.unwrap_or(query_dim); + + let to_q = linear_no_bias(query_dim, inner_dim, vb_to_q)?; + let to_kv = linear_no_bias(context_dim, inner_dim * 2, vb_to_kv)?; + let to_out = linear_no_bias(inner_dim, query_dim, vb_to_out)?; + Ok(Self { + heads, + to_q, + to_kv, + to_out, + dim_head, + }) + } + + // Cross attn takes queries from the mistral decoder and kv from latent attention model + fn forward(&self, x: &Tensor, context: &Tensor) -> Result { + let h = self.heads; + let q = self.to_q.forward(x)?; + let kv_chunks = self + .to_kv + .forward(context)? + .chunk(2, context.shape().dims().len() - 1)?; + let (k, v) = (kv_chunks[0].clone(), kv_chunks[1].clone()); + + let (b_sz, q_len, _) = q.dims3()?; + let q = q + .reshape((b_sz, q_len, h, self.dim_head))? + .transpose(1, 2)? + .contiguous()?; + + let (_, q_len, _) = k.dims3()?; + let k = k + .reshape((b_sz, q_len, h, self.dim_head))? + .transpose(1, 2)? + .contiguous()?; + + let (_, q_len, _) = v.dims3()?; + let v = v + .reshape((b_sz, q_len, h, self.dim_head))? + .transpose(1, 2)? + .contiguous()?; + + let scale = 1f64 / f64::sqrt(self.dim_head as f64); + + let attn_weight = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + let attn_weight = softmax_last_dim(&attn_weight)?; + + let out = attn_weight.matmul(&v)?; + + let (_, _, q_len, _) = out.dims4()?; + let out = out + .transpose(1, 2)? + .reshape((b_sz, q_len, self.dim_head * h))?; + + self.to_out.forward(&out) + } +} + +#[derive(Debug, Clone)] +enum PreNormInnerLayer { + Attention(Attention), + FeedForward(FeedForward), +} + +#[derive(Debug, Clone)] +struct PreNorm { + norm: LayerNorm, + norm_context: Option, + inner_layer: PreNormInnerLayer, +} + +impl PreNorm { + fn new( + dim: usize, + context_dim: Option, + inner_layer: PreNormInnerLayer, + norm_vb: VarBuilder, + norm_context_vb: Option, + ) -> Result { + let norm = layer_norm(dim, candle_nn::LayerNormConfig::default(), norm_vb)?; + + let norm_context = match context_dim { + Some(context_dim) => { + let norm_context_vb = norm_context_vb + .expect("norm_context_vb must be passed if context_dim is passed"); + match layer_norm( + context_dim, + candle_nn::LayerNormConfig::default(), + norm_context_vb, + ) { + Ok(norm_context) => Some(norm_context), + Err(e) => return Err(e), + } + } + None => None, + }; + Ok(Self { + norm, + norm_context, + inner_layer, + }) + } + + // Applies a layernorm to the input before passing to cross attn or feed forward + fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result { + let xs = self.norm.forward(xs)?; + + let mut normed_context = None; + if let Some(norm_context) = &self.norm_context { + if let Some(context) = context { + normed_context = Some(norm_context.forward(context)?); + } + } + + match &self.inner_layer { + PreNormInnerLayer::Attention(attn) => attn.forward(&xs, &normed_context.unwrap()), + PreNormInnerLayer::FeedForward(ff) => ff.forward(&xs), + } + } +} + +#[derive(Debug, Clone)] +struct LatentAttentionModel { + cross_attn: PreNorm, + ff: PreNorm, + output_normalize: bool, + latents: Tensor, +} + +impl LatentAttentionModel { + fn new(vb: VarBuilder, config: LatentAttentionConfig) -> Result { + let vb_cross = vb.pp("cross_attend_blocks"); + + let num_latents = config.num_latents_value; + let latent_dim = config.latent_dim; + let cross_heads = config.num_cross_heads; + let cross_dim_head = config.cross_dim_head; + let dim = config.hidden_dim; + let hidden_size = config.hidden_size; + + let cross_attn = PreNorm::new( + latent_dim, + Some(hidden_size), + PreNormInnerLayer::Attention(Attention::new( + latent_dim, + Some(dim), + Some(cross_heads), + Some(cross_dim_head), + vb_cross.pp("0.fn.to_q"), + vb_cross.pp("0.fn.to_kv"), + vb_cross.pp("0.fn.to_out"), + )?), + vb_cross.pp("0.norm"), + Some(vb_cross.pp("0.norm_context")), + )?; + + let ff = PreNorm::new( + latent_dim, + None, + PreNormInnerLayer::FeedForward(FeedForward::new( + latent_dim, + vb_cross.pp("1.fn.net.0"), + vb_cross.pp("1.fn.net.2"), + )?), + vb_cross.pp("1.norm"), + None, + )?; + + let output_normalize = config.output_normalize; + let latents = vb.get((num_latents, latent_dim), "latents")?; + + Ok(Self { + cross_attn, + ff, + output_normalize, + latents, + }) + } + + fn forward(&self, hiddens: &Tensor, attention_mask: &Tensor) -> Result { + let b = hiddens.dims()[0]; + let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?; + + let hiddens = (self.cross_attn.forward(hiddens, Some(&x))? + hiddens)?; + let hiddens = (self.ff.forward(&hiddens, None)? + hiddens)?; + + // Mean pooling + let hiddens_masked = hiddens.broadcast_mul(&attention_mask.unsqueeze(D::Minus1)?)?; + let s = hiddens_masked.sum(1)?; + let d = attention_mask.sum_keepdim(1)?; + let hiddens = s.broadcast_div(&d)?; + + if self.output_normalize { + let hiddens = div_l2_norm(&hiddens)?; + + Ok(hiddens) + } else { + Ok(hiddens) + } + } +} + +#[derive(Debug, Clone)] +pub struct NVEmbedModel { + latent_attention_model: LatentAttentionModel, + embedding_model: MistralModel, + pub device: Device, + pub dtype: DType, +} + +impl NVEmbedModel { + pub fn new(vb: VarBuilder, output_normalize: bool) -> Result { + let cfg = Config::config_7b_v0_1(false); + let embedding_model = MistralModel::new(&cfg, vb.pp("embedding_model"))?; + let hidden_size = embedding_model.cfg.hidden_size; + let latent_attention_model = LatentAttentionModel::new( + vb.pp("latent_attention_model"), + LatentAttentionConfig::new(hidden_size, output_normalize), + )?; + + Ok(Self { + latent_attention_model, + embedding_model, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + attn_mask: &Tensor, + pool_mask: &Tensor, + ) -> Result { + let outputs = self + .embedding_model + .forward(attn_mask, input_ids, self.dtype)?; + + self.latent_attention_model.forward(&outputs, pool_mask) + } +} + +fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} From 666526fc5abd7237512b543ced992b382bf16872 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:35:58 -0400 Subject: [PATCH 05/11] Create main.rs --- candle-examples/examples/nvembed_v2/main.rs | 213 ++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 candle-examples/examples/nvembed_v2/main.rs diff --git a/candle-examples/examples/nvembed_v2/main.rs b/candle-examples/examples/nvembed_v2/main.rs new file mode 100644 index 0000000000..4a3c951683 --- /dev/null +++ b/candle-examples/examples/nvembed_v2/main.rs @@ -0,0 +1,213 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use candle::{DType, IndexOp, Shape, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::nvembed_v2::model; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::{PaddingDirection, PaddingParams, Tokenizer, TruncationParams}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// When set, compute embeddings for this prompt. + #[arg(long)] + prompt: Option, + + /// L2 normalization for embeddings. + #[arg(long, default_value = "true")] + normalize_embeddings: bool, + + #[arg(long)] + tokenizer: Option, + + #[arg(long)] + model: Option, + + /// Comma-separated list of model files (e.g., '/path/file1.safetensors,/path/file2.safetensors,/path/file3.safetensors') + #[arg(long)] + model_files: Option, +} + +impl Args { + fn build_model_and_tokenizer( + &self, + ) -> anyhow::Result<(model::NVEmbedModel, tokenizers::Tokenizer)> { + let model_name = match self.model.as_ref() { + Some(model) => model.to_string(), + None => "nvidia/NV-Embed-v2".to_string(), + }; + + let api = Api::new()?; + let repo = api.repo(Repo::new(model_name.to_string(), RepoType::Model)); + + let model_files = match &self.model_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + + let tokenizer_file = match &self.tokenizer { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + let device = candle_examples::device(self.cpu)?; + + let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; + + let _ = tokenizer + .with_padding(Some(PaddingParams { + direction: PaddingDirection::Right, + pad_id: 2, + pad_token: "".to_string(), + ..Default::default() + })) + .with_truncation(Some(TruncationParams { + max_length: 32768, + ..Default::default() + })); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device) }?; + + let nvembed_model = model::NVEmbedModel::new(vb, self.normalize_embeddings); + Ok((nvembed_model?, tokenizer)) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let (mut model, tokenizer) = args.build_model_and_tokenizer()?; + + if let Some(prompt) = args.prompt { + let emb = encode(&mut model, &tokenizer, vec![prompt], "")?; + println!("Embedding: {emb}"); + } else { + let queries = [ + "are judo throws allowed in wrestling?", + "how to become a radiology technician in michigan?", + ]; + + let passages = [ + "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.", + "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan." + ]; + let passage_instruction = "".to_string(); + let query_instruction = + "Instruct: Given a question, retrieve passages that answer the question\nQuery: " + .to_string(); + + let passages: Vec = passages.iter().map(|s| s.to_string()).collect(); + let queries: Vec = queries.iter().map(|s| s.to_string()).collect(); + + let emb_query = encode(&mut model, &tokenizer, queries, &query_instruction)?; + let emb_passage = encode(&mut model, &tokenizer, passages, &passage_instruction)?; + + let scores = (emb_query.matmul(&emb_passage.t()?)? * 100.0)?; + + println!("scores: {scores}"); + } + Ok(()) +} + +fn encode( + model: &mut model::NVEmbedModel, + tokenizer: &Tokenizer, + examples: Vec, + instruction: &str, +) -> Result { + let device = &model.device; + let dtype = model.dtype; + + // Format input text + let eos_token = if let Some(padding) = tokenizer.get_padding() { + padding.pad_token.clone() + } else { + "".to_string() + }; + let bos = "".to_string(); + let input_texts = examples + .iter() + .map(|input_example| format!("{bos}{instruction}{input_example}{eos_token}")) + .collect::>(); + let encodings = tokenizer.encode_batch(input_texts, false).map_err(E::msg)?; + + let input_ids_list = encodings + .iter() + .map(|encoding| { + Tensor::from_slice( + encoding.get_ids(), + Shape::from(encoding.get_ids().len()), + device, + ) + }) + .collect::, _>>()?; + + // Mask out padding tokens for both mistral and latent attention model + let attention_masks: Vec = encodings + .iter() + .map(|encoding| { + Tensor::from_slice( + encoding.get_attention_mask(), + Shape::from(encoding.get_attention_mask().len()), + device, + )? + .to_dtype(dtype) + }) + .collect::, _>>()?; + + let input_ids = Tensor::stack(&input_ids_list, 0)?; + let attention_mask = Tensor::stack(&attention_masks, 0)?; + + let instruction_lens = if !instruction.is_empty() { + let encoded_instruction = tokenizer.encode(instruction, false).map_err(E::msg)?; + encoded_instruction.get_tokens().len() + } else { + 0 + }; + + // Mask out instruction tokens for latent attention model + let pool_mask = if instruction_lens > 0 { + let zeros = Tensor::zeros( + attention_mask.i((.., ..instruction_lens))?.shape(), + dtype, + device, + )?; + + let batch_size = attention_mask.dims()[0]; + attention_mask.slice_assign(&[..batch_size, ..instruction_lens], &zeros)? + } else { + attention_mask.clone() + }; + + Ok(model + .forward(&input_ids, &attention_mask, &pool_mask)? + .squeeze(1)?) +} From 1557e16673fbad0fb1503d1c67b4b52d050bc0cb Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:37:22 -0400 Subject: [PATCH 06/11] Create README.md --- candle-examples/examples/nvembed_v2/README.md | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 candle-examples/examples/nvembed_v2/README.md diff --git a/candle-examples/examples/nvembed_v2/README.md b/candle-examples/examples/nvembed_v2/README.md new file mode 100644 index 0000000000..f6543fcf47 --- /dev/null +++ b/candle-examples/examples/nvembed_v2/README.md @@ -0,0 +1,43 @@ +# NV-Embed-v2 + +Candle implementation (inference only) of [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2), a text embedding model that ranks No. 1 (as of Nov 25 2024) on the [MTEB](https://huggingface.co/spaces/mteb/leaderboard) benchmark with a score of 72.31 across 56 text embedding tasks. + +## Running an example: Retrieval +```bash +cargo run --example nvembed_v2 --release +> scores: [[87.4269, 0.4629], +> [ 0.9653, 86.0372]] +> Tensor[[2, 2], f32] +``` +In this example, we have two queries and two passages (the corresponding answers). The output tensor represents the similarity scores between each query-passage pair. The scores are computed by taking the dot product of the query and passage embeddings and scaling the result by 100. +```rust +let queries = [ + "are judo throws allowed in wrestling?", + "how to become a radiology technician in michigan?", +]; +let query_instruction = + "Instruct: Given a question, retrieve passages that answer the question\nQuery: " + .to_string(); + +let passages = [ + "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.", + "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan." +]; +let passage_instruction = "".to_string(); +``` + +If you already have the model and tokenizer files, you can use the `--tokenizer` and `--model-files` options to specify their full paths, instead of downloading them from the hub. + +## Running an example: Sentence embedding +```bash +cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence" +> Embedding: [[ 0.0066, -0.0048, 0.0066, ..., -0.0096, 0.0119, -0.0052]] +> Tensor[[1, 4096], f32] +``` +In this example, we pass a prompt to the model and it outputs the vector encoding of the prompt. + +## Hardware Requirements +29.25GB at fp32 + +## License +This model should not be used for any commercial purpose. Refer the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms. From 3b7ef5e471fe997d6b317e47dc86e8c91c2657d4 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Fri, 29 Nov 2024 02:52:38 -0400 Subject: [PATCH 07/11] Update README.md --- candle-examples/examples/nvembed_v2/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candle-examples/examples/nvembed_v2/README.md b/candle-examples/examples/nvembed_v2/README.md index f6543fcf47..66b10fab04 100644 --- a/candle-examples/examples/nvembed_v2/README.md +++ b/candle-examples/examples/nvembed_v2/README.md @@ -40,4 +40,4 @@ In this example, we pass a prompt to the model and it outputs the vector encodin 29.25GB at fp32 ## License -This model should not be used for any commercial purpose. Refer the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms. +CC-BY-NC-4.0. This model should not be used for any commercial purpose. Refer the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms. From d6c9925ffce230fea2c8e13aac9a7558de7b5ba6 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:07:11 -0400 Subject: [PATCH 08/11] Update main.rs --- candle-examples/examples/nvembed_v2/main.rs | 141 ++++++++++---------- 1 file changed, 71 insertions(+), 70 deletions(-) diff --git a/candle-examples/examples/nvembed_v2/main.rs b/candle-examples/examples/nvembed_v2/main.rs index 4a3c951683..8db9a100fe 100644 --- a/candle-examples/examples/nvembed_v2/main.rs +++ b/candle-examples/examples/nvembed_v2/main.rs @@ -5,9 +5,9 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::{Error as E, Result}; -use candle::{DType, IndexOp, Shape, Tensor}; +use candle::{DType, IndexOp, Shape, Tensor, D}; use candle_nn::VarBuilder; -use candle_transformers::models::nvembed_v2::model; +use candle_transformers::models::nvembed_v2::model::Model; use clap::Parser; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::{PaddingDirection, PaddingParams, Tokenizer, TruncationParams}; @@ -43,9 +43,7 @@ struct Args { } impl Args { - fn build_model_and_tokenizer( - &self, - ) -> anyhow::Result<(model::NVEmbedModel, tokenizers::Tokenizer)> { + fn build_model_and_tokenizer(&self) -> anyhow::Result<(Model, tokenizers::Tokenizer)> { let model_name = match self.model.as_ref() { Some(model) => model.to_string(), None => "nvidia/NV-Embed-v2".to_string(), @@ -85,60 +83,13 @@ impl Args { let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device) }?; - let nvembed_model = model::NVEmbedModel::new(vb, self.normalize_embeddings); + let nvembed_model = Model::new(vb); Ok((nvembed_model?, tokenizer)) } } -fn main() -> anyhow::Result<()> { - use tracing_chrome::ChromeLayerBuilder; - use tracing_subscriber::prelude::*; - - let args = Args::parse(); - let _guard = if args.tracing { - println!("tracing..."); - let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); - tracing_subscriber::registry().with(chrome_layer).init(); - Some(guard) - } else { - None - }; - - let (mut model, tokenizer) = args.build_model_and_tokenizer()?; - - if let Some(prompt) = args.prompt { - let emb = encode(&mut model, &tokenizer, vec![prompt], "")?; - println!("Embedding: {emb}"); - } else { - let queries = [ - "are judo throws allowed in wrestling?", - "how to become a radiology technician in michigan?", - ]; - - let passages = [ - "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.", - "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan." - ]; - let passage_instruction = "".to_string(); - let query_instruction = - "Instruct: Given a question, retrieve passages that answer the question\nQuery: " - .to_string(); - - let passages: Vec = passages.iter().map(|s| s.to_string()).collect(); - let queries: Vec = queries.iter().map(|s| s.to_string()).collect(); - - let emb_query = encode(&mut model, &tokenizer, queries, &query_instruction)?; - let emb_passage = encode(&mut model, &tokenizer, passages, &passage_instruction)?; - - let scores = (emb_query.matmul(&emb_passage.t()?)? * 100.0)?; - - println!("scores: {scores}"); - } - Ok(()) -} - fn encode( - model: &mut model::NVEmbedModel, + model: &mut Model, tokenizer: &Tokenizer, examples: Vec, instruction: &str, @@ -157,6 +108,8 @@ fn encode( .iter() .map(|input_example| format!("{bos}{instruction}{input_example}{eos_token}")) .collect::>(); + + // Tokenize let encodings = tokenizer.encode_batch(input_texts, false).map_err(E::msg)?; let input_ids_list = encodings @@ -169,8 +122,9 @@ fn encode( ) }) .collect::, _>>()?; + let input_ids = Tensor::stack(&input_ids_list, 0)?; - // Mask out padding tokens for both mistral and latent attention model + // Mask out padding tokens for both embedding model and latent attention model let attention_masks: Vec = encodings .iter() .map(|encoding| { @@ -182,32 +136,79 @@ fn encode( .to_dtype(dtype) }) .collect::, _>>()?; - - let input_ids = Tensor::stack(&input_ids_list, 0)?; let attention_mask = Tensor::stack(&attention_masks, 0)?; - let instruction_lens = if !instruction.is_empty() { - let encoded_instruction = tokenizer.encode(instruction, false).map_err(E::msg)?; - encoded_instruction.get_tokens().len() - } else { - 0 - }; - // Mask out instruction tokens for latent attention model - let pool_mask = if instruction_lens > 0 { + let pool_mask = if !instruction.is_empty() { + let encoded_instruction = tokenizer.encode(instruction, false).map_err(E::msg)?; + let instruction_lens = encoded_instruction.get_tokens().len(); let zeros = Tensor::zeros( attention_mask.i((.., ..instruction_lens))?.shape(), dtype, device, )?; - - let batch_size = attention_mask.dims()[0]; - attention_mask.slice_assign(&[..batch_size, ..instruction_lens], &zeros)? + let b = attention_mask.dims()[0]; + attention_mask.slice_assign(&[..b, ..instruction_lens], &zeros)? } else { attention_mask.clone() }; - Ok(model + let hiddens = model .forward(&input_ids, &attention_mask, &pool_mask)? - .squeeze(1)?) + .squeeze(1)?; + + // Normalize embedding + div_l2_norm(&hiddens) +} + +fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + Ok(v.broadcast_div(&l2_norm)?) +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let (mut model, tokenizer) = args.build_model_and_tokenizer()?; + + if let Some(prompt) = args.prompt { + let emb = encode(&mut model, &tokenizer, vec![prompt], "")?; + println!("Embedding: {emb}"); + } else { + let queries = [ + "are judo throws allowed in wrestling?", + "how to become a radiology technician in michigan?", + ]; + + let passages = [ + "Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.", + "Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan." + ]; + let passage_instruction = "".to_string(); + let query_instruction = + "Instruct: Given a question, retrieve passages that answer the question\nQuery: " + .to_string(); + + let passages: Vec = passages.iter().map(|s| s.to_string()).collect(); + let queries: Vec = queries.iter().map(|s| s.to_string()).collect(); + + let emb_query = encode(&mut model, &tokenizer, queries, &query_instruction)?; + let emb_passage = encode(&mut model, &tokenizer, passages, &passage_instruction)?; + + let scores = (emb_query.matmul(&emb_passage.t()?)? * 100.0)?; + + println!("scores: {scores}"); + } + Ok(()) } From 70a77ad9fe36b33501905d7af427d2ca450c888b Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:10:46 -0400 Subject: [PATCH 09/11] Update and rename decoder.rs to embedding.rs --- .../src/models/nvembed_v2/{decoder.rs => embedding.rs} | 4 ---- 1 file changed, 4 deletions(-) rename candle-transformers/src/models/nvembed_v2/{decoder.rs => embedding.rs} (98%) diff --git a/candle-transformers/src/models/nvembed_v2/decoder.rs b/candle-transformers/src/models/nvembed_v2/embedding.rs similarity index 98% rename from candle-transformers/src/models/nvembed_v2/decoder.rs rename to candle-transformers/src/models/nvembed_v2/embedding.rs index 303d74bf36..a52192afdf 100644 --- a/candle-transformers/src/models/nvembed_v2/decoder.rs +++ b/candle-transformers/src/models/nvembed_v2/embedding.rs @@ -1,7 +1,3 @@ -// Copyright (c) NVIDIA CORPORATION, all rights reserved. -// This source code is licensed under the CC-BY-NC-4.0 license. -// See https://spdx.org/licenses/CC-BY-NC-4.0 for details. - /// Mistral LLM, https://github.com/mistralai/mistral-src use crate::models::{ mistral::Config, From 3be3432750e782950b76fa232485c97c4cc6f094 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:12:15 -0400 Subject: [PATCH 10/11] Update mod.rs --- candle-transformers/src/models/nvembed_v2/mod.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/candle-transformers/src/models/nvembed_v2/mod.rs b/candle-transformers/src/models/nvembed_v2/mod.rs index 610a44d94c..8a8f700782 100644 --- a/candle-transformers/src/models/nvembed_v2/mod.rs +++ b/candle-transformers/src/models/nvembed_v2/mod.rs @@ -2,7 +2,7 @@ //! //! NV-Embed-v2 is a text embedding model that combines a Mistral decoder with a latent attention mechanism to produce high-quality text embeddings. //! -//! - [HuggingFace Model Card](https://huggingface.co/nvidia/NV-Embed-v2) +//! This implementation is based on the [paper](https://arxiv.org/pdf/2405.17428) and [weights](https://huggingface.co/nvidia/NV-Embed-v2) //! //! # Query-Passage Retrieval Example //! ```bash @@ -14,9 +14,5 @@ //! cargo run --example nvembed_v2 --release -- --prompt "Here is a test sentence" //! ``` -// Copyright (c) NVIDIA CORPORATION, all rights reserved. -// This source code is licensed under the CC-BY-NC-4.0 license. -// See https://spdx.org/licenses/CC-BY-NC-4.0 for details. - -pub mod decoder; +pub mod embedding; pub mod model; From a274da98c318e42b298ef31f38b12aaca4d177c0 Mon Sep 17 00:00:00 2001 From: cdoko <190060110+cdoko@users.noreply.github.com> Date: Tue, 3 Dec 2024 05:12:55 -0400 Subject: [PATCH 11/11] Update model.rs --- .../src/models/nvembed_v2/model.rs | 423 +++++++----------- 1 file changed, 150 insertions(+), 273 deletions(-) diff --git a/candle-transformers/src/models/nvembed_v2/model.rs b/candle-transformers/src/models/nvembed_v2/model.rs index 39c7f88bd3..73ef776e3b 100644 --- a/candle-transformers/src/models/nvembed_v2/model.rs +++ b/candle-transformers/src/models/nvembed_v2/model.rs @@ -1,336 +1,201 @@ -// Copyright (c) NVIDIA CORPORATION, all rights reserved. -// This source code is licensed under the CC-BY-NC-4.0 license. -// See https://spdx.org/licenses/CC-BY-NC-4.0 for details. - -use super::decoder::Model as MistralModel; +use super::embedding::Model as EmbeddingModel; use crate::models::{ mistral::Config, with_tracing::{layer_norm, linear, linear_no_bias, LayerNorm, Linear}, }; use candle::{DType, Device, Result, Tensor, D}; -use candle_nn::{ops::softmax_last_dim, Module, VarBuilder}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] -struct LatentAttentionConfig { - num_latents_value: usize, - num_cross_heads: usize, - output_normalize: bool, - hidden_dim: usize, - latent_dim: usize, - cross_dim_head: usize, - hidden_size: usize, -} +use candle_nn::{ops::softmax_last_dim, LayerNormConfig, Module, VarBuilder}; -impl LatentAttentionConfig { - fn new(hidden_size: usize, output_normalize: bool) -> Self { - Self { - num_latents_value: 512, - num_cross_heads: 8, - output_normalize, - hidden_dim: 4096, - latent_dim: 4096, - cross_dim_head: 4096, - hidden_size, - } - } +// Geglu and feedforward from candle-transformers/src/models/stable_diffusion/attention.rs +#[derive(Debug)] +struct GeGlu { + proj: Linear, + span: tracing::Span, } -#[derive(Debug, Clone)] -#[allow(clippy::upper_case_acronyms)] -struct GEGLU {} - -impl GEGLU { - fn new() -> Self { - Self {} +impl GeGlu { + fn new(vs: VarBuilder, dim_in: usize, dim_out: usize) -> Result { + let proj = linear(dim_in, dim_out * 2, vs)?; + let span = tracing::span!(tracing::Level::TRACE, "geglu"); + Ok(Self { proj, span }) } } -impl Module for GEGLU { - fn forward(&self, x: &Tensor) -> Result { - let last_dim = x.dims().len() - 1; - let chunks = x.chunk(2, last_dim)?; - let (x, gates) = (chunks[0].clone(), chunks[1].clone()); - - let gates = gates.gelu()?; - x * gates +impl Module for GeGlu { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?; + &hidden_states_and_gate[0] * hidden_states_and_gate[1].gelu()? } } -#[derive(Debug, Clone)] +#[derive(Debug)] struct FeedForward { - linear1: Linear, - gelu: GEGLU, - linear2: Linear, + project_in: GeGlu, + linear: Linear, + span: tracing::Span, } impl FeedForward { - fn new(dim: usize, vb1: VarBuilder, vb2: VarBuilder) -> Result { - let linear1 = linear(dim, dim * 4 * 2, vb1)?; - let gelu = GEGLU::new(); - let linear2 = linear(dim * 4, dim, vb2)?; - + fn new(vs: VarBuilder, dim: usize, dim_out: Option, mult: usize) -> Result { + let inner_dim = dim * mult; + let dim_out = dim_out.unwrap_or(dim); + let vs = vs.pp("net"); + let project_in = GeGlu::new(vs.pp("0"), dim, inner_dim)?; + let linear = linear(inner_dim, dim_out, vs.pp("2"))?; + let span = tracing::span!(tracing::Level::TRACE, "ff"); Ok(Self { - linear1, - gelu, - linear2, + project_in, + linear, + span, }) } +} +impl Module for FeedForward { fn forward(&self, xs: &Tensor) -> Result { - let xs = self.linear1.forward(xs)?; - let xs = self.gelu.forward(&xs)?; - let xs = self.linear2.forward(&xs)?; - Ok(xs) + let _enter = self.span.enter(); + let xs = self.project_in.forward(xs)?; + self.linear.forward(&xs) } } -#[derive(Debug, Clone)] -struct Attention { - heads: usize, +// CrossAttention from candle-transformers/src/models/stable_diffusion/attention.rs +#[derive(Debug)] +struct CrossAttention { to_q: Linear, to_kv: Linear, to_out: Linear, - dim_head: usize, + heads: usize, + scale: f64, + span: tracing::Span, + span_attn: tracing::Span, + span_softmax: tracing::Span, } -#[allow(clippy::too_many_arguments)] -impl Attention { +impl CrossAttention { fn new( + vs: VarBuilder, query_dim: usize, context_dim: Option, - heads: Option, - dim_head: Option, - vb_to_q: VarBuilder, - vb_to_kv: VarBuilder, - vb_to_out: VarBuilder, + heads: usize, + dim_head: usize, ) -> Result { - let heads = heads.unwrap_or(8); - let dim_head = dim_head.unwrap_or(64); let inner_dim = dim_head * heads; let context_dim = context_dim.unwrap_or(query_dim); - - let to_q = linear_no_bias(query_dim, inner_dim, vb_to_q)?; - let to_kv = linear_no_bias(context_dim, inner_dim * 2, vb_to_kv)?; - let to_out = linear_no_bias(inner_dim, query_dim, vb_to_out)?; + let scale = 1.0 / f64::sqrt(dim_head as f64); + let to_q = linear_no_bias(query_dim, inner_dim, vs.pp("to_q"))?; + let to_kv = linear_no_bias(context_dim, inner_dim * 2, vs.pp("to_kv"))?; + let to_out = linear_no_bias(inner_dim, query_dim, vs.pp("to_out"))?; + let span = tracing::span!(tracing::Level::TRACE, "xa"); + let span_attn = tracing::span!(tracing::Level::TRACE, "xa-attn"); + let span_softmax = tracing::span!(tracing::Level::TRACE, "xa-softmax"); Ok(Self { - heads, to_q, to_kv, to_out, - dim_head, + heads, + scale, + span, + span_attn, + span_softmax, }) } - // Cross attn takes queries from the mistral decoder and kv from latent attention model - fn forward(&self, x: &Tensor, context: &Tensor) -> Result { - let h = self.heads; - let q = self.to_q.forward(x)?; - let kv_chunks = self - .to_kv - .forward(context)? - .chunk(2, context.shape().dims().len() - 1)?; - let (k, v) = (kv_chunks[0].clone(), kv_chunks[1].clone()); - - let (b_sz, q_len, _) = q.dims3()?; - let q = q - .reshape((b_sz, q_len, h, self.dim_head))? - .transpose(1, 2)? - .contiguous()?; - - let (_, q_len, _) = k.dims3()?; - let k = k - .reshape((b_sz, q_len, h, self.dim_head))? - .transpose(1, 2)? - .contiguous()?; - - let (_, q_len, _) = v.dims3()?; - let v = v - .reshape((b_sz, q_len, h, self.dim_head))? + fn reshape_heads_to_batch_dim(&self, xs: &Tensor) -> Result { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size, seq_len, self.heads, dim / self.heads))? .transpose(1, 2)? - .contiguous()?; - - let scale = 1f64 / f64::sqrt(self.dim_head as f64); - - let attn_weight = (q.matmul(&k.transpose(2, 3)?)? * scale)?; - let attn_weight = softmax_last_dim(&attn_weight)?; - - let out = attn_weight.matmul(&v)?; + .reshape((batch_size * self.heads, seq_len, dim / self.heads)) + } - let (_, _, q_len, _) = out.dims4()?; - let out = out + fn reshape_batch_dim_to_heads(&self, xs: &Tensor) -> Result { + let (batch_size, seq_len, dim) = xs.dims3()?; + xs.reshape((batch_size / self.heads, self.heads, seq_len, dim))? .transpose(1, 2)? - .reshape((b_sz, q_len, self.dim_head * h))?; - - self.to_out.forward(&out) + .reshape((batch_size / self.heads, seq_len, dim * self.heads)) } -} - -#[derive(Debug, Clone)] -enum PreNormInnerLayer { - Attention(Attention), - FeedForward(FeedForward), -} - -#[derive(Debug, Clone)] -struct PreNorm { - norm: LayerNorm, - norm_context: Option, - inner_layer: PreNormInnerLayer, -} -impl PreNorm { - fn new( - dim: usize, - context_dim: Option, - inner_layer: PreNormInnerLayer, - norm_vb: VarBuilder, - norm_context_vb: Option, - ) -> Result { - let norm = layer_norm(dim, candle_nn::LayerNormConfig::default(), norm_vb)?; - - let norm_context = match context_dim { - Some(context_dim) => { - let norm_context_vb = norm_context_vb - .expect("norm_context_vb must be passed if context_dim is passed"); - match layer_norm( - context_dim, - candle_nn::LayerNormConfig::default(), - norm_context_vb, - ) { - Ok(norm_context) => Some(norm_context), - Err(e) => return Err(e), - } - } - None => None, + fn attention(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result { + let _enter = self.span_attn.enter(); + + let in_dtype = query.dtype(); + let query = query.to_dtype(DType::F32)?; + let key = key.to_dtype(DType::F32)?; + let value = value.to_dtype(DType::F32)?; + let xs = query.matmul(&(key.t()? * self.scale)?)?; + let xs = { + let _enter = self.span_softmax.enter(); + softmax_last_dim(&xs)? }; - Ok(Self { - norm, - norm_context, - inner_layer, - }) + let xs = xs.matmul(&value)?.to_dtype(in_dtype)?; + + self.reshape_batch_dim_to_heads(&xs) } - // Applies a layernorm to the input before passing to cross attn or feed forward fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result { - let xs = self.norm.forward(xs)?; - - let mut normed_context = None; - if let Some(norm_context) = &self.norm_context { - if let Some(context) = context { - normed_context = Some(norm_context.forward(context)?); - } - } + let _enter = self.span.enter(); + let query = self.to_q.forward(xs)?; + let context = context.unwrap_or(xs).contiguous()?; + let kv_chunks = self + .to_kv + .forward(&context)? + .chunk(2, context.shape().dims().len() - 1)?; + let (key, value) = (kv_chunks[0].clone(), kv_chunks[1].clone()); + let query = self.reshape_heads_to_batch_dim(&query)?; + let key = self.reshape_heads_to_batch_dim(&key)?; + let value = self.reshape_heads_to_batch_dim(&value)?; - match &self.inner_layer { - PreNormInnerLayer::Attention(attn) => attn.forward(&xs, &normed_context.unwrap()), - PreNormInnerLayer::FeedForward(ff) => ff.forward(&xs), - } + let xs = self.attention(&query, &key, &value)?; + self.to_out.forward(&xs) } } -#[derive(Debug, Clone)] -struct LatentAttentionModel { - cross_attn: PreNorm, - ff: PreNorm, - output_normalize: bool, +#[derive(Debug)] +pub struct Model { + embedding_model: EmbeddingModel, + cross_attn: CrossAttention, + cross_attn_norm: LayerNorm, + cross_attn_context_norm: LayerNorm, + ff: FeedForward, + ff_norm: LayerNorm, latents: Tensor, -} - -impl LatentAttentionModel { - fn new(vb: VarBuilder, config: LatentAttentionConfig) -> Result { - let vb_cross = vb.pp("cross_attend_blocks"); - - let num_latents = config.num_latents_value; - let latent_dim = config.latent_dim; - let cross_heads = config.num_cross_heads; - let cross_dim_head = config.cross_dim_head; - let dim = config.hidden_dim; - let hidden_size = config.hidden_size; - - let cross_attn = PreNorm::new( - latent_dim, - Some(hidden_size), - PreNormInnerLayer::Attention(Attention::new( - latent_dim, - Some(dim), - Some(cross_heads), - Some(cross_dim_head), - vb_cross.pp("0.fn.to_q"), - vb_cross.pp("0.fn.to_kv"), - vb_cross.pp("0.fn.to_out"), - )?), - vb_cross.pp("0.norm"), - Some(vb_cross.pp("0.norm_context")), - )?; - - let ff = PreNorm::new( - latent_dim, - None, - PreNormInnerLayer::FeedForward(FeedForward::new( - latent_dim, - vb_cross.pp("1.fn.net.0"), - vb_cross.pp("1.fn.net.2"), - )?), - vb_cross.pp("1.norm"), - None, - )?; - - let output_normalize = config.output_normalize; - let latents = vb.get((num_latents, latent_dim), "latents")?; - - Ok(Self { - cross_attn, - ff, - output_normalize, - latents, - }) - } - - fn forward(&self, hiddens: &Tensor, attention_mask: &Tensor) -> Result { - let b = hiddens.dims()[0]; - let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?; - - let hiddens = (self.cross_attn.forward(hiddens, Some(&x))? + hiddens)?; - let hiddens = (self.ff.forward(&hiddens, None)? + hiddens)?; - - // Mean pooling - let hiddens_masked = hiddens.broadcast_mul(&attention_mask.unsqueeze(D::Minus1)?)?; - let s = hiddens_masked.sum(1)?; - let d = attention_mask.sum_keepdim(1)?; - let hiddens = s.broadcast_div(&d)?; - - if self.output_normalize { - let hiddens = div_l2_norm(&hiddens)?; - - Ok(hiddens) - } else { - Ok(hiddens) - } - } -} - -#[derive(Debug, Clone)] -pub struct NVEmbedModel { - latent_attention_model: LatentAttentionModel, - embedding_model: MistralModel, pub device: Device, pub dtype: DType, } -impl NVEmbedModel { - pub fn new(vb: VarBuilder, output_normalize: bool) -> Result { +impl Model { + pub fn new(vb: VarBuilder) -> Result { + // Embedding model let cfg = Config::config_7b_v0_1(false); - let embedding_model = MistralModel::new(&cfg, vb.pp("embedding_model"))?; - let hidden_size = embedding_model.cfg.hidden_size; - let latent_attention_model = LatentAttentionModel::new( - vb.pp("latent_attention_model"), - LatentAttentionConfig::new(hidden_size, output_normalize), + let embedding_model = EmbeddingModel::new(&cfg, vb.pp("embedding_model"))?; + + // Latent attention + let dim = 4096; + let vb = vb.pp("latent_attention_model"); + let latents = vb.get((512, dim), "latents")?; + + // Cross attend blocks + let vb = vb.pp("cross_attend_blocks"); + let cross_attn_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("0.norm"))?; + let cross_attn_context_norm = layer_norm( + dim, + candle_nn::LayerNormConfig::default(), + vb.pp("0.norm_context"), )?; + let cross_attn = CrossAttention::new(vb.pp("0.fn"), dim, None, 8, 4096)?; + + let ff_norm = layer_norm(dim, LayerNormConfig::default(), vb.pp("1.norm"))?; + let ff = FeedForward::new(vb.pp("1.fn"), dim, None, 4)?; Ok(Self { - latent_attention_model, embedding_model, + cross_attn, + cross_attn_norm, + cross_attn_context_norm, + ff, + ff_norm, + latents, device: vb.device().clone(), dtype: vb.dtype(), }) @@ -342,15 +207,27 @@ impl NVEmbedModel { attn_mask: &Tensor, pool_mask: &Tensor, ) -> Result { - let outputs = self + // Embedding model + let hiddens = self .embedding_model .forward(attn_mask, input_ids, self.dtype)?; - self.latent_attention_model.forward(&outputs, pool_mask) - } -} + // Latent attention + let b = hiddens.dims()[0]; + let x = self.latents.unsqueeze(0)?.repeat((b, 1, 1))?; + let original_hiddens = &hiddens; -fn div_l2_norm(v: &Tensor) -> Result { - let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; - v.broadcast_div(&l2_norm) + let hiddens = self.cross_attn_norm.forward(original_hiddens)?; + let x = self.cross_attn_context_norm.forward(&x)?; + let cross_hiddens = (self.cross_attn.forward(&hiddens, Some(&x))? + original_hiddens)?; + + let hiddens = self.ff_norm.forward(&cross_hiddens)?; + let hiddens = (self.ff.forward(&hiddens)? + cross_hiddens)?; + + // Mean pooling + let hiddens_masked = hiddens.broadcast_mul(&pool_mask.unsqueeze(D::Minus1)?)?; + let s = hiddens_masked.sum(1)?; + let d = pool_mask.sum_keepdim(1)?; + s.broadcast_div(&d) + } }