Skip to content

Commit

Permalink
candle.i16 => candle.i64
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Oct 28, 2023
1 parent b58056f commit 4348569
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
14 changes: 9 additions & 5 deletions candle-pyo3/py_src/candle/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from candle import Tensor


_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])


def _assert_tensor_metadata(
actual: Tensor,
expected: Tensor,
Expand Down Expand Up @@ -56,11 +59,12 @@ def assert_almost_equal(
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)

# Secure against overflow of u32 and u8 tensors
diff = (
(actual - expected).abs()
if actual.sum_all().values() > expected.sum_all().values()
else (expected - actual).abs()
)
if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
actual = actual.to(candle.i64)
expected = expected.to(candle.i64)

diff = (actual - expected).abs()

threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)

assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great"
5 changes: 3 additions & 2 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ macro_rules! pydtype {
}
};
}

pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
Expand Down Expand Up @@ -1576,7 +1577,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
m.add("i16", PyDType(DType::I64))?;
m.add("i64", PyDType(DType::I64))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;
Expand Down
4 changes: 2 additions & 2 deletions candle-pyo3/tests/bindings/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8])
@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
Expand All @@ -14,7 +14,7 @@ def test_assert_equal_asserts_correctly(dtype: candle.DType):
assert_equal(a, b + 1)


@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8])
@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
Expand Down

0 comments on commit 4348569

Please sign in to comment.