From 1da21cd1666bcc6843ee3f37dde971929bcacc08 Mon Sep 17 00:00:00 2001 From: macroexpansion Date: Tue, 24 Oct 2023 02:55:33 +0700 Subject: [PATCH] convert pytorch's tensor --- candle-pyo3/py_src/candle/__init__.pyi | 5 +++++ candle-pyo3/src/lib.rs | 24 ++++++++++++++++++++++++ candle-pyo3/test.py | 13 +++++++++++++ 3 files changed, 42 insertions(+) diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index 7a0b2fcf15..437221683b 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -396,6 +396,11 @@ class Tensor: Convert the tensor to a new dtype. """ pass + def to_torch(self) -> torch.Tensor: + """ + Converts candle's tensor to pytorch's tensor + """ + pass def transpose(self, dim1: int, dim2: int) -> Tensor: """ Returns a tensor that is a transposed version of the input, the given dimensions are swapped. diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index e2c8014f36..6d4de80bfb 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -211,6 +211,16 @@ enum Indexer { IndexSelect(Tensor), } +#[derive(Clone, Debug)] +struct TorchTensor(PyObject); + +impl<'source> pyo3::FromPyObject<'source> for TorchTensor { + fn extract(ob: &'source PyAny) -> PyResult { + let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?; + Ok(TorchTensor(numpy_value)) + } +} + #[pymethods] impl PyTensor { #[new] @@ -246,6 +256,8 @@ impl PyTensor { Tensor::new(vs, &Cpu).map_err(wrap_err)? } else if let Ok(vs) = data.extract::>>>(py) { Tensor::new(vs, &Cpu).map_err(wrap_err)? + } else if let Ok(TorchTensor(numpy)) = data.extract::(py) { + return PyTensor::new(py, numpy); } else { let ty = data.as_ref(py).get_type(); Err(PyTypeError::new_err(format!( @@ -299,6 +311,18 @@ impl PyTensor { M(py).map(self) } + /// Converts candle's tensor to pytorch's tensor + /// &RETURNS&: torch.Tensor + fn to_torch(&self, py: Python<'_>) -> PyResult { + let candle_values = self.values(py)?; + let torch_tensor: PyObject = py + .import("torch")? + .getattr("tensor")? + .call1((candle_values,))? + .extract()?; + Ok(torch_tensor) + } + #[getter] /// Gets the tensor's shape. /// &RETURNS&: Tuple[int] diff --git a/candle-pyo3/test.py b/candle-pyo3/test.py index e4ff772a1f..4d0b52f9e7 100644 --- a/candle-pyo3/test.py +++ b/candle-pyo3/test.py @@ -1,4 +1,5 @@ import candle +import torch print(f"mkl: {candle.utils.has_mkl()}") print(f"accelerate: {candle.utils.has_accelerate()}") @@ -29,3 +30,15 @@ dequant_t = quant_t.dequantize() diff2 = (t - dequant_t).sqr() print(diff2.mean_all()) + +# convert from candle tensor to torch tensor +t = candle.randn((3, 512, 512)) +torch_tensor = t.to_torch() +print(torch_tensor) +print(type(torch_tensor)) + +# convert from torch tensor to candle tensor +t = torch.randn((3, 512, 512)) +candle_tensor = candle.Tensor(t) +print(candle_tensor) +print(type(candle_tensor))