diff --git a/compute_endpoint/globus_compute_endpoint/engines/helper.py b/compute_endpoint/globus_compute_endpoint/engines/helper.py index 67238d03f..accfcbd8b 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/helper.py +++ b/compute_endpoint/globus_compute_endpoint/engines/helper.py @@ -20,7 +20,7 @@ log = logging.getLogger(__name__) -serializer = ComputeSerializer() +_serde = ComputeSerializer() def execute_task( @@ -76,7 +76,12 @@ def execute_task( try: _task, task_buffer = _unpack_messagebody(task_body) log.debug("executing task task_id='%s'", task_id) - result = _call_user_function(task_buffer, result_size_limit=result_size_limit) + result = _call_user_function(task_buffer) + + res_len = len(result) + if res_len > result_size_limit: + raise MaxResultSizeExceeded(res_len, result_size_limit) + log.debug("Execution completed without exception") result_message = dict(task_id=task_id, data=result) @@ -139,29 +144,20 @@ def _unpack_messagebody(message: bytes) -> t.Tuple[Task, str]: return task, task_buffer -def _call_user_function( - task_buffer: str, result_size_limit: int, serializer=serializer -) -> str: +def _call_user_function(task_buffer: str, serde: ComputeSerializer = _serde) -> str: """Deserialize the buffer and execute the task. Parameters ---------- task_buffer: serialized buffer of (fn, args, kwargs) - result_size_limit: size limit in bytes for results - serializer: serializer for the buffers + serde: serializer for the buffers Returns ------- Returns serialized result or throws exception. """ GC_TASK_TIMEOUT = max(0.0, float(os.environ.get("GC_TASK_TIMEOUT", 0.0))) - f, args, kwargs = serializer.unpack_and_deserialize(task_buffer) + f, args, kwargs = serde.unpack_and_deserialize(task_buffer) if GC_TASK_TIMEOUT > 0.0: log.debug(f"Setting task timeout to GC_TASK_TIMEOUT={GC_TASK_TIMEOUT}s") f = timeout(f, GC_TASK_TIMEOUT) - result_data = f(*args, **kwargs) - serialized_data = serializer.serialize(result_data) - - if len(serialized_data) > result_size_limit: - raise MaxResultSizeExceeded(len(serialized_data), result_size_limit) - - return serialized_data + return serde.serialize(f(*args, **kwargs))