From 4646ae211e710fae2d883a7e46bbecb8d9e63e58 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 20 Oct 2023 16:08:50 +0100 Subject: [PATCH] Add the sequential layer. (#1136) --- candle-nn/src/lib.rs | 2 ++ candle-nn/src/sequential.rs | 62 +++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 candle-nn/src/sequential.rs diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index 8e5580dfff..be95f53121 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -11,6 +11,7 @@ pub mod loss; pub mod ops; pub mod optim; pub mod rnn; +pub mod sequential; pub mod var_builder; pub mod var_map; @@ -29,6 +30,7 @@ pub use linear::{linear, linear_no_bias, Linear}; pub use ops::Dropout; pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD}; pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN}; +pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; diff --git a/candle-nn/src/sequential.rs b/candle-nn/src/sequential.rs new file mode 100644 index 0000000000..2fef774297 --- /dev/null +++ b/candle-nn/src/sequential.rs @@ -0,0 +1,62 @@ +//! A sequential layer used to chain multiple layers and closures. +use candle::{Module, Result, Tensor}; + +/// A sequential layer combining multiple other layers. +pub struct Sequential { + layers: Vec>, +} + +/// Creates a new empty sequential layer. +pub fn seq() -> Sequential { + Sequential { layers: vec![] } +} + +impl Sequential { + /// The number of sub-layers embedded in this layer. + pub fn len(&self) -> i64 { + self.layers.len() as i64 + } + + /// Returns true if this layer does not have any sub-layer. + pub fn is_empty(&self) -> bool { + self.layers.is_empty() + } +} + +impl Module for Sequential { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)? + } + Ok(xs) + } +} + +impl Sequential { + /// Appends a layer after all the current layers. + #[allow(clippy::should_implement_trait)] + pub fn add(mut self, layer: M) -> Self { + self.layers.push(Box::new(layer)); + self + } + + /// Appends a closure after all the current layers. + pub fn add_fn(self, f: F) -> Self + where + F: 'static + Fn(&Tensor) -> Result + Send, + { + self.add(super::func(f)) + } + + /// Applies the forward pass and returns the output for each layer. + pub fn forward_all(&self, xs: &Tensor) -> Result> { + let mut vec = Vec::with_capacity(self.layers.len()); + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs)?; + vec.push(xs.clone()) + } + Ok(vec) + } +}