diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs index 0c81bca8ef..6702618ec2 100644 --- a/candle-examples/examples/yolo-v3/darknet.rs +++ b/candle-examples/examples/yolo-v3/darknet.rs @@ -108,7 +108,7 @@ pub fn parse_config>(path: T) -> Result { } enum Bl { - Layer(Box), + Layer(Box), Route(Vec), Shortcut(usize), Yolo(usize, Vec<(usize, usize)>), diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index e7fd73ae1d..39311d458c 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -1,10 +1,12 @@ //! Layers defined by closures. use candle::{Result, Tensor}; +use std::sync::Arc; /// A layer defined by a simple closure. +#[derive(Clone)] pub struct Func<'a> { #[allow(clippy::type_complexity)] - f: Box Result + Send>, + f: Arc Result + Send + Sync>, } impl<'a> std::fmt::Debug for Func<'a> { @@ -15,9 +17,9 @@ impl<'a> std::fmt::Debug for Func<'a> { pub fn func<'a, F>(f: F) -> Func<'a> where - F: 'a + Fn(&Tensor) -> Result + Send, + F: 'a + Fn(&Tensor) -> Result + Send + Sync, { - Func { f: Box::new(f) } + Func { f: Arc::new(f) } } impl<'a> super::Module for Func<'a> { @@ -29,8 +31,8 @@ impl<'a> super::Module for Func<'a> { impl<'a> Func<'a> { pub fn new(f: F) -> Self where - F: 'a + Fn(&Tensor) -> Result + Send, + F: 'a + Fn(&Tensor) -> Result + Send + Sync, { - Self { f: Box::new(f) } + Self { f: Arc::new(f) } } } diff --git a/candle-nn/src/sequential.rs b/candle-nn/src/sequential.rs index 2fef774297..bef9975287 100644 --- a/candle-nn/src/sequential.rs +++ b/candle-nn/src/sequential.rs @@ -44,7 +44,7 @@ impl Sequential { /// Appends a closure after all the current layers. pub fn add_fn(self, f: F) -> Self where - F: 'static + Fn(&Tensor) -> Result + Send, + F: 'static + Fn(&Tensor) -> Result + Send + Sync, { self.add(super::func(f)) }