From d37e03fac9a274355a4ad6f7bda6e28e16fcc3ed Mon Sep 17 00:00:00 2001 From: Kevin Hunter Kesling Date: Tue, 12 Nov 2024 15:15:27 -0500 Subject: [PATCH] Update execute_task signature (#1719) `run_dir` is not optional; in a previous commit, it was enforced by an inline check. Update the function signature to match that requirement. While here, test the addition, and also bolster the module's tests in general. --- .../globus_compute_endpoint/engines/helper.py | 33 ++++--- compute_endpoint/tests/unit/conftest.py | 7 ++ .../tests/unit/test_execute_task.py | 97 +++++++++++++++++-- compute_endpoint/tests/unit/test_worker.py | 32 +----- 4 files changed, 119 insertions(+), 50 deletions(-) diff --git a/compute_endpoint/globus_compute_endpoint/engines/helper.py b/compute_endpoint/globus_compute_endpoint/engines/helper.py index accfcbd8b..3e27e209f 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/helper.py +++ b/compute_endpoint/globus_compute_endpoint/engines/helper.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging import os +import pathlib import time import typing as t import uuid @@ -21,14 +24,16 @@ log = logging.getLogger(__name__) _serde = ComputeSerializer() +_RESULT_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MiB def execute_task( task_id: uuid.UUID, task_body: bytes, - endpoint_id: t.Optional[uuid.UUID], - result_size_limit: int = 10 * 1024 * 1024, - run_dir: t.Optional[t.Union[str, os.PathLike]] = None, + endpoint_id: t.Optional[uuid.UUID] = None, + *, + run_dir: t.Union[str, os.PathLike], + result_size_limit: int = _RESULT_SIZE_LIMIT, run_in_sandbox: bool = False, ) -> bytes: """Execute task is designed to enable any executor to execute a Task payload @@ -38,7 +43,7 @@ def execute_task( ---------- task_id: uuid string task_body: packed message as bytes - endpoint_id: uuid string or None + endpoint_id: uuid.UUID or None result_size_limit: result size in bytes run_dir: directory to run function in run_in_sandbox: if enabled run task under 
run_dir/ @@ -56,19 +61,25 @@ def execute_task( uuid.UUID | str | tuple[str, str] | list[TaskTransition] | dict[str, str], ] + task_id_str = str(task_id) os.environ.pop("GC_TASK_SANDBOX_DIR", None) - os.environ["GC_TASK_UUID"] = str(task_id) + os.environ["GC_TASK_UUID"] = task_id_str - if not run_dir or not os.path.isabs(run_dir): - raise RuntimeError( - f"execute_task requires an absolute path for run_dir, got {run_dir=}" + if result_size_limit < 128: + raise ValueError( + f"Invalid result limit; must be at least 128 bytes ({result_size_limit=})" ) - os.makedirs(run_dir, exist_ok=True) + run_dir = pathlib.Path(run_dir) + if not run_dir.is_absolute(): + raise ValueError(f"Absolute path required, not relative: {str(run_dir)}") + + run_dir.mkdir(parents=True, exist_ok=True) os.chdir(run_dir) if run_in_sandbox: - os.makedirs(str(task_id)) # task_id is expected to be unique - os.chdir(str(task_id)) + sb_dir = run_dir / task_id_str + sb_dir.mkdir(exist_ok=True) # handle task re-transmission gracefully + os.chdir(sb_dir) # Set sandbox dir so that apps can use it os.environ["GC_TASK_SANDBOX_DIR"] = os.getcwd() diff --git a/compute_endpoint/tests/unit/conftest.py b/compute_endpoint/tests/unit/conftest.py index c31244a1f..e1623e104 100644 --- a/compute_endpoint/tests/unit/conftest.py +++ b/compute_endpoint/tests/unit/conftest.py @@ -1,3 +1,4 @@ +import functools import inspect import os import pathlib @@ -6,6 +7,7 @@ import uuid import pytest +from globus_compute_endpoint.engines.helper import execute_task from tests.conftest import randomstring_impl @@ -85,3 +87,8 @@ def get_random_of_datatype_impl(cls): @pytest.fixture def get_random_of_datatype(): return get_random_of_datatype_impl + + +@pytest.fixture +def execute_task_runner(task_uuid, tmp_path): + return functools.partial(execute_task, task_uuid, run_dir=tmp_path) diff --git a/compute_endpoint/tests/unit/test_execute_task.py b/compute_endpoint/tests/unit/test_execute_task.py index 06b156e86..f30f14776 100644 --- 
a/compute_endpoint/tests/unit/test_execute_task.py +++ b/compute_endpoint/tests/unit/test_execute_task.py @@ -1,10 +1,12 @@ import logging +import os import random from unittest import mock import pytest from globus_compute_common import messagepack -from globus_compute_endpoint.engines.helper import execute_task +from globus_compute_endpoint.engines.helper import _RESULT_SIZE_LIMIT, execute_task +from globus_compute_sdk.errors import MaxResultSizeExceeded from tests.utils import divide logger = logging.getLogger(__name__) @@ -12,24 +14,28 @@ _MOCK_BASE = "globus_compute_endpoint.engines.helper." -@pytest.mark.parametrize("run_dir", ("tmp", None, "$HOME")) +@pytest.mark.parametrize("run_dir", ("", ".", "./", "../", "tmp", "$HOME")) def test_bad_run_dir(endpoint_uuid, task_uuid, run_dir): - with pytest.raises(RuntimeError): + with pytest.raises(ValueError): # not absolute execute_task(task_uuid, b"", endpoint_uuid, run_dir=run_dir) + with pytest.raises(TypeError): # not anything, allow-any-type language + execute_task(task_uuid, b"", endpoint_uuid, run_dir=None) -def test_execute_task(endpoint_uuid, serde, task_uuid, ez_pack_task, tmp_path): + +def test_happy_path(serde, task_uuid, ez_pack_task, execute_task_runner): out = random.randint(1, 100_000) divisor = random.randint(1, 100_000) task_bytes = ez_pack_task(divide, divisor * out, divisor) - packed_result = execute_task(task_uuid, task_bytes, endpoint_uuid, run_dir=tmp_path) + packed_result = execute_task_runner(task_bytes) assert isinstance(packed_result, bytes) result = messagepack.unpack(packed_result) assert isinstance(result, messagepack.message_types.Result) assert result.data + assert result.task_id == task_uuid assert "os" in result.details assert "python_version" in result.details assert "dill_version" in result.details @@ -37,13 +43,86 @@ def test_execute_task(endpoint_uuid, serde, task_uuid, ez_pack_task, tmp_path): assert serde.deserialize(result.data) == out -def 
test_execute_task_with_exception(endpoint_uuid, task_uuid, ez_pack_task, tmp_path): +def test_sandbox(ez_pack_task, execute_task_runner, task_uuid, tmp_path): + task_bytes = ez_pack_task(divide, 10, 2) + packed_result = execute_task_runner(task_bytes, run_in_sandbox=True) + result = messagepack.unpack(packed_result) + assert result.task_id == task_uuid + assert result.error_details is None, "Verify test setup: execution successful" + + exp_dir = tmp_path / str(task_uuid) + assert os.environ.get("GC_TASK_SANDBOX_DIR") == str(exp_dir), "Share dir w/ func" + assert os.getcwd() == str(exp_dir), "Expect sandbox dir entered" + + +def test_nested_run_dir(ez_pack_task, task_uuid, tmp_path): + task_bytes = ez_pack_task(divide, 10, 2) + nested_root = tmp_path / "a/" + nested_path = nested_root / "b/c/d" + assert not nested_root.exists(), "Verify test setup" + + packed_result = execute_task(task_uuid, task_bytes, run_dir=nested_path) + + result = messagepack.unpack(packed_result) + assert result.error_details is None, "Verify test setup: execution successful" + + assert nested_path.exists(), "Test namesake" + + +@pytest.mark.parametrize("size_limit", (128, 256, 1024, 4096, _RESULT_SIZE_LIMIT)) +def test_result_size_limit(serde, ez_pack_task, execute_task_runner, size_limit): + task_bytes = ez_pack_task(divide, 10, 2) + exp_data = f"{MaxResultSizeExceeded.__name__}({size_limit + 1}, {size_limit})" + res_data_good = "a" * size_limit + res_data_bad = "a" * (size_limit + 1) + + with mock.patch(f"{_MOCK_BASE}_call_user_function") as mock_callfn: + with mock.patch(f"{_MOCK_BASE}log.exception"): # silence tests + mock_callfn.return_value = res_data_good + res_bytes = execute_task_runner(task_bytes, result_size_limit=size_limit) + result = messagepack.unpack(res_bytes) + assert result.data == res_data_good + + mock_callfn.return_value = res_data_bad + res_bytes = execute_task_runner(task_bytes, result_size_limit=size_limit) + result = messagepack.unpack(res_bytes) + assert exp_data 
== result.data + assert result.error_details.code == "MaxResultSizeExceeded" + + +def test_default_result_size_limit(ez_pack_task, execute_task_runner): + task_bytes = ez_pack_task(divide, 10, 2) + default = _RESULT_SIZE_LIMIT + exp_data = f"{MaxResultSizeExceeded.__name__}({default + 1}, {default})" + res_data_good = "a" * default + res_data_bad = "a" * (default + 1) + + with mock.patch(f"{_MOCK_BASE}_call_user_function") as mock_callfn: + with mock.patch(f"{_MOCK_BASE}log.exception"): # silence tests + mock_callfn.return_value = res_data_good + res_bytes = execute_task_runner(task_bytes) + result = messagepack.unpack(res_bytes) + assert result.data == res_data_good + + mock_callfn.return_value = res_data_bad + res_bytes = execute_task_runner(task_bytes) + result = messagepack.unpack(res_bytes) + assert exp_data == result.data + assert result.error_details.code == "MaxResultSizeExceeded" + + +@pytest.mark.parametrize("size_limit", (-5, 0, 1, 65, 127)) +def test_invalid_result_size_limit(size_limit): + with pytest.raises(ValueError) as pyt_e: + execute_task("test_tid", b"", run_dir="/", result_size_limit=size_limit) + assert "must be at least" in str(pyt_e.value) + + +def test_execute_task_with_exception(ez_pack_task, execute_task_runner): task_bytes = ez_pack_task(divide, 10, 0) with mock.patch(f"{_MOCK_BASE}log") as mock_log: - packed_result = execute_task( - task_uuid, task_bytes, endpoint_uuid, run_dir=tmp_path - ) + packed_result = execute_task_runner(task_bytes) assert mock_log.exception.called a, _k = mock_log.exception.call_args diff --git a/compute_endpoint/tests/unit/test_worker.py b/compute_endpoint/tests/unit/test_worker.py index 2bd99b73c..62ed6f984 100644 --- a/compute_endpoint/tests/unit/test_worker.py +++ b/compute_endpoint/tests/unit/test_worker.py @@ -5,7 +5,6 @@ import pytest from globus_compute_common import messagepack -from globus_compute_endpoint.engines.helper import execute_task from globus_compute_endpoint.engines.high_throughput.messages import
Task from globus_compute_endpoint.engines.high_throughput.worker import Worker @@ -124,39 +123,12 @@ def test_execute_failing_function(test_worker): ) -def test_execute_function_exceeding_result_size_limit( - test_worker, endpoint_uuid, task_uuid, ez_pack_task, tmp_path -): - return_size = 10 - - task_bytes = ez_pack_task(large_result, return_size) - - with mock.patch("globus_compute_endpoint.engines.helper.log") as mock_log: - s_result = execute_task( - task_uuid, - task_bytes, - endpoint_uuid, - result_size_limit=return_size - 2, - run_dir=tmp_path, - ) - result = messagepack.unpack(s_result) - - assert isinstance(result, messagepack.message_types.Result) - assert result.error_details - assert result.task_id == task_uuid - assert result.error_details - assert result.error_details.code == "MaxResultSizeExceeded" - assert mock_log.exception.called - - -def test_app_timeout(test_worker, endpoint_uuid, task_uuid, ez_pack_task, tmp_path): +def test_app_timeout(test_worker, execute_task_runner, task_uuid, ez_pack_task): task_bytes = ez_pack_task(sleeper, 1) with mock.patch("globus_compute_endpoint.engines.helper.log") as mock_log: with mock.patch.dict(os.environ, {"GC_TASK_TIMEOUT": "0.01"}): - packed_result = execute_task( - task_uuid, task_bytes, endpoint_uuid, run_dir=tmp_path - ) + packed_result = execute_task_runner(task_bytes) result = messagepack.unpack(packed_result) assert isinstance(result, messagepack.message_types.Result)