PyO3: Better shape handling #1143

Merged: 5 commits, Oct 29, 2023
Changes from 1 commit
2 changes: 1 addition & 1 deletion candle-examples/Cargo.toml
@@ -21,7 +21,7 @@ half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
4 changes: 2 additions & 2 deletions candle-pyo3/Cargo.toml
@@ -17,10 +17,10 @@ crate-type = ["cdylib"]
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
pyo3 = { version = "0.20.0", features = ["extension-module"] }

[build-dependencies]
pyo3-build-config = "0.19"
pyo3-build-config = "0.20"

[features]
default = []
16 changes: 8 additions & 8 deletions candle-pyo3/py_src/candle/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape

class bf16(DType):
pass
@@ -26,21 +26,21 @@ class i64(DType):
pass

@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with ones.
"""
pass

@staticmethod
def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values.
"""
pass

@staticmethod
def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values from a normal distribution.
"""
@@ -67,7 +67,7 @@ class u8(DType):
pass

@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
def zeros(shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with zeros.
"""
@@ -174,7 +174,7 @@ class Tensor:
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
def broadcast_as(self, shape: Sequence[int]) -> Tensor:
def broadcast_as(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape.
"""
@@ -184,7 +184,7 @@ class Tensor:
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
def broadcast_left(self, shape: Sequence[int]) -> Tensor:
def broadcast_left(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape, adding new dimensions on the left.
"""
@@ -329,7 +329,7 @@ class Tensor:
Get the `recip` of the tensor.
"""
pass
def reshape(self, shape: Sequence[int]) -> Tensor:
def reshape(self, *shape: Shape) -> Tensor:
"""
Reshapes the tensor to the given shape.
"""
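Taken together, the new *shape: Shape stubs mean the factory functions and the shape-changing methods accept dimensions either as separate integer arguments or as a single sequence. A minimal usage sketch, assuming candle-pyo3 is built locally and importable as candle, and that dtype objects such as candle.u8 can be passed directly as the stubs suggest:

import candle

# Dimensions can be passed variadically or as one tuple/list.
a = candle.ones(2, 3)                   # same as candle.ones((2, 3))
b = candle.rand((2, 3))                 # the single-sequence form still works
c = candle.ones(2, 3, dtype=candle.u8)  # dtype/device remain keyword arguments

# Tensor methods follow the same convention.
d = a.reshape(3, 2)                     # shape (3, 2)
e = a.broadcast_as(4, 2, 3)             # shape (4, 2, 3)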
2 changes: 1 addition & 1 deletion candle-pyo3/py_src/candle/functional/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor

@staticmethod
2 changes: 2 additions & 0 deletions candle-pyo3/py_src/candle/typing/__init__.py
@@ -18,3 +18,5 @@
Scalar = Union[int, float]

Index = Union[int, slice, None, "Ellipsis"]

Shape = Union[int, Sequence[int]]
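The Shape alias admits either a bare integer or a sequence of integers, which is what lets a *shape parameter accept both f(2, 3) and f((2, 3)). A pure-Python sketch of the normalization this implies, using a hypothetical helper that mirrors (but is not part of) the Rust extraction logic:

from typing import Sequence, Tuple, Union

Shape = Union[int, Sequence[int]]

def normalize_shape(*shape: Shape) -> Tuple[int, ...]:
    # A single non-int argument is treated as the whole shape; otherwise the
    # positional ints are taken as the dimensions.
    if len(shape) == 1 and not isinstance(shape[0], int):
        return tuple(int(d) for d in shape[0])
    return tuple(int(d) for d in shape)

assert normalize_shape(2, 3) == (2, 3)
assert normalize_shape((2, 3)) == (2, 3)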
2 changes: 1 addition & 1 deletion candle-pyo3/py_src/candle/utils/__init__.pyi
@@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor

@staticmethod
71 changes: 35 additions & 36 deletions candle-pyo3/src/lib.rs
@@ -10,26 +10,13 @@ use half::{bf16, f16};

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

mod shape;
use shape::{PyRelativeShape, PyShape};

pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}

#[derive(Clone, Debug)]
struct PyShape(Vec<usize>);

impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
Ok(PyShape(dims))
}
}

impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}

#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A `candle` tensor.
@@ -654,25 +641,37 @@ impl PyTensor {
Ok(Self(tensor))
}

#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
/// &RETURNS&: Tensor
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
fn reshape(&self, shape: PyRelativeShape) -> PyResult<Self> {
Ok(PyTensor(
self.0
.reshape(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}

#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape.
/// &RETURNS&: Tensor
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
fn broadcast_as(&self, shape: PyRelativeShape) -> PyResult<Self> {
Ok(PyTensor(
self.0
.broadcast_as(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}

#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
/// &RETURNS&: Tensor
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
fn broadcast_left(&self, shape: PyRelativeShape) -> PyResult<Self> {
Ok(PyTensor(
self.0
.broadcast_left(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}

#[pyo3(text_signature = "(self, dim:int)")]
@@ -885,21 +884,21 @@ impl PyTensor {
}

if let Some(kwargs) = kwargs {
if let Some(any) = kwargs.get_item("dtype") {
if let Ok(Some(any)) = kwargs.get_item("dtype") {
handle_duplicates(
&mut dtype,
any.extract::<PyDType>(),
"cannot specify multiple dtypes",
)?;
}
if let Some(any) = kwargs.get_item("device") {
if let Ok(Some(any)) = kwargs.get_item("device") {
handle_duplicates(
&mut device,
any.extract::<PyDevice>(),
"cannot specify multiple devices",
)?;
}
if let Some(any) = kwargs.get_item("other") {
if let Ok(Some(any)) = kwargs.get_item("other") {
handle_duplicates(
&mut other,
any.extract::<PyTensor>(),
@@ -1019,27 +1018,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
}

#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor
fn ones(
@@ -1053,12 +1052,12 @@ fn ones(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}

#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor
fn zeros(
@@ -1072,7 +1071,7 @@ fn zeros(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}

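On the binding side, reshape, broadcast_as and broadcast_left now go through PyRelativeShape, so a -1 dimension is resolved against the tensor's element count before the call reaches candle, while the factory functions take *shape with dtype/device as keyword-only arguments. A short usage sketch under the same assumptions as above:

import candle

t = candle.rand(2, 3, 4)           # device, when needed, is keyword-only

# A single -1 dimension is inferred from the remaining element count.
pairs = t.reshape(6, -1)           # shape (6, 4)
merged = t.reshape(2, -1)          # shape (2, 12)

# The broadcasting helpers accept the same variadic form.
wide = t.broadcast_as(5, 2, 3, 4)  # shape (5, 2, 3, 4)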
89 changes: 89 additions & 0 deletions candle-pyo3/src/shape.rs
@@ -0,0 +1,89 @@
use ::candle::Tensor;
use pyo3::prelude::*;

#[derive(Clone, Debug)]
/// Represents an absolute shape e.g. (1, 2, 3)
pub struct PyShape(Vec<usize>);

impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}

let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
Ok(PyShape(dims))
} else {
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
Ok(PyShape(dims))
}
}
}

impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}

#[derive(Clone, Debug)]
/// Represents a relative shape e.g. (1, -1, 3)
pub struct PyRelativeShape(Vec<isize>);

impl<'source> pyo3::FromPyObject<'source> for PyRelativeShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}

let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<isize> = pyo3::FromPyObject::extract(first_element)?;
Ok(PyRelativeShape(dims))
} else {
let dims: Vec<isize> = pyo3::FromPyObject::extract(tuple)?;
Ok(PyRelativeShape(dims))
}
}
}

impl PyRelativeShape {
/// Returns `true` if the shape is absolute e.g. (1, 2, 3)
pub fn is_absolute(&self) -> bool {
self.0.iter().all(|x| *x > 0)
}

/// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
if self.is_absolute() {
return Ok(PyShape(
self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
));
}

let mut elements = t.elem_count();
let mut new_dims: Vec<usize> = vec![];
for dim in self.0.iter() {
if *dim > 0 {
new_dims.push(*dim as usize);
elements /= *dim as usize;
} else if *dim == -1 {
new_dims.push(elements);
} else {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {}",
dim
)));
}
}
Ok(PyShape(new_dims))
}
}
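to_absolute fills in a single -1 dimension by walking the requested dims and dividing the tensor's element count by each explicit dimension seen so far. A pure-Python sketch of the same logic, for illustration only:

from typing import List, Sequence

def to_absolute(dims: Sequence[int], elem_count: int) -> List[int]:
    """Resolve a relative shape such as (2, -1) against a total element count."""
    if all(d > 0 for d in dims):
        return list(dims)
    remaining = elem_count
    out: List[int] = []
    for d in dims:
        if d > 0:
            out.append(d)
            remaining //= d
        elif d == -1:
            out.append(remaining)
        else:
            raise ValueError(f"Invalid dimension in shape: {d}")
    return out

assert to_absolute([2, -1], 24) == [2, 12]
assert to_absolute([-1], 24) == [24]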
2 changes: 1 addition & 1 deletion candle-pyo3/stub.py
@@ -13,7 +13,7 @@
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: "
ADDITIONAL_TYPEHINTS = {}