From 43d4ff9df8249b6b641dadc6a169c060b8de5017 Mon Sep 17 00:00:00 2001 From: Kevin Hunter Kesling Date: Tue, 12 Nov 2024 23:13:15 -0500 Subject: [PATCH] Move result size check one level higher (#1714) No functional change but move result length check one level higher in the call stack. There's no sense in passing the size limit. `_call_user_function`'s purview is to invoke the function. Checking the result validity is for the higher-level helper. While in `_call_user_function`, make IDE slightly happier by not shadowing the module variable. --- .../globus_compute_endpoint/engines/helper.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/compute_endpoint/globus_compute_endpoint/engines/helper.py b/compute_endpoint/globus_compute_endpoint/engines/helper.py index 67238d03f..accfcbd8b 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/helper.py +++ b/compute_endpoint/globus_compute_endpoint/engines/helper.py @@ -20,7 +20,7 @@ log = logging.getLogger(__name__) -serializer = ComputeSerializer() +_serde = ComputeSerializer() def execute_task( @@ -76,7 +76,12 @@ def execute_task( try: _task, task_buffer = _unpack_messagebody(task_body) log.debug("executing task task_id='%s'", task_id) - result = _call_user_function(task_buffer, result_size_limit=result_size_limit) + result = _call_user_function(task_buffer) + + res_len = len(result) + if res_len > result_size_limit: + raise MaxResultSizeExceeded(res_len, result_size_limit) + log.debug("Execution completed without exception") result_message = dict(task_id=task_id, data=result) @@ -139,29 +144,20 @@ def _unpack_messagebody(message: bytes) -> t.Tuple[Task, str]: return task, task_buffer -def _call_user_function( - task_buffer: str, result_size_limit: int, serializer=serializer -) -> str: +def _call_user_function(task_buffer: str, serde: ComputeSerializer = _serde) -> str: """Deserialize the buffer and execute the task. Parameters ---------- task_buffer: serialized buffer of (fn, args, kwargs) - result_size_limit: size limit in bytes for results - serializer: serializer for the buffers + serde: serializer for the buffers Returns ------- Returns serialized result or throws exception. """ GC_TASK_TIMEOUT = max(0.0, float(os.environ.get("GC_TASK_TIMEOUT", 0.0))) - f, args, kwargs = serializer.unpack_and_deserialize(task_buffer) + f, args, kwargs = serde.unpack_and_deserialize(task_buffer) if GC_TASK_TIMEOUT > 0.0: log.debug(f"Setting task timeout to GC_TASK_TIMEOUT={GC_TASK_TIMEOUT}s") f = timeout(f, GC_TASK_TIMEOUT) - result_data = f(*args, **kwargs) - serialized_data = serializer.serialize(result_data) - - if len(serialized_data) > result_size_limit: - raise MaxResultSizeExceeded(len(serialized_data), result_size_limit) - - return serialized_data + return serde.serialize(f(*args, **kwargs))