Skip to content

Commit

Permalink
Bugfixes for marian-mt. (#1219)
Browse files: browse the repository at this point in the history.
* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
  • Loading branch information
LaurentMazare authored Oct 30, 2023
1 parent 5fc66bd commit 9699608
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
7 changes: 3 additions & 4 deletions candle-examples/examples/marian-mt/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ struct Args {
text: String,
}

const SEP_TOKEN_ID: u32 = 102;

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();

Expand All @@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
model.encoder().forward(&tokens, 0)?
};

let mut token_ids = vec![30522u32];
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
// TODO: Add a kv cache.
let context_size = if index >= 1000 { 1 } else { token_ids.len() };
Expand All @@ -72,7 +70,8 @@ pub fn main() -> anyhow::Result<()> {
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
println!("{token}");
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
token_ids.push(token);
Expand Down
27 changes: 18 additions & 9 deletions candle-transformers/src/models/marian.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ pub struct Config {
pub is_encoder_decoder: bool,
pub activation_function: candle_nn::Activation,
pub d_model: usize,
pub decoder_start_token_id: usize,
pub decoder_start_token_id: u32,
pub scale_embedding: bool,
pub pad_token_id: usize,
pub eos_token_id: usize,
pub forced_eos_token_id: usize,
pub pad_token_id: u32,
pub eos_token_id: u32,
pub forced_eos_token_id: u32,
pub share_encoder_decoder_embeddings: bool,
}

Expand Down Expand Up @@ -224,7 +224,8 @@ impl DecoderLayer {
let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn_layer_norm =
layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
Expand All @@ -249,15 +250,16 @@ impl DecoderLayer {
Some(encoder_xs) => {
let residual = &xs;
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
(residual + xs)?.apply(&self.self_attn_layer_norm)?
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
}
};
let residual = &xs;
let xs = xs
.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)?;
(xs + residual)?.apply(&self.final_layer_norm)
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
Ok(xs)
}
}

Expand Down Expand Up @@ -356,7 +358,7 @@ impl Decoder {
.unsqueeze(0)?;
let mut xs = xs.broadcast_add(&embed_pos)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs, encoder_xs)?
xs = layer.forward(&xs, encoder_xs)?;
}
Ok(xs)
}
Expand Down Expand Up @@ -385,6 +387,7 @@ impl Model {
/// Marian translation model: the wrapped `Model` together with the final
/// decoding head that is applied to the decoder output in `decode`.
#[derive(Debug, Clone)]
pub struct MTModel {
// Underlying model; `decode` runs its `decoder` and `new` reads its `shared`
// embeddings to build `lm_head`.
model: Model,
// Bias-free linear projection created in `new` from a clone of the shared
// embedding weights (`Linear::from_weights(.., None)`); applied to the decoder
// output before the logits bias is added.
lm_head: Linear,
// Tensor loaded under the name "final_logits_bias" with shape
// (1, target_vocab_size); broadcast-added after `lm_head` in `decode`.
final_logits_bias: Tensor,
}

Expand All @@ -393,8 +396,10 @@ impl MTModel {
let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
let model = Model::new(cfg, vb.pp("model"))?;
let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
Ok(Self {
model,
lm_head,
final_logits_bias,
})
}
Expand All @@ -408,6 +413,10 @@ impl MTModel {
}

pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
self.model.decoder.forward(xs, Some(encoder_xs), 0)
self.model
.decoder
.forward(xs, Some(encoder_xs), 0)?
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}
}

0 comments on commit 9699608

Please sign in to comment.