From ea1f17efd988b0153b2f41adf71f51daf70a830a Mon Sep 17 00:00:00 2001 From: Niko Sirmpilatze Date: Wed, 4 Sep 2024 10:47:27 +0100 Subject: [PATCH] Compatibility with numpy 2.0 (#109) * replaced np.alltrue with np.all * added ruff rule for checking numpy2.0 deprecations --- .../agents/whittington_2020_extras/whittington_2020_utils.py | 2 +- neuralplayground/utils.py | 2 +- pyproject.toml | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py b/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py index f8b662a2..1d9005a2 100644 --- a/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py +++ b/neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py @@ -325,7 +325,7 @@ def check_wall(pre_state, new_state, wall, wall_closenes=1e-5, tolerance=1e-9): larger_than_zero = intersection >= 0 # If condition is true, then the points cross the wall - cross_wall = np.alltrue(np.logical_and(smaller_than_one, larger_than_zero)) + cross_wall = np.all(np.logical_and(smaller_than_one, larger_than_zero)) if cross_wall: new_state = (intersection[-1] - wall_closenes) * (new_state - pre_state) + pre_state diff --git a/neuralplayground/utils.py b/neuralplayground/utils.py index 36446657..3dc10f59 100644 --- a/neuralplayground/utils.py +++ b/neuralplayground/utils.py @@ -49,7 +49,7 @@ def check_crossing_wall( larger_than_zero = intersection >= 0 # If condition is true, then the points cross the wall - cross_wall = np.alltrue(np.logical_and(smaller_than_one, larger_than_zero)) + cross_wall = np.all(np.logical_and(smaller_than_one, larger_than_zero)) if cross_wall: new_state = (intersection[-1] - wall_closenes) * (new_state - pre_state) + pre_state diff --git a/pyproject.toml b/pyproject.toml index 1192a276..a6f573e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,9 +96,11 @@ ignore = [ [tool.ruff] line-length = 127 exclude = ["__init__.py","build",".eggs"] -select = ["I", "E", "F"] fix = true +[tool.ruff.lint] +select = ["I", "E", "F", "NPY201"] + [tool.cibuildwheel] build = "cp38-* cp39-* cp310-* cp311-*"