Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert pytorch's tensor in Python API #1172

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ class Tensor:
Convert the tensor to a new dtype.
"""
pass
def to_torch(self) -> torch.Tensor:
"""
Converts candle's tensor to pytorch's tensor
"""
pass
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
Expand Down
24 changes: 24 additions & 0 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,16 @@ enum Indexer {
IndexSelect(Tensor),
}

#[derive(Clone, Debug)]
struct TorchTensor(PyObject);

impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
}

#[pymethods]
impl PyTensor {
#[new]
Expand Down Expand Up @@ -246,6 +256,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
Expand Down Expand Up @@ -299,6 +311,18 @@ impl PyTensor {
M(py).map(self)
}

/// Converts candle's tensor to pytorch's tensor
/// &RETURNS&: torch.Tensor
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
.import("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
Ok(torch_tensor)
}

#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
Expand Down
14 changes: 14 additions & 0 deletions candle-pyo3/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import candle
import torch

# convert from candle tensor to torch tensor
t = candle.randn((3, 512, 512))
torch_tensor = t.to_torch()
print(torch_tensor)
print(type(torch_tensor))

# convert from torch tensor to candle tensor
t = torch.randn((3, 512, 512))
candle_tensor = candle.Tensor(t)
print(candle_tensor)
print(type(candle_tensor))
Loading