diff --git a/changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst b/changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst new file mode 100644 index 000000000..bcfcd292f --- /dev/null +++ b/changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst @@ -0,0 +1,5 @@ +Bug Fixes +^^^^^^^^^ + +- The ``GlobusComputeEngine`` has been updated to fully support the + ``heartbeat_period`` parameter. \ No newline at end of file diff --git a/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py b/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py index 0d6c2ff31..79910482e 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py +++ b/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py @@ -41,7 +41,7 @@ def __init__( self.max_workers_per_node = 1 if executor is None: executor = HighThroughputExecutor( # type: ignore - *args, address=address, **kwargs + *args, address=address, heartbeat_period=heartbeat_period, **kwargs ) self.executor = executor diff --git a/compute_endpoint/tests/conftest.py b/compute_endpoint/tests/conftest.py index a244e3954..483342cb1 100644 --- a/compute_endpoint/tests/conftest.py +++ b/compute_endpoint/tests/conftest.py @@ -106,8 +106,8 @@ def func(): @pytest.fixture -def engine_heartbeat() -> float: - return 0.1 +def engine_heartbeat() -> int: + return 1 @pytest.fixture diff --git a/compute_endpoint/tests/unit/test_engines.py b/compute_endpoint/tests/unit/test_engines.py index 42868ef5b..d860b5668 100644 --- a/compute_endpoint/tests/unit/test_engines.py +++ b/compute_endpoint/tests/unit/test_engines.py @@ -19,6 +19,7 @@ from globus_compute_endpoint.engines.base import GlobusComputeEngineBase from globus_compute_sdk.serialize import ComputeSerializer from parsl.executors.high_throughput.interchange import ManagerLost +from pytest_mock import MockFixture from tests.utils import double, ez_pack_function, slow_double logger = logging.getLogger(__name__) @@ -185,3 +186,17 @@ def test_serialized_engine_config_has_provider(engine_type: GlobusComputeEngineB executor = res["executors"][0].get("executor") or res["executors"][0] assert executor.get("provider") + + +def test_gcengine_pass_through_to_executor(mocker: MockFixture): + mock_executor = mocker.patch( + "globus_compute_endpoint.engines.globus_compute.HighThroughputExecutor" + ) + + args = ("arg1", 2) + kwargs = {"address": "127.0.0.1", "heartbeat_period": 10, "foo": "bar"} + GlobusComputeEngine(*args, **kwargs) + + a, k = mock_executor.call_args + assert a == args + assert kwargs == k diff --git a/compute_endpoint/tests/unit/test_status_reporting.py b/compute_endpoint/tests/unit/test_status_reporting.py index 493332ff1..88c7edb6f 100644 --- a/compute_endpoint/tests/unit/test_status_reporting.py +++ b/compute_endpoint/tests/unit/test_status_reporting.py @@ -12,7 +12,7 @@ "engine_type", (engines.ProcessPoolEngine, engines.ThreadPoolEngine, engines.GlobusComputeEngine), ) -def test_status_reporting(engine_type, engine_runner, engine_heartbeat: float): +def test_status_reporting(engine_type, engine_runner, engine_heartbeat: int): engine = engine_runner(engine_type) report = engine.get_status_report() @@ -28,7 +28,7 @@ def test_status_reporting(engine_type, engine_runner, engine_heartbeat: float): # Confirm heartbeats in regular intervals for _i in range(3): - q_msg = results_q.get(timeout=1) + q_msg = results_q.get(timeout=2) assert isinstance(q_msg, dict) message = q_msg["message"]