diff --git a/compute_endpoint/tests/conftest.py b/compute_endpoint/tests/conftest.py index 17dc144f8..3acb05799 100644 --- a/compute_endpoint/tests/conftest.py +++ b/compute_endpoint/tests/conftest.py @@ -12,14 +12,14 @@ import globus_sdk import pytest import responses -from globus_compute_common import messagepack from globus_compute_endpoint import engines from globus_compute_endpoint.engines.base import GlobusComputeEngineBase from globus_compute_sdk.sdk.web_client import WebClient from globus_compute_sdk.serialize import ComputeSerializer from parsl.launchers import SimpleLauncher from parsl.providers import LocalProvider -from tests.utils import ez_pack_function + +from .utils import create_task_packer @pytest.fixture(autouse=True) @@ -215,12 +215,4 @@ def serde(): @pytest.fixture def ez_pack_task(serde, task_uuid, container_uuid): - def _pack_it(fn, *a, **k) -> bytes: - task_body = ez_pack_function(serde, fn, a, k) - return messagepack.pack( - messagepack.message_types.Task( - task_id=task_uuid, container_id=container_uuid, task_buffer=task_body - ) - ) - - return _pack_it + return create_task_packer(serde, task_uuid, container_uuid) diff --git a/compute_endpoint/tests/utils.py b/compute_endpoint/tests/utils.py index 65551e206..3d8260d1a 100644 --- a/compute_endpoint/tests/utils.py +++ b/compute_endpoint/tests/utils.py @@ -1,9 +1,15 @@ +from __future__ import annotations + import itertools import pathlib import sys import time import types import typing as t +import uuid + +from globus_compute_common import messagepack +from globus_compute_sdk.serialize import ComputeSerializer def create_traceback(start: int = 0) -> types.TracebackType: @@ -86,6 +92,48 @@ def ez_pack_function(serializer, func, args, kwargs): ) +def create_task_packer( + serde: ComputeSerializer | None = None, + task_uuid: uuid.UUID | None = None, + container_uuid: uuid.UUID | None = None, +) -> t.Callable[[t.Callable, ...], bytes]: + """ + A quick go-to for easier development while hacking on the engine submit routines. + + Reminder for the dev: + + >>> import uuid + >>> from tests.utils import create_task_packer + >>> from globus_compute_common import messagepack + >>> from globus_compute_sdk.serialize import ComputeSerializer + >>> from globus_compute_endpoint.engines import ThreadPoolEngine + >>> + >>> def some_func(*a, **k): + ... return f"[args=<{a}>, k=<{k}>]" + ... + >>> pack_task = create_task_packer() + >>> task_bytes = pack_task(some_func) + >>> + >>> e = ThreadPoolEngine() + >>> e.start(endpoint_id=uuid.uuid4()) + >>> encoded_result = e.submit("some_task_id", task_bytes, {}).result() + >>> result = messagepack.unpack(encoded_result) + >>> payload = serde.deserialize(result.data) + """ + serde = serde or ComputeSerializer() + task_uuid = task_uuid or uuid.uuid4() + + def _pack_it(fn, *a, **k) -> bytes: + task_body = ez_pack_function(serde, fn, a, k) + return messagepack.pack( + messagepack.message_types.Task( + task_id=task_uuid, container_id=container_uuid, task_buffer=task_body + ) + ) + + return _pack_it + + def double(x: int) -> int: return x * 2