Skip to content

Commit

Permalink
Move result size check one level higher (#1714)
Browse files Browse the repository at this point in the history
No functional change but move result length check one level higher in the call
stack.  There's no sense in passing the size limit.  `_call_user_function`'s
purview is to invoke the function.  Checking the result validity is for the
higher-level helper.

While in `_call_user_function`, make IDE slightly happier by not shadowing the
module variable.
  • Loading branch information
khk-globus committed Nov 13, 2024
1 parent 1cd2f0d commit c641420
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions compute_endpoint/globus_compute_endpoint/engines/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

log = logging.getLogger(__name__)

serializer = ComputeSerializer()
_serde = ComputeSerializer()


def execute_task(
Expand Down Expand Up @@ -76,7 +76,12 @@ def execute_task(
try:
_task, task_buffer = _unpack_messagebody(task_body)
log.debug("executing task task_id='%s'", task_id)
result = _call_user_function(task_buffer, result_size_limit=result_size_limit)
result = _call_user_function(task_buffer)

res_len = len(result)
if res_len > result_size_limit:
raise MaxResultSizeExceeded(res_len, result_size_limit)

log.debug("Execution completed without exception")
result_message = dict(task_id=task_id, data=result)

Expand Down Expand Up @@ -139,29 +144,20 @@ def _unpack_messagebody(message: bytes) -> t.Tuple[Task, str]:
return task, task_buffer


def _call_user_function(
task_buffer: str, result_size_limit: int, serializer=serializer
) -> str:
def _call_user_function(task_buffer: str, serde: ComputeSerializer = _serde) -> str:
"""Deserialize the buffer and execute the task.
Parameters
----------
task_buffer: serialized buffer of (fn, args, kwargs)
result_size_limit: size limit in bytes for results
serializer: serializer for the buffers
serde: serializer for the buffers
Returns
-------
Returns serialized result or throws exception.
"""
GC_TASK_TIMEOUT = max(0.0, float(os.environ.get("GC_TASK_TIMEOUT", 0.0)))
f, args, kwargs = serializer.unpack_and_deserialize(task_buffer)
f, args, kwargs = serde.unpack_and_deserialize(task_buffer)
if GC_TASK_TIMEOUT > 0.0:
log.debug(f"Setting task timeout to GC_TASK_TIMEOUT={GC_TASK_TIMEOUT}s")
f = timeout(f, GC_TASK_TIMEOUT)

result_data = f(*args, **kwargs)
serialized_data = serializer.serialize(result_data)

if len(serialized_data) > result_size_limit:
raise MaxResultSizeExceeded(len(serialized_data), result_size_limit)

return serialized_data
return serde.serialize(f(*args, **kwargs))

0 comments on commit c641420

Please sign in to comment.