From 461e8c1685e003bdddfd1e7d1aa5092786ca9df5 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Mon, 13 Jan 2025 09:39:27 +0200 Subject: [PATCH] ModernBERT model (#2713) * layer_norm_no_bias * Modernbert model. * Format + cleanup error. --------- Co-authored-by: laurent --- candle-examples/examples/modernbert/README.md | 12 + candle-examples/examples/modernbert/main.rs | 180 ++++++++ candle-nn/src/layer_norm.rs | 9 + candle-nn/src/lib.rs | 4 +- candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/modernbert.rs | 407 ++++++++++++++++++ 6 files changed, 612 insertions(+), 1 deletion(-) create mode 100644 candle-examples/examples/modernbert/README.md create mode 100644 candle-examples/examples/modernbert/main.rs create mode 100644 candle-transformers/src/models/modernbert.rs diff --git a/candle-examples/examples/modernbert/README.md b/candle-examples/examples/modernbert/README.md new file mode 100644 index 0000000000..4eba2d7dbd --- /dev/null +++ b/candle-examples/examples/modernbert/README.md @@ -0,0 +1,12 @@ +# candle-modernbert + +ModernBERT is a bidirectional encoder-only language model. In this example it is used for the fill-mask task: + +## Usage + +```bash +cargo run --example modernbert --release -- --model modern-bert-large --prompt 'The capital of France is [MASK].' +``` +```markdown +Sentence: 1 : The capital of France is Paris. +``` diff --git a/candle-examples/examples/modernbert/main.rs b/candle-examples/examples/modernbert/main.rs new file mode 100644 index 0000000000..122aa99533 --- /dev/null +++ b/candle-examples/examples/modernbert/main.rs @@ -0,0 +1,180 @@ +use std::path::PathBuf; + +use anyhow::{Error as E, Result}; +use candle::{Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::models::modernbert; +use clap::{Parser, ValueEnum}; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::{PaddingParams, Tokenizer}; + +#[derive(Debug, Clone, ValueEnum)] +enum Model { + ModernBertBase, + ModernBertLarge, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + model_id: Option, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long, default_value = "modern-bert-base")] + model: Model, + + // Path to the tokenizer file. + #[arg(long)] + tokenizer_file: Option, + + // Path to the weight files. + #[arg(long)] + weight_files: Option, + + // Path to the config file. + #[arg(long)] + config_file: Option, + + /// When set, compute embeddings for this prompt. + #[arg(long)] + prompt: Option, +} + +fn main() -> Result<()> { + let args = Args::parse(); + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => match args.model { + Model::ModernBertBase => "answerdotai/ModernBERT-base".to_string(), + Model::ModernBertLarge => "answerdotai/ModernBERT-large".to_string(), + }, + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + + let weights_filename = match args.weight_files { + Some(files) => PathBuf::from(files), + None => match repo.get("model.safetensors") { + Ok(safetensors) => safetensors, + Err(_) => match repo.get("pytorch_model.bin") { + Ok(pytorch_model) => pytorch_model, + Err(e) => { + anyhow::bail!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}") + } + }, + }, + }; + + let config = std::fs::read_to_string(config_filename)?; + let config: modernbert::Config = serde_json::from_str(&config)?; + let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let device = candle_examples::device(args.cpu)?; + + let vb = if weights_filename.ends_with("model.safetensors") { + unsafe { + VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F32, &device) + .unwrap() + } + } else { + println!("Loading weights from pytorch_model.bin"); + VarBuilder::from_pth(&weights_filename, candle::DType::F32, &device).unwrap() + }; + tokenizer + .with_padding(Some(PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + pad_id: config.pad_token_id, + ..Default::default() + })) + .with_truncation(None) + .map_err(E::msg)?; + + let prompt = match &args.prompt { + Some(p) => vec![p.as_str()], + None => vec![ + "Hello I'm a [MASK] model.", + "I'm a [MASK] boy.", + "I'm [MASK] in berlin.", + "The capital of France is [MASK].", + ], + }; + let model = modernbert::ModernBertForMaskedLM::load(vb, &config)?; + + let input_ids = tokenize_batch(&tokenizer, prompt.clone(), &device)?; + let attention_mask = get_attention_mask(&tokenizer, prompt.clone(), &device)?; + + let output = model + .forward(&input_ids, &attention_mask)? + .to_dtype(candle::DType::F32)?; + + let max_outs = output.argmax(2)?; + + let max_out = max_outs.to_vec2::()?; + let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect(); + let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap(); + for (i, sentence) in decoded.iter().enumerate() { + println!("Sentence: {} : {}", i + 1, sentence); + } + + Ok(()) +} + +pub fn tokenize_batch( + tokenizer: &Tokenizer, + input: Vec<&str>, + device: &Device, +) -> anyhow::Result { + let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?; + + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), device) + }) + .collect::>>()?; + + Ok(Tensor::stack(&token_ids, 0)?) +} + +pub fn get_attention_mask( + tokenizer: &Tokenizer, + input: Vec<&str>, + device: &Device, +) -> anyhow::Result { + let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?; + + let attention_mask = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_attention_mask().to_vec(); + Tensor::new(tokens.as_slice(), device) + }) + .collect::>>()?; + Ok(Tensor::stack(&attention_mask, 0)?) +} diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index b7dd61cba1..468fe24d26 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -155,6 +155,15 @@ pub fn layer_norm>( }) } +pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result { + let config = LayerNormConfig { + eps, + remove_mean: true, + affine: false, + }; + layer_norm(size, config, vb) +} + /// RmsNorm is a specialized version of the LayerNorm module. #[derive(Clone, Debug)] pub struct RmsNorm(LayerNorm); diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index eb3cde4a75..2113566d33 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -46,7 +46,9 @@ pub use embedding::{embedding, Embedding}; pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; -pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; +pub use layer_norm::{ + layer_norm, layer_norm_no_bias, rms_norm, LayerNorm, LayerNormConfig, RmsNorm, +}; pub use linear::{linear, linear_b, linear_no_bias, Linear}; pub use ops::Dropout; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 5f56699135..473a276f0d 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -60,6 +60,7 @@ pub mod mmdit; pub mod mobileclip; pub mod mobilenetv4; pub mod mobileone; +pub mod modernbert; pub mod moondream; pub mod mpt; pub mod nvembed_v2; diff --git a/candle-transformers/src/models/modernbert.rs b/candle-transformers/src/models/modernbert.rs new file mode 100644 index 0000000000..b0ba9b4695 --- /dev/null +++ b/candle-transformers/src/models/modernbert.rs @@ -0,0 +1,407 @@ +//! ModernBERT +//! +//! ModernBERT is a modernized bidirectional encoder-only Transformer model. +//! - [Arxiv](https://arxiv.org/abs/2412.13663) "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference" +//! - Upstream [Github repo](https://github.com/AnswerDotAI/ModernBERT). +//! - See modernbert in [candle-examples](https://github.com/huggingface/candle/tree/main/candle-examples/) for runnable code +//! + +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{ + embedding, layer_norm_no_bias, linear_no_bias, ops::softmax, Embedding, LayerNorm, Linear, + Module, VarBuilder, +}; +use serde::Deserialize; + +use core::f32; +use std::sync::Arc; + +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub max_position_embeddings: usize, + pub layer_norm_eps: f64, + pub pad_token_id: u32, + pub global_attn_every_n_layers: usize, + pub global_rope_theta: f64, + pub local_attention: usize, + pub local_rope_theta: f64, +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, config: &Config, rope_theta: f64, dev: &Device) -> Result { + let dim = config.hidden_size / config.num_attention_heads; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let max_seq_len = config.max_position_embeddings; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &self.cos, &self.sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &self.cos, &self.sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Clone)] +struct ModernBertAttention { + qkv: Linear, + proj: Linear, + num_attention_heads: usize, + attention_head_size: usize, + rotary_emb: Arc, +} + +impl ModernBertAttention { + fn load(vb: VarBuilder, config: &Config, rotary_emb: Arc) -> Result { + let num_attention_heads = config.num_attention_heads; + let attention_head_size = config.hidden_size / config.num_attention_heads; + + let qkv = linear_no_bias(config.hidden_size, config.hidden_size * 3, vb.pp("Wqkv"))?; + let proj = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("Wo"))?; + + Ok(Self { + qkv, + proj, + num_attention_heads, + attention_head_size, + rotary_emb, + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let xs = hidden_states.clone(); + let (b, seq_len, d) = xs.dims3()?; + let qkv = xs + .apply(&self.qkv)? + .reshape(( + b, + seq_len, + 3, + self.num_attention_heads, + self.attention_head_size, + ))? + .permute((2, 0, 3, 1, 4))?; + + let q = qkv.get(0)?; + let k = qkv.get(1)?; + let v = qkv.get(2)?; + + let (q, k) = self.rotary_emb.apply_rotary_emb_qkv(&q, &k)?; + + let scale = (self.attention_head_size as f64).powf(-0.5); + let q = (q * scale)?; + + let att = q.matmul(&k.transpose(D::Minus2, D::Minus1)?)?; + + let att = att.broadcast_add(attention_mask)?; + let att = softmax(&att, D::Minus1)?; + + let xs = att.matmul(&v)?; + + let xs = xs.transpose(1, 2)?.reshape((b, seq_len, d))?; + let xs = xs.apply(&self.proj)?; + let xs = xs.reshape((b, seq_len, d))?; + + Ok(xs) + } +} + +#[derive(Clone)] +pub struct ModernBertMLP { + wi: Linear, + wo: Linear, +} + +impl ModernBertMLP { + fn load(vb: VarBuilder, config: &Config) -> Result { + let wi = linear_no_bias( + config.hidden_size, + config.intermediate_size * 2, + vb.pp("Wi"), + )?; + let wo = linear_no_bias(config.intermediate_size, config.hidden_size, vb.pp("Wo"))?; + Ok(Self { wi, wo }) + } +} + +impl Module for ModernBertMLP { + fn forward(&self, xs: &Tensor) -> Result { + let xs = xs.apply(&self.wi)?; + let xs = xs.chunk(2, D::Minus1)?; + let xs = (&xs[0].gelu_erf()? * &xs[1])?.apply(&self.wo)?; // GeGLU + Ok(xs) + } +} + +#[derive(Clone)] +pub struct ModernBertLayer { + attn: ModernBertAttention, + mlp: ModernBertMLP, + attn_norm: Option, + mlp_norm: LayerNorm, + uses_local_attention: bool, +} + +impl ModernBertLayer { + fn load( + vb: VarBuilder, + config: &Config, + rotary_emb: Arc, + uses_local_attention: bool, + ) -> Result { + let attn = ModernBertAttention::load(vb.pp("attn"), config, rotary_emb)?; + let mlp = ModernBertMLP::load(vb.pp("mlp"), config)?; + let attn_norm = layer_norm_no_bias( + config.hidden_size, + config.layer_norm_eps, + vb.pp("attn_norm"), + ) + .ok(); + let mlp_norm = + layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("mlp_norm"))?; + Ok(Self { + attn, + mlp, + attn_norm, + mlp_norm, + uses_local_attention, + }) + } + + fn forward( + &self, + xs: &Tensor, + global_attention_mask: &Tensor, + local_attention_mask: &Tensor, + ) -> Result { + let residual = xs.clone(); + let mut xs = xs.clone(); + if let Some(norm) = &self.attn_norm { + xs = xs.apply(norm)?; + } + + let attention_mask = if self.uses_local_attention { + &global_attention_mask.broadcast_add(local_attention_mask)? + } else { + global_attention_mask + }; + let xs = self.attn.forward(&xs, attention_mask)?; + let xs = (xs + residual)?; + let mlp_out = xs.apply(&self.mlp_norm)?.apply(&self.mlp)?; + let xs = (xs + mlp_out)?; + Ok(xs) + } +} + +#[derive(Clone)] +pub struct ModernBertHead { + dense: Linear, + norm: LayerNorm, +} + +impl ModernBertHead { + fn load(vb: VarBuilder, config: &Config) -> Result { + let dense = linear_no_bias(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let norm = layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, vb.pp("norm"))?; + Ok(Self { dense, norm }) + } +} + +impl Module for ModernBertHead { + fn forward(&self, xs: &Tensor) -> Result { + let xs = xs.apply(&self.dense)?.gelu_erf()?.apply(&self.norm)?; + Ok(xs) + } +} + +#[derive(Clone)] +pub struct ModernBertDecoder { + decoder: Linear, +} + +impl ModernBertDecoder { + fn load(vb: VarBuilder, config: &Config) -> Result { + // The decoder weights are tied with the embeddings layer weights + let decoder_weights = vb.get( + (config.vocab_size, config.hidden_size), + "model.embeddings.tok_embeddings.weight", + )?; + let decoder_bias = vb.get(config.vocab_size, "decoder.bias")?; + let decoder = Linear::new(decoder_weights, Some(decoder_bias)); + Ok(Self { decoder }) + } +} + +impl Module for ModernBertDecoder { + fn forward(&self, xs: &Tensor) -> Result { + let xs = xs.apply(&self.decoder)?; + Ok(xs) + } +} + +// Global attention mask calculated from padded token inputs +fn prepare_4d_attention_mask( + mask: &Tensor, + dtype: DType, + tgt_len: Option, +) -> Result { + let bsz = mask.dim(0)?; + let src_len = mask.dim(1)?; + let tgt_len = tgt_len.unwrap_or(src_len); + + let expanded_mask = mask + .unsqueeze(1)? + .unsqueeze(2)? + .expand((bsz, 1, tgt_len, src_len))? + .to_dtype(dtype)?; + + let inverted_mask = (1.0 - expanded_mask)?; + + (inverted_mask * f32::MIN as f64)?.to_dtype(dtype) +} + +// Attention mask caused by the sliding window +fn get_local_attention_mask( + seq_len: usize, + max_distance: usize, + device: &Device, +) -> Result { + let mask: Vec<_> = (0..seq_len) + .flat_map(|i| { + (0..seq_len).map(move |j| { + if (j as i32 - i as i32).abs() > max_distance as i32 { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (seq_len, seq_len), device) +} + +// ModernBERT backbone +#[derive(Clone)] +pub struct ModernBert { + word_embeddings: Embedding, + norm: LayerNorm, + layers: Vec, + final_norm: LayerNorm, + head: ModernBertHead, + local_attention_size: usize, +} + +impl ModernBert { + fn load(vb: VarBuilder, config: &Config) -> Result { + let word_embeddings = embedding( + config.vocab_size, + config.hidden_size, + vb.pp("model.embeddings.tok_embeddings"), + )?; + let norm = layer_norm_no_bias( + config.hidden_size, + config.layer_norm_eps, + vb.pp("model.embeddings.norm"), + )?; + let global_rotary_emb = Arc::new(RotaryEmbedding::new( + vb.dtype(), + config, + config.global_rope_theta, + vb.device(), + )?); + let local_rotary_emb = Arc::new(RotaryEmbedding::new( + vb.dtype(), + config, + config.local_rope_theta, + vb.device(), + )?); + + let mut layers = Vec::with_capacity(config.num_hidden_layers); + for layer_id in 0..config.num_hidden_layers { + let layer_uses_local_attention = layer_id % config.global_attn_every_n_layers != 0; + layers.push(ModernBertLayer::load( + vb.pp(format!("model.layers.{layer_id}")), + config, + if layer_uses_local_attention { + local_rotary_emb.clone() + } else { + global_rotary_emb.clone() + }, + layer_uses_local_attention, + )?); + } + + let final_norm = layer_norm_no_bias( + config.hidden_size, + config.layer_norm_eps, + vb.pp("model.final_norm"), + )?; + let head = ModernBertHead::load(vb.pp("head"), config)?; + + Ok(Self { + word_embeddings, + norm, + layers, + final_norm, + head, + local_attention_size: config.local_attention, + }) + } + + fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result { + let seq_len = xs.shape().dims()[1]; + let global_attention_mask = + prepare_4d_attention_mask(mask, DType::F32, None)?.to_device(xs.device())?; + let local_attention_mask = + get_local_attention_mask(seq_len, self.local_attention_size / 2, xs.device())?; + let mut xs = xs.apply(&self.word_embeddings)?.apply(&self.norm)?; + for layer in self.layers.iter() { + xs = layer.forward(&xs, &global_attention_mask, &local_attention_mask)?; + } + let xs = xs.apply(&self.final_norm)?.apply(&self.head)?; + Ok(xs) + } +} + +// ModernBERT for the fill-mask task +#[derive(Clone)] +pub struct ModernBertForMaskedLM { + model: ModernBert, + decoder: ModernBertDecoder, +} + +impl ModernBertForMaskedLM { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let model = ModernBert::load(vb.clone(), config)?; + let decoder = ModernBertDecoder::load(vb.clone(), config)?; + Ok(Self { model, decoder }) + } + + pub fn forward(&self, xs: &Tensor, mask: &Tensor) -> Result { + let xs = self.model.forward(xs, mask)?.apply(&self.decoder)?; + Ok(xs) + } +}