From 55bc3382cfd3a86018c54f2343567f7c0c0b677c Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 29 Oct 2023 07:53:09 +0100
Subject: [PATCH] Allow for different behavior between training and eval (#1213)

* Forward with training.

* Do not use dropout on vgg evaluation.
---
 candle-core/src/lib.rs                | 12 +++++++
 candle-core/src/tensor.rs             |  5 +++
 .../examples/mnist-training/main.rs   |  4 +--
 candle-examples/examples/vgg/main.rs  |  4 +--
 candle-nn/src/func.rs                 | 35 +++++++++++++++++++
 candle-nn/src/lib.rs                  |  4 +--
 candle-nn/src/ops.rs                  |  6 ++++
 candle-transformers/src/models/vgg.rs | 35 ++++++++++---------
 8 files changed, 83 insertions(+), 22 deletions(-)

diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 52effdcf80..73830229cf 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
         self(xs)
     }
 }
+
+// A trait defining a module with a forward method using a single tensor argument and a flag to
+// separate the training and evaluation behaviors.
+pub trait ModuleT {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
+}
+
+impl<M: Module> ModuleT for M {
+    fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
+        self.forward(xs)
+    }
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index ce81d8aff0..c6f2364d60 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2271,6 +2271,11 @@ impl Tensor {
         m.forward(self)
     }
 
+    /// Run the `forward_t` method of `m` on `self`.
+    pub fn apply_t<M: ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
+        m.forward_t(self, train)
+    }
+
     pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
         self.storage.read().unwrap()
     }
diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs
index a07505bf46..a41a6496b9 100644
--- a/candle-examples/examples/mnist-training/main.rs
+++ b/candle-examples/examples/mnist-training/main.rs
@@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
 use rand::prelude::*;
 
 use candle::{DType, Result, Tensor, D};
-use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
+use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
 
 const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
@@ -95,7 +95,7 @@ impl ConvNet {
             .flatten_from(1)?
             .apply(&self.fc1)?
             .relu()?;
-        self.dropout.forward(&xs, train)?.apply(&self.fc2)
+        self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
     }
 }
 
diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs
index e01fa8e8b5..27e141cb95 100644
--- a/candle-examples/examples/vgg/main.rs
+++ b/candle-examples/examples/vgg/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{DType, IndexOp, D};
-use candle_nn::{Module, VarBuilder};
+use candle_nn::{ModuleT, VarBuilder};
 use candle_transformers::models::vgg::{Models, Vgg};
 use clap::{Parser, ValueEnum};
 
@@ -53,7 +53,7 @@ pub fn main() -> anyhow::Result<()> {
         Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
         Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
     };
-    let logits = model.forward(&image)?;
+    let logits = model.forward_t(&image, /*train=*/ false)?;
 
     let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
         .i(0)?
diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs
index 39311d458c..3adfda860d 100644
--- a/candle-nn/src/func.rs
+++ b/candle-nn/src/func.rs
@@ -36,3 +36,38 @@ impl<'a> Func<'a> {
         Self { f: Arc::new(f) }
     }
 }
+
+/// A layer defined by a simple closure.
+#[derive(Clone)]
+pub struct FuncT<'a> {
+    #[allow(clippy::type_complexity)]
+    f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
+}
+
+impl<'a> std::fmt::Debug for FuncT<'a> {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "func")
+    }
+}
+
+pub fn func_t<'a, F>(f: F) -> FuncT<'a>
+where
+    F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+{
+    FuncT { f: Arc::new(f) }
+}
+
+impl<'a> super::ModuleT for FuncT<'a> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+        (*self.f)(xs, train)
+    }
+}
+
+impl<'a> FuncT<'a> {
+    pub fn new<F>(f: F) -> Self
+    where
+        F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
+    {
+        Self { f: Arc::new(f) }
+    }
+}
diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs
index be95f53121..52d8f0c595 100644
--- a/candle-nn/src/lib.rs
+++ b/candle-nn/src/lib.rs
@@ -22,7 +22,7 @@ pub use conv::{
     Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
 };
 pub use embedding::{embedding, Embedding};
-pub use func::{func, Func};
+pub use func::{func, func_t, Func, FuncT};
 pub use group_norm::{group_norm, GroupNorm};
 pub use init::Init;
 pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
@@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential};
 pub use var_builder::VarBuilder;
 pub use var_map::VarMap;
 
-pub use candle::Module;
+pub use candle::{Module, ModuleT};
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 32de1af9c8..e98121083e 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -84,6 +84,12 @@ impl Dropout {
     }
 }
 
+impl candle::ModuleT for Dropout {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
+        self.forward(xs, train)
+    }
+}
+
 struct SoftmaxLastDim;
 
 impl candle::CustomOp1 for SoftmaxLastDim {
diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs
index 7837dc3e69..a20b5e3725 100644
--- a/candle-transformers/src/models/vgg.rs
+++ b/candle-transformers/src/models/vgg.rs
@@ -2,8 +2,8 @@
 //!
 //! See Very Deep Convolutional Networks for Large-Scale Image Recognition
 //! <https://arxiv.org/abs/1409.1556>
-use candle::{Module, Result, Tensor};
-use candle_nn::{Func, VarBuilder};
+use candle::{ModuleT, Result, Tensor};
+use candle_nn::{FuncT, VarBuilder};
 
 // Enum representing the different VGG models
 pub enum Models {
@@ -15,7 +15,7 @@ pub enum Models {
 // Struct representing a VGG model
 #[derive(Debug)]
 pub struct Vgg<'a> {
-    blocks: Vec<Func<'a>>,
+    blocks: Vec<FuncT<'a>>,
 }
 
 // Struct representing the configuration for the pre-logit layer
@@ -39,11 +39,11 @@ impl<'a> Vgg<'a> {
 }
 
 // Implementation of the forward pass for the VGG model
-impl Module for Vgg<'_> {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl ModuleT for Vgg<'_> {
+    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
         let mut xs = xs.unsqueeze(0)?;
         for block in self.blocks.iter() {
-            xs = xs.apply(block)?;
+            xs = xs.apply_t(block, train)?;
         }
         Ok(xs)
     }
@@ -51,7 +51,7 @@ impl Module for Vgg<'_> {
 
 // Function to create a conv2d block
 // The block is composed of two conv2d layers followed by a max pool layer
-fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func<'static>> {
+fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
     let layers = convs
         .iter()
         .enumerate()
@@ -70,7 +70,7 @@ fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<Func
         .collect::<Result<Vec<_>>>()?;
-    Ok(Func::new(move |xs| {
+    Ok(FuncT::new(move |xs, _train| {
         let mut xs = xs.clone();
         for layer in layers.iter() {
             xs = xs.apply(layer)?.relu()?
         }
         xs.max_pool2d_with_stride(2, 2)
@@ -87,7 +87,7 @@ fn fully_connected(
     pre_logit_1: PreLogitConfig,
     pre_logit_2: PreLogitConfig,
     vb: VarBuilder,
-) -> Result<Func<'static>> {
+) -> Result<FuncT<'static>> {
     let lin = get_weights_and_biases(
         &vb.pp("pre_logits.fc1"),
         pre_logit_1.in_dim,
@@ -100,12 +100,15 @@ fn fully_connected(
         pre_logit_2.target_in,
         pre_logit_2.target_out,
     )?;
-    Ok(Func::new(move |xs| {
+    let dropout1 = candle_nn::Dropout::new(0.5);
+    let dropout2 = candle_nn::Dropout::new(0.5);
+    let dropout3 = candle_nn::Dropout::new(0.5);
+    Ok(FuncT::new(move |xs, train| {
         let xs = xs.reshape((1, pre_logit_1.target_out))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?;
+        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
+        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
         let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
-        let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?;
+        let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
         Ok(xs)
     }))
 }
@@ -130,7 +133,7 @@ fn get_weights_and_biases(
     Ok(candle_nn::Linear::new(ws, Some(bs)))
 }
 
-fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
+fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT<'static>>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -156,7 +159,7 @@ fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
     Ok(blocks)
 }
 
-fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
+fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT<'static>>> {
     let num_classes = 1000;
     let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
@@ -203,7 +206,7 @@ fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
     Ok(blocks)
 }
 
-fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<Func<'static>>> {
+fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT<'static>>> {
     let num_classes = 1000;
    let blocks = vec![
         conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
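
Usage sketch (not part of the patch): a minimal example of how the ModuleT / FuncT additions above are meant to be driven, assuming a crate built against candle and candle_nn with this change applied. The `main` function, the 2x4 all-ones tensor, the CPU device, and the 0.5 drop probability are illustrative only; the calls to `Dropout::new`, `apply_t`, `forward_t`, and `func_t` come from the patch itself.

use candle::{DType, Device, ModuleT, Result, Tensor};
use candle_nn::{func_t, Dropout};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::ones((2, 4), DType::F32, &dev)?;

    // Dropout now implements ModuleT: activations are randomly zeroed (and the
    // survivors rescaled) when `train` is true, and passed through unchanged otherwise.
    let dropout = Dropout::new(0.5);
    let train_out = xs.apply_t(&dropout, true)?;
    let eval_out = xs.apply_t(&dropout, false)?;
    println!("train: {train_out}");
    println!("eval:  {eval_out}");

    // FuncT wraps a closure that also receives the `train` flag, so a whole block
    // can switch between training and evaluation behavior in one place.
    let block = func_t(move |xs, train| xs.apply_t(&dropout, train));
    let ys = block.forward_t(&xs, false)?;
    println!("block (eval): {ys}");
    Ok(())
}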