Skip to content

Commit

Permalink
convert pytorch's tensor in Python API (#1172)
Browse files Browse the repository at this point in the history
* convert pytorch's tensor

* separate tests for convert pytorch tensor
  • Loading branch information
macroexpansion authored Oct 25, 2023
1 parent 0acd167 commit 6a446d9
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
5 changes: 5 additions & 0 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ class Tensor:
Convert the tensor to a new dtype.
"""
pass
def to_torch(self) -> torch.Tensor:
"""
Converts candle's tensor to pytorch's tensor
"""
pass
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
Expand Down
24 changes: 24 additions & 0 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ enum Indexer {
IndexSelect(Tensor),
}

#[derive(Clone, Debug)]
struct TorchTensor(PyObject);

impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
}

#[pymethods]
impl PyTensor {
#[new]
Expand Down Expand Up @@ -246,6 +256,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
Expand Down Expand Up @@ -299,6 +311,18 @@ impl PyTensor {
M(py).map(self)
}

/// Converts candle's tensor to pytorch's tensor
/// &RETURNS&: torch.Tensor
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
.import("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
Ok(torch_tensor)
}

#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
Expand Down
14 changes: 14 additions & 0 deletions candle-pyo3/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import candle
import torch

# convert from candle tensor to torch tensor
t = candle.randn((3, 512, 512))
torch_tensor = t.to_torch()
print(torch_tensor)
print(type(torch_tensor))

# convert from torch tensor to candle tensor
t = torch.randn((3, 512, 512))
candle_tensor = candle.Tensor(t)
print(candle_tensor)
print(type(candle_tensor))

0 comments on commit 6a446d9

Please sign in to comment.