Skip to content

Commit

Permalink
Add abs + candle.testing
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Oct 27, 2023
1 parent d094028 commit dd06ff8
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 61 deletions.
36 changes: 36 additions & 0 deletions candle-pyo3/_additional_typing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,39 @@ def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Te
Return a slice of a tensor.
"""
pass

def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
46 changes: 41 additions & 5 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
Expand All @@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
Expand Down Expand Up @@ -231,11 +266,6 @@ class Tensor:
Gets the tensor's dtype.
"""
pass
def equal(self, rhs: Tensor) -> bool:
"""
True if two tensors have the same size and elements, False otherwise.
"""
pass
def exp(self) -> Tensor:
"""
Performs the `exp` operation on the tensor.
Expand Down Expand Up @@ -313,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
@property
def nelements(self) -> int:
"""
Gets the tensor's element count.
"""
pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
Expand Down
66 changes: 66 additions & 0 deletions candle-pyo3/py_src/candle/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import candle
from candle import Tensor


def _assert_tensor_metadata(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
if check_device:
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"

if check_dtype:
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"

if check_layout:
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"

if check_stride:
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"


def assert_equal(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts that two tensors are exact equals.
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"


def assert_almost_equal(
actual: Tensor,
expected: Tensor,
rtol=1e-05,
atol=1e-08,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts, that two tensors are almost equal by performing an element wise comparison of the tensors with a tolerance.
Computes: |actual - expected| ≤ atol + rtol x |expected|
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

# Secure against overflow of u32 and u8 tensors
diff = (
(actual - expected).abs()
if actual.sum_all().values() > expected.sum_all().values()
else (expected - actual).abs()
)
threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great"
38 changes: 17 additions & 21 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,13 @@ impl PyTensor {
PyTuple::new(py, self.0.dims()).to_object(py)
}

#[getter]
/// Gets the tensor's element count.
/// &RETURNS&: int
fn nelements(&self) -> usize {
self.0.elem_count()
}

#[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
Expand Down Expand Up @@ -373,6 +380,16 @@ impl PyTensor {
self.__repr__()
}

/// Performs the `abs` operation on the tensor.
/// &RETURNS&: Tensor
fn abs(&self) -> PyResult<Self> {
match self.0.dtype() {
DType::U8 => Ok(PyTensor(self.0.clone())),
DType::U32 => Ok(PyTensor(self.0.clone())),
_ => Ok(PyTensor(self.0.abs().map_err(wrap_err)?)),
}
}

/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
Expand Down Expand Up @@ -739,27 +756,6 @@ impl PyTensor {
hasher.finish()
}

#[pyo3(text_signature = "(self, rhs:Tensor)")]
/// True if two tensors have the same size and elements, False otherwise.
/// &RETURNS&: bool
fn equal(&self, rhs: &Self) -> PyResult<bool> {
if self.0.shape() != rhs.0.shape() {
return Ok(false);
}
let result = self
.0
.eq(&rhs.0)
.map_err(wrap_err)?
.to_dtype(DType::I64)
.map_err(wrap_err)?;
Ok(result
.sum_all()
.map_err(wrap_err)?
.to_scalar::<i64>()
.map_err(wrap_err)?
== self.0.elem_count() as i64)
}

#[pyo3(text_signature = "(self, shape:Sequence[int])")]
/// Reshapes the tensor to the given shape.
/// &RETURNS&: Tensor
Expand Down
33 changes: 33 additions & 0 deletions candle-pyo3/tests/bindings/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_equal(a, b)

with pytest.raises(AssertionError):
assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_almost_equal(a, b)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1)

assert_almost_equal(a, b + 1, atol=20)
assert_almost_equal(a, b + 1, rtol=20)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, atol=0.9)

with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, rtol=0.1)
66 changes: 31 additions & 35 deletions candle-pyo3/tests/native/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import candle
from candle import Tensor
from candle.utils import cuda_is_available
from candle.testing import assert_equal
import pytest


Expand Down Expand Up @@ -77,61 +78,56 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]


def test_tensors_can_be_compared_with_equal():
t = Tensor(42.0)
other = Tensor(42.0)
assert t.equal(other)
t = Tensor([42.0, 42.1])
other = Tensor([42.0, 42.0])
assert not t.equal(other)
t = Tensor(42.0)
other = Tensor([42.0, 42.0])
assert not t.equal(other)
def assert_bool(t: Tensor, expected: bool):
assert t.shape == ()
assert str(t.dtype) == str(candle.u8)
assert bool(t.values()) == expected


def test_tensor_supports_equality_opperations_with_scalars():
t = Tensor(42.0)
assert (t == 42.0).equal(Tensor(1).to_dtype(candle.u8))
assert (t == 43.0).equal(Tensor(0).to_dtype(candle.u8))

assert (t != 42.0).equal(Tensor(0).to_dtype(candle.u8))
assert (t != 43.0).equal(Tensor(1).to_dtype(candle.u8))
assert_bool(t == 42.0, True)
assert_bool(t == 43.0, False)

assert_bool(t != 42.0, False)
assert_bool(t != 43.0, True)

assert (t > 41.0).equal(Tensor(1).to_dtype(candle.u8))
assert (t > 42.0).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t > 41.0, True)
assert_bool(t > 42.0, False)

assert (t >= 42.0).equal(Tensor(1).to_dtype(candle.u8))
assert (t >= 43.0).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t >= 41.0, True)
assert_bool(t >= 42.0, True)

assert (t < 43.0).equal(Tensor(1).to_dtype(candle.u8))
assert (t < 42.0).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t < 43.0, True)
assert_bool(t < 42.0, False)

assert (t <= 42.0).equal(Tensor(1).to_dtype(candle.u8))
assert (t <= 41.0).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t <= 43.0, True)
assert_bool(t <= 42.0, True)


def test_tensor_supports_equality_opperations_with_tensors():
t = Tensor(42.0)
same = Tensor(42.0)
other = Tensor(43.0)

assert (t == same).equal(Tensor(1).to_dtype(candle.u8))
assert (t == other).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t == same, True)
assert_bool(t == other, False)

assert (t != same).equal(Tensor(0).to_dtype(candle.u8))
assert (t != other).equal(Tensor(1).to_dtype(candle.u8))
assert_bool(t != same, False)
assert_bool(t != other, True)

assert (t > same).equal(Tensor(0).to_dtype(candle.u8))
assert (t > other).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t > same, False)
assert_bool(t > other, False)

assert (t >= same).equal(Tensor(1).to_dtype(candle.u8))
assert (t >= other).equal(Tensor(0).to_dtype(candle.u8))
assert_bool(t >= same, True)
assert_bool(t >= other, False)

assert (t < same).equal(Tensor(0).to_dtype(candle.u8))
assert (t < other).equal(Tensor(1).to_dtype(candle.u8))
assert_bool(t < same, False)
assert_bool(t < other, True)

assert (t <= same).equal(Tensor(1).to_dtype(candle.u8))
assert (t <= other).equal(Tensor(1).to_dtype(candle.u8))
assert_bool(t <= same, True)
assert_bool(t <= other, True)


def test_tensor_equality_opperations_can_broadcast():
Expand All @@ -143,7 +139,7 @@ def test_tensor_equality_opperations_can_broadcast():
mask_cond = candle.Tensor([0, 1, 2])
mask = mask_cond < (mask_cond + 1).reshape((3, 1))
assert mask.shape == (3, 3)
assert mask.equal(Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))


def test_tensor_can_be_hashed():
Expand Down

0 comments on commit dd06ff8

Please sign in to comment.