Skip to content

Commit

Permalink
Add the sequential layer. (huggingface#1136)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored and EricLBuehler committed Oct 25, 2023
1 parent f7b06fc commit 4646ae2
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
2 changes: 2 additions & 0 deletions candle-nn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
pub mod sequential;
pub mod var_builder;
pub mod var_map;

Expand All @@ -29,6 +30,7 @@ pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;

Expand Down
62 changes: 62 additions & 0 deletions candle-nn/src/sequential.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//! A sequential layer used to chain multiple layers and closures.
use candle::{Module, Result, Tensor};

/// A sequential layer combining multiple other layers.
pub struct Sequential {
layers: Vec<Box<dyn Module>>,
}

/// Creates a new empty sequential layer.
pub fn seq() -> Sequential {
Sequential { layers: vec![] }
}

impl Sequential {
/// The number of sub-layers embedded in this layer.
pub fn len(&self) -> i64 {
self.layers.len() as i64
}

/// Returns true if this layer does not have any sub-layer.
pub fn is_empty(&self) -> bool {
self.layers.is_empty()
}
}

impl Module for Sequential {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
Ok(xs)
}
}

impl Sequential {
/// Appends a layer after all the current layers.
#[allow(clippy::should_implement_trait)]
pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
self.layers.push(Box::new(layer));
self
}

/// Appends a closure after all the current layers.
pub fn add_fn<F>(self, f: F) -> Self
where
F: 'static + Fn(&Tensor) -> Result<Tensor> + Send,
{
self.add(super::func(f))
}

/// Applies the forward pass and returns the output for each layer.
pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> {
let mut vec = Vec::with_capacity(self.layers.len());
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs)?;
vec.push(xs.clone())
}
Ok(vec)
}
}

0 comments on commit 4646ae2

Please sign in to comment.