Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing: timestepper fix pytest and cleanup integration routine #373

Merged
merged 12 commits into from
May 7, 2024
Merged
24 changes: 12 additions & 12 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#* Variables
PYTHON := python3
PYTHONPATH := `pwd`
AUTOFLAKE8_ARGS := -r --exclude '__init__.py' --keep-pass-after-docstring
AUTOFLAKE_ARGS := -r
#* Poetry
.PHONY: poetry-download
poetry-download:
Expand Down Expand Up @@ -47,19 +47,19 @@ flake8:
poetry run flake8 --version
poetry run flake8 elastica tests

.PHONY: autoflake8-check
autoflake8-check:
poetry run autoflake8 --version
poetry run autoflake8 $(AUTOFLAKE8_ARGS) elastica tests examples
poetry run autoflake8 --check $(AUTOFLAKE8_ARGS) elastica tests examples
.PHONY: autoflake-check
autoflake-check:
poetry run autoflake --version
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: autoflake8-format
autoflake8-format:
poetry run autoflake8 --version
poetry run autoflake8 --in-place $(AUTOFLAKE8_ARGS) elastica tests examples
.PHONY: autoflake-format
autoflake-format:
poetry run autoflake --version
poetry run autoflake --in-place $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: format-codestyle
format-codestyle: black flake8
format-codestyle: black autoflake-format

.PHONY: mypy
mypy:
Expand All @@ -78,7 +78,7 @@ test_coverage_xml:
NUMBA_DISABLE_JIT=1 poetry run pytest --cov=elastica --cov-report=xml

.PHONY: check-codestyle
check-codestyle: black-check flake8 autoflake8-check
check-codestyle: black-check flake8 autoflake-check

.PHONY: formatting
formatting: format-codestyle
Expand Down
11 changes: 6 additions & 5 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
Basic coordinating for multiple, smaller systems that have an independently integrable
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
"""
from typing import Iterable, Callable, AnyStr
from typing import Iterable, Callable, AnyStr, Type
from elastica.typing import SystemType

import numpy as np

Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self):
# We need to initialize our mixin classes
super(BaseSystemCollection, self).__init__()
# List of system types/bases that are allowed
self.allowed_sys_types = (RodBase, RigidBodyBase, SurfaceBase)
self.allowed_sys_types: tuple[Type, ...] = (RodBase, RigidBodyBase, SurfaceBase)
# List of systems to be integrated
self._systems = []
# Flag Finalize: Finalizing twice will cause an error,
Expand Down Expand Up @@ -98,11 +99,11 @@ def insert(self, idx, system):
def __str__(self):
return str(self._systems)

def extend_allowed_types(self, additional_types):
def extend_allowed_types(self, additional_types: list[Type, ...]):
self.allowed_sys_types += additional_types

def override_allowed_types(self, allowed_types):
self.allowed_sys_types = allowed_types
def override_allowed_types(self, allowed_types: list[Type, ...]):
self.allowed_sys_types = tuple(allowed_types)

def _get_sys_idx_if_valid(self, sys_to_be_added):
from numpy import int_ as npint
Expand Down
1 change: 0 additions & 1 deletion elastica/py.typed
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@

4 changes: 3 additions & 1 deletion elastica/systems/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from elastica._rotations import _rotate
from elastica.rod.data_structures import _RodSymplecticStepperMixin
from elastica.rod.rod_base import RodBase


class BaseStatefulSystem:
Expand Down Expand Up @@ -355,8 +356,9 @@ def make_simple_system_with_positions_directors(
)


class SimpleSystemWithPositionsDirectors(_RodSymplecticStepperMixin):
class SimpleSystemWithPositionsDirectors(_RodSymplecticStepperMixin, RodBase):
def __init__(self, start_position, end_position, start_director):
self.ring_rod_flag = False # TODO:
self.a = 0.5
self.b = 1
self.c = 2
Expand Down
4 changes: 2 additions & 2 deletions elastica/systems/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def external_forces(self) -> NDArray: ...
@property
def external_torques(self) -> NDArray: ...

def update_internal_forces_and_torques(self, time: np.floating) -> None: ...


class SymplecticSystemProtocol(SystemProtocol, Protocol):
"""
Expand All @@ -64,8 +66,6 @@ def dynamic_rates(
self, time: np.floating, prefac: np.floating
) -> tuple[NDArray]: ...

def update_internal_forces_and_torques(self, time: np.floating) -> None: ...


class ExplicitSystemProtocol(SystemProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
Expand Down
111 changes: 52 additions & 59 deletions elastica/timestepper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__doc__ = """Timestepping utilities to be used with Rod and RigidBody classes"""

from typing import Tuple, List, Callable, Type
from elastica.typing import SystemType
from typing import Tuple, List, Callable, Type, Any, overload
from elastica.typing import SystemType, SystemCollectionType, SteppersOperatorsType

import numpy as np
from tqdm import tqdm
Expand All @@ -10,57 +10,52 @@

from .symplectic_steppers import PositionVerlet, PEFRL
from .explicit_steppers import RungeKutta4, EulerForward
from .protocol import StepperProtocol, SymplecticStepperProtocol

from .tag import SymplecticStepperTag, ExplicitStepperTag
from .protocol import StepperProtocol, StatefulStepperProtocol
from .protocol import MethodCollectorProtocol


# TODO: Both extend_stepper_interface and integrate should be in separate file.
# __init__ is probably not an ideal place to have these scripts.
# Deprecated: Remove in the future version
# Many script still uses this method to control timestep. Keep it for backward compatibility
def extend_stepper_interface(
Stepper: StepperProtocol, System: SystemType
) -> Tuple[Callable, Tuple[Callable]]:

# StepperMethodCollector: Type[MethodCollectorProtocol]
# SystemStepper: Type[StepperProtocol]
if isinstance(Stepper.Tag, SymplecticStepperTag):
from elastica.timestepper.symplectic_steppers import (
_SystemInstanceStepper,
_SystemCollectionStepper,
SymplecticStepperMethods,
)

StepperMethodCollector = SymplecticStepperMethods
elif isinstance(Stepper.Tag, ExplicitStepperTag): # type: ignore[no-redef]
from elastica.timestepper.explicit_steppers import (
_SystemInstanceStepper,
_SystemCollectionStepper,
ExplicitStepperMethods,
)

StepperMethodCollector = ExplicitStepperMethods
else:
raise NotImplementedError(
"Only explicit and symplectic steppers are supported, given stepper is {}".format(
Stepper.__class__.__name__
)
)

# Check if system is a "collection" of smaller systems
if is_system_a_collection(System):
SystemStepper = _SystemCollectionStepper
else:
SystemStepper = _SystemInstanceStepper

stepper_methods: Tuple[Callable] = StepperMethodCollector(Stepper).step_methods()
do_step_method: Callable = SystemStepper.do_step
stepper: StepperProtocol, system_collection: SystemCollectionType
) -> Tuple[
Callable[
[StepperProtocol, SystemCollectionType, np.floating, np.floating], np.floating
],
SteppersOperatorsType,
]:
try:
stepper_methods: SteppersOperatorsType = stepper.steps_and_prefactors
do_step_method: Callable = stepper.do_step # type: ignore[attr-defined]
except AttributeError as e:
raise NotImplementedError(f"{stepper} stepper is not supported.") from e
return do_step_method, stepper_methods


@overload
def integrate(
stepper: StepperProtocol,
systems: SystemType,
final_time: float,
n_steps: int,
restart_time: float,
progress_bar: bool,
) -> float: ...


@overload
def integrate(
StatefulStepper: StatefulStepperProtocol,
System: SystemType,
stepper: StepperProtocol,
systems: SystemCollectionType,
final_time: float,
n_steps: int,
restart_time: float,
progress_bar: bool,
) -> float: ...

skim0119 marked this conversation as resolved.
Show resolved Hide resolved

def integrate(
stepper: StepperProtocol,
systems: SystemType | SystemCollectionType,
final_time: float,
n_steps: int = 1000,
restart_time: float = 0.0,
Expand All @@ -70,9 +65,9 @@ def integrate(

Parameters
----------
StatefulStepper : StatefulStepperProtocol
stepper : StepperProtocol
Stepper algorithm to use.
System : SystemType
systems : SystemType | SystemCollectionType
The elastica-system to simulate.
final_time : float
Total simulation time. The timestep is determined by final_time / n_steps.
Expand All @@ -86,17 +81,15 @@ def integrate(
assert final_time > 0.0, "Final time is negative!"
assert n_steps > 0, "Number of integration steps is negative!"

# Extend the stepper's interface after introspecting the properties
# of the system. If system is a collection of small systems (whose
# states cannot be aggregated), then stepper now loops over the system
# state
do_step, stages_and_updates = extend_stepper_interface(StatefulStepper, System)

dt = np.float64(float(final_time) / n_steps)
time = restart_time
dt = np.float_(float(final_time) / n_steps)
time = np.float_(restart_time)

for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = do_step(StatefulStepper, stages_and_updates, System, time, dt)
if is_system_a_collection(systems):
for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = stepper.step(systems, time, dt) # type: ignore[arg-type]
else:
for i in tqdm(range(n_steps), disable=(not progress_bar)):
time = stepper.step_single_instance(systems, time, dt) # type: ignore[arg-type]
skim0119 marked this conversation as resolved.
Show resolved Hide resolved

print("Final time of simulation is : ", time)
return time
return float(time)
Loading
Loading