From 4e88911c1f15d46a5c73ddc99c5e855a80deefeb Mon Sep 17 00:00:00 2001 From: Zachary Charlop-Powers Date: Tue, 26 Nov 2024 11:00:43 -0500 Subject: [PATCH 1/3] Provide a method to allow PTH files with state maps to be loaded. --- candle-nn/src/var_builder.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 0d836c7fd4..2489988f75 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -544,7 +544,16 @@ impl<'a> VarBuilder<'a> { let pth = candle::pickle::PthTensors::new(p, None)?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } - + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + pub fn from_pth_with_state<P: AsRef<std::path::Path>>( + p: P, + dtype: DType, + state_key: String, + dev: &Device, + ) -> Result<Self> { + let pth = candle::pickle::PthTensors::new(p, Some(state_key.as_str()))?; + Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) + } /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before /// passing the new names to the inner VarBuilder. /// From 0d753ee0717c047127a55bcb30f3a93abb4b1fae Mon Sep 17 00:00:00 2001 From: Zachary Charlop-Powers Date: Tue, 26 Nov 2024 11:03:07 -0500 Subject: [PATCH 2/3] add a line to the doc --- candle-nn/src/var_builder.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 2489988f75..66f93e19a2 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -545,6 +545,7 @@ impl<'a> VarBuilder<'a> { Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + /// similar to [`from_pth`] but requires a `state_key`. 
pub fn from_pth_with_state<P: AsRef<std::path::Path>>( p: P, dtype: DType, From c4dba2fe4802237fe6e38f2b51f875c26652266f Mon Sep 17 00:00:00 2001 From: Zachary Charlop-Powers Date: Tue, 26 Nov 2024 11:53:08 -0500 Subject: [PATCH 3/3] String -> &str --- candle-nn/src/var_builder.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 66f93e19a2..2731456d43 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -549,10 +549,10 @@ impl<'a> VarBuilder<'a> { pub fn from_pth_with_state<P: AsRef<std::path::Path>>( p: P, dtype: DType, - state_key: String, + state_key: &str, dev: &Device, ) -> Result<Self> { - let pth = candle::pickle::PthTensors::new(p, Some(state_key.as_str()))?; + let pth = candle::pickle::PthTensors::new(p, Some(state_key))?; Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) } /// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before