diff --git a/candle-pyo3/py_src/candle/testing/__init__.py b/candle-pyo3/py_src/candle/testing/__init__.py index a11b56e252..7b2dec9ec3 100644 --- a/candle-pyo3/py_src/candle/testing/__init__.py +++ b/candle-pyo3/py_src/candle/testing/__init__.py @@ -2,6 +2,9 @@ from candle import Tensor +_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)]) + + def _assert_tensor_metadata( actual: Tensor, expected: Tensor, @@ -56,11 +59,12 @@ def assert_almost_equal( _assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride) # Secure against overflow of u32 and u8 tensors - diff = ( - (actual - expected).abs() - if actual.sum_all().values() > expected.sum_all().values() - else (expected - actual).abs() - ) + if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES: + actual = actual.to(candle.i64) + expected = expected.to(candle.i64) + + diff = (actual - expected).abs() + threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected) assert (diff <= threshold).sum_all().values() == actual.nelements, f"Difference between tensors was to great" diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 4b75a6871e..d1045b5bca 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -148,9 +148,10 @@ macro_rules! pydtype { } }; } + +pydtype!(i64, |v| v); pydtype!(u8, |v| v); pydtype!(u32, |v| v); -pydtype!(i64, |v| v); pydtype!(f16, f32::from); pydtype!(bf16, f32::from); pydtype!(f32, |v| v); @@ -1576,7 +1577,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add("u8", PyDType(DType::U8))?; m.add("u32", PyDType(DType::U32))?; - m.add("i16", PyDType(DType::I64))?; + m.add("i64", PyDType(DType::I64))?; m.add("bf16", PyDType(DType::BF16))?; m.add("f16", PyDType(DType::F16))?; m.add("f32", PyDType(DType::F32))?; diff --git a/candle-pyo3/tests/bindings/test_testing.py b/candle-pyo3/tests/bindings/test_testing.py index 58a2ed1f51..db2fd3f7fa 100644 --- a/candle-pyo3/tests/bindings/test_testing.py +++ b/candle-pyo3/tests/bindings/test_testing.py @@ -4,7 +4,7 @@ import pytest -@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8]) +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) def test_assert_equal_asserts_correctly(dtype: candle.DType): a = Tensor([1, 2, 3]).to(dtype) b = Tensor([1, 2, 3]).to(dtype) @@ -14,7 +14,7 @@ def test_assert_equal_asserts_correctly(dtype: candle.DType): assert_equal(a, b + 1) -@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8]) +@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64]) def test_assert_almost_equal_asserts_correctly(dtype: candle.DType): a = Tensor([1, 2, 3]).to(dtype) b = Tensor([1, 2, 3]).to(dtype)