Skip to content

Commit

Permalink
Improve typehinting at root and rod directory
Browse files Browse the repository at this point in the history
  • Loading branch information
ankith26 committed May 17, 2024
1 parent 4d05804 commit 5041d0b
Show file tree
Hide file tree
Showing 22 changed files with 1,102 additions and 751 deletions.
32 changes: 22 additions & 10 deletions elastica/_calculus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
__doc__ = """ Quadrature and difference kernels """
from typing import Any, Union
import numpy as np
from numpy import zeros, empty
from numpy.typing import NDArray
from numba import njit
from elastica.reset_functions_for_block_structure._reset_ghost_vector_or_scalar import (
_reset_vector_ghost,
Expand All @@ -9,15 +11,17 @@


@functools.lru_cache(maxsize=2)
def _get_zero_array(dim, ndim):
def _get_zero_array(dim: int, ndim: int) -> Union[float, NDArray[np.floating], None]:
if ndim == 1:
return 0.0
if ndim == 2:
return np.zeros((dim, 1))

return None


@njit(cache=True)
def _trapezoidal(array_collection):
def _trapezoidal(array_collection: NDArray[np.floating]) -> NDArray[np.floating]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way
Expand Down Expand Up @@ -63,7 +67,9 @@ def _trapezoidal(array_collection):


@njit(cache=True)
def _trapezoidal_for_block_structure(array_collection, ghost_idx):
def _trapezoidal_for_block_structure(
array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
) -> NDArray[np.floating]:
"""
Simple trapezoidal quadrature rule with zero at end-points, in a dimension agnostic way. This form
specifically for the block structure implementation and there is a reset function call, to reset
Expand Down Expand Up @@ -115,7 +121,9 @@ def _trapezoidal_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
def _two_point_difference(array_collection):
def _two_point_difference(
array_collection: NDArray[np.floating],
) -> NDArray[np.floating]:
"""
This function does differentiation.
Expand Down Expand Up @@ -156,7 +164,9 @@ def _two_point_difference(array_collection):


@njit(cache=True)
def _two_point_difference_for_block_structure(array_collection, ghost_idx):
def _two_point_difference_for_block_structure(
array_collection: NDArray[np.floating], ghost_idx: NDArray[np.integer]
) -> NDArray[np.floating]:
"""
This function does the differentiation, for Cosserat rod model equations. This form
specifically for the block structure implementation and there is a reset function call, to
Expand Down Expand Up @@ -207,7 +217,7 @@ def _two_point_difference_for_block_structure(array_collection, ghost_idx):


@njit(cache=True)
def _difference(vector):
def _difference(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes difference between elements of a batch vector.
Expand Down Expand Up @@ -238,7 +248,7 @@ def _difference(vector):


@njit(cache=True)
def _average(vector):
def _average(vector: NDArray[np.floating]) -> NDArray[np.floating]:
"""
This function computes the average between elements of a vector.
Expand Down Expand Up @@ -268,7 +278,9 @@ def _average(vector):


@njit(cache=True)
def _clip_array(input_array, vmin, vmax):
def _clip_array(
input_array: NDArray[np.floating], vmin: np.floating, vmax: np.floating
) -> NDArray[np.floating]:
"""
This function clips an array values
between user defined minimum and maximum
Expand Down Expand Up @@ -304,7 +316,7 @@ def _clip_array(input_array, vmin, vmax):


@njit(cache=True)
def _isnan_check(array):
def _isnan_check(array: NDArray[Any]) -> bool:
"""
This function checks if there is any nan inside the array.
If there is nan, it returns True boolean.
Expand All @@ -324,7 +336,7 @@ def _isnan_check(array):
Python version: 2.24 µs ± 96.1 ns per loop
This version: 479 ns ± 6.49 ns per loop
"""
return np.isnan(array).any()
return bool(np.isnan(array).any())


position_difference_kernel = _difference
Expand Down
214 changes: 107 additions & 107 deletions elastica/_contact_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,29 @@
)
import numba
import numpy as np
from numpy.typing import NDArray


@numba.njit(cache=True)
def _calculate_contact_forces_rod_cylinder(
x_collection_rod,
edge_collection_rod,
x_cylinder_center,
x_cylinder_tip,
edge_cylinder,
radii_sum,
length_sum,
internal_forces_rod,
external_forces_rod,
external_forces_cylinder,
external_torques_cylinder,
cylinder_director_collection,
velocity_rod,
velocity_cylinder,
contact_k,
contact_nu,
velocity_damping_coefficient,
friction_coefficient,
x_collection_rod: NDArray[np.floating],
edge_collection_rod: NDArray[np.floating],
x_cylinder_center: NDArray[np.floating],
x_cylinder_tip: NDArray[np.floating],
edge_cylinder: NDArray[np.floating],
radii_sum: NDArray[np.floating],
length_sum: NDArray[np.floating],
internal_forces_rod: NDArray[np.floating],
external_forces_rod: NDArray[np.floating],
external_forces_cylinder: NDArray[np.floating],
external_torques_cylinder: NDArray[np.floating],
cylinder_director_collection: NDArray[np.floating],
velocity_rod: NDArray[np.floating],
velocity_cylinder: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
velocity_damping_coefficient: np.floating,
friction_coefficient: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points = x_collection_rod.shape[1]
Expand Down Expand Up @@ -155,22 +156,22 @@ def _calculate_contact_forces_rod_cylinder(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_rod(
x_collection_rod_one,
radius_rod_one,
length_rod_one,
tangent_rod_one,
velocity_rod_one,
internal_forces_rod_one,
external_forces_rod_one,
x_collection_rod_two,
radius_rod_two,
length_rod_two,
tangent_rod_two,
velocity_rod_two,
internal_forces_rod_two,
external_forces_rod_two,
contact_k,
contact_nu,
x_collection_rod_one: NDArray[np.floating],
radius_rod_one: NDArray[np.floating],
length_rod_one: NDArray[np.floating],
tangent_rod_one: NDArray[np.floating],
velocity_rod_one: NDArray[np.floating],
internal_forces_rod_one: NDArray[np.floating],
external_forces_rod_one: NDArray[np.floating],
x_collection_rod_two: NDArray[np.floating],
radius_rod_two: NDArray[np.floating],
length_rod_two: NDArray[np.floating],
tangent_rod_two: NDArray[np.floating],
velocity_rod_two: NDArray[np.floating],
internal_forces_rod_two: NDArray[np.floating],
external_forces_rod_two: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points_rod_one = x_collection_rod_one.shape[1]
Expand Down Expand Up @@ -272,14 +273,14 @@ def _calculate_contact_forces_rod_rod(

@numba.njit(cache=True)
def _calculate_contact_forces_self_rod(
x_collection_rod,
radius_rod,
length_rod,
tangent_rod,
velocity_rod,
external_forces_rod,
contact_k,
contact_nu,
x_collection_rod: NDArray[np.floating],
radius_rod: NDArray[np.floating],
length_rod: NDArray[np.floating],
tangent_rod: NDArray[np.floating],
velocity_rod: NDArray[np.floating],
external_forces_rod: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points_rod = x_collection_rod.shape[1]
Expand Down Expand Up @@ -360,24 +361,24 @@ def _calculate_contact_forces_self_rod(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_sphere(
x_collection_rod,
edge_collection_rod,
x_sphere_center,
x_sphere_tip,
edge_sphere,
radii_sum,
length_sum,
internal_forces_rod,
external_forces_rod,
external_forces_sphere,
external_torques_sphere,
sphere_director_collection,
velocity_rod,
velocity_sphere,
contact_k,
contact_nu,
velocity_damping_coefficient,
friction_coefficient,
x_collection_rod: NDArray[np.floating],
edge_collection_rod: NDArray[np.floating],
x_sphere_center: NDArray[np.floating],
x_sphere_tip: NDArray[np.floating],
edge_sphere: NDArray[np.floating],
radii_sum: NDArray[np.floating],
length_sum: NDArray[np.floating],
internal_forces_rod: NDArray[np.floating],
external_forces_rod: NDArray[np.floating],
external_forces_sphere: NDArray[np.floating],
external_torques_sphere: NDArray[np.floating],
sphere_director_collection: NDArray[np.floating],
velocity_rod: NDArray[np.floating],
velocity_sphere: NDArray[np.floating],
contact_k: np.floating,
contact_nu: np.floating,
velocity_damping_coefficient: np.floating,
friction_coefficient: np.floating,
) -> None:
# We already pass in only the first n_elem x
n_points = x_collection_rod.shape[1]
Expand Down Expand Up @@ -486,18 +487,18 @@ def _calculate_contact_forces_rod_sphere(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_plane(
plane_origin,
plane_normal,
surface_tol,
k,
nu,
radius,
mass,
position_collection,
velocity_collection,
internal_forces,
external_forces,
):
plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
k: np.floating,
nu: np.floating,
radius: NDArray[np.floating],
mass: NDArray[np.floating],
position_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
internal_forces: NDArray[np.floating],
external_forces: NDArray[np.floating],
) -> tuple[NDArray[np.floating], NDArray[np.intp]]:
"""
This function computes the plane force response on the element, in the
case of contact. Contact model given in Eqn 4.8 Gazzola et. al. RSoS 2018 paper
Expand Down Expand Up @@ -571,30 +572,30 @@ def _calculate_contact_forces_rod_plane(

@numba.njit(cache=True)
def _calculate_contact_forces_rod_plane_with_anisotropic_friction(
plane_origin,
plane_normal,
surface_tol,
slip_velocity_tol,
k,
nu,
kinetic_mu_forward,
kinetic_mu_backward,
kinetic_mu_sideways,
static_mu_forward,
static_mu_backward,
static_mu_sideways,
radius,
mass,
tangents,
position_collection,
director_collection,
velocity_collection,
omega_collection,
internal_forces,
external_forces,
internal_torques,
external_torques,
):
plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
slip_velocity_tol: np.floating,
k: np.floating,
nu: np.floating,
kinetic_mu_forward: np.floating,
kinetic_mu_backward: np.floating,
kinetic_mu_sideways: np.floating,
static_mu_forward: np.floating,
static_mu_backward: np.floating,
static_mu_sideways: np.floating,
radius: NDArray[np.floating],
mass: NDArray[np.floating],
tangents: NDArray[np.floating],
position_collection: NDArray[np.floating],
director_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
omega_collection: NDArray[np.floating],
internal_forces: NDArray[np.floating],
external_forces: NDArray[np.floating],
internal_torques: NDArray[np.floating],
external_torques: NDArray[np.floating],
) -> None:
(
plane_response_force_mag,
no_contact_point_idx,
Expand Down Expand Up @@ -784,17 +785,16 @@ def _calculate_contact_forces_rod_plane_with_anisotropic_friction(

@numba.njit(cache=True)
def _calculate_contact_forces_cylinder_plane(
plane_origin,
plane_normal,
surface_tol,
k,
nu,
length,
position_collection,
velocity_collection,
external_forces,
):

plane_origin: NDArray[np.floating],
plane_normal: NDArray[np.floating],
surface_tol: np.floating,
k: np.floating,
nu: np.floating,
length: NDArray[np.floating],
position_collection: NDArray[np.floating],
velocity_collection: NDArray[np.floating],
external_forces: NDArray[np.floating],
) -> tuple[NDArray[np.floating], NDArray[np.intp]]:
# Compute plane response force
# total_forces = system.internal_forces + system.external_forces
total_forces = external_forces
Expand Down
Loading

0 comments on commit 5041d0b

Please sign in to comment.