Skip to content

Commit

Permalink
Extract MPI code from execute_task() (#3702)
Browse files Browse the repository at this point in the history
# Description

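The `execute_task()` function is used by multiple executors, but the MPI
code is specific to HTEX. This change moves the MPI setup into a separate
`_init_mpi_env()` helper that the HTEX worker calls before `execute_task()`.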

## Type of change

- Code maintenance/cleanup
  • Loading branch information
rjmello authored Nov 19, 2024
1 parent 9fb5269 commit 91146c1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
1 change: 1 addition & 0 deletions parsl/executors/high_throughput/mpi_resource_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def put_task(self, task_package: dict):
self._map_tasks_to_nodes[task_package["task_id"]] = allocated_nodes
buffer = pack_res_spec_apply_message(_f, _args, _kwargs, resource_spec)
task_package["buffer"] = buffer
task_package["resource_spec"] = resource_spec

self.pending_task_q.put(task_package)

Expand Down
44 changes: 24 additions & 20 deletions parsl/executors/high_throughput/process_worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,28 +590,25 @@ def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_i
os.environ[key] = prefix_table[key]


def execute_task(bufs, mpi_launcher: Optional[str] = None):
"""Deserialize the buffer and execute the task.
def _init_mpi_env(mpi_launcher: str, resource_spec: Dict):
node_list = resource_spec.get("MPI_NODELIST")
if node_list is None:
return
nodes_for_task = node_list.split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task)


def execute_task(bufs: bytes):
"""Deserialize the buffer and execute the task.
Returns the result or throws exception.
"""
user_ns = locals()
user_ns.update({'__builtins__': __builtins__})

f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, user_ns, copy=False)
f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, copy=False)

for varname in resource_spec:
envname = "PARSL_" + str(varname).upper()
os.environ[envname] = str(resource_spec[varname])

if resource_spec.get("MPI_NODELIST"):
worker_id = os.environ['PARSL_WORKER_RANK']
nodes_for_task = resource_spec["MPI_NODELIST"].split(',')
logger.info(f"Launching task on provisioned nodes: {nodes_for_task}")
assert mpi_launcher
update_resource_spec_env_vars(mpi_launcher,
resource_spec=resource_spec,
node_info=nodes_for_task)
# We might need to look into callability of the function from itself
# since we change its name in the new namespace
prefix = "parsl_"
Expand All @@ -620,13 +617,18 @@ def execute_task(bufs, mpi_launcher: Optional[str] = None):
kwargname = prefix + "kwargs"
resultname = prefix + "result"

user_ns.update({fname: f,
argname: args,
kwargname: kwargs,
resultname: resultname})

code = "{0} = {1}(*{2}, **{3})".format(resultname, fname,
argname, kwargname)

user_ns = locals()
user_ns.update({
'__builtins__': __builtins__,
fname: f,
argname: args,
kwargname: kwargs,
resultname: resultname
})

exec(code, user_ns, user_ns)
return user_ns.get(resultname)

Expand Down Expand Up @@ -786,8 +788,10 @@ def manager_is_alive():
ready_worker_count.value -= 1
worker_enqueued = False

_init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=req["resource_spec"])

try:
result = execute_task(req['buffer'], mpi_launcher=mpi_launcher)
result = execute_task(req['buffer'])
serialized_result = serialize(result, buffer_threshold=1000000)
except Exception as e:
logger.info('Caught an exception: {}'.format(e))
Expand Down

0 comments on commit 91146c1

Please sign in to comment.