diff --git a/parsl/executors/high_throughput/mpi_resource_management.py b/parsl/executors/high_throughput/mpi_resource_management.py index 3f3fc33ea4..09745421d1 100644 --- a/parsl/executors/high_throughput/mpi_resource_management.py +++ b/parsl/executors/high_throughput/mpi_resource_management.py @@ -177,6 +177,7 @@ def put_task(self, task_package: dict): self._map_tasks_to_nodes[task_package["task_id"]] = allocated_nodes buffer = pack_res_spec_apply_message(_f, _args, _kwargs, resource_spec) task_package["buffer"] = buffer + task_package["resource_spec"] = resource_spec self.pending_task_q.put(task_package) diff --git a/parsl/executors/high_throughput/process_worker_pool.py b/parsl/executors/high_throughput/process_worker_pool.py index e75af86743..957d670188 100755 --- a/parsl/executors/high_throughput/process_worker_pool.py +++ b/parsl/executors/high_throughput/process_worker_pool.py @@ -590,28 +590,25 @@ def update_resource_spec_env_vars(mpi_launcher: str, resource_spec: Dict, node_i os.environ[key] = prefix_table[key] -def execute_task(bufs, mpi_launcher: Optional[str] = None): - """Deserialize the buffer and execute the task. +def _init_mpi_env(mpi_launcher: str, resource_spec: Dict): + node_list = resource_spec.get("MPI_NODELIST") + if node_list is None: + return + nodes_for_task = node_list.split(',') + logger.info(f"Launching task on provisioned nodes: {nodes_for_task}") + update_resource_spec_env_vars(mpi_launcher=mpi_launcher, resource_spec=resource_spec, node_info=nodes_for_task) + +def execute_task(bufs: bytes): + """Deserialize the buffer and execute the task. Returns the result or throws exception. """ - user_ns = locals() - user_ns.update({'__builtins__': __builtins__}) - - f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, user_ns, copy=False) + f, args, kwargs, resource_spec = unpack_res_spec_apply_message(bufs, copy=False) for varname in resource_spec: envname = "PARSL_" + str(varname).upper() os.environ[envname] = str(resource_spec[varname]) - if resource_spec.get("MPI_NODELIST"): - worker_id = os.environ['PARSL_WORKER_RANK'] - nodes_for_task = resource_spec["MPI_NODELIST"].split(',') - logger.info(f"Launching task on provisioned nodes: {nodes_for_task}") - assert mpi_launcher - update_resource_spec_env_vars(mpi_launcher, - resource_spec=resource_spec, - node_info=nodes_for_task) # We might need to look into callability of the function from itself # since we change it's name in the new namespace prefix = "parsl_" @@ -620,13 +617,18 @@ def execute_task(bufs, mpi_launcher: Optional[str] = None): kwargname = prefix + "kwargs" resultname = prefix + "result" - user_ns.update({fname: f, - argname: args, - kwargname: kwargs, - resultname: resultname}) - code = "{0} = {1}(*{2}, **{3})".format(resultname, fname, argname, kwargname) + + user_ns = locals() + user_ns.update({ + '__builtins__': __builtins__, + fname: f, + argname: args, + kwargname: kwargs, + resultname: resultname + }) + exec(code, user_ns, user_ns) return user_ns.get(resultname) @@ -786,8 +788,10 @@ def manager_is_alive(): ready_worker_count.value -= 1 worker_enqueued = False + _init_mpi_env(mpi_launcher=mpi_launcher, resource_spec=req["resource_spec"]) + try: - result = execute_task(req['buffer'], mpi_launcher=mpi_launcher) + result = execute_task(req['buffer']) serialized_result = serialize(result, buffer_threshold=1000000) except Exception as e: logger.info('Caught an exception: {}'.format(e))