From e5f009d9486b6d26cee9a2c504ee11c6ac7501b3 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 2 Sep 2024 22:16:33 -0400 Subject: [PATCH 1/2] Add the phi 3.5 moe model --- candle-kernels/src/lib.rs | 4 + candle-transformers/src/layers/mod.rs | 3 + candle-transformers/src/layers/rope.rs | 306 ++++++++++ candle-transformers/src/lib.rs | 1 + candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/phi3_5moe.rs | 607 ++++++++++++++++++++ 6 files changed, 922 insertions(+) create mode 100644 candle-transformers/src/layers/mod.rs create mode 100644 candle-transformers/src/layers/rope.rs create mode 100644 candle-transformers/src/models/phi3_5moe.rs diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index 1c73d6b774..cec1b1e2d4 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -3,7 +3,11 @@ pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx")); pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); +pub const FUSED_LAYER_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_layer_norm.ptx")); +pub const FUSED_RMS_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rms_norm.ptx")); +pub const FUSED_ROPE: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_rope.ptx")); pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx")); +pub const KVCONCAT: &str = include_str!(concat!(env!("OUT_DIR"), "/kvconcat.ptx")); pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const SORT: &str = include_str!(concat!(env!("OUT_DIR"), "/sort.ptx")); diff --git a/candle-transformers/src/layers/mod.rs b/candle-transformers/src/layers/mod.rs new file mode 100644 index 0000000000..b9e47503c7 --- /dev/null +++ b/candle-transformers/src/layers/mod.rs @@ -0,0 +1,3 @@ +pub mod rope; + +pub use rope::{PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding}; diff --git a/candle-transformers/src/layers/rope.rs b/candle-transformers/src/layers/rope.rs new file mode 100644 index 0000000000..ffcf096059 --- /dev/null +++ b/candle-transformers/src/layers/rope.rs @@ -0,0 +1,306 @@ +use std::{ops::Mul, str::FromStr}; + +use candle::{DType, Device, IndexOp, Result, Tensor}; +use serde::Deserialize; + +/// RoPE supporting LongRope +#[derive(Debug, Clone)] +pub struct PhiRotaryEmbedding { + short_sin: Tensor, + short_cos: Tensor, + long_cos: Option, + long_sin: Option, + original_max_position_embeddings: usize, +} + +#[derive(Debug, Clone, Deserialize)] +pub enum ScaledRopeType { + #[serde(alias = "su")] + #[serde(alias = "longrope")] + Su, + #[serde(alias = "yarn")] + Yarn, +} + +impl FromStr for ScaledRopeType { + type Err = candle::Error; + fn from_str(s: &str) -> std::result::Result { + match s { + "su" | "longrope" => Ok(Self::Su), + "yarn" => Ok(Self::Yarn), + _ => Err(candle::Error::Msg( + "Expected either `su` or `yarn` scaled RoPE type.".to_string(), + )), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum PhiRopeScalingConfig { + Classic { + short_factor: Vec, + long_factor: Vec, + #[serde(rename = "type")] + scaling_type: ScaledRopeType, + }, + Scaled { + short_factor: Vec, + long_factor: Vec, + #[serde(rename = "type")] + scaling_type: ScaledRopeType, + long_mscale: f64, + short_mscale: f64, + }, +} + +pub struct PhiRopeConfig { + pub rope_scaling: Option, + pub max_position_embeddings: usize, + pub original_max_position_embeddings: usize, + pub rope_theta: f64, + pub head_dim: usize, +} + +impl PhiRotaryEmbedding { + fn new_classic_scaled( + short_factor: &[f64], + long_factor: &[f64], + scaling_type: &ScaledRopeType, + cfg: &PhiRopeConfig, + dtype: DType, + dev: &Device, + ) -> Result { + let max_seq_len = cfg.max_position_embeddings; + let dim = cfg.head_dim; + + // Calculate scale + let scale = + cfg.max_position_embeddings as f64 / cfg.original_max_position_embeddings as f64; + let scaling_factor = if scale <= 1.0 { + 1.0 + } else { + match scaling_type { + ScaledRopeType::Su => { + (1.0 + scale.ln() / (cfg.original_max_position_embeddings as f64).ln()).sqrt() + } + ScaledRopeType::Yarn => 0.1 * scale.ln() + 1.0, + } + }; + + // Calculate inv freqs for short, long + let inv_freq_long = (0..dim) + .step_by(2) + .enumerate() + .map(|(k, i)| { + (1f64 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32 + }) + .collect::>(); + let inv_freq_short = (0..dim) + .step_by(2) + .enumerate() + .map(|(k, i)| { + (1f64 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64))) as f32 + }) + .collect::>(); + let inv_freq_len = inv_freq_long.len(); + + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + + // Calculate sin,cos for long + let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len), dev)?; + let freqs_long = t.matmul(&inv_freq_long)?; + let long_sin = freqs_long.sin()?.mul(scaling_factor)?.to_dtype(dtype)?; + let long_cos = freqs_long.cos()?.mul(scaling_factor)?.to_dtype(dtype)?; + + // Calculate sin,cos for short + let inv_freq_short = + Tensor::from_vec(inv_freq_short, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?; + let freqs_short = t.matmul(&inv_freq_short)?; + let short_sin = freqs_short.sin()?.mul(scaling_factor)?.to_dtype(dtype)?; + let short_cos = freqs_short.cos()?.mul(scaling_factor)?.to_dtype(dtype)?; + + Ok(Self { + short_cos, + short_sin, + long_cos: Some(long_cos), + long_sin: Some(long_sin), + original_max_position_embeddings: cfg.original_max_position_embeddings, + }) + } + + fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result { + let max_seq_len = cfg.max_position_embeddings; + let dim = cfg.head_dim; + + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + let sin = freqs.sin()?.to_dtype(dtype)?; + let cos = freqs.cos()?.to_dtype(dtype)?; + Ok(Self { + short_cos: cos, + short_sin: sin, + long_cos: None, + long_sin: None, + original_max_position_embeddings: cfg.original_max_position_embeddings, + }) + } + + #[allow(clippy::too_many_arguments)] + fn new_scaled( + short_factor: &[f64], + long_factor: &[f64], + scaling_type: &ScaledRopeType, + long_mscale: f64, + short_mscale: f64, + cfg: &PhiRopeConfig, + dtype: DType, + dev: &Device, + ) -> Result { + let max_seq_len = cfg.max_position_embeddings; + let dim = cfg.head_dim; + + if !matches!(scaling_type, ScaledRopeType::Su) { + candle::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`."); + } + + if short_factor.len() != dim / 2 { + candle::bail!( + "Misaligned length {}, expected {} for `su`/`longrope` short rescale factors", + short_factor.len(), + dim / 2 + ); + } + if long_factor.len() != dim / 2 { + candle::bail!( + "Misaligned length {}, expected {} for `su`/`longrope` long rescale factors", + long_factor.len(), + dim / 2 + ); + } + + // Short cos/sin + let inv_freq_short: Vec<_> = (0..dim) + .step_by(2) + .enumerate() + .map(|(k, i)| { + 1f32 / (short_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32 + }) + .collect(); + let inv_freq_len_short = inv_freq_short.len(); + let inv_freq_short = Tensor::from_vec(inv_freq_short, (1, inv_freq_len_short), dev)?; + let t_short = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs_short = t_short.matmul(&inv_freq_short)?; + let sin_short = (freqs_short.sin()?.to_dtype(dtype)? * short_mscale)?; + let cos_short = (freqs_short.cos()?.to_dtype(dtype)? * short_mscale)?; + + // Long cos/sin + let inv_freq_long: Vec<_> = (0..dim) + .step_by(2) + .enumerate() + .map(|(k, i)| { + 1f32 / (long_factor[k] * cfg.rope_theta.powf(i as f64 / dim as f64)) as f32 + }) + .collect(); + let inv_freq_len_long = inv_freq_long.len(); + let inv_freq_long = Tensor::from_vec(inv_freq_long, (1, inv_freq_len_long), dev)?; + let t_long = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(DType::F32)? + .reshape((max_seq_len, 1))?; + let freqs_long = t_long.matmul(&inv_freq_long)?; + let sin_long = (freqs_long.sin()?.to_dtype(dtype)? * long_mscale)?; + let cos_long = (freqs_long.cos()?.to_dtype(dtype)? * long_mscale)?; + Ok(Self { + short_cos: cos_short, + short_sin: sin_short, + long_cos: Some(cos_long), + long_sin: Some(sin_long), + original_max_position_embeddings: cfg.original_max_position_embeddings, + }) + } + + pub fn new(dtype: DType, cfg: impl Into, dev: &Device) -> Result { + let cfg: PhiRopeConfig = cfg.into(); + + match &cfg.rope_scaling { + Some(PhiRopeScalingConfig::Classic { + short_factor, + long_factor, + scaling_type, + }) => { + Self::new_classic_scaled(short_factor, long_factor, scaling_type, &cfg, dtype, dev) + } + + Some(PhiRopeScalingConfig::Scaled { + short_factor, + long_factor, + scaling_type, + long_mscale, + short_mscale, + }) => Self::new_scaled( + short_factor, + long_factor, + scaling_type, + *long_mscale, + *short_mscale, + &cfg, + dtype, + dev, + ), + + None => Self::new_unscaled(&cfg, dtype, dev), + } + } + + /// Returns (sin, cos) taking into account LongRope + fn get_long_or_short_sin_cos(&self, position_ids: &[usize]) -> (&Tensor, &Tensor) { + if self.long_cos.is_none() { + return (&self.short_sin, &self.short_cos); + } + let seq_len = position_ids.iter().max().unwrap() + 1; + if seq_len > self.original_max_position_embeddings { + ( + self.long_sin.as_ref().unwrap(), + self.long_cos.as_ref().unwrap(), + ) + } else { + (&self.short_sin, &self.short_cos) + } + } + + pub fn forward( + &self, + q: &Tensor, + k: &Tensor, + seqlen_offsets: &[usize], + position_ids: &[usize], + ) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let mut q_embeds = Vec::new(); + let mut k_embeds = Vec::new(); + let (sin, cos) = self.get_long_or_short_sin_cos(position_ids); + for (i, offset) in seqlen_offsets.iter().enumerate() { + let cos = cos.narrow(0, *offset, seq_len)?; + let sin = sin.narrow(0, *offset, seq_len)?; + let q_embed = + candle_nn::rotary_emb::rope(&q.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; + let k_embed = + candle_nn::rotary_emb::rope(&k.i(i)?.unsqueeze(0)?.contiguous()?, &cos, &sin)?; + q_embeds.push(q_embed); + k_embeds.push(k_embed); + } + Ok((Tensor::cat(&q_embeds, 0)?, Tensor::cat(&k_embeds, 0)?)) + } +} diff --git a/candle-transformers/src/lib.rs b/candle-transformers/src/lib.rs index b2b062a9d7..e0df04ee88 100644 --- a/candle-transformers/src/lib.rs +++ b/candle-transformers/src/lib.rs @@ -1,4 +1,5 @@ pub mod generation; +pub mod layers; pub mod models; pub mod object_detection; pub mod pipelines; diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 9f7856ea20..91faee99e4 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -48,6 +48,7 @@ pub mod parler_tts; pub mod persimmon; pub mod phi; pub mod phi3; +pub mod phi3_5moe; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; diff --git a/candle-transformers/src/models/phi3_5moe.rs b/candle-transformers/src/models/phi3_5moe.rs new file mode 100644 index 0000000000..03522fdc7c --- /dev/null +++ b/candle-transformers/src/models/phi3_5moe.rs @@ -0,0 +1,607 @@ +#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] + +use candle::{ + CpuStorage, CustomOp1, DType, Device, Error, IndexOp, Layout, Module, Result, Shape, Tensor, + WithDType, D, +}; +use candle_nn::{layer_norm, LayerNorm, Linear, VarBuilder}; +use std::sync::Arc; + +use crate::layers::{PhiRopeConfig, PhiRopeScalingConfig, PhiRotaryEmbedding}; + +fn default_use_flash_attn() -> bool { + false +} + +#[derive(Debug, Clone, serde::Deserialize, Default)] +pub struct Config { + pub vocab_size: usize, + pub hidden_act: candle_nn::Activation, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub rms_norm_eps: f64, + pub rope_theta: f64, + pub rope_scaling: Option, + pub max_position_embeddings: usize, + pub sliding_window: Option, + pub original_max_position_embeddings: usize, + pub lm_head_bias: bool, + pub attention_bias: bool, + pub num_local_experts: usize, + pub router_jitter_noise: f64, + #[serde(default = "default_use_flash_attn")] + pub use_flash_attn: bool, +} + +impl From for PhiRopeConfig { + fn from(val: Config) -> Self { + PhiRopeConfig { + rope_scaling: val.rope_scaling, + max_position_embeddings: val.max_position_embeddings, + original_max_position_embeddings: val.original_max_position_embeddings, + rope_theta: val.rope_theta, + head_dim: val.hidden_size / val.num_attention_heads, + } + } +} + +impl Config { + pub fn head_dim(&self) -> usize { + self.hidden_size / self.num_attention_heads + } + + pub fn with_flash_attn(mut self, use_flash_attn: bool) -> Self { + self.use_flash_attn = use_flash_attn; + self + } +} + +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + head_dim: usize, + rotary_emb: Arc, + kv_cache: Option<(Tensor, Tensor)>, + use_flash_attn: bool, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let head_dim = cfg.head_dim(); + + let q_proj = candle_nn::linear_b( + cfg.hidden_size, + num_heads * head_dim, + cfg.attention_bias, + vb.pp("q_proj"), + )?; + let k_proj = candle_nn::linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + cfg.attention_bias, + vb.pp("k_proj"), + )?; + let v_proj = candle_nn::linear_b( + cfg.hidden_size, + num_kv_heads * head_dim, + cfg.attention_bias, + vb.pp("v_proj"), + )?; + let o_proj = candle_nn::linear_b( + num_heads * head_dim, + cfg.hidden_size, + cfg.attention_bias, + vb.pp("o_proj"), + )?; + + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + rotary_emb, + num_heads, + num_kv_heads, + head_dim, + kv_cache: None, + use_flash_attn: cfg.use_flash_attn, + }) + } + + #[allow(clippy::too_many_arguments)] + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + position_id: usize, + ) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let q = self.q_proj.forward(xs)?; + let k = self.k_proj.forward(xs)?; + let v = self.v_proj.forward(xs)?; + + let (q, k, v) = if q_len != 1 { + let q = q + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let k = k + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let v = v + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + (q, k, v) + } else { + let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?; + let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?; + let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?; + (q, k, v) + }; + + let (q, k) = self + .rotary_emb + .forward(&q, &k, &[seqlen_offset], &[position_id])?; + + let (k, v) = match &self.kv_cache { + None => (k, v), + Some((prev_k, prev_v)) => { + let key_states = Tensor::cat(&[prev_k, &k], 2)?; + let value_states = Tensor::cat(&[prev_v, &v], 2)?; + (key_states, value_states) + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let attn_output = if self.use_flash_attn { + // flash-attn expects (b_sz, seq_len, nheads, head_dim) + let q = q.transpose(1, 2)?; + let k = k.transpose(1, 2)?; + let v = v.transpose(1, 2)?; + let softmax_scale = 1f32 / (self.head_dim as f32).sqrt(); + flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)? + } else { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&v)? + }; + + let attn_output = if attention_mask.is_some() { + attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))? + } else { + attn_output.reshape((b_sz, q_len, ()))? + }; + self.o_proj.forward(&attn_output) + } +} + +#[derive(Clone)] +struct Mlp { + w1: Linear, + w2: Linear, + w3: Linear, + act_fn: candle_nn::Activation, +} + +impl Mlp { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_size = cfg.hidden_size; + let i_size = cfg.intermediate_size; + + let w1 = candle_nn::linear_no_bias(hidden_size, i_size, vb.pp("w1"))?; + let w2 = candle_nn::linear_no_bias(i_size, hidden_size, vb.pp("w2"))?; + let w3 = candle_nn::linear_no_bias(hidden_size, i_size, vb.pp("w3"))?; + + Ok(Self { + w1, + w2, + w3, + act_fn: cfg.hidden_act, + }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let mut current_hidden_states = self.w1.forward(xs)?.apply(&self.act_fn)?; + let rhs = self.w3.forward(xs)?; + current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?; + self.w2.forward(¤t_hidden_states) + } +} + +struct NonZero {} + +impl NonZero { + // Sequential CPU version + fn nonzero(&self, vs: &[T], layout: &Layout) -> Vec { + let n = layout.dims().len(); + let mut result = Vec::new(); + let mut indices = vec![0u32; n]; + for (i, v) in vs.iter().enumerate() { + if !v.is_zero() { + let mut idx = i; + for (dim_index, dim) in layout.dims().iter().enumerate().rev() { + let d = idx % dim; + indices[dim_index] = u32::try_from(d).unwrap(); + idx /= dim; + } + result.extend_from_slice(&indices); + } + } + result + } +} + +impl CustomOp1 for NonZero { + fn name(&self) -> &'static str { + "nonzero" + } + + fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> { + if !layout.is_contiguous() { + return Err(Error::RequiresContiguous { op: "nonzero" }); + } + let result = match storage { + candle::CpuStorage::U8(vs) => self.nonzero(vs, layout), + candle::CpuStorage::U32(vs) => self.nonzero(vs, layout), + candle::CpuStorage::I64(vs) => self.nonzero(vs, layout), + candle::CpuStorage::BF16(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F16(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F32(vs) => self.nonzero(vs, layout), + candle::CpuStorage::F64(vs) => self.nonzero(vs, layout), + }; + let index_len = layout.dims().len(); + let result_len = result.len() / index_len; + let result = CpuStorage::U32(result); + let shape = Shape::from_dims(&[result_len, index_len]); + Ok((result, shape)) + } +} + +pub trait NonZeroOp { + fn nonzero(&self) -> Result; +} + +impl NonZeroOp for Tensor { + fn nonzero(&self) -> Result { + if !self.is_contiguous() { + return Err(candle::Error::RequiresContiguous { op: "nonzero" }); + } + let original_device = self.device(); + self.to_device(&candle::Device::Cpu)? + .apply_op1_no_bwd(&NonZero {})? + .to_device(original_device) + } +} + +struct MoeMlp { + gate: candle_nn::Linear, + experts: Vec, + router_jitter_noise: f64, + num_experts: usize, +} + +// https://github.com/mokeyish/candle-ext/blob/main/src/masked_fill.rs +/// xs are on false (0), value is on true (1) +pub fn masked_fill(xs: &Tensor, mask: &Tensor, value: D) -> Result { + let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?; + let on_false = xs; + let res = mask + .broadcast_as(xs.shape())? + .where_cond(&on_true, on_false)?; + Ok(res) +} + +impl MoeMlp { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let num_experts = cfg.num_local_experts; + let gate = candle_nn::linear_no_bias(cfg.hidden_size, num_experts, vb.pp("gate"))?; + + let experts_vb = vb.pp("experts"); + let mut experts = Vec::with_capacity(num_experts); + for i in 0..num_experts { + experts.push(Mlp::new(cfg, experts_vb.pp(i))?); + } + + Ok(Self { + gate, + experts, + router_jitter_noise: cfg.router_jitter_noise, + num_experts, + }) + } + + fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> { + // Compute mask for sparsity + let selected_experts = scores.argmax_keepdim(D::Minus1)?; + let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?; + let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?; + let mask_logits_threshold = mask_logits_threshold + .broadcast_sub(scores)? + .broadcast_div(&factor)? + .gt(2. * jitter_eps)?; + + // Apply mask + let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?; + + // Compute scores + let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?; + let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?; + + // Mask out first expert + let masked_scores = scores.scatter_add( + &selected_experts + .broadcast_as(scores.shape())? + .contiguous()?, + &(scores.ones_like()? * f64::NEG_INFINITY)?, + D::Minus1, + )?; + + // Compute mask for sparsity + let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?; + let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?; + let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?; + let mask_logits_threshold = mask_logits_threshold + .broadcast_sub(scores)? + .broadcast_div(&factor)? + .gt(2. * jitter_eps)?; + + // Apply mask + let masked_gates_top2 = + masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?; + let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?; + let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?; + + let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?; + let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?; + + Ok((multiplier, selected_experts)) + } + + fn forward(&self, xs: &Tensor) -> Result { + let (bs, seq, hidden) = xs.dims3()?; + let xs = xs.reshape(((), hidden))?; + let xs_dev = xs.device(); + let xs = xs.to_device(&Device::Cpu)?; + + // Sparse MoE block accumulates hidden states on CPU, but MLP and gate weights are untouched (maybe on GPU) + + let router_logits = self + .gate + .forward(&xs.to_device(xs_dev)?)? + .to_device(&Device::Cpu)?; + let (routing_weights, selected_experts) = self.sparsemixer( + &router_logits.to_device(&Device::Cpu)?, + self.router_jitter_noise, + )?; + + let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs.device())?; + + // One hot encode the selected experts to create an expert mask + // this will be used to easily index which expert to activate + let experts_mask = + candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)? + .permute((2, 1, 0))?; + + // Loop over all avail experts in the model and perform the computation on each expert + for expert_idx in 0..self.num_experts { + let expert = &self.experts[expert_idx]; + let expert_mask = experts_mask.i(expert_idx)?; + assert_eq!(expert_mask.rank(), 2); + let nonzero_mask = expert_mask.contiguous()?.nonzero()?; + let idx = nonzero_mask.i((.., 0))?; + let top_x = nonzero_mask.i((.., 1))?; + + if top_x.dim(0)? == 0 { + continue; + } + + // Index the correct hidden staters and compute the expert hidden state + // for the current expert, we need to make sure to multiply the output hidden + // states by `routing_weights` on the corresponding tokens (top-1, top-2) + let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden))?; + let current_routing_weights = routing_weights + .index_select(&top_x, 0)? + .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?; + let exp_out = expert + .forward(¤t_state.to_device(xs_dev)?)? + .to_device(&Device::Cpu)?; + + let current_hidden_states = exp_out.broadcast_mul(¤t_routing_weights)?; + + final_hidden_states = final_hidden_states.index_add( + &top_x.contiguous()?, + ¤t_hidden_states.to_dtype(xs.dtype())?, + 0, + )?; + } + + final_hidden_states + .reshape((bs, seq, hidden))? + .to_device(xs_dev) + } +} + +struct DecoderLayer { + self_attn: Attention, + mlp: MoeMlp, + input_layernorm: LayerNorm, + post_attention_layernorm: LayerNorm, +} + +impl DecoderLayer { + #[allow(clippy::too_many_arguments)] + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = MoeMlp::new(cfg, vb.pp("block_sparse_moe"))?; + let input_layernorm = + layer_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = layer_norm( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + #[allow(clippy::too_many_arguments)] + fn forward( + &mut self, + xs: &Tensor, + attention_mask: Option<&Tensor>, + seqlen_offset: usize, + position_id: usize, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self + .self_attn + .forward(&xs, attention_mask, seqlen_offset, position_id)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .mlp + .forward(&xs.apply(&self.post_attention_layernorm)?)?; + residual + xs + } +} + +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: LayerNorm, + lm_head: Linear, + device: Device, + sliding_window: Option, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let rotary_emb = Arc::new(PhiRotaryEmbedding::new( + vb.dtype(), + cfg.clone(), + vb.device(), + )?); + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = layer_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + let lm_head = candle_nn::linear_b( + cfg.hidden_size, + cfg.vocab_size, + cfg.lm_head_bias, + vb.pp("lm_head"), + )?; + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + device: vb.device().clone(), + sliding_window: cfg.sliding_window, + }) + } + + fn prepare_decoder_attention_mask( + &self, + tgt_len: usize, + seqlen_offset: usize, + dtype: DType, + ) -> Result { + let sliding_window = self.sliding_window.unwrap_or(tgt_len + 1); + let mask: Vec<_> = (0..tgt_len) + .flat_map(|i| { + (0..tgt_len).map(move |j| { + if i < j || j + sliding_window < i { + f32::NEG_INFINITY + } else { + 0. + } + }) + }) + .collect(); + let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?; + let mask = if seqlen_offset > 0 { + let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?; + Tensor::cat(&[&mask0, &mask], D::Minus1)? + } else { + mask + }; + mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))? + .to_dtype(dtype) + } + + pub fn forward( + &mut self, + input_ids: &Tensor, + seqlen_offset: usize, + position_id: usize, + ) -> Result { + let mut xs = self.embed_tokens.forward(input_ids)?; + + let seq_len = xs.dim(1)?; + let attention_mask = if seq_len <= 1 { + None + } else { + let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset, xs.dtype())?; + Some(mask) + }; + + for layer in self.layers.iter_mut() { + xs = layer.forward( + &xs, + attention_mask + .as_ref() + .map(|m| m.to_device(xs.device()).unwrap()) + .as_ref(), + seqlen_offset, + position_id, + )? + } + + self.lm_head.forward(&xs.apply(&self.norm)?) + } +} From 69f499fbe44232a18a4d401628402bc72ebc077b Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Mon, 2 Sep 2024 22:27:01 -0400 Subject: [PATCH 2/2] Add example --- candle-examples/examples/phi3_5_moe/README.md | 13 + candle-examples/examples/phi3_5_moe/main.rs | 344 ++++++++++++++++++ candle-transformers/src/models/phi3_5moe.rs | 14 + 3 files changed, 371 insertions(+) create mode 100644 candle-examples/examples/phi3_5_moe/README.md create mode 100644 candle-examples/examples/phi3_5_moe/main.rs diff --git a/candle-examples/examples/phi3_5_moe/README.md b/candle-examples/examples/phi3_5_moe/README.md new file mode 100644 index 0000000000..13cea110cc --- /dev/null +++ b/candle-examples/examples/phi3_5_moe/README.md @@ -0,0 +1,13 @@ +# candle-phi3_5_moe: High performing 16x3.8B model, 6.6B active parameters + +Model: [Phi-3.5 MoE](https://huggingface.co/microsoft/Phi-3.5-MoE-instruct) + +The candle implementation provides the standard version. + +## Running some examples + +For the v2 version. +```bash +$ cargo run --example phi3_5_moe --release -- --model 2 \ + --prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?" +``` diff --git a/candle-examples/examples/phi3_5_moe/main.rs b/candle-examples/examples/phi3_5_moe/main.rs new file mode 100644 index 0000000000..7738a76c00 --- /dev/null +++ b/candle-examples/examples/phi3_5_moe/main.rs @@ -0,0 +1,344 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::phi3_5moe::{Config, Model}; + +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: Model, + device: Device, + tokenizer: TokenOutputStream, + logits_processor: LogitsProcessor, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, +} + +impl TextGeneration { + #[allow(clippy::too_many_arguments)] + fn new( + model: Model, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + top_p: Option, + repeat_penalty: f32, + repeat_last_n: usize, + verbose_prompt: bool, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp, top_p); + Self { + model, + tokenizer: TokenOutputStream::new(tokenizer), + logits_processor, + repeat_penalty, + repeat_last_n, + verbose_prompt, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + use std::io::Write; + println!("starting the inference loop"); + let tokens = self + .tokenizer + .tokenizer() + .encode(prompt, true) + .map_err(E::msg)?; + if tokens.is_empty() { + anyhow::bail!("Empty prompts are not supported in the phi model.") + } + if self.verbose_prompt { + for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { + let token = token.replace('▁', " ").replace("<0x0A>", "\n"); + println!("{id:7} -> '{token}'"); + } + } + let mut tokens = tokens.get_ids().to_vec(); + let mut generated_tokens = 0usize; + let eos_token = match self.tokenizer.get_token("<|endoftext|>") { + Some(token) => token, + None => anyhow::bail!("cannot find the endoftext token"), + }; + print!("{prompt}"); + std::io::stdout().flush()?; + let start_gen = std::time::Instant::now(); + let mut pos = 0; + for index in 0..sample_len { + let context_size = if index > 0 { 1 } else { tokens.len() }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, pos, pos + 1)?; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits = if self.repeat_penalty == 1. { + logits + } else { + let start_at = tokens.len().saturating_sub(self.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + self.repeat_penalty, + &tokens[start_at..], + )? + }; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + generated_tokens += 1; + if next_token == eos_token { + if let Some(t) = self.tokenizer.decode_rest()? { + print!("{t}"); + std::io::stdout().flush()?; + } + break; + } + if let Some(t) = self.tokenizer.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + pos += context_size; + } + let dt = start_gen.elapsed(); + println!( + "\n{generated_tokens} tokens generated ({:.2} token/s)", + generated_tokens as f64 / dt.as_secs_f64(), + ); + Ok(()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Display the token for the specified prompt. + #[arg(long)] + verbose_prompt: bool, + + #[arg(long)] + prompt: Option, + + #[arg(long)] + mmlu_dir: Option, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, short = 'n', default_value_t = 5000)] + sample_len: usize, + + #[arg(long)] + model_id: Option, + + #[arg(long)] + revision: Option, + + #[arg(long)] + weight_file: Option, + + #[arg(long)] + tokenizer: Option, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The dtype to be used for running the model, e.g. f32, bf16, or f16. + #[arg(long)] + dtype: Option, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature.unwrap_or(0.), + args.repeat_penalty, + args.repeat_last_n + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let model_id = match args.model_id { + Some(model_id) => model_id.to_string(), + None => "microsoft/Phi-3.5-MoE-instruct".to_string(), + }; + let revision = match args.revision { + Some(rev) => rev.to_string(), + None => "main".to_string(), + }; + let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); + let tokenizer_filename = match args.tokenizer { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + let filenames = match args.weight_file { + Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], + None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let start = std::time::Instant::now(); + + let device = candle_examples::device(args.cpu)?; + let dtype = match args.dtype { + Some(dtype) => std::str::FromStr::from_str(&dtype)?, + None => DType::F32, + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; + let model = { + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: Config = serde_json::from_str(&config)?; + Model::new(&config, vb)? + }; + println!("loaded the model in {:?}", start.elapsed()); + + match (args.prompt, args.mmlu_dir) { + (None, None) | (Some(_), Some(_)) => { + anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified") + } + (Some(prompt), None) => { + let mut pipeline = TextGeneration::new( + model, + tokenizer, + args.seed, + args.temperature, + args.top_p, + args.repeat_penalty, + args.repeat_last_n, + args.verbose_prompt, + &device, + ); + pipeline.run(&prompt, args.sample_len)?; + } + (None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?, + } + Ok(()) +} + +fn mmlu>( + mut model: Model, + tokenizer: Tokenizer, + device: &Device, + mmlu_dir: P, +) -> anyhow::Result<()> { + for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() { + let dir_entry = dir_entry.path(); + let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) { + None => "".to_string(), + Some(v) => match v.strip_suffix("_test") { + None => v.replace('_', " "), + Some(v) => v.replace('_', " "), + }, + }; + if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") { + continue; + } + println!("reading {dir_entry:?}"); + let dir_entry = std::fs::File::open(dir_entry)?; + let mut reader = csv::ReaderBuilder::new() + .has_headers(false) + .from_reader(dir_entry); + let token_a = tokenizer.token_to_id("A").unwrap(); + let token_b = tokenizer.token_to_id("B").unwrap(); + let token_c = tokenizer.token_to_id("C").unwrap(); + let token_d = tokenizer.token_to_id("D").unwrap(); + for row in reader.records() { + let row = match row { + Err(_) => continue, + Ok(row) => row, + }; + if row.len() < 5 { + continue; + } + let question = row.get(0).unwrap(); + let answer_a = row.get(1).unwrap(); + let answer_b = row.get(2).unwrap(); + let answer_c = row.get(3).unwrap(); + let answer_d = row.get(4).unwrap(); + let answer = row.get(5).unwrap(); + let prompt = format!( + "{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n", + "The following are multiple choice questions (with answers) about" + ); + let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?; + let tokens = tokens.get_ids().to_vec(); + let input = Tensor::new(tokens, device)?.unsqueeze(0)?; + let logits = { + model.clear_kv_cache(); + model.forward(&input, 0, 1)? + }; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + let logits_v: Vec = logits.to_vec1()?; + let pr_a = logits_v[token_a as usize]; + let pr_b = logits_v[token_b as usize]; + let pr_c = logits_v[token_c as usize]; + let pr_d = logits_v[token_d as usize]; + let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d { + "A" + } else if pr_b > pr_c && pr_b > pr_d { + "B" + } else if pr_c > pr_d { + "C" + } else { + "D" + }; + + println!("{prompt}\n -> {model_answer} vs {answer}"); + } + } + Ok(()) +} diff --git a/candle-transformers/src/models/phi3_5moe.rs b/candle-transformers/src/models/phi3_5moe.rs index 03522fdc7c..1ae33b1c97 100644 --- a/candle-transformers/src/models/phi3_5moe.rs +++ b/candle-transformers/src/models/phi3_5moe.rs @@ -205,6 +205,10 @@ impl Attention { }; self.o_proj.forward(&attn_output) } + + fn clear_kv_cache(&mut self) { + self.kv_cache = None + } } #[derive(Clone)] @@ -500,6 +504,10 @@ impl DecoderLayer { .forward(&xs.apply(&self.post_attention_layernorm)?)?; residual + xs } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache() + } } pub struct Model { @@ -604,4 +612,10 @@ impl Model { self.lm_head.forward(&xs.apply(&self.norm)?) } + + pub fn clear_kv_cache(&mut self) { + for layer in self.layers.iter_mut() { + layer.clear_kv_cache() + } + } }