Add the blip example. (#1144)
* Add the blip example.

* Tweak the example.

* Implement the cross-attn logic.

* Fix some shape mismatches.

* Get some logits out.

* Get some caption to be generated.
LaurentMazare authored Oct 21, 2023
1 parent e8f760e commit 0d9bb4e
Showing 3 changed files with 223 additions and 45 deletions.
54 changes: 54 additions & 0 deletions candle-examples/examples/blip/main.rs
@@ -0,0 +1,54 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::Parser;

use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;

#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,

#[arg(long)]
image: String,

/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();

let device = candle_examples::device(args.cpu)?;

let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");

let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.repo(hf_hub::Repo::with_revision(
"Salesforce/blip-image-captioning-large".to_string(),
hf_hub::RepoType::Model,
"refs/pr/18".to_string(),
));
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let config = blip::Config::image_captioning_large();
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
println!("model built");
// TODO: Maybe add support for the conditional prompt.
let out = model.generate(&image.unsqueeze(0)?, None, None)?;
println!(">>>\n{out}");
Ok(())
}
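
To run the example, the usual candle invocation applies: cargo run --example blip --release -- --image path/to/image.jpg (the path is a placeholder). One wrinkle worth flagging: Config::image_captioning_large() below sets image_size: 384, while load_image224 produces a 224x224 tensor. Here is a minimal sketch of a matching 384x384 loader, modeled on candle-examples' load_image224; the resize filter and the CLIP-style normalization constants are assumptions, not part of this commit:

use candle::{DType, Device, Result, Tensor};

/// Load an image as a (3, 384, 384) f32 tensor (sketch; constants are assumed).
pub fn load_image384<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
    let img = image::io::Reader::open(p)?
        .decode()
        .map_err(candle::Error::wrap)?
        .resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
    let data = Tensor::from_vec(img.to_rgb8().into_raw(), (384, 384, 3), &Device::Cpu)?
        .permute((2, 0, 1))?;
    // Assumed normalization; check the model's preprocessor_config.json on the Hub.
    let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.26862954f32, 0.26130258, 0.27577711], &Device::Cpu)?
        .reshape((3, 1, 1))?;
    (data.to_dtype(DType::F32)? / 255.)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}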
108 changes: 83 additions & 25 deletions candle-transformers/src/models/blip.rs
@@ -5,24 +5,59 @@ use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
 
 #[derive(Debug, Clone)]
-struct VisionConfig {
-    hidden_size: usize,
-    intermediate_size: usize,
-    projection_dim: usize,
-    num_hidden_layers: usize,
-    num_attention_heads: usize,
-    image_size: usize,
-    patch_size: usize,
-    hidden_act: candle_nn::Activation,
-    layer_norm_eps: f64,
+pub struct VisionConfig {
+    pub hidden_size: usize,
+    pub intermediate_size: usize,
+    pub projection_dim: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub image_size: usize,
+    pub patch_size: usize,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
 }
 
 #[derive(Debug, Clone)]
-struct Config {
-    text_config: blip_text::Config,
-    vision_config: VisionConfig,
-    projection_dim: usize,
-    image_text_hidden_size: usize,
+pub struct Config {
+    pub text_config: blip_text::Config,
+    pub vision_config: VisionConfig,
+    pub projection_dim: usize,
+    pub image_text_hidden_size: usize,
 }
+
+impl Config {
+    pub fn image_captioning_large() -> Self {
+        let text_config = blip_text::Config {
+            vocab_size: 30524,
+            hidden_size: 768,
+            encoder_hidden_size: 1024,
+            intermediate_size: 3072,
+            projection_dim: 768,
+            num_hidden_layers: 12,
+            num_attention_heads: 12,
+            max_position_embeddings: 512,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-12,
+            is_decoder: true,
+        };
+        let vision_config = VisionConfig {
+            hidden_size: 1024,
+            intermediate_size: 4096,
+            projection_dim: 512,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            image_size: 384,
+            patch_size: 16,
+            hidden_act: candle_nn::Activation::Gelu,
+            layer_norm_eps: 1e-5,
+        };
+        Self {
+            text_config,
+            vision_config,
+            projection_dim: 512,
+            image_text_hidden_size: 256,
+        }
+    }
+}
 
 #[derive(Debug, Clone)]
@@ -200,6 +235,7 @@ struct Encoder
 impl Encoder {
     fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb = vb.pp("layers");
         for i in 0..cfg.num_hidden_layers {
             let layer = EncoderLayer::new(cfg, vb.pp(i))?;
             layers.push(layer)
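
An aside on the one-line vb.pp("layers") fix above (editorial, not text from the commit): candle's VarBuilder resolves tensors by dotted key path, and pp pushes one path segment, so the nesting has to reproduce the checkpoint's safetensors key names:

// Illustrative only: how the prefixes compose, assuming candle_nn's pp semantics.
let layers_vb = vb.pp("layers"); // e.g. ".../encoder" -> ".../encoder.layers"
let layer0_vb = layers_vb.pp(0); // -> ".../encoder.layers.0"
// Without the added line, layer 0 would try to load keys like ".../encoder.0.*",
// which are absent from the checkpoint, so building the model would fail with a
// missing-tensor error.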
@@ -217,7 +253,7 @@ impl Encoder {
 }
 
 #[derive(Debug, Clone)]
-struct VisionModel {
+pub struct VisionModel {
     embeddings: VisionEmbeddings,
     encoder: Encoder,
     post_layernorm: LayerNorm,
Expand All @@ -241,23 +277,19 @@ impl Module for VisionModel {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embeddings)?;
let encoder_outputs = self.encoder.forward(&xs, None)?;
let last_hidden_state = encoder_outputs.get(0)?;
last_hidden_state
.apply(&self.post_layernorm)?
.narrow(1, 0, 1)?
.squeeze(1)?
.apply(&self.post_layernorm)
// Return the last hidden state rather than pooled outputs.
encoder_outputs.apply(&self.post_layernorm)
}
}

#[derive(Debug, Clone)]
struct BlipForConditionalGeneration {
pub struct BlipForConditionalGeneration {
vision_model: VisionModel,
text_decoder: blip_text::TextLMHeadModel,
}

impl BlipForConditionalGeneration {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
let text_decoder =
blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
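
A note on the forward change above (editorial, derived from the config earlier in this diff): the text decoder's cross-attention needs the full sequence of image tokens, so VisionModel::forward now returns the complete post-layernorm hidden states instead of a pooled class embedding. The shapes implied by Config::image_captioning_large():

// Back-of-the-envelope check, not commit code: the vision tower's output shape.
let (image_size, patch_size, hidden_size) = (384usize, 16usize, 1024usize);
let num_patches = (image_size / patch_size).pow(2); // 24 * 24 = 576 patches
let seq_len = num_patches + 1; // +1 for the class embedding token -> 577
assert_eq!(seq_len, 577);
// Old forward: (batch, hidden_size), a pooled vector with nothing to attend over.
// New forward: (batch, seq_len, hidden_size) = (1, 577, 1024), which the text
// decoder consumes as encoder hidden states (hence encoder_hidden_size: 1024 above).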
@@ -267,12 +299,38 @@ impl BlipForConditionalGeneration {
         })
     }
 
-    fn forward(
+    pub fn vision_model(&self) -> &VisionModel {
+        &self.vision_model
+    }
+
+    pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
+        &self.text_decoder
+    }
+
+    pub fn generate(
         &self,
         pixel_values: &Tensor,
         input_ids: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
+        let image_embeds = pixel_values.apply(&self.vision_model)?;
+        let b_size = image_embeds.dim(0)?;
+        if b_size > 1 {
+            candle::bail!("only a batch size of 1 is supported")
+        }
+        let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
+        let mut token_ids = vec![30522u32];
+        for i in 0..1000 {
+            let input_ids =
+                Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
+            let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
+            println!("{logits:?}");
+            let logits = logits.squeeze(0)?;
+            let logits = logits.get(logits.dim(0)? - 1)?;
+            let token = logits_processor.sample(&logits)?;
+            println!("{token}");
+            token_ids.push(token)
+        }
         todo!()
     }
 }
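
The committed loop is intentionally unfinished: it prints raw logits and sampled tokens and ends in todo!(). As a sketch of where it is headed, here is one way the body of generate could be completed; the stop condition (the BERT-style SEP id, 102, treated as end-of-sequence alongside the 30522 begin-of-sequence id used above) and the return value are assumptions, not part of this commit:

// Editorial sketch, not commit code: a finished greedy decoding loop.
const SEP_TOKEN_ID: u32 = 102;
let mut logits_processor = crate::generation::LogitsProcessor::new(1337, None, None);
let mut token_ids = vec![30522u32];
for _index in 0..1000 {
    let input_ids =
        Tensor::new(token_ids.as_slice(), pixel_values.device())?.broadcast_left(b_size)?;
    let logits = self.text_decoder.forward(&input_ids, &image_embeds)?;
    // Keep only the logits for the last position: that distribution scores the next token.
    let logits = logits.squeeze(0)?;
    let logits = logits.get(logits.dim(0)? - 1)?;
    let token = logits_processor.sample(&logits)?;
    if token == SEP_TOKEN_ID {
        break;
    }
    token_ids.push(token)
}
// Return the generated ids so the caller can detokenize them into a caption string.
Tensor::new(token_ids.as_slice(), pixel_values.device())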
(diff for the third changed file not shown)
