PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)
* add `equal` to tensor

* add `__richcmp__` support for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
LLukas22 authored Oct 30, 2023
1 parent 9699608 commit c05c0a8
Showing 7 changed files with 325 additions and 3 deletions.
2 changes: 1 addition & 1 deletion candle-core/src/shape.rs
@@ -203,7 +203,7 @@ impl Shape {

/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
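(The visibility change above, `pub(crate)` to `pub`, is what lets the Python bindings further down reuse `broadcast_shape_binary_op` to broadcast operands inside `__richcmp__`.)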
36 changes: 36 additions & 0 deletions candle-pyo3/_additional_typing/__init__.py
@@ -53,3 +53,39 @@ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Te
        Return a slice of a tensor.
        """
        pass

    def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
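A minimal usage sketch of the new comparison operators (not part of the diff; it assumes a candle build that includes this change). The comparisons are element-wise and return a candle Tensor of 0/1 values rather than a Python bool:

import candle
from candle import Tensor

a = Tensor([1.0, 2.0, 3.0])
b = Tensor([1.0, 0.0, 3.0])

print(a == b)    # element-wise equality, returns a Tensor mask
print(a >= 2.0)  # comparing against a scalar broadcasts it over the tensor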
41 changes: 41 additions & 0 deletions candle-pyo3/py_src/candle/__init__.pyi
@@ -124,16 +124,46 @@ class Tensor:
        Add a scalar to a tensor or two tensors together.
        """
        pass
    def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
    def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
    def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
        """
        Return a slice of a tensor.
        """
        pass
    def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
    def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
    def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
    def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Multiply a tensor by a scalar or one tensor by another.
        """
        pass
    def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass
    def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Add a scalar to a tensor or two tensors together.
@@ -159,6 +189,11 @@ class Tensor:
        Divide a tensor by a scalar or one tensor by another.
        """
        pass
    def abs(self) -> Tensor:
        """
        Performs the `abs` operation on the tensor.
        """
        pass
    def argmax_keepdim(self, dim: int) -> Tensor:
        """
        Returns the indices of the maximum value(s) across the selected dimension.
@@ -308,6 +343,12 @@ class Tensor:
        ranges from `start` to `start + len`.
        """
        pass
    @property
    def nelement(self) -> int:
        """
        Gets the tensor's element count.
        """
        pass
    def powf(self, p: float) -> Tensor:
        """
        Performs the `pow` operation on the tensor with the given exponent.
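A small sketch of the two new Tensor members typed above, `abs` and the `nelement` property (illustrative only, assuming the updated bindings):

import candle
from candle import Tensor

t = Tensor([-1.0, 2.0, -3.0])
print(t.abs())     # element-wise absolute value, returns a new Tensor
print(t.nelement)  # element count exposed as a property -> 3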
70 changes: 70 additions & 0 deletions candle-pyo3/py_src/candle/testing/__init__.py
@@ -0,0 +1,70 @@
import candle
from candle import Tensor


_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])


def _assert_tensor_metadata(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    if check_device:
        assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"

    if check_dtype:
        assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"

    if check_layout:
        assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"

    if check_stride:
        assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"


def assert_equal(
    actual: Tensor,
    expected: Tensor,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are exactly equal.
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
    assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"


def assert_almost_equal(
    actual: Tensor,
    expected: Tensor,
    rtol=1e-05,
    atol=1e-08,
    check_device: bool = True,
    check_dtype: bool = True,
    check_layout: bool = True,
    check_stride: bool = False,
):
    """
    Asserts that two tensors are almost equal by performing an element-wise comparison with a tolerance.
    Checks: |actual - expected| ≤ atol + rtol × |expected|
    """
    _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

    # Guard against overflow of u32 and u8 tensors by promoting them to i64 first
    if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
        actual = actual.to(candle.i64)
        expected = expected.to(candle.i64)

    diff = (actual - expected).abs()

    threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

    assert (diff <= threshold).sum_all().values() == actual.nelement, "Difference between tensors was too great"
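A short usage sketch for the new `candle.testing` helpers (assuming the bindings from this PR). `assert_almost_equal` checks |actual - expected| ≤ atol + rtol × |expected| element-wise, promoting u8/u32 inputs to i64 first so the subtraction cannot wrap around:

import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal

a = Tensor([1, 2, 3]).to(candle.u8)
b = Tensor([1, 2, 3]).to(candle.u8)

assert_equal(a, b)                      # exact element-wise equality
assert_almost_equal(a, b + 1, atol=20)  # off by one, but within the absolute tolerance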
73 changes: 71 additions & 2 deletions candle-pyo3/src/lib.rs
@@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;

@@ -132,9 +135,10 @@ macro_rules! pydtype {
}
};
}

pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
@@ -317,6 +321,13 @@ impl PyTensor {
PyTuple::new(py, self.0.dims()).to_object(py)
}

#[getter]
/// Gets the tensor's element count.
/// &RETURNS&: int
fn nelement(&self) -> usize {
self.0.elem_count()
}

#[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
@@ -353,6 +364,12 @@ impl PyTensor {
self.__repr__()
}

/// Performs the `abs` operation on the tensor.
/// &RETURNS&: Tensor
fn abs(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
}

/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
@@ -670,6 +687,58 @@ impl PyTensor {
};
Ok(Self(tensor))
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
CompareOp::Ne => lhs.ne(rhs),
CompareOp::Lt => lhs.lt(rhs),
CompareOp::Le => lhs.le(rhs),
CompareOp::Gt => lhs.gt(rhs),
CompareOp::Ge => lhs.ge(rhs),
};
Ok(PyTensor(t.map_err(wrap_err)?))
};
if let Ok(rhs) = rhs.extract::<PyTensor>() {
if self.0.shape() == rhs.0.shape() {
compare(&self.0, &rhs.0)
} else {
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
let broadcast_shape = self
.0
.shape()
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
.map_err(wrap_err)?;
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;

compare(&broadcasted_lhs, &broadcasted_rhs)
}
} else if let Ok(rhs) = rhs.extract::<f64>() {
let scalar_tensor = Tensor::new(rhs, self.0.device())
.map_err(wrap_err)?
.to_dtype(self.0.dtype())
.map_err(wrap_err)?
.broadcast_as(self.0.shape())
.map_err(wrap_err)?;

compare(&self.0, &scalar_tensor)
} else {
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
}
}

fn __hash__(&self) -> u64 {
// we have overridden __richcmp__ => pyo3 wants us to also override __hash__
// we simply hash the address of the tensor
let mut hasher = DefaultHasher::new();
let pointer = &self.0 as *const Tensor;
let address = pointer as usize;
address.hash(&mut hasher);
hasher.finish()
}

#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
@@ -1503,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
m.add("i16", PyDType(DType::I64))?;
m.add("i64", PyDType(DType::I64))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;
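From the Python side, the `__richcmp__` and `__hash__` implementations above behave roughly as in this hedged sketch (values are illustrative): same-shape tensors are compared element-wise, differently-shaped tensors are broadcast via `broadcast_shape_binary_op`, scalars are materialized as a tensor of the lhs dtype and shape, and `hash()` is based on the tensor's address:

import candle
from candle import Tensor

a = Tensor([1.0, 2.0, 3.0])

print(a == Tensor([1.0, 0.0, 3.0]))  # same shape: element-wise comparison
print(a <= Tensor([2.0]))            # shapes (3,) and (1,): broadcast before comparing
print(a > 2.0)                       # scalar rhs: converted to a's dtype and broadcast
print(hash(a))                       # address-based hash, required once __richcmp__ is overridden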
33 changes: 33 additions & 0 deletions candle-pyo3/tests/bindings/test_testing.py
@@ -0,0 +1,33 @@
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
    a = Tensor([1, 2, 3]).to(dtype)
    b = Tensor([1, 2, 3]).to(dtype)
    assert_equal(a, b)

    with pytest.raises(AssertionError):
        assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
    a = Tensor([1, 2, 3]).to(dtype)
    b = Tensor([1, 2, 3]).to(dtype)
    assert_almost_equal(a, b)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1)

    assert_almost_equal(a, b + 1, atol=20)
    assert_almost_equal(a, b + 1, rtol=20)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1, atol=0.9)

    with pytest.raises(AssertionError):
        assert_almost_equal(a, b + 1, rtol=0.1)
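Assuming a development install of the bindings (e.g. built with maturin), these tests should be runnable with `pytest candle-pyo3/tests/bindings/test_testing.py`.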