diff --git a/changelog.d/20250109_164900_kevin_remove_deprecated_htex.rst b/changelog.d/20250109_164900_kevin_remove_deprecated_htex.rst new file mode 100644 index 000000000..53b1fc3bd --- /dev/null +++ b/changelog.d/20250109_164900_kevin_remove_deprecated_htex.rst @@ -0,0 +1,6 @@ +Removed +^^^^^^^ + +- Remove ``HighThroughputEngine``. This class was deprecated in :ref:`v2.27.0 + `. We recommend migrating relevant configurations to use + |GlobusComputeEngine|. diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/messages_compat.py b/compute_endpoint/globus_compute_endpoint/endpoint/messages_compat.py deleted file mode 100644 index 5d2b12758..000000000 --- a/compute_endpoint/globus_compute_endpoint/endpoint/messages_compat.py +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -import logging -import pickle -import uuid - -from globus_compute_common.messagepack import Message as OutgoingMessage -from globus_compute_common.messagepack import pack -from globus_compute_common.messagepack.message_types import ( - EPStatusReport as OutgoingEPStatusReport, -) -from globus_compute_common.messagepack.message_types import Result as OutgoingResult -from globus_compute_common.messagepack.message_types import ( - ResultErrorDetails as OutgoingResultErrorDetails, -) -from globus_compute_common.messagepack.message_types import Task as OutgoingTask -from globus_compute_common.messagepack.message_types import TaskTransition -from globus_compute_endpoint.engines.high_throughput.messages import ( - EPStatusReport as InternalEPStatusReport, -) -from globus_compute_endpoint.engines.high_throughput.messages import ( - Task as InternalTask, -) - -logger = logging.getLogger(__name__) - - -def convert_ep_status_report( - internal: InternalEPStatusReport, -) -> OutgoingEPStatusReport: - messagepack_msg = OutgoingEPStatusReport( - endpoint_id=internal._header, - global_state=internal.global_state, - task_statuses=internal.task_statuses, - ) - return messagepack_msg - - -def try_convert_to_messagepack(message: bytes) -> bytes: - try: - unpacked = pickle.loads(message) - except pickle.UnpicklingError: - # message isn't pickled; assume that it's already in messagepack format - return message - - messagepack_msg: OutgoingMessage | None = None - - if isinstance(unpacked, InternalEPStatusReport): - messagepack_msg = OutgoingEPStatusReport( - endpoint_id=unpacked._header, - global_state=unpacked.global_state, - task_statuses=unpacked.task_statuses, - ) - elif isinstance(unpacked, dict): - kwargs: dict[ - str, str | uuid.UUID | OutgoingResultErrorDetails | list[TaskTransition] - ] = { - "task_id": uuid.UUID(unpacked["task_id"]), - } - if "details" in unpacked: - kwargs["details"] = unpacked["details"] - if "task_statuses" in unpacked: - kwargs["task_statuses"] = unpacked["task_statuses"] - if "exception" in unpacked: - kwargs["data"] = unpacked["exception"] - code, user_message = unpacked.get("error_details", ("Unknown", "Unknown")) - kwargs["error_details"] = OutgoingResultErrorDetails( - code=code, user_message=user_message - ) - else: - kwargs["data"] = unpacked["data"] - - messagepack_msg = OutgoingResult(**kwargs) - - if messagepack_msg: - message = pack(messagepack_msg) - - return message - - -def convert_to_internaltask(message: OutgoingTask, container_type: str | None) -> bytes: - container_loc = "RAW" - if message.container: - for img in message.container.images: - if img.image_type == container_type: - container_loc = img.location - break - - return InternalTask( - task_id=str(message.task_id), - container_id=container_loc, - task_buffer=message.task_buffer, - ).pack() diff --git a/compute_endpoint/globus_compute_endpoint/engines/__init__.py b/compute_endpoint/globus_compute_endpoint/engines/__init__.py index 0629306f1..f7d43706a 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/__init__.py +++ b/compute_endpoint/globus_compute_endpoint/engines/__init__.py @@ -1,6 +1,5 @@ from .globus_compute import GlobusComputeEngine from .globus_mpi import GlobusMPIEngine -from .high_throughput.engine import HighThroughputEngine from .process_pool import ProcessPoolEngine from .thread_pool import ThreadPoolEngine @@ -9,5 +8,4 @@ "GlobusMPIEngine", "ProcessPoolEngine", "ThreadPoolEngine", - "HighThroughputEngine", ) diff --git a/compute_endpoint/globus_compute_endpoint/engines/helper.py b/compute_endpoint/globus_compute_endpoint/engines/helper.py index 814cc4749..99e6522e5 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/helper.py +++ b/compute_endpoint/globus_compute_endpoint/engines/helper.py @@ -9,7 +9,6 @@ from globus_compute_common import messagepack from globus_compute_common.messagepack.message_types import Result, Task, TaskTransition from globus_compute_common.tasks import ActorName, TaskState -from globus_compute_endpoint.engines.high_throughput.messages import Message from globus_compute_endpoint.exception_handling import ( get_error_string, get_result_error_details, @@ -135,22 +134,12 @@ def _unpack_messagebody(message: bytes) -> tuple[Task, str]: ------- tuple(task, task_buffer) """ - try: - task = messagepack.unpack(message) - if not isinstance(task, messagepack.message_types.Task): - raise CouldNotExecuteUserTaskError( - f"wrong type of message in worker: {type(task)}" - ) - task_buffer = task.task_buffer - # on parse errors, failover to trying the "legacy" message reading - except ( - messagepack.InvalidMessageError, - messagepack.UnrecognizedProtocolVersion, - ): - task = Message.unpack(message) - assert isinstance(task, Task) - task_buffer = task.task_buffer.decode("utf-8") # type: ignore[attr-defined] - return task, task_buffer + task = messagepack.unpack(message) + if not isinstance(task, messagepack.message_types.Task): + raise CouldNotExecuteUserTaskError( + f"wrong type of message in worker: {type(task)}" + ) + return task, task.task_buffer def _call_user_function(task_buffer: str, serde: ComputeSerializer = _serde) -> str: diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/__init__.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/container_sched.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/container_sched.py deleted file mode 100644 index d74915000..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/container_sched.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -import math -import random - -from globus_compute_endpoint.logging_config import ComputeLogger - -log: ComputeLogger = logging.getLogger(__name__) # type: ignore - - -def naive_scheduler( - task_qs, outstanding_task_count, max_workers, old_worker_map, to_die_list -): - """ - Return two items (as one tuple) - dict kill_list :: KILL [(worker_type, num_kill), ...] - dict create_list :: CREATE [(worker_type, num_create), ...] - - In this scheduler model, there is minimum 1 instance of each nonempty task queue. - """ - - log.trace("Entering scheduler...") - log.trace("old_worker_map: %s", old_worker_map) - q_sizes = {} - q_types = [] - new_worker_map = {} - - # Sum the size of each *available* (unblocked) task queue - sum_q_size = 0 - for q_type in outstanding_task_count: - q_types.append(q_type) - q_size = outstanding_task_count[q_type] - sum_q_size += q_size - q_sizes[q_type] = q_size - - if sum_q_size > 0: - log.info(f"Total number of tasks is {sum_q_size}") - - # Set proportions of workers equal to the proportion of queue size. - for q_type in q_sizes: - ratio = q_sizes[q_type] / sum_q_size - new_worker_map[q_type] = min( - int(math.floor(ratio * max_workers)), q_sizes[q_type] - ) - - # Check the difference - tmp_sum_q_size = sum(new_worker_map.values()) - difference = 0 - if sum_q_size > tmp_sum_q_size: - difference = min(max_workers - tmp_sum_q_size, sum_q_size - tmp_sum_q_size) - log.debug(f"Offset difference: {difference}") - log.debug(f"Queue Types: {q_types}") - - if len(q_types) > 0: - while difference > 0: - win_q = random.choice(q_types) - if q_sizes[win_q] > new_worker_map[win_q]: - new_worker_map[win_q] += 1 - difference -= 1 - - return new_worker_map diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py deleted file mode 100644 index 9d389d0e4..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py +++ /dev/null @@ -1,983 +0,0 @@ -"""HighThroughputEngine builds on Parsl's HighThroughputExecutor for execution -of functions within containerized workers in a distributed setting. -""" - -from __future__ import annotations - -import concurrent.futures -import ipaddress -import logging -import multiprocessing -import os -import queue -import socket -import threading -import time -import typing as t -import uuid -import warnings -from concurrent.futures import Future -from multiprocessing import Process - -import dill -from globus_compute_common import messagepack -from globus_compute_common.messagepack.message_types import ( - EPStatusReport as CommonEPStatusReport, -) -from globus_compute_common.messagepack.message_types import Result, ResultErrorDetails -from globus_compute_endpoint.endpoint.messages_compat import convert_ep_status_report -from globus_compute_endpoint.engines.base import GlobusComputeEngineBase -from globus_compute_endpoint.engines.helper import execute_task -from globus_compute_endpoint.engines.high_throughput import interchange, zmq_pipes -from globus_compute_endpoint.engines.high_throughput.mac_safe_queue import mpQueue -from globus_compute_endpoint.engines.high_throughput.messages import ( - EPStatusReport, - Heartbeat, - HeartbeatReq, - Task, - TaskCancel, -) -from globus_compute_endpoint.exception_handling import get_result_error_details -from globus_compute_endpoint.strategies.simple import SimpleStrategy -from globus_compute_sdk.serialize import ComputeSerializer -from parsl.errors import ConfigurationError -from parsl.executors.errors import BadMessage, ScalingFailed -from parsl.providers import LocalProvider -from parsl.utils import RepresentationMixin - -serializer = ComputeSerializer() - -log = logging.getLogger(__name__) - - -BUFFER_THRESHOLD = 1024 * 1024 -ITEM_THRESHOLD = 1024 - - -class HighThroughputEngine(GlobusComputeEngineBase, RepresentationMixin): - """Engine designed for cluster-scale - - The HighThroughputEngine system has the following components: - 1. The HighThroughputEngine instance which is run as a client - 2. The Interchange which is acts as a load-balancing proxy between workers and - Parsl - 3. The multiprocessing based worker pool which coordinates task execution over - several cores on a node. - 4. ZeroMQ pipes connect the HighThroughputEngine, Interchange and the - process_worker_pool - - Here is a diagram - - .. code:: python - - - | Data | Engine | Interchange | External Process(es) - | Flow | | | - Task | Kernel | | | - +----->|-------->|------------>|->outgoing_q---|-> process_worker_pool - | | | | batching | | | - Parsl<---Fut-| | | load-balancing| result exception - ^ | | | watchdogs | | | - | | | Q_mngmnt | | V V - | | | Thread<--|-incoming_q<---|--- +---------+ - | | | | | | - | | | | | | - +----update_fut-----+ - - - Parameters - ---------- - - provider : :class:`~parsl.providers.base.ExecutionProvider` - Provider to access computation resources. Can be one of - :class:`~parsl.providers.aws.aws.EC2Provider`, - :class:`~parsl.providers.cobalt.cobalt.Cobalt`, - :class:`~parsl.providers.condor.condor.Condor`, - :class:`~parsl.providers.googlecloud.googlecloud.GoogleCloud`, - :class:`~parsl.providers.gridEngine.gridEngine.GridEngine`, - :class:`~parsl.providers.jetstream.jetstream.Jetstream`, - :class:`~parsl.providers.local.local.Local`, - :class:`~parsl.providers.sge.sge.GridEngine`, - :class:`~parsl.providers.slurm.slurm.Slurm`, or - :class:`~parsl.providers.torque.torque.Torque`. - - label : str - Label for this Engine instance. - - launch_cmd : str - Command line string to launch the process_worker_pool from the provider. The - command line string will be formatted with appropriate values for the following - values: ( - debug, - task_url, - result_url, - cores_per_worker, - nodes_per_block, - heartbeat_period, - heartbeat_threshold, - logdir, - ). - For example: - launch_cmd="process_worker_pool.py {debug} -c {cores_per_worker} \ - --task_url={task_url} --result_url={result_url}" - - address : string - An address of the host on which the engine runs, which is reachable from the - network in which workers will be running. This can be either a hostname as - returned by `hostname` or an IP address. Most login nodes on clusters have - several network interfaces available, only some of which can be reached - from the compute nodes. Some trial and error might be necessary to - indentify what addresses are reachable from compute nodes. - - worker_ports : (int, int) - Specify the ports to be used by workers to connect to Parsl. If this - option is specified, worker_port_range will not be honored. - - worker_port_range : (int, int) - Worker ports will be chosen between the two integers provided. - - interchange_port_range : (int, int) - Port range used by Parsl to communicate with the Interchange. - - working_dir : str - Working dir to be used by the engine. - - worker_debug : Bool - Enables worker debug logging. - - cores_per_worker : float - cores to be assigned to each worker. Oversubscription is possible - by setting cores_per_worker < 1.0. Default=1 - - mem_per_worker : float - Memory to be assigned to each worker. Default=None(no limits) - - available_accelerators: int, list of str - Either a list of accelerators device IDs - or an integer defining the number of accelerators available. - If an integer, sequential device IDs will be created starting at 0. - The manager will ensure each worker is pinned to an accelerator and - will set the maximum number of workers per node to be no more - than the number of accelerators. - - Workers are pinned to specific accelerators using environment variables, - such as by setting the ``CUDA_VISIBLE_DEVICES`` or - ``SYCL_DEVICE_FILTER`` to the selected accelerator. - Default: None - - max_workers_per_node : int - Caps the number of workers launched by the manager. Default: infinity - - suppress_failure : Bool - If set, the interchange will suppress failures rather than terminate early. - Default: False - - heartbeat_threshold : int - Seconds since the last message from the counterpart in the communication pair: - (interchange, manager) after which the counterpart is assumed to be unavailable. - Default:120s - - heartbeat_period : int - Number of seconds after which a heartbeat message indicating liveness is sent to - the endpoint - counterpart (interchange, manager). Default:30s - - poll_period : int - Timeout period to be used by the engine components in milliseconds. - Increasing poll_periods trades performance for cpu efficiency. Default: 10ms - - container_image : str - Path or identfier to the container image to be used by the workers - - scheduler_mode: str - Scheduling mode to be used by the node manager. Options: 'hard', 'soft' - 'hard' -> managers cannot replace worker's container types - 'soft' -> managers can replace unused worker's containers based on demand - - worker_mode : str - Select the mode of operation from no_container, singularity_reuse, - singularity_single_use - Default: singularity_reuse - - container_cmd_options: str - Container command strings to be added to associated container command. - For example, singularity exec {container_cmd_options} - - task_status_queue : queue.Queue - Queue to pass updates to task statuses back to the forwarder. - - strategy: Stategy Object - Specify the scaling strategy to use for this engine. - - launch_cmd: str - Specify the launch command as using f-string format that will be used to specify - command to launch managers. Default: None - - prefetch_capacity: int - Number of tasks that can be fetched by managers in excess of available - workers is a prefetching optimization. This option can cause poor - load-balancing for long running functions. - Default: 10 - - provider: Provider object - Provider determines how managers can be provisioned, say LocalProvider - offers forked processes, and SlurmProvider interfaces to request - resources from the Slurm batch scheduler. - Default: LocalProvider - """ - - def __init__( - self, - label="HighThroughputEngine", - # NEW - strategy=SimpleStrategy(), - max_workers_per_node=float("inf"), - mem_per_worker=None, - launch_cmd=None, - available_accelerators=None, - # Container specific - worker_mode="no_container", - scheduler_mode="hard", - container_type=None, - container_cmd_options="", - cold_routing_interval=10.0, - # Tuning info - prefetch_capacity=10, - provider=LocalProvider(), - address="127.0.0.1", - worker_ports=None, - worker_port_range=(54000, 55000), - interchange_port_range=(55000, 56000), - storage_access=None, - working_dir=None, - worker_debug=False, - cores_per_worker=1.0, - heartbeat_threshold=120, - heartbeat_period=30, - poll_period=10, - container_image=None, - suppress_failure=True, - run_dir=None, - endpoint_id=None, - passthrough=True, - task_status_queue=None, - ): - warnings.warn( - "HighThroughputEngine is deprecated." - " Please use GlobusComputeEngine instead.", - DeprecationWarning, - stacklevel=2, - ) - log.debug("Initializing HighThroughputEngine") - self.provider = provider - self.label = label - self.launch_cmd = launch_cmd - self.worker_debug = worker_debug - self.max_workers_per_node = max_workers_per_node - - # NEW - self.strategy = strategy - self.cores_per_worker = cores_per_worker - self.mem_per_worker = mem_per_worker - - # Container specific - self.scheduler_mode = scheduler_mode - self.container_type = container_type - self.container_cmd_options = container_cmd_options - self.cold_routing_interval = cold_routing_interval - - # Tuning info - self.prefetch_capacity = prefetch_capacity - - self.storage_access = storage_access if storage_access is not None else [] - if len(self.storage_access) > 1: - raise ConfigurationError( - "Multiple storage access schemes are not supported" - ) - self.working_dir = working_dir - self.blocks = [] - self.cores_per_worker = cores_per_worker - self.endpoint_id = endpoint_id - self._task_counter = 0 - - if not HighThroughputEngine.is_hostname_or_ip(address): - err_msg = ( - "Expecting an interface name, hostname, IPv4 address, or IPv6 address." - ) - log.critical(err_msg) - raise ValueError(err_msg) - - self.address = address - self.worker_ports = worker_ports - self.worker_port_range = worker_port_range - self.interchange_port_range = interchange_port_range - self.heartbeat_threshold = heartbeat_threshold - self.heartbeat_period = heartbeat_period - self.poll_period = poll_period - self.suppress_failure = suppress_failure - self.run_dir = run_dir - self.queue_proc: multiprocessing.Process | None = None - self.passthrough = passthrough - self.task_status_queue = task_status_queue - self.tasks = {} - - self.outgoing_q: zmq_pipes.TasksOutgoing | None = None - self.incoming_q: zmq_pipes.ResultsIncoming | None = None - self.command_client: zmq_pipes.CommandClient | None = None - self.results_passthrough: queue.Queue | None = None - self._queue_management_thread: threading.Thread | None = None - - self.is_alive = False - - # Set the available accelerators - if available_accelerators is None: - self.available_accelerators = () - else: - if isinstance(available_accelerators, int): - self.available_accelerators = [ - str(i) for i in range(available_accelerators) - ] - else: - self.available_accelerators = list(available_accelerators) - log.debug( - "Workers will be assigned " - f"to accelerators: {self.available_accelerators}" - ) - - # Globus Compute specific options - self.container_image = container_image - self.worker_mode = worker_mode - - if not launch_cmd: - self.launch_cmd = ( - "process_worker_pool.py {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--mode={worker_mode} " - "--container_image={container_image} " - ) - - def start( - self, - *args, - endpoint_id: uuid.UUID | None = None, - run_dir: str | None = None, - results_passthrough: queue.Queue[dict[str, bytes | str | None]] | None = None, - **kwargs, - ): - """Create the Interchange process and connect to it.""" - assert run_dir, "HighThroughputEngine requires kwarg:run_dir at start" - assert endpoint_id, "HighThroughputEngine requires kwarg:endpoint_id at start" - self.run_dir = run_dir - self.endpoint_id = endpoint_id - - self.outgoing_q = zmq_pipes.TasksOutgoing( - "127.0.0.1", self.interchange_port_range - ) - self.incoming_q = zmq_pipes.ResultsIncoming( - "127.0.0.1", self.interchange_port_range - ) - self.command_client = zmq_pipes.CommandClient( - "127.0.0.1", self.interchange_port_range - ) - - self.is_alive = True - - if self.passthrough is True: - if results_passthrough is None: - raise Exception( - "Engines configured in passthrough mode, must be started with" - "a multiprocessing queue for results_passthrough" - ) - self.results_passthrough = results_passthrough - log.debug(f"Engine:{self.label} starting in results_passthrough mode") - - self._engine_bad_state = threading.Event() - self._engine_exception: t.Optional[Exception] = None - self._start_queue_management_thread() - - log.info("Attempting local interchange start") - self._start_local_interchange_process() - log.info( - "Started local interchange with ports: %s. %s", - self.worker_task_port, - self.worker_result_port, - ) - - log.debug(f"Created management thread: {self._queue_management_thread}") - - if self.provider: - pass - else: - self._scaling_enabled = False - log.debug("Starting HighThroughputEngine with no provider") - - return self.outgoing_q.port, self.incoming_q.port, self.command_client.port - - @staticmethod - def is_hostname_or_ip(hostname_or_ip: str) -> bool: - """ - Utility method to verify that the input is a valid hostname or - IP address. - """ - if not hostname_or_ip: - return False - else: - try: - socket.gethostbyname(hostname_or_ip) - return True - except socket.gaierror: - # Not a hostname, now check IP - pass - try: - ipaddress.ip_address(address=hostname_or_ip) - except ValueError: - return False - return True - - def _start_local_interchange_process(self): - """Starts the interchange process locally - - Starts the interchange process locally and uses an internal command queue to - get the worker task and result ports that the interchange has bound to. - """ - comm_q = mpQueue(maxsize=10) - self.queue_proc = Process( - target=interchange.starter, - name="Engine-Interchange", - args=(comm_q,), - kwargs={ - "client_address": "127.0.0.1", # engine and ix are on the same node - "client_ports": ( - self.outgoing_q.port, - self.incoming_q.port, - self.command_client.port, - ), - "provider": self.provider, - "strategy": self.strategy, - "poll_period": self.poll_period, - "heartbeat_period": self.heartbeat_period, - "heartbeat_threshold": self.heartbeat_threshold, - "working_dir": self.working_dir, - "worker_debug": self.worker_debug, - "max_workers_per_node": self.max_workers_per_node, - "mem_per_worker": self.mem_per_worker, - "cores_per_worker": self.cores_per_worker, - "available_accelerators": self.available_accelerators, - "prefetch_capacity": self.prefetch_capacity, - "scheduler_mode": self.scheduler_mode, - "worker_mode": self.worker_mode, - "container_type": self.container_type, - "container_cmd_options": self.container_cmd_options, - "cold_routing_interval": self.cold_routing_interval, - "interchange_address": self.address, - "worker_ports": self.worker_ports, - "worker_port_range": self.worker_port_range, - "logdir": os.path.join(self.run_dir, self.label), - "suppress_failure": self.suppress_failure, - "endpoint_id": self.endpoint_id, - }, - ) - self.queue_proc.start() - msg = None - try: - msg = comm_q.get(block=True, timeout=120) - except queue.Empty: - log.error("Interchange did not complete initialization.") - - if not msg: - # poor-person's attempt to not interweave the traceback log lines - # with the subprocess' likely traceback lines - time.sleep(0.5) - raise Exception("Interchange failed to start") - - comm_q.close() # not strictly necessary, but be plain about intentions - comm_q.join_thread() - del comm_q - - self.worker_task_port, self.worker_result_port = msg - - self.worker_task_url = f"tcp://{self.address}:{self.worker_task_port}" - self.worker_result_url = "tcp://{}:{}".format( - self.address, self.worker_result_port - ) - - def get_status_report(self) -> CommonEPStatusReport: - """HTEX Interchange reports EPStatusReport periodically""" - raise NotImplementedError - - def _queue_management_worker(self): - """Listen to the queue for task status messages and handle them. - - Depending on the message, tasks will be updated with results, exceptions, - or updates. It expects the following messages: - - .. code:: python - - { - "task_id" : - "result" : serialized result object, if task succeeded - ... more tags could be added later - } - - { - "task_id" : - "exception" : serialized exception object, on failure - } - - We do not support these yet, but they could be added easily. - - .. code:: python - - { - "task_id" : - "cpu_stat" : <> - "mem_stat" : <> - "io_stat" : <> - "started" : tstamp - } - - The `None` message is a die request. - """ - log.debug("queue management worker starting") - - while self.is_alive and not self._engine_bad_state.is_set(): - try: - msgs = self.incoming_q.get(timeout=1) - - except queue.Empty: - log.debug("queue empty") - # Timed out. - continue - - except OSError as e: - log.exception(f"Caught broken queue with exception code {e.errno}: {e}") - return - - except Exception as e: - log.exception(f"Caught unknown exception: {e}") - return - - else: - if msgs is None: - log.debug("Got None, exiting") - return - - elif isinstance(msgs, EPStatusReport): - log.debug(f"Received {msgs!r}") - if self.passthrough: - external_ep_status = convert_ep_status_report(msgs) - self.results_passthrough.put( - {"message": messagepack.pack(external_ep_status)} - ) - - else: - log.debug("Unpacking results") - for serialized_msg in msgs: - try: - msg = dill.loads(serialized_msg) - tid = msg["task_id"] - except dill.UnpicklingError: - raise BadMessage("Message received could not be unpickled") - - except Exception: - raise BadMessage( - "Message received does not contain 'task_id' field" - ) - - if tid == -1 and "exception" in msg: - # TODO: This could be handled better we are - # essentially shutting down the client with little - # indication to the user. - log.warning( - "Engine shutting down due to fatal " - "exception from interchange" - ) - self._engine_exception = serializer.deserialize( - msg["exception"] - ) - log.exception(f"Exception: {self._engine_exception}") - # Set bad state to prevent new tasks from being submitted - self._engine_bad_state.set() - # We set all current tasks to this exception to make sure - # that this is raised in the main context. - # YADU: Report failure on all pending tasks - for task_id in self.tasks: - try: - self.tasks[task_id].set_exception( - self._engine_exception - ) - except concurrent.futures.InvalidStateError: - # Task was already cancelled, the exception can be - # ignored - log.debug( - f"Task:{task_id} result couldn't be set. " - "Already in terminal state" - ) - break - - if self.passthrough is True: - log.debug(f"Pushing results for task:{tid}") - # we are only interested in actual task ids here, not - # identifiers for other message types - sent_task_id = tid if isinstance(tid, str) else None - packed_result = self._convert_result_to_outgoing( - sent_task_id, msg - ) - if packed_result is None: - continue - task_result = { - "task_id": sent_task_id, - "message": packed_result, - } - self.results_passthrough.put(task_result) - self.tasks[sent_task_id].set_result(packed_result) - continue - - try: - task_fut = self.tasks.pop(tid) - except KeyError: - # This is triggered when the result of a cancelled task is - # returned - # We should log, and proceed. - log.warning( - f"Task:{tid} not found in tasks table\n" - "Task likely was cancelled and removed." - ) - continue - - if "result" in msg: - result = serializer.deserialize(msg["result"]) - try: - task_fut.set_result(result) - except concurrent.futures.InvalidStateError: - log.debug( - f"Task:{tid} result couldn't be set. " - "Already in terminal state" - ) - elif "exception" in msg: - exception = serializer.deserialize(msg["exception"]) - try: - task_fut.set_result(exception) - except concurrent.futures.InvalidStateError: - log.debug( - f"Task:{tid} result couldn't be set. " - "Already in terminal state" - ) - else: - raise BadMessage( - "Message received is neither result or exception" - ) - - log.info("queue management worker finished") - - def _convert_result_to_outgoing( - self, task_id: str | None, msg: dict[str, t.Union[str, bytes]] - ) -> bytes | None: - """The messagepacked result should be in the result dict in the data field""" - try: - # The data body already is packed on the worker - assert isinstance(msg["data"], bytes) - return msg["data"] - except (KeyError, AssertionError) as e: - log.debug(f"Invalid result message: {msg}", exc_info=e) - if task_id is None: - return None - err_code, err_msg = get_result_error_details() - failed_result = Result( - task_id=task_id, - data=f"Task {task_id} failed to run on endpoint {self.endpoint_id}", - error_details=ResultErrorDetails(code=err_code, user_message=err_msg), - ) - return messagepack.pack(failed_result) - - # When the engine gets lost, the weakref callback will wake up - # the queue management thread. - def weakref_cb(self, q=None): - """We do not use this yet.""" - q.put(None) - - def _start_queue_management_thread(self): - """Method to start the management thread as a daemon. - - Checks if a thread already exists, then starts it. - Could be used later as a restart if the management thread dies. - """ - if self._queue_management_thread is None: - log.debug("Starting queue management thread") - self._queue_management_thread = threading.Thread( - target=self._queue_management_worker, name="Queue-Management" - ) - self._queue_management_thread.daemon = True - self._queue_management_thread.start() - log.debug("Started queue management thread") - - else: - log.debug("Management thread already exists, returning") - - def hold_worker(self, worker_id): - """Puts a worker on hold, preventing scheduling of additional tasks to it. - - This is called "hold" mostly because this only stops scheduling of tasks, - and does not actually kill the worker. - - Parameters - ---------- - - worker_id : str - Worker id to be put on hold - """ - c = self.command_client.run(f"HOLD_WORKER;{worker_id}") - log.debug(f"Sent hold request to worker: {worker_id}") - return c - - def send_heartbeat(self): - log.warning("Sending heartbeat to interchange") - msg = Heartbeat(endpoint_id="") - self.outgoing_q.put(msg.pack()) - - def wait_for_endpoint(self): - heartbeat = self.command_client.run(HeartbeatReq()) - log.debug("Attempting heartbeat to interchange") - return heartbeat - - @property - def outstanding(self): - outstanding_c = self.command_client.run("OUTSTANDING_C") - log.debug(f"Got outstanding count: {outstanding_c}") - return outstanding_c - - @property - def connected_workers(self): - workers = self.command_client.run("MANAGERS") - log.debug(f"Got managers: {workers}") - return workers - - def _submit(self, func: t.Callable, *args: t.Any, **kwargs: t.Any) -> Future: - raise RuntimeError("Invalid attempt to _submit()") - - def _get_container_location(self, packed_task: bytes) -> str: - """Unpack the task message to get the container location.""" - task_msg = messagepack.unpack(packed_task) - if not isinstance(task_msg, messagepack.message_types.Task): - raise messagepack.InvalidMessageError() - - container_loc = "RAW" - if task_msg.container: - for img in task_msg.container.images: - if img.image_type == self.container_type: - container_loc = img.location - break - - return container_loc - - def submit( - self, task_id: str, packed_task: bytes, resource_specification: t.Dict - ) -> HTEXFuture: - """Submits a messagepacked.Task for execution - - Parameters - ---------- - Packed Task (messages.Task) - A packed Task object which contains task_id, - container_id, and serialized fn, args, kwargs packages. - - Returns: - Submit status - """ - if self._engine_bad_state.is_set(): - # If the flag is set the exception body must exist - raise self._engine_exception # type: ignore - - self._task_counter += 1 - - future = HTEXFuture(self, task_id) - self.tasks[task_id] = future - - container_loc = self._get_container_location(packed_task) - ser = serializer.serialize( - ( - execute_task, - [task_id, packed_task, self.endpoint_id], - {"run_dir": self.run_dir}, - ) - ) - payload = Task(task_id, container_loc, ser).pack() - assert self.outgoing_q # Placate mypy - self.outgoing_q.put(payload) - - return future - - def _get_block_and_job_ids(self): - # Not using self.blocks.keys() and self.blocks.values() simultaneously - # The dictionary may be changed during invoking this function - # As scale_in and scale_out are invoked in multiple threads - block_ids = list(self.blocks.keys()) - job_ids = [] # types: List[Any] - for bid in block_ids: - job_ids.append(self.blocks[bid]) - return block_ids, job_ids - - @property - def connection_info(self): - """All connection info necessary for the endpoint to connect back - - Returns: - Dict with connection info - """ - return { - "address": self.address, - # A memorial to the ungodly amount of time and effort spent, - # troubleshooting the order of these ports. - "client_ports": "{},{},{}".format( - self.outgoing_q.port, self.incoming_q.port, self.command_client.port - ), - } - - @property - def scaling_enabled(self): - return self._scaling_enabled - - def scale_out(self, blocks=1): - """Scales out the number of blocks by "blocks" - - Raises: - NotImplementedError - """ - r = [] - for i in range(blocks): - if self.provider: - block = self.provider.submit(self.launch_cmd, 1, 1) - log.debug(f"Launched block {i}:{block}") - if not block: - raise ( - ScalingFailed( - self.provider.label, - "Attempts to provision nodes via provider has failed", - ) - ) - self.blocks.extend([block]) - else: - log.error("No execution provider available") - r = None - return r - - def scale_in(self, blocks): - """Scale in the number of active blocks by specified amount. - - The scale in method here is very rude. It doesn't give the workers - the opportunity to finish current tasks or cleanup. This is tracked - in issue #530 - - Raises: - NotImplementedError - """ - to_kill = self.blocks[:blocks] - if self.provider: - r = self.provider.cancel(to_kill) - return r - - def _get_job_ids(self): - return list(self.block.values()) - - def status(self): - """Return status of all blocks.""" - - status = [] - if self.provider: - status = self.provider.status(self.blocks) - - return status - - def shutdown(self, /, hub=True, targets="all", block=False, **kwargs) -> None: - """Shutdown the Engine, including all workers and controllers. - - This is not implemented. - - Kwargs: - - hub (Bool): Whether the hub should be shutdown, Default:True, - - targets (list of ints| 'all'): List of block id's to kill, Default:'all' - - block (Bool): To block for confirmations or not - - Raises: - NotImplementedError - """ - - log.info("Attempting HighThroughputEngine shutdown") - if self.queue_proc: - try: - self.queue_proc.terminate() - except AttributeError: - log.info("Engine interchange terminate skipped due to wrong context") - except Exception: - log.exception("Terminating the interchange failed") - if block: - self.queue_proc.join(timeout=5) - if self.queue_proc.exitcode is None: - self.queue_proc.kill() - log.info("Finished HighThroughputEngine shutdown attempt") - - def cancel_task(self, task_id: uuid.UUID | str) -> bool: - """ - Attempt to cancel the task `task_id` by requesting cancellation - from the interchange. Task cancellation is attempted only if the - future is cancellable (i.e., not already in a terminal state). This - relies on the Engine not setting the task to a running state, and the - task only tracking pending, and completed states. - - Parameters - ---------- - task_id - - Returns - ------- - Bool - """ - - log.debug("Send TaskCancel to interchange (%s)", task_id) - log.debug("Sending cancel of task_id:{future.task_id} to interchange") - # TODO: Doesn't yet return a bool ... - assert self.command_client - return self.command_client.run(TaskCancel(str(task_id))) - - -CANCELLED = "CANCELLED" -CANCELLED_AND_NOTIFIED = "CANCELLED_AND_NOTIFIED" -FINISHED = "FINISHED" - - -class HTEXFuture(concurrent.futures.Future): - __slots__ = ("engine", "task_id") - - def __init__(self, engine: HighThroughputEngine, task_id: str): - super().__init__() - self.engine = engine - self.task_id = task_id - - def cancel(self): - raise NotImplementedError( - f"{self.__class__} does not implement cancel() " - "try using best_effort_cancel()" - ) - - def best_effort_cancel(self): - """Attempt to cancel the function. - - If the function has finished running, the task cannot be cancelled - and the method will return False. - If the function is yet to start or is running, cancellation will be - attempted without guarantees, and the method will return True. - - Please note that a return value of True does not guarantee that your - function will not execute at all, but it does guarantee that the - future will be in a cancelled state. - - Returns - ------- - Bool - """ - return self.engine.cancel_task(self.task_id) diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/interchange.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/interchange.py deleted file mode 100644 index 338dd04e1..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/interchange.py +++ /dev/null @@ -1,1250 +0,0 @@ -from __future__ import annotations - -import argparse -import collections -import copy -import json -import logging -import os -import platform -import queue -import signal -import sys -import threading -import time -import typing as t -from collections import defaultdict - -import daemon -import dill -import zmq -from globus_compute_common.messagepack.message_types import TaskTransition -from globus_compute_common.tasks import ActorName, TaskState -from globus_compute_endpoint.engines.high_throughput.interchange_task_dispatch import ( # noqa: E501 - naive_interchange_task_dispatch, -) -from globus_compute_endpoint.engines.high_throughput.messages import ( - BadCommand, - EPStatusReport, - Heartbeat, - Message, - MessageType, -) -from globus_compute_endpoint.exception_handling import ( - get_error_string, - get_result_error_details, -) -from globus_compute_endpoint.logging_config import ComputeLogger -from globus_compute_sdk.sdk.utils import chunk_by -from globus_compute_sdk.serialize import ComputeSerializer -from parsl.version import VERSION as PARSL_VERSION - -if t.TYPE_CHECKING: - import multiprocessing as mp - -log: ComputeLogger = logging.getLogger(__name__) # type: ignore - -HEARTBEAT_CODE = (2**32) - 1 -PKL_HEARTBEAT_CODE = dill.dumps(HEARTBEAT_CODE) - - -class ManagerLost(Exception): - """Task lost due to worker loss. Worker is considered lost when multiple heartbeats - have been missed. - """ - - def __init__(self, worker_id): - self.worker_id = worker_id - self.tstamp = time.time() - - def __repr__(self): - return f"Task failure due to loss of manager {self.worker_id}" - - def __str__(self): - return self.__repr__() - - -class BadRegistration(Exception): - """A new Manager tried to join the Engine with a BadRegistration message""" - - def __init__(self, worker_id, critical=False): - self.worker_id = worker_id - self.tstamp = time.time() - self.handled = "critical" if critical else "suppressed" - - def __repr__(self): - return ( - f"Manager {self.worker_id} attempted to register with a bad " - f"registration message. Caused a {self.handled} failure" - ) - - def __str__(self): - return self.__repr__() - - -class Interchange: - """Interchange is a task orchestrator for distributed systems. - - 1. Asynchronously queue large volume of tasks (>100K) - 2. Allow for workers to join and leave the union - 3. Detect workers that have failed using heartbeats - 4. Service single and batch requests from workers - 5. Be aware of requests worker resource capacity, - eg. schedule only jobs that fit into walltime. - - TODO: We most likely need a PUB channel to send out global commands, like shutdown - """ - - def __init__( - self, - strategy=None, - poll_period=None, - heartbeat_period=None, - heartbeat_threshold=1, - working_dir=None, - provider=None, - max_workers_per_node=None, - mem_per_worker=None, - available_accelerators: t.Sequence[str] = (), - prefetch_capacity=None, - scheduler_mode=None, - container_type=None, - container_cmd_options="", - worker_mode=None, - cold_routing_interval=10.0, - scaling_enabled=True, - client_address="127.0.0.1", - interchange_address="127.0.0.1", - client_ports: tuple[int, int, int] = (50055, 50056, 50057), - worker_ports=None, - worker_port_range=None, - cores_per_worker=1.0, - worker_debug=False, - launch_cmd=None, - logdir=".", - endpoint_id=None, - suppress_failure=False, - ): - """ - Parameters - ---------- - config : globus_compute_sdk.UserEndpointConfig object - Globus Compute config object that describes how compute should - be provisioned - - client_address : str - The ip address at which the parsl client can be reached. - Default: "localhost" - - interchange_address : str - The ip address at which the workers will be able to reach the Interchange. - Default: "localhost" - - client_ports : tuple[int, int, int] - The ports at which the client can be reached - - launch_cmd : str - TODO : update - - worker_ports : tuple(int, int) - The specific two ports at which workers will connect to the Interchange. - Default: None - - worker_port_range : tuple(int, int) - The interchange picks ports at random from the range which will be used by - workers. This is overridden when the worker_ports option is set. - Default: (54000, 55000) - - cores_per_worker : float - cores to be assigned to each worker. Oversubscription is possible - by setting cores_per_worker < 1.0. Default=1 - - available_accelerators: sequence of str - List of device IDs for accelerators available on each node - Default: Empty list - - container_cmd_options: str - Container command strings to be added to associated container command. - For example, singularity exec {container_cmd_options} - - cold_routing_interval: float - The time interval between warm and cold function routing in SOFT - scheduler_mode. - It is ONLY used when using soft scheduler_mode. - We need this to avoid container workers being idle for too long. - But we dont't want this cold routing to occur too often, - since this may cause many warm container workers to switch to a new one. - Default: 10.0 seconds - - worker_debug : Bool - Enables worker debug logging. - - logdir : str - Parsl log directory paths. Logs and temp files go here. Default: '.' - - endpoint_id : str - Identity string that identifies the endpoint to the broker - - suppress_failure : Bool - When set to True, the interchange will attempt to suppress failures. - Default: False - """ - - self.logdir = logdir - os.makedirs(self.logdir, exist_ok=True) - log.info(f"Initializing Interchange process with Endpoint ID: {endpoint_id}") - - # - self.max_workers_per_node = max_workers_per_node - self.mem_per_worker = mem_per_worker - self.cores_per_worker = cores_per_worker - self.available_accelerators = available_accelerators - self.prefetch_capacity = prefetch_capacity - - self.scheduler_mode = scheduler_mode - self.container_type = container_type - self.container_cmd_options = container_cmd_options - self.worker_mode = worker_mode - self.cold_routing_interval = cold_routing_interval - - self.working_dir = working_dir - self.provider = provider - self.worker_debug = worker_debug - self.scaling_enabled = scaling_enabled - - self.strategy = strategy - self.client_address = client_address - self.interchange_address = interchange_address - self.suppress_failure = suppress_failure - - self.poll_period = poll_period - self.heartbeat_period = heartbeat_period - self.heartbeat_threshold = heartbeat_threshold - # initialize the last heartbeat time to start the loop - self.last_heartbeat = time.time() - - self.serializer = ComputeSerializer() - log.info( - "Attempting connection to forwarder at {} on ports: {},{},{}".format( - client_address, client_ports[0], client_ports[1], client_ports[2] - ) - ) - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) - self.task_incoming.set_hwm(0) - self.task_incoming.RCVTIMEO = 10 # in milliseconds - log.info(f"Task incoming on tcp://{client_address}:{client_ports[0]}") - self.task_incoming.connect(f"tcp://{client_address}:{client_ports[0]}") - - self.results_outgoing = self.context.socket(zmq.DEALER) - self.results_outgoing.set_hwm(0) - log.info(f"Results outgoing on tcp://{client_address}:{client_ports[1]}") - self.results_outgoing.connect(f"tcp://{client_address}:{client_ports[1]}") - - self.command_channel = self.context.socket(zmq.DEALER) - self.command_channel.RCVTIMEO = 1000 # in milliseconds - # self.command_channel.set_hwm(0) - log.info(f"Command _channel on tcp://{client_address}:{client_ports[2]}") - self.command_channel.connect(f"tcp://{client_address}:{client_ports[2]}") - log.info("Connected to forwarder") - - self.pending_task_queue: dict[str, queue.Queue] = {} - self.containers: dict[str, str] = {} - self.total_pending_task_count = 0 - - log.info(f"Interchange address is {self.interchange_address}") - self.worker_ports = worker_ports - self.worker_port_range = ( - worker_port_range if worker_port_range is not None else (54000, 55000) - ) - - self.task_outgoing = self.context.socket(zmq.ROUTER) - self.task_outgoing.set_hwm(0) - self.results_incoming = self.context.socket(zmq.ROUTER) - self.results_incoming.set_hwm(0) - - self.endpoint_id = endpoint_id - worker_bind_address = f"tcp://{self.interchange_address}" - log.info(f"Interchange binding worker ports to {worker_bind_address}") - if self.worker_ports: - self.worker_task_port = self.worker_ports[0] - self.worker_result_port = self.worker_ports[1] - - self.task_outgoing.bind(f"{worker_bind_address}:{self.worker_task_port}") - self.results_incoming.bind( - f"{worker_bind_address}:{self.worker_result_port}" - ) - - else: - self.worker_task_port = self.task_outgoing.bind_to_random_port( - worker_bind_address, - min_port=worker_port_range[0], - max_port=worker_port_range[1], - max_tries=100, - ) - self.worker_result_port = self.results_incoming.bind_to_random_port( - worker_bind_address, - min_port=worker_port_range[0], - max_port=worker_port_range[1], - max_tries=100, - ) - - log.info( - "Bound to ports {},{} for incoming worker connections".format( - self.worker_task_port, self.worker_result_port - ) - ) - - self._ready_manager_queue: dict[bytes, t.Any] = {} - - self.blocks: dict[str, str] = {} - self.block_id_map: dict[str, str] = {} - self.launch_cmd = launch_cmd - self.last_core_hr_counter = 0 - if not launch_cmd: - self.launch_cmd = ( - "globus-compute-manager {debug} {max_workers} " - "-c {cores_per_worker} " - "--poll {poll_period} " - "--task_url={task_url} " - "--result_url={result_url} " - "--logdir={logdir} " - "--block_id={{block_id}} " - "--hb_period={heartbeat_period} " - "--hb_threshold={heartbeat_threshold} " - "--worker_mode={worker_mode} " - "--container_cmd_options='{container_cmd_options}' " - "--scheduler_mode={scheduler_mode} " - "--worker_type={{worker_type}} " - "--available-accelerators {accelerator_list}" - ) - - self.current_platform = { - "parsl_v": PARSL_VERSION, - "python_v": "{}.{}.{}".format( - sys.version_info.major, sys.version_info.minor, sys.version_info.micro - ), - "os": platform.system(), - "hname": platform.node(), - "dir": os.getcwd(), - } - - log.info(f"Platform info: {self.current_platform}") - self._block_counter = 0 - try: - self.load_config() - except Exception: - log.exception("Caught exception") - raise - - self.task_cancel_running_queue: queue.Queue = queue.Queue() - self.task_cancel_pending_trap: dict[str, str] = {} - self.task_status_deltas: dict[str, list[TaskTransition]] = defaultdict(list) - self._task_status_delta_lock = threading.Lock() - self.container_switch_count: dict[bytes, int] = {} - - def load_config(self): - """Load the config""" - log.info("Loading endpoint local config") - working_dir = self.working_dir - if self.working_dir is None: - working_dir = os.path.join(self.logdir, "worker_logs") - log.info(f"Setting working_dir: {working_dir}") - - self.provider.script_dir = working_dir - if hasattr(self.provider, "channel"): - self.provider.channel.script_dir = os.path.join( - working_dir, "submit_scripts" - ) - os.makedirs(self.provider.channel.script_dir, mode=0o700, exist_ok=True) - os.makedirs(self.provider.script_dir, exist_ok=True) - - debug_opts = "--debug" if self.worker_debug else "" - max_workers = ( - "" - if self.max_workers_per_node == float("inf") - else f"--max_workers={self.max_workers_per_node}" - ) - - worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}" - worker_result_url = ( - f"tcp://{self.interchange_address}:{self.worker_result_port}" - ) - - l_cmd = self.launch_cmd.format( - debug=debug_opts, - max_workers=max_workers, - cores_per_worker=self.cores_per_worker, - accelerator_list=" ".join(self.available_accelerators), - # mem_per_worker=self.mem_per_worker, - prefetch_capacity=self.prefetch_capacity, - task_url=worker_task_url, - result_url=worker_result_url, - nodes_per_block=self.provider.nodes_per_block, - heartbeat_period=self.heartbeat_period, - heartbeat_threshold=self.heartbeat_threshold, - poll_period=self.poll_period, - worker_mode=self.worker_mode, - container_cmd_options=self.container_cmd_options, - scheduler_mode=self.scheduler_mode, - logdir=working_dir, - ) - - self.launch_cmd = l_cmd - log.info(f"Launch command: {self.launch_cmd}") - - if self.scaling_enabled: - log.info("Scaling ...") - self.scale_out(self.provider.init_blocks) - - def migrate_tasks_to_internal(self, kill_event): - """Pull tasks from the incoming tasks 0mq pipe onto the internal - pending task queue - - Parameters: - ----------- - kill_event : threading.Event - Event to let the thread know when it is time to die. - """ - log.info("Starting") - task_counter = 0 - poller = zmq.Poller() - poller.register(self.task_incoming, zmq.POLLIN) - - while not kill_event.is_set(): - # We are no longer doing heartbeats on the task side. - try: - raw_msg = self.task_incoming.recv() - self.last_heartbeat = time.time() - except zmq.Again: - log.trace( - "No new incoming task - %s tasks in internal queue", - self.total_pending_task_count, - ) - continue - - try: - msg = Message.unpack(raw_msg) - except Exception: - log.exception(f"Failed to unpack message, RAW:{raw_msg}") - continue - - if msg == "STOP": - # TODO: Yadu. This should be replaced by a proper MessageType - log.debug("Received STOP message.") - kill_event.set() - break - elif isinstance(msg, Heartbeat): - log.debug("Got heartbeat") - else: - log.info(f"Received task: {msg.task_id}") - local_container = msg.container_id - self.containers[local_container] = local_container - msg.set_local_container(local_container) - if local_container not in self.pending_task_queue: - self.pending_task_queue[local_container] = queue.Queue( - maxsize=10**6 - ) - - # We pass the raw message along - self.pending_task_queue[local_container].put( - { - "task_id": msg.task_id, - "container_id": msg.container_id, - "local_container": local_container, - "raw_buffer": raw_msg, - } - ) - self.total_pending_task_count += 1 - tt = TaskTransition( - timestamp=time.time_ns(), - state=TaskState.WAITING_FOR_NODES, - actor=ActorName.INTERCHANGE, - ) - - with self._task_status_delta_lock: - self.task_status_deltas[msg.task_id].append(tt) - - log.debug( - f"[TASK_PULL_THREAD] task {msg.task_id} is now WAITING_FOR_NODES" - ) - log.debug( - "[TASK_PULL_THREAD] pending task count: {}".format( - self.total_pending_task_count - ) - ) - task_counter += 1 - log.debug(f"[TASK_PULL_THREAD] Fetched task:{task_counter}") - - def get_total_tasks_outstanding(self): - """Get the outstanding tasks in total""" - outstanding = {} - for task_type in self.pending_task_queue: - outstanding[task_type] = ( - outstanding.get(task_type, 0) - + self.pending_task_queue[task_type].qsize() - ) - for manager in self._ready_manager_queue: - for task_type in self._ready_manager_queue[manager]["tasks"]: - outstanding[task_type] = outstanding.get(task_type, 0) + len( - self._ready_manager_queue[manager]["tasks"][task_type] - ) - return outstanding - - def get_total_live_workers(self): - """Get the total active workers""" - active = 0 - for manager in self._ready_manager_queue: - if self._ready_manager_queue[manager]["active"]: - active += self._ready_manager_queue[manager]["max_worker_count"] - return active - - def get_outstanding_breakdown(self): - """Get outstanding breakdown per manager and in the interchange queues - - Returns - ------- - List of status for online elements - [ (element, tasks_pending, status) ... ] - """ - - pending_on_interchange = self.total_pending_task_count - # Reporting pending on interchange is a deviation from Parsl - reply = [("interchange", pending_on_interchange, True)] - for manager in self._ready_manager_queue: - resp = ( - manager.decode("utf-8"), - sum( - len(tids) - for tids in self._ready_manager_queue[manager]["tasks"].values() - ), - self._ready_manager_queue[manager]["active"], - ) - reply.append(resp) - return reply - - def _hold_block(self, block_id): - """Sends hold command to all managers which are in a specific block - - Parameters - ---------- - block_id : str - Block identifier of the block to be put on hold - """ - for manager in self._ready_manager_queue: - if ( - self._ready_manager_queue[manager]["active"] - and self._ready_manager_queue[manager]["block_id"] == block_id - ): - log.debug(f"[HOLD_BLOCK]: Sending hold to manager: {manager}") - self.hold_manager(manager) - - def hold_manager(self, manager): - """Put manager on hold - Parameters - ---------- - - manager : str - Manager id to be put on hold while being killed - """ - if manager in self._ready_manager_queue: - self._ready_manager_queue[manager]["active"] = False - - def _status_report_loop(self, kill_event, status_report_queue: queue.Queue): - log.info(f"Endpoint id: {self.endpoint_id}") - - def _enqueue_status_report(ep_state: dict, task_states: dict): - try: - msg = EPStatusReport(str(self.endpoint_id), ep_state, dict(task_states)) - status_report_queue.put(msg.pack()) - except Exception: - log.exception("Unable to create or send EP status report.") - log.debug("Attempted to send chunk: %s", tsd_chunk) - # ignoring so that the thread continues; "it's just a status" - - while True: - with self._task_status_delta_lock: - task_status_deltas = copy.deepcopy(self.task_status_deltas) - self.task_status_deltas.clear() - - log.debug( - "Cleared task deltas (%s); sending status report to Engine.", - len(task_status_deltas), - ) - - # For multi-chunk reports ("lots-o-tasks"), the state won't change *that* - # much each iteration, so cache the result - global_state = self.get_global_state_for_status_report() - - # The result processor will gracefully handle any size message, but - # courtesy says to chunk work; 4,096 is empirically chosen to be plenty - # "bulk enough," but not rude. - for tsd_chunk in chunk_by(task_status_deltas.items(), 4_096): - _enqueue_status_report(global_state, dict(tsd_chunk)) - - if not task_status_deltas: - _enqueue_status_report(global_state, {}) - - del task_status_deltas # free some memory in "large case" - - if kill_event.wait(self.heartbeat_period): - break - - def _command_server(self, kill_event): - """Command server to run async command to the interchange - - We want to be able to receive the following not yet implemented/updated - commands: - - OutstandingCount - - ListManagers (get outstanding broken down by manager) - - HoldWorker - - Shutdown - """ - log.debug("Command Server Starting") - - while not kill_event.is_set(): - try: - buffer = self.command_channel.recv() - log.debug(f"Received command request {buffer}") - command = Message.unpack(buffer) - - if command.type is MessageType.TASK_CANCEL: - log.info(f"Received TASK_CANCEL for Task:{command.task_id}") - self.enqueue_task_cancel(command.task_id) - reply = command - - elif command.type is MessageType.HEARTBEAT_REQ: - log.info("Received synchonous HEARTBEAT_REQ from hub") - log.info(f"Replying with Heartbeat({self.endpoint_id})") - reply = Heartbeat(self.endpoint_id) - - else: - log.error( - f"Received unsupported message type:{command.type} on " - "command _channel" - ) - reply = BadCommand(f"Unknown command type: {command.type}") - - log.debug(f"Reply: {reply}") - self.command_channel.send(reply.pack()) - - except zmq.Again: - log.trace("Command server is alive") - continue - - def enqueue_task_cancel(self, task_id): - """Cancel a task on the interchange - Here are the task states and responses we issue here - 1. Task is pending in queues -> we add task to a trap to capture while in - dispatch and delegate cancel to the manager the task is assigned to - 2. Task is in a transitionary state between pending in queue and dispatched -> - task is added pre-emptively to trap - 3. Task is pending on a manager -> we delegate cancellation to manager - 4. Task is already complete -> we leave in trap, since we can't know - - We place the task in the trap so that even if the search misses, the task - will be caught from getting dispatched even if the search fails due to a - race-condition. Since the task can't be dispatched before scheduling is - complete, either must work. - """ - log.debug(f"Received task_cancel request for Task:{task_id}") - - self.task_cancel_pending_trap[task_id] = task_id - for manager in self._ready_manager_queue: - for task_type in self._ready_manager_queue[manager]["tasks"]: - for tid in self._ready_manager_queue[manager]["tasks"][task_type]: - if tid == task_id: - log.debug( - f"Task:{task_id} is running, " - "moving task_cancel message onto queue" - ) - self.task_cancel_running_queue.put((manager, task_id)) - self.task_cancel_pending_trap.pop(task_id, None) - break - return - - def handle_sigterm(self, sig_num, curr_stack_frame): - log.warning("Received SIGTERM, stopping") - self.stop() - - def stop(self): - """Prepare the interchange for shutdown""" - self._kill_event.set() - - self._task_puller_thread.join() - self._command_thread.join() - self._status_report_thread.join() - log.debug("HighThroughput Interchange stopped") - - def start(self, poll_period: int | None = None) -> None: - """Start the Interchange - - Parameters: - ---------- - poll_period : int - poll_period in milliseconds - """ - signal.signal(signal.SIGTERM, self.handle_sigterm) - log.info("Incoming ports bound") - - if poll_period is None: - poll_period = self.poll_period - - start = time.time() - count = 0 - - self._kill_event = threading.Event() - self._task_puller_thread = threading.Thread( - target=self.migrate_tasks_to_internal, - args=(self._kill_event,), - name="TASK_PULL_THREAD", - ) - self._task_puller_thread.start() - - self._command_thread = threading.Thread( - target=self._command_server, args=(self._kill_event,), name="COMMAND_THREAD" - ) - self._command_thread.start() - - status_report_queue: queue.Queue[bytes] = queue.Queue() - self._status_report_thread = threading.Thread( - target=self._status_report_loop, - args=(self._kill_event, status_report_queue), - name="STATUS_THREAD", - ) - self._status_report_thread.start() - - try: - log.info("Starting strategy.") - self.strategy.start(self) - except RuntimeError: - # This is raised when re-registering an endpoint as strategy already exists - log.exception("Failed to start strategy.") - - poller = zmq.Poller() - # poller.register(self.task_incoming, zmq.POLLIN) - poller.register(self.task_outgoing, zmq.POLLIN) - poller.register(self.results_incoming, zmq.POLLIN) - - # These are managers which we should examine in an iteration - # for scheduling a job (or maybe any other attention?). - # Anything altering the state of the manager should add it - # onto this list. - interesting_managers: set[bytes] = set() - - # This value records when the last cold routing in soft mode happens - # When the cold routing in soft mode happens, it may cause worker containers to - # switch - # Cold routing is to reduce the number idle workers of specific task types on - # the managers when there are not enough tasks of those types in the task queues - # on interchange - last_cold_routing_time = time.time() - prev_manager_stat = None - - task_deltas_to_merge: dict[str, list[TaskTransition]] = defaultdict(list) - - while not self._kill_event.is_set(): - self.socks = dict(poller.poll(timeout=poll_period)) - - # Listen for requests for work - if ( - self.task_outgoing in self.socks - and self.socks[self.task_outgoing] == zmq.POLLIN - ): - log.trace("starting task_outgoing section") - message = self.task_outgoing.recv_multipart() - manager = message[0] - - mdata = self._ready_manager_queue.get(manager) - if not mdata: - reg_flag = False - - try: - msg = json.loads(message[1].decode("utf-8")) - reg_flag = True - except Exception: - log.warning( - "Got a non-json registration message from manager:%s", - manager, - ) - log.debug("Message :\n%s\n", message) - - # By default we set up to ignore bad nodes/registration messages. - now = time.time() - mdata = { - "last": now, - "reg_time": now, - "free_capacity": {"total_workers": 0}, - "max_worker_count": 0, - "active": True, - "tasks": collections.defaultdict(set), - "total_tasks": 0, - } - if reg_flag is True: - interesting_managers.add(manager) - log.info( - f"Add manager to ready queue: {manager!r}" - f"\n Registration info: {msg})" - ) - mdata.update(msg) - self._ready_manager_queue[manager] = mdata - - if ( - msg["python_v"].rsplit(".", 1)[0] - != self.current_platform["python_v"].rsplit(".", 1)[0] - or msg["parsl_v"] != self.current_platform["parsl_v"] - ): - log.info( - f"Manager:{manager!r} version:{msg['python_v']} " - "does not match the interchange" - ) - else: - # Registration has failed. - if self.suppress_failure is False: - log.debug("Setting kill event for bad manager") - self._kill_event.set() - e = BadRegistration(manager, critical=True) - result_package = { - "task_id": -1, - "exception": self.serializer.serialize(e), - } - pkl_package = dill.dumps(result_package) - self.results_outgoing.send(dill.dumps([pkl_package])) - else: - log.debug( - "Suppressing bad registration from manager: %s", - manager, - ) - - else: - mdata["last"] = time.time() - if message[1] == b"HEARTBEAT": - log.debug("Manager %s sends heartbeat", manager) - self.task_outgoing.send_multipart( - [manager, b"", PKL_HEARTBEAT_CODE] - ) - else: - manager_adv = dill.loads(message[1]) - log.debug("Manager %s requested %s", manager, manager_adv) - manager_adv["total_workers"] = sum(manager_adv["free"].values()) - mdata["free_capacity"].update(manager_adv) - interesting_managers.add(manager) - del manager_adv - - # If we had received any requests, check if there are tasks that could be - # passed - - cur_manager_stat = len(self._ready_manager_queue), len(interesting_managers) - if cur_manager_stat != prev_manager_stat: - prev_manager_stat = cur_manager_stat - _msg = "[MAIN] New managers count (total/interesting): {}/{}" - log.debug(_msg.format(*cur_manager_stat)) - - if time.time() - last_cold_routing_time > self.cold_routing_interval: - task_dispatch, dispatched_task = naive_interchange_task_dispatch( - interesting_managers, - self.pending_task_queue, - self._ready_manager_queue, - scheduler_mode=self.scheduler_mode, - cold_routing=True, - ) - last_cold_routing_time = time.time() - else: - task_dispatch, dispatched_task = naive_interchange_task_dispatch( - interesting_managers, - self.pending_task_queue, - self._ready_manager_queue, - scheduler_mode=self.scheduler_mode, - cold_routing=False, - ) - - self.total_pending_task_count -= dispatched_task - - # Task cancel is high priority, so we'll process all requests - # in one go - try: - while True: - manager, task_id = self.task_cancel_running_queue.get(block=False) - log.debug( - "CANCELLED running task (id: %s, manager: %s)", task_id, manager - ) - cancel_message = dill.dumps(("TASK_CANCEL", task_id)) - self.task_outgoing.send_multipart([manager, b"", cancel_message]) - except queue.Empty: - pass - - for manager in task_dispatch: - tasks = task_dispatch[manager] - if tasks: - log.info( - 'Sending task message "{}..." to manager {!r}'.format( - str(tasks)[:50], manager - ) - ) - serializd_raw_tasks_buffer = dill.dumps(tasks) - self.task_outgoing.send_multipart( - [manager, b"", serializd_raw_tasks_buffer] - ) - - for task in tasks: - task_id = task["task_id"] - log.info(f"Sent task {task_id} to manager {manager!r}") - if ( - self.task_cancel_pending_trap - and task_id in self.task_cancel_pending_trap - ): - log.info(f"Task:{task_id} CANCELLED before launch") - cancel_message = dill.dumps(("TASK_CANCEL", task_id)) - self.task_outgoing.send_multipart( - [manager, b"", cancel_message] - ) - self.task_cancel_pending_trap.pop(task_id) - else: - log.debug("Task:%s is now WAITING_FOR_LAUNCH", task_id) - tt = TaskTransition( - timestamp=time.time_ns(), - state=TaskState.WAITING_FOR_LAUNCH, - actor=ActorName.INTERCHANGE, - ) - task_deltas_to_merge[task_id].append(tt) - - if task_deltas_to_merge: - with self._task_status_delta_lock: - for task_id, deltas in task_deltas_to_merge.items(): - self.task_status_deltas[task_id].extend(deltas) - task_deltas_to_merge.clear() - - # Receive any results and forward to client - if ( - self.results_incoming in self.socks - and self.socks[self.results_incoming] == zmq.POLLIN - ): - log.debug("entering results_incoming section") - manager, *b_messages = self.results_incoming.recv_multipart() - mdata = self._ready_manager_queue.get(manager) - if not mdata: - log.warning( - "Received a result from a un-registered manager: %s", - manager, - ) - else: - # We expect the batch of messages to be (optionally) a task status - # update message followed by 0 or more task results - try: - log.debug("Trying to unpack") - manager_report = Message.unpack(b_messages[0]) - if manager_report.task_statuses: - log.info( - "Got manager status report: %s", - manager_report.task_statuses, - ) - - for tid, statuses in manager_report.task_statuses.items(): - task_deltas_to_merge[tid].extend(statuses) - - self.task_outgoing.send_multipart( - [manager, b"", PKL_HEARTBEAT_CODE] - ) - b_messages = b_messages[1:] - mdata["last"] = time.time() - self.container_switch_count[manager] = ( - manager_report.container_switch_count - ) - log.info( - "Got container switch count: %s", - self.container_switch_count, - ) - except Exception: - pass - if len(b_messages): - log.info(f"Got {len(b_messages)} result items in batch") - for idx, b_message in enumerate(b_messages): - r = dill.loads(b_message) - tid = r["task_id"] - - log.debug("Received task result %s (from %s)", tid, manager) - task_container = self.containers[r["container_id"]] - log.debug( - "Removing for manager: %s from %s", - manager, - self._ready_manager_queue, - ) - - mdata["tasks"][task_container].remove(tid) - b_messages[idx] = dill.dumps(r) - - mdata["total_tasks"] -= len(b_messages) - - self.results_outgoing.send(dill.dumps(b_messages)) - interesting_managers.add(manager) - - log.debug(f"Current tasks: {mdata['tasks']}") - log.debug("leaving results_incoming section") - - # Send status reports from this main thread to avoid thread-safety on zmq - # sockets - try: - packed_status_report = status_report_queue.get(block=False) - log.trace("forwarding status report: %s", packed_status_report) - self.results_outgoing.send(packed_status_report) - except queue.Empty: - pass - - now = time.time() - hbt_window_start = now - self.heartbeat_threshold - bad_managers = [ - manager - for manager, mdata in self._ready_manager_queue.items() - if hbt_window_start > mdata["last"] - ] - bad_manager_msgs = [] - for manager in bad_managers: - log.debug( - "Last: %s Current: %s", - self._ready_manager_queue[manager]["last"], - now, - ) - log.warning(f"Too many heartbeats missed for manager {manager!r}") - for tasks in self._ready_manager_queue[manager]["tasks"].values(): - for tid in tasks: - try: - raise ManagerLost(manager) - except Exception: - result_package = { - "task_id": tid, - "exception": get_error_string(), - "error_details": get_result_error_details(), - } - pkl_package = dill.dumps(result_package) - bad_manager_msgs.append(pkl_package) - log.warning(f"Unregistering manager {manager!r}") - self._ready_manager_queue.pop(manager, None) - if manager in interesting_managers: - interesting_managers.remove(manager) - if bad_manager_msgs: - log.warning(f"Sending task failure reports of manager {manager!r}") - self.results_outgoing.send(dill.dumps(bad_manager_msgs)) - - delta = time.time() - start - log.info(f"Processed {count} tasks in {delta} seconds") - log.warning("Exiting") - - def get_global_state_for_status_report(self): - outstanding_tasks = self.get_total_tasks_outstanding() - pending_tasks = self.total_pending_task_count - num_managers = len(self._ready_manager_queue) - live_workers = self.get_total_live_workers() - free_capacity = sum( - m["free_capacity"]["total_workers"] - for m in self._ready_manager_queue.values() - ) - - return { - "managers": num_managers, - "total_workers": live_workers, - "idle_workers": free_capacity, - "pending_tasks": pending_tasks, - "outstanding_tasks": outstanding_tasks, - "heartbeat_period": self.heartbeat_period, - } - - def scale_out(self, blocks=1, task_type=None): - """Scales out the number of blocks by "blocks" - - Raises: - NotImplementedError - """ - log.info(f"Scaling out by {blocks} more blocks for task type {task_type}") - r = [] - for _i in range(blocks): - if self.provider: - self._block_counter += 1 - external_block_id = str(self._block_counter) - if not task_type and self.scheduler_mode == "hard": - launch_cmd = self.launch_cmd.format( - block_id=external_block_id, worker_type="RAW" - ) - else: - launch_cmd = self.launch_cmd.format( - block_id=external_block_id, worker_type=task_type - ) - if not task_type: - internal_block = self.provider.submit(launch_cmd, 1) - else: - internal_block = self.provider.submit(launch_cmd, 1, task_type) - log.debug(f"Launched block {external_block_id}->{internal_block}") - if not internal_block: - raise RuntimeError( - "Attempt to provision nodes via provider " - f"{self.provider.label} has failed" - ) - self.blocks[external_block_id] = internal_block - self.block_id_map[internal_block] = external_block_id - else: - log.error("No execution provider available") - r = None - return r - - def scale_in(self, blocks=None, block_ids=None, task_type=None): - """Scale in the number of active blocks by specified amount. - - Parameters - ---------- - blocks : int - # of blocks to terminate - - block_ids : [str.. ] - List of external block ids to terminate - """ - if block_ids is None: - block_ids = [] - if task_type: - log.info( - "Scaling in blocks of specific task type %s. Let the provider decide " - "which to kill", - task_type, - ) - if self.scaling_enabled and self.provider: - to_kill, r = self.provider.cancel(blocks, task_type) - log.info(f"Get the killed blocks: {to_kill}, and status: {r}") - for job in to_kill: - log.info( - "[scale_in] Getting the block_id map {} for job {}".format( - self.block_id_map, job - ) - ) - block_id = self.block_id_map[job] - log.info(f"[scale_in] Holding block {block_id}") - self._hold_block(block_id) - self.blocks.pop(block_id) - return r - - if block_ids: - block_ids_to_kill = block_ids - else: - block_ids_to_kill = list(self.blocks.keys())[:blocks] - - # Try a polite terminate - # TODO : Missing logic to hold blocks - for block_id in block_ids_to_kill: - self._hold_block(block_id) - - # Now kill via provider - to_kill = [self.blocks.pop(bid) for bid in block_ids_to_kill] - - if self.scaling_enabled and self.provider: - r = self.provider.cancel(to_kill) - - return r - - def provider_status(self): - """Get status of all blocks from the provider. The return type is - defined by the particular provider in use. - """ - status = [] - if self.provider: - job_ids: list[str] = list(self.blocks.values()) - log.trace("Getting the status of %s blocks.", job_ids) - status = self.provider.status(job_ids) - log.trace("The status is %s", status) - - return status - - -def starter(comm_q: mp.Queue, *args, **kwargs) -> None: - """Start the interchange process - - The Engine is expected to call this function. The args, kwargs match that of the - Interchange.__init__ - """ - ic = None - try: - ic = Interchange(*args, **kwargs) - comm_q.put((ic.worker_task_port, ic.worker_result_port)) - finally: - if not ic: # There was an exception - comm_q.put(None) - - # no sense in having the queue open past it's usefulness - comm_q.close() - comm_q.join_thread() - del comm_q - - if ic: - ic.start() - - -def cli_run(): - from globus_compute_endpoint.logging_config import setup_logging - - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--client_address", required=True, help="Client address") - parser.add_argument( - "--client_ports", - required=True, - help="client ports as a triple of outgoing,incoming,command", - ) - parser.add_argument("--worker_port_range", help="Worker port range as a tuple") - parser.add_argument( - "-l", - "--logdir", - default="./parsl_worker_logs", - help="Parsl worker log directory", - ) - parser.add_argument( - "-p", "--poll_period", help="REQUIRED: poll period used for main thread" - ) - parser.add_argument( - "--worker_ports", - default=None, - help="OPTIONAL, pair of workers ports to listen on, " - "eg --worker_ports=50001,50005", - ) - parser.add_argument( - "--suppress_failure", - action="store_true", - help="Enables suppression of failures", - ) - parser.add_argument( - "--endpoint_id", - default=None, - help="Endpoint ID, used to identify the endpoint to the remote broker", - ) - parser.add_argument("--hb_threshold", help="Heartbeat threshold in seconds") - parser.add_argument( - "--config", - default=None, - help="Configuration object that describes provisioning", - ) - parser.add_argument( - "-d", "--debug", action="store_true", help="Enables debug logging" - ) - - print("Starting HTEX Intechange") - - args = parser.parse_args() - - args.logdir = os.path.abspath(args.logdir) - if args.worker_ports: - args.worker_ports = [int(i) for i in args.worker_ports.split(",")] - if args.worker_port_range: - args.worker_port_range = [int(i) for i in args.worker_port_range.split(",")] - - setup_logging( - logfile=os.path.join(args.logdir, "interchange.log"), - debug=args.debug, - console_enabled=False, - ) - - with daemon.DaemonContext(): - ic = Interchange( - logdir=args.logdir, - suppress_failure=args.suppress_failure, - client_address=args.client_address, - client_ports=[int(i) for i in args.client_ports.split(",")], - endpoint_id=args.endpoint_id, - config=args.config, - worker_ports=args.worker_ports, - worker_port_range=args.worker_port_range, - ) - ic.start() diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/interchange_task_dispatch.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/interchange_task_dispatch.py deleted file mode 100644 index dd3c57489..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/interchange_task_dispatch.py +++ /dev/null @@ -1,256 +0,0 @@ -from __future__ import annotations - -import collections -import logging -import queue -import random - -from globus_compute_endpoint.logging_config import ComputeLogger - -log: ComputeLogger = logging.getLogger(__name__) # type: ignore -log.info("Interchange task dispatch started") - - -def naive_interchange_task_dispatch( - interesting_managers: set[bytes], - pending_task_queue: dict[str, queue.Queue[dict]], - ready_manager_queue: dict[bytes, dict], - scheduler_mode: str = "hard", - cold_routing: bool = False, -) -> tuple[dict[bytes, list], int]: - """ - This is an initial task dispatching algorithm for interchange. - It returns a dictionary, whose key is manager, and the value is the list of tasks - to be sent to manager, and the total number of dispatched tasks. - """ - task_dispatch: dict[bytes, list] = {} - dispatched_tasks = 0 - if scheduler_mode == "hard": - dispatched_tasks += dispatch( - task_dispatch, - interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode="hard", - ) - - elif scheduler_mode == "soft": - loops = ["warm"] if not cold_routing else ["warm", "cold"] - for loop in loops: - dispatched_tasks += dispatch( - task_dispatch, - interesting_managers, - pending_task_queue, - ready_manager_queue, - scheduler_mode="soft", - loop=loop, - ) - return task_dispatch, dispatched_tasks - - -def dispatch( - task_dispatch: dict[bytes, list], - interesting_managers: set[bytes], - pending_task_queue: dict[str, queue.Queue[dict]], - ready_manager_queue: dict[bytes, dict], - scheduler_mode: str = "hard", - loop: str = "warm", -) -> int: - """ - This is the core task dispatching algorithm for interchange. - The algorithm depends on the scheduler mode and which loop. - """ - dispatched_tasks = 0 - if interesting_managers: - shuffled_managers = list(interesting_managers) - random.shuffle(shuffled_managers) - for manager in shuffled_managers: - mdata = ready_manager_queue[manager] - tasks_inflight = mdata["total_tasks"] - real_capacity: int = min( - mdata["free_capacity"]["total_workers"], - mdata["max_worker_count"] - tasks_inflight, - ) - if not (real_capacity > 0 and mdata["active"]): - interesting_managers.remove(manager) - continue - - if scheduler_mode == "hard": - tasks, tids = get_tasks_hard(pending_task_queue, mdata, real_capacity) - else: - tasks, tids = get_tasks_soft( - pending_task_queue, - mdata, - real_capacity, - loop=loop, - ) - if tasks: - log.debug("Got %s tasks from queue", len(tasks)) - for task_type in tids: - # This line is a set update, not dict update - mdata["tasks"][task_type].update(tids[task_type]) - log.debug(f"The tasks on manager %s is {mdata['tasks']}", manager) - mdata["total_tasks"] += len(tasks) - if manager not in task_dispatch: - task_dispatch[manager] = [] - task_dispatch[manager] += tasks - dispatched_tasks += len(tasks) - log.debug("Assigned tasks %s to manager %s", tids, manager) - if mdata["free_capacity"]["total_workers"] > 0: - log.trace( - "Manager %s still has free_capacity %s", - manager, - mdata["free_capacity"]["total_workers"], - ) - else: - log.debug("Manager %s is now saturated", manager) - interesting_managers.remove(manager) - - log.trace( - "The task dispatch of %s loop is %s, in total %s tasks", - loop, - task_dispatch, - dispatched_tasks, - ) - return dispatched_tasks - - -def get_tasks_hard( - pending_task_queue: dict[str, queue.Queue[dict]], - manager_ads: dict, - real_capacity: int, -) -> tuple[list[dict], dict[str, set[str]]]: - tasks: list[dict] = [] - tids: dict[str, set[str]] = collections.defaultdict(set) - task_type: str = manager_ads["worker_type"] - if not task_type: - log.warning( - "Using hard scheduler mode but with manager worker type unset. " - "Use soft scheduler mode. Set this in the config." - ) - return tasks, tids - task_q = pending_task_queue.get(task_type) - if not task_q: - log.trace("No task of type %s. Exiting task fetching.", task_type) - return tasks, tids - - # dispatch tasks of available types on manager - free_cap = manager_ads["free_capacity"] - if task_type in free_cap["free"]: - try: - while real_capacity > 0 and free_cap["free"][task_type] > 0: - x = task_q.get(block=False) - log.debug(f"Get task {x}") - tasks.append(x) - tids[task_type].add(x["task_id"]) - free_cap["free"][task_type] -= 1 - free_cap["total_workers"] -= 1 - real_capacity -= 1 - except queue.Empty: - pass - - # dispatch tasks to unused slots based on the manager type - log.trace("Second round of task fetching in hard mode") - try: - while real_capacity > 0 and free_cap["free"]["unused"] > 0: - x = task_q.get(block=False) - log.debug(f"Get task {x}") - tasks.append(x) - tids[task_type].add(x["task_id"]) - free_cap["free"]["unused"] -= 1 - free_cap["total_workers"] -= 1 - real_capacity -= 1 - except queue.Empty: - pass - return tasks, tids - - -def get_tasks_soft( - pending_task_queue: dict[str, queue.Queue[dict]], - manager_ads: dict, - real_capacity: int, - loop: str = "warm", -) -> tuple[list[dict], dict[str, set[str]]]: - tasks = [] - tids = collections.defaultdict(set) - - # Warm routing to dispatch tasks - free_cap = manager_ads["free_capacity"] - if loop == "warm": - for task_type in free_cap["free"]: - # Dispatch tasks that are of the available container types on the manager - if task_type != "unused": - task_q = pending_task_queue.get(task_type) - if not task_q: - continue - type_inflight = len(manager_ads["tasks"].get(task_type, set())) - type_capacity = min( - free_cap["free"][task_type], - free_cap["total"][task_type] - type_inflight, - ) - try: - while ( - real_capacity > 0 - and type_capacity > 0 - and free_cap["free"][task_type] > 0 - ): - x = task_q.get(block=False) - log.debug(f"Get task {x}") - tasks.append(x) - tids[task_type].add(x["task_id"]) - free_cap["free"][task_type] -= 1 - free_cap["total_workers"] -= 1 - real_capacity -= 1 - type_capacity -= 1 - except queue.Empty: - pass - # Dispatch tasks to unused container slots on the manager - else: - task_q = pending_task_queue.get(task_type) # "unused" queue - if not task_q: - log.debug("Unexpectedly non-existent 'unused' queue") - continue - task_types = list(pending_task_queue.keys()) - random.shuffle(task_types) - for task_type in task_types: - try: - while ( - real_capacity > 0 - and free_cap["free"]["unused"] > 0 - and free_cap["total_workers"] > 0 - ): - x = task_q.get(block=False) - log.debug(f"Get task {x}") - tasks.append(x) - tids[task_type].add(x["task_id"]) - free_cap["free"]["unused"] -= 1 - free_cap["total_workers"] -= 1 - real_capacity -= 1 - except queue.Empty: - pass - return tasks, tids - - # Cold routing round: allocate tasks of random types - # to workers that are of different types on the manager - # This will possibly cause container switching on the manager - # This is needed to avoid workers being idle for too long - # Potential issues may be that it could kill containers of short tasks frequently - # Tune cold_routing_interval in the config to balance such a tradeoff - log.debug("Cold function routing!") - task_types = list(pending_task_queue.keys()) - random.shuffle(task_types) - for task_type in task_types: - task_q = pending_task_queue.get(task_type) - if not task_q: - continue - try: - while real_capacity > 0 and free_cap["total_workers"] > 0: - x = task_q.get(block=False) - tasks.append(x) - tids[task_type].add(x["task_id"]) - free_cap["total_workers"] -= 1 - real_capacity -= 1 - log.debug(f"Get task {x}") - except queue.Empty: - pass - return tasks, tids diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/mac_safe_queue.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/mac_safe_queue.py deleted file mode 100644 index 1444d150f..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/mac_safe_queue.py +++ /dev/null @@ -1,12 +0,0 @@ -import multiprocessing as mp -import platform -import typing as t - -mpQueue: t.Type[mp.Queue] - -if platform.system() == "Darwin": - from parsl.multiprocessing import MacSafeQueue as mpQueue -else: - from multiprocessing import Queue as mpQueue - -__all__ = ("mpQueue",) diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/manager.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/manager.py deleted file mode 100755 index bf51f243e..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/manager.py +++ /dev/null @@ -1,926 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import json -import logging -import math -import multiprocessing -import os -import platform -import queue -import subprocess -import sys -import threading -import time -import uuid -from collections import defaultdict -from typing import Any - -import dill -import psutil -import zmq -from globus_compute_common.messagepack.message_types import TaskTransition -from globus_compute_common.tasks import ActorName, TaskState -from globus_compute_endpoint.engines.high_throughput.container_sched import ( - naive_scheduler, -) -from globus_compute_endpoint.engines.high_throughput.mac_safe_queue import mpQueue -from globus_compute_endpoint.engines.high_throughput.messages import ( - ManagerStatusReport, - Message, - Task, -) -from globus_compute_endpoint.engines.high_throughput.worker_map import WorkerMap -from globus_compute_endpoint.exception_handling import ( - get_error_string, - get_result_error_details, -) -from globus_compute_endpoint.logging_config import ComputeLogger, setup_logging -from parsl.version import VERSION as PARSL_VERSION - -RESULT_TAG = 10 -TASK_REQUEST_TAG = 11 -HEARTBEAT_CODE = (2**32) - 1 - -log: ComputeLogger = logging.getLogger(__name__) # type: ignore - - -class TaskCancelled(Exception): - """Task is cancelled by user request.""" - - def __init__(self, worker_id, manager_id): - self.worker_id = worker_id - self.manager_id = manager_id - self.tstamp = time.time() - - def __str__(self): - return ( - "Task cancelled based on user request on manager: " - f"{self.manager_id}, worker: {self.worker_id}" - ) - - -class Manager: - """Manager manages task execution by the workers - - | 0mq | Manager | Worker Processes - | | | - | <-----Request N task-----+--Count task reqs | Request task<--+ - Interchange | -------------------------+->Receive task batch| | | - | | Distribute tasks--+----> Get(block) & | - | | | Execute task | - | | | | | - | <------------------------+--Return results----+---- Post result | - | | | | | - | | | +----------+ - | | IPC-Qeueues - - """ - - def __init__( - self, - task_q_url="tcp://localhost:50097", - result_q_url="tcp://localhost:50098", - max_queue_size=10, - cores_per_worker=1, - available_accelerators: list[str] | None = None, - max_workers=float("inf"), - uid=None, - heartbeat_threshold=120, - heartbeat_period=30, - logdir=None, - debug=False, - block_id=None, - internal_worker_port_range=(50000, 60000), - worker_mode="singularity_reuse", - container_cmd_options="", - scheduler_mode="hard", - worker_type=None, - worker_max_idletime=60, - # TODO : This should be 10ms - poll_period=100, - ): - """ - Parameters - ---------- - worker_url : str - Worker url on which workers will attempt to connect back - - uid : str - string unique identifier - - cores_per_worker : float - cores to be assigned to each worker. Oversubscription is possible - by setting cores_per_worker < 1.0. Default=1 - - available_accelerators: list of strings - Accelerators available for workers to use. - default: empty list - - max_workers : int - caps the maximum number of workers that can be launched. - default: infinity - - heartbeat_threshold : int - Seconds since the last message from the interchange after which the - interchange is assumed to be un-available, and the manager initiates - shutdown. Default:120s - - Number of seconds since the last message from the interchange after which - the worker assumes that the interchange is lost and the manager shuts down. - Default:120 - - heartbeat_period : int - Number of seconds after which a heartbeat message is sent to the - interchange - - internal_worker_port_range : tuple(int, int) - Port range from which the port(s) for the workers to connect to the manager - is picked. - Default: (50000,60000) - - worker_mode : str - Pick between 3 supported modes for the worker: - 1. no_container : Worker launched without containers - 2. singularity_reuse : Worker launched inside a singularity container that - will be reused - 3. singularity_single_use : Each worker and task runs inside a new - container instance. - - container_cmd_options: str - Container command strings to be added to associated container command. - For example, singularity exec {container_cmd_options} - - scheduler_mode : str - Pick between 2 supported modes for the manager: - 1. hard: the manager cannot change the launched container type - 2. soft: the manager can decide whether to launch different containers - - worker_type : str - If set, the worker type for this manager is fixed. Default: None - - poll_period : int - Timeout period used by the manager in milliseconds. Default: 10ms - """ - log.info("Manager started") - - self.context = zmq.Context() - self.task_incoming = self.context.socket(zmq.DEALER) - self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode("utf-8")) - # Linger is set to 0, so that the manager can exit even when there might be - # messages in the pipe - self.task_incoming.setsockopt(zmq.LINGER, 0) - self.task_incoming.setsockopt(zmq.IPV6, True) - self.task_incoming.connect(task_q_url) - - self.logdir = logdir - self.debug = debug - self.block_id = block_id - self.result_outgoing = self.context.socket(zmq.DEALER) - self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode("utf-8")) - self.result_outgoing.setsockopt(zmq.LINGER, 0) - self.result_outgoing.setsockopt(zmq.IPV6, True) - self.result_outgoing.connect(result_q_url) - - log.info("Manager connected") - - self.uid = uid - - self.worker_mode = worker_mode - self.container_cmd_options = container_cmd_options - self.scheduler_mode = scheduler_mode - self.worker_type = worker_type - self.worker_max_idletime = worker_max_idletime - self.cores_on_node = multiprocessing.cpu_count() - self.max_workers = max_workers - self.cores_per_workers = cores_per_worker - self.available_mem_on_node = round( - psutil.virtual_memory().available / (2**30), 1 - ) - self.max_worker_count = min( - max_workers, math.floor(self.cores_on_node / cores_per_worker) - ) - - # Control pinning to accelerators - self.available_accelerators = available_accelerators or [] - if self.available_accelerators: - self.max_worker_count = min( - self.max_worker_count, len(self.available_accelerators) - ) - - self.worker_map = WorkerMap(self.max_worker_count, self.available_accelerators) - - self.internal_worker_port_range = internal_worker_port_range - - self.funcx_task_socket = self.context.socket(zmq.ROUTER) - self.funcx_task_socket.set_hwm(0) - self.funcx_task_socket.setsockopt(zmq.IPV6, True) - self.address = "localhost" - self.worker_port = self.funcx_task_socket.bind_to_random_port( - "tcp://*", - min_port=self.internal_worker_port_range[0], - max_port=self.internal_worker_port_range[1], - ) - - log.info( - "Manager listening on {} port for incoming worker connections".format( - self.worker_port - ) - ) - - self.task_queues: dict[str, queue.Queue] = {} - if worker_type: - self.task_queues[worker_type] = queue.Queue() - self.outstanding_task_count: dict[str, int] = {} - self.task_type_mapping: dict[str, str] = {} - - self.pending_result_queue = mpQueue() - - self.max_queue_size = max_queue_size + self.max_worker_count - self.tasks_per_round = 1 - - self.heartbeat_period = heartbeat_period - self.heartbeat_threshold = heartbeat_threshold - self.poll_period = poll_period - self.next_worker_q: list[str] = [] # FIFO queue for spinning up workers. - self.worker_procs: dict[str, subprocess.Popen] = {} - - self.task_status_deltas: dict[str, list[TaskTransition]] = defaultdict(list) - - self._kill_event = threading.Event() - self._result_pusher_thread = threading.Thread( - target=self.push_results, args=(self._kill_event,), name="Result-Pusher" - ) - self._status_report_thread = threading.Thread( - target=self._status_report_loop, - args=(self._kill_event,), - name="Status-Report", - ) - self.container_switch_count = 0 - - self.poller = zmq.Poller() - self.poller.register(self.task_incoming, zmq.POLLIN) - self.poller.register(self.funcx_task_socket, zmq.POLLIN) - self.task_worker_map: dict[str, Any] = {} - - self.task_done_counter = 0 - self.task_finalization_lock = threading.Lock() - - def create_reg_message(self): - """Creates a registration message to identify the worker to the interchange""" - msg = { - "parsl_v": PARSL_VERSION, - "python_v": "{}.{}.{}".format( - sys.version_info.major, sys.version_info.minor, sys.version_info.micro - ), - "max_worker_count": self.max_worker_count, - "cores": self.cores_on_node, - "mem": self.available_mem_on_node, - "block_id": self.block_id, - "worker_type": self.worker_type, - "os": platform.system(), - "hname": platform.node(), - "dir": os.getcwd(), - } - b_msg = json.dumps(msg).encode("utf-8") - return b_msg - - def pull_tasks(self, kill_event): - """Pull tasks from the incoming tasks 0mq pipe onto the internal - pending task queue - - - While : - receive results and task requests from the workers - receive tasks/heartbeats from the Interchange - match tasks to workers - if task doesn't have appropriate worker type: - launch worker of type.. with LRU or some sort of caching strategy. - if workers >> tasks: - advertize available capacity - - Parameters: - ----------- - kill_event : threading.Event - Event to let the thread know when it is time to die. - """ - log.info("starting") - - # Send a registration message - msg = self.create_reg_message() - log.debug(f"Sending registration message: {msg}") - self.task_incoming.send(msg) - last_interchange_contact = time.time() - task_recv_counter = 0 - - poll_timer = self.poll_period - - new_worker_map = None - last_count_pending = -1 - last_count_worker = -1 - while not kill_event.is_set(): - # Disabling the check on ready_worker_queue disables batching - pending_task_count = task_recv_counter - self.task_done_counter - ready_worker_count = self.worker_map.ready_worker_count() - if (last_count_pending, last_count_worker) != ( - pending_task_count, - ready_worker_count, - ): - log.trace( - "pending_task_count: %s, Ready_worker_count: %s", - pending_task_count, - ready_worker_count, - ) - last_count_pending = pending_task_count - last_count_worker = ready_worker_count - - if pending_task_count < self.max_queue_size and ready_worker_count > 0: - ads = self.worker_map.advertisement() - log.trace("Requesting tasks: %s", ads) - msg = dill.dumps(ads) - self.task_incoming.send(msg) - - # Receive results from the workers, if any - socks = dict(self.poller.poll(timeout=poll_timer)) - - if ( - self.funcx_task_socket in socks - and socks[self.funcx_task_socket] == zmq.POLLIN - ): - self.poll_funcx_task_socket() - - # Receive task batches from Interchange and forward to workers - if self.task_incoming in socks and socks[self.task_incoming] == zmq.POLLIN: - # If we want to wrap the task_incoming polling into a separate function, - # we need to - # self.poll_task_incoming( - # poll_timer, - # last_interchange_contact, - # kill_event, - # task_revc_counter - # ) - poll_timer = 0 - _, pkl_msg = self.task_incoming.recv_multipart() - message = dill.loads(pkl_msg) - last_interchange_contact = time.time() - - if message == "STOP": - log.critical("Received stop request") - kill_event.set() - break - - elif isinstance(message, tuple) and message[0] == "TASK_CANCEL": - with self.task_finalization_lock: - task_id = message[1] - log.info(f"Received TASK_CANCEL request for task: {task_id}") - if task_id not in self.task_worker_map: - log.warning(f"Task:{task_id} is not in task_worker_map.") - log.warning("Possible duplicate cancel or race-condition") - continue - # Cancel task by killing the worker it is on - worker_id_raw = self.task_worker_map[task_id]["worker_id"] - worker_to_kill = self.task_worker_map[task_id][ - "worker_id" - ].decode("utf-8") - worker_type = self.task_worker_map[task_id]["task_type"] - log.debug( - "Cancelling task running on worker: %s", - self.task_worker_map[task_id], - ) - try: - log.info(f"Removing worker:{worker_id_raw} from map") - self.worker_map.start_remove_worker(worker_type) - - log.info( - f"Popping worker:{worker_to_kill} from worker_procs" - ) - proc = self.worker_procs.pop(worker_to_kill) - log.warning(f"Sending process:{proc.pid} terminate signal") - proc.terminate() - try: - proc.wait(1) # Wait 1 second before attempting SIGKILL - except subprocess.TimeoutExpired: - log.exception("Process did not terminate in 1 second") - log.warning(f"Sending process:{proc.pid} kill signal") - proc.kill() - else: - log.debug( - f"Worker process exited with : {proc.returncode}" - ) - - # Now that the worker is dead, remove it from worker map - self.worker_map.remove_worker(worker_id_raw) - raise TaskCancelled(worker_to_kill, self.uid) - except Exception as e: - log.exception(f"Raise exception, handling: {e}") - result_package = { - "task_id": task_id, - "container_id": worker_type, - "error_details": get_result_error_details(e), - "exception": get_error_string(tb_levels=0), - } - self.pending_result_queue.put(dill.dumps(result_package)) - - worker_proc = self.worker_map.add_worker( - worker_id=str(self.worker_map.worker_id_counter), - worker_type=self.worker_type, - container_cmd_options=self.container_cmd_options, - address=self.address, - debug=self.debug, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port, - ) - self.worker_procs.update(worker_proc) - self.task_worker_map.pop(task_id) - self.remove_task(task_id) - - elif message == HEARTBEAT_CODE: - log.debug("Got heartbeat from interchange") - - else: - tasks = [ - (rt["local_container"], Message.unpack(rt["raw_buffer"])) - for rt in message - ] - - task_recv_counter += len(tasks) - log.debug( - "Got tasks: {} of {}".format( - [t[1].task_id for t in tasks], task_recv_counter - ) - ) - - for task_type, task in tasks: - log.debug(f"Task is of type: {task_type}") - - if task_type not in self.task_queues: - self.task_queues[task_type] = queue.Queue() - if task_type not in self.outstanding_task_count: - self.outstanding_task_count[task_type] = 0 - self.task_queues[task_type].put(task) - self.outstanding_task_count[task_type] += 1 - self.task_type_mapping[task.task_id] = task_type - log.debug( - "Got task: Outstanding task counts: {}".format( - self.outstanding_task_count - ) - ) - log.debug( - f"Task {task.task_id} pushed to task queue " - f"for type: {task_type}" - ) - - else: - log.trace("No incoming tasks") - # Limit poll duration to heartbeat_period - # heartbeat_period is in s vs poll_timer in ms - if not poll_timer: - poll_timer = self.poll_period - poll_timer = min(self.heartbeat_period * 1000, poll_timer * 2) - - # Only check if no messages were received. - if time.time() > last_interchange_contact + self.heartbeat_threshold: - log.critical( - "Missing contact with interchange beyond heartbeat_threshold" - ) - kill_event.set() - log.critical("Killing all workers") - for proc in self.worker_procs.values(): - proc.kill() - log.critical("Exiting") - break - - log.trace( - "To-Die Counts: %s, alive worker counts: %s", - self.worker_map.to_die_count, - self.worker_map.total_worker_type_counts, - ) - - new_worker_map = naive_scheduler( - self.task_queues, - self.outstanding_task_count, - self.max_worker_count, - new_worker_map, - self.worker_map.to_die_count, - ) - log.trace("New worker map: %s", new_worker_map) - - # NOTE: Wipes the queue -- previous scheduling loops don't affect what's - # needed now. - self.next_worker_q, need_more = self.worker_map.get_next_worker_q( - new_worker_map - ) - - # Spin up any new workers according to the worker queue. - # Returns the total number of containers that have spun up. - self.worker_procs.update( - self.worker_map.spin_up_workers( - self.next_worker_q, - mode=self.worker_mode, - debug=self.debug, - container_cmd_options=self.container_cmd_options, - address=self.address, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port, - ) - ) - log.trace("Worker processes: %s", self.worker_procs) - - # Count the workers of each type that need to be removed - spin_downs, container_switch_count = self.worker_map.spin_down_workers( - new_worker_map, - worker_max_idletime=self.worker_max_idletime, - need_more=need_more, - scheduler_mode=self.scheduler_mode, - ) - self.container_switch_count += container_switch_count - log.trace( - "Container switch count: total %s, cur %s", - self.container_switch_count, - container_switch_count, - ) - - for w_type in spin_downs: - self.remove_worker_init(w_type) - - current_worker_map = self.worker_map.get_worker_counts() - for task_type in current_worker_map: - if task_type == "unused": - continue - - # *** Match tasks to workers *** # - else: - available_workers = current_worker_map[task_type] - log.trace( - "Available workers of type %s: %s", task_type, available_workers - ) - - for _i in range(available_workers): - if ( - task_type in self.task_queues - and not self.task_queues[task_type].qsize() == 0 - and not self.worker_map.worker_queues[task_type].qsize() - == 0 - ): - log.debug( - "Task type {} has task queue size {}".format( - task_type, self.task_queues[task_type].qsize() - ) - ) - log.debug( - "... and available workers: {}".format( - self.worker_map.worker_queues[task_type].qsize() - ) - ) - - self.send_task_to_worker(task_type) - - def poll_funcx_task_socket(self, test=False): - try: - w_id, m_type, message = self.funcx_task_socket.recv_multipart() - if m_type == b"REGISTER": - reg_info = dill.loads(message) - log.debug(f"Registration received from worker:{w_id} {reg_info}") - self.worker_map.register_worker(w_id, reg_info["worker_type"]) - - elif m_type == b"TASK_RET": - # the following steps are also shared by task_cancel - with self.task_finalization_lock: - log.debug(f"Result received from worker: {w_id}") - task_id = dill.loads(message)["task_id"] - try: - self.remove_task(task_id) - except KeyError: - log.exception(f"Task:{task_id} missing in task structure") - else: - self.pending_result_queue.put(message) - self.worker_map.put_worker(w_id) - - elif m_type == b"WRKR_DIE": - log.debug(f"[WORKER_REMOVE] Removing worker {w_id} from worker_map...") - log.debug( - "Ready worker counts: {}".format( - self.worker_map.ready_worker_type_counts - ) - ) - log.debug( - "Total worker counts: {}".format( - self.worker_map.total_worker_type_counts - ) - ) - self.worker_map.remove_worker(w_id) - proc = self.worker_procs.pop(w_id.decode()) - if not proc.poll(): - try: - proc.wait(timeout=1) - except subprocess.TimeoutExpired: - log.warning( - "[WORKER_REMOVE] Timeout waiting for worker %s process to " - "terminate", - w_id, - ) - log.debug(f"[WORKER_REMOVE] Removing worker {w_id} process object") - log.debug(f"[WORKER_REMOVE] Worker processes: {self.worker_procs}") - - if test: - return dill.loads(message) - - except Exception: - log.exception("Unhandled exception while processing worker messages") - - def remove_task(self, task_id: str): - task_type = self.task_type_mapping.pop(task_id) - self.task_status_deltas.pop(task_id, None) - self.outstanding_task_count[task_type] -= 1 - self.task_done_counter += 1 - - def send_task_to_worker(self, task_type): - task = self.task_queues[task_type].get() - worker_id = self.worker_map.get_worker(task_type) - - log.debug(f"Sending task {task.task_id} to {worker_id}") - # TODO: Some duplication of work could be avoided here - to_send = [ - worker_id, - dill.dumps(task.task_id), - dill.dumps(task.container_id), - task.pack(), - ] - self.funcx_task_socket.send_multipart(to_send) - self.worker_map.update_worker_idle(task_type) - if task.task_id != "KILL": - log.debug(f"Set task {task.task_id} to RUNNING") - tt = TaskTransition( - timestamp=time.time_ns(), - state=TaskState.RUNNING, - actor=ActorName.MANAGER, - ) - self.task_status_deltas[task.task_id].append(tt) - self.task_worker_map[task.task_id] = { - "worker_id": worker_id, - "task_type": task_type, - } - log.debug("Sending complete") - - def _status_report_loop(self, kill_event: threading.Event): - log.debug("Manager status reporting loop starting") - - while not kill_event.wait(timeout=self.heartbeat_period): - msg = ManagerStatusReport( - self.task_status_deltas, - self.container_switch_count, - ) - log.info(f"Sending status report to interchange: {msg.task_statuses}") - self.pending_result_queue.put(msg) - if self.task_status_deltas: - log.info("Clearing task deltas") - self.task_status_deltas.clear() - - def push_results(self, kill_event, max_result_batch_size=1): - """Listens on the pending_result_queue and sends out results via 0mq - - Parameters: - ----------- - kill_event : threading.Event - Event to let the thread know when it is time to die. - """ - - log.debug("Starting thread") - - push_poll_period = ( - max(10, self.poll_period) / 1000 - ) # push_poll_period must be atleast 10 ms - log.debug(f"push poll period: {push_poll_period}") - - last_beat = time.time() - items = [] - - while not kill_event.is_set(): - try: - r = self.pending_result_queue.get(block=True, timeout=push_poll_period) - # This avoids the interchange searching and attempting to unpack every - # message in case it's a status report. - # (It would be better to use Task Messages eventually to make this more - # uniform) - # TODO: use task messages, and don't have to prepend - if isinstance(r, ManagerStatusReport): - items.insert(0, r.pack()) - else: - items.append(r) - except queue.Empty: - pass - except Exception as e: - log.exception(f"Got an exception: {e}") - - # If we have reached poll_period duration or timer has expired, we send - # results - if ( - len(items) >= self.max_queue_size - or time.time() > last_beat + push_poll_period - ): - last_beat = time.time() - if items: - self.result_outgoing.send_multipart(items) - items = [] - - log.critical("Exiting") - - def remove_worker_init(self, worker_type): - """ - Kill/Remove a worker of a given worker_type. - - Add a kill message to the task_type queue. - - Assumption : All workers of the same type are uniform, and therefore don't - discriminate when killing. - """ - - log.debug( - "[WORKER_REMOVE] Appending KILL message to worker queue {}".format( - worker_type - ) - ) - self.worker_map.start_remove_worker(worker_type) - task = Task(task_id="KILL", container_id="RAW", task_buffer="KILL") - self.task_queues[worker_type].put(task) - - def start(self): - """ - * while True: - Receive tasks and start appropriate workers - Push tasks to available workers - Forward results - """ - - if self.worker_type and self.scheduler_mode == "hard": - log.debug( - "[MANAGER] Start an initial worker with worker type {}".format( - self.worker_type - ) - ) - self.worker_procs.update( - self.worker_map.add_worker( - worker_id=str(self.worker_map.worker_id_counter), - worker_type=self.worker_type, - container_cmd_options=self.container_cmd_options, - address=self.address, - debug=self.debug, - uid=self.uid, - logdir=self.logdir, - worker_port=self.worker_port, - ) - ) - - log.debug("Initial workers launched") - self._result_pusher_thread.start() - self._status_report_thread.start() - self.pull_tasks(self._kill_event) - log.info("Waiting") - - -def cli_run(): - parser = argparse.ArgumentParser() - parser.add_argument( - "-d", "--debug", action="store_true", help="Count of apps to launch" - ) - parser.add_argument( - "-l", - "--logdir", - default="process_worker_pool_logs", - help="Process worker pool log directory", - ) - parser.add_argument( - "-u", - "--uid", - default=str(uuid.uuid4()).split("-")[-1], - help="Unique identifier string for Manager", - ) - parser.add_argument( - "-b", "--block_id", default=None, help="Block identifier string for Manager" - ) - parser.add_argument( - "-c", - "--cores_per_worker", - default="1.0", - help="Number of cores assigned to each worker process. Default=1.0", - ) - parser.add_argument( - "-a", - "--available-accelerators", - default=(), - nargs="*", - help="List of available accelerators", - ) - parser.add_argument( - "-t", "--task_url", required=True, help="REQUIRED: ZMQ url for receiving tasks" - ) - parser.add_argument( - "--max_workers", - default=float("inf"), - help="Caps the maximum workers that can be launched, default:infinity", - ) - parser.add_argument( - "--hb_period", - default=30, - help="Heartbeat period in seconds. Uses manager default unless set", - ) - parser.add_argument( - "--hb_threshold", - default=120, - help="Heartbeat threshold in seconds. Uses manager default unless set", - ) - parser.add_argument("--poll", default=10, help="Poll period used in milliseconds") - parser.add_argument( - "--worker_type", default=None, help="Fixed worker type of manager" - ) - parser.add_argument( - "--worker_mode", - default="singularity_reuse", - help=( - "Choose the mode of operation from " - "(no_container, singularity_reuse, singularity_single_use" - ), - ) - parser.add_argument( - "--container_cmd_options", - default="", - help=("Container cmd options to add to container startup cmd"), - ) - parser.add_argument( - "--scheduler_mode", - default="soft", - help=("Choose the mode of scheduler (hard, soft"), - ) - parser.add_argument( - "-r", - "--result_url", - required=True, - help="REQUIRED: ZMQ url for posting results", - ) - - args = parser.parse_args() - - setup_logging( - logfile=os.path.join(args.logdir, args.uid, "manager.log"), debug=args.debug - ) - - try: - log.info(f"Python version: {sys.version}") - log.info( - "Arguments:" - f"\n Debug logging: {args.debug}" - f"\n Log dir: {args.logdir}" - f"\n Manager ID: {args.uid}" - f"\n Block ID: {args.block_id}" - f"\n cores_per_worker: {args.cores_per_worker}" - f"\n available_accelerators: {args.available_accelerators}" - f"\n task_url: {args.task_url}" - f"\n result_url: {args.result_url}" - f"\n hb_period: {args.hb_period}" - f"\n hb_threshold: {args.hb_threshold}" - f"\n max_workers: {args.max_workers}" - f"\n poll_period: {args.poll}" - f"\n worker_mode: {args.worker_mode}" - f"\n container_cmd_options: {args.container_cmd_options}" - f"\n scheduler_mode: {args.scheduler_mode}" - f"\n worker_type: {args.worker_type}" - ) - - manager = Manager( - task_q_url=args.task_url, - result_q_url=args.result_url, - uid=args.uid, - block_id=args.block_id, - cores_per_worker=float(args.cores_per_worker), - available_accelerators=args.available_accelerators, - max_workers=( - args.max_workers - if args.max_workers == float("inf") - else int(args.max_workers) - ), - heartbeat_threshold=int(args.hb_threshold), - heartbeat_period=int(args.hb_period), - logdir=args.logdir, - debug=args.debug, - worker_mode=args.worker_mode, - container_cmd_options=args.container_cmd_options, - scheduler_mode=args.scheduler_mode, - worker_type=args.worker_type, - poll_period=int(args.poll), - ) - manager.start() - - except Exception as e: - log.critical("process_worker_pool exiting from an exception") - log.exception(f"Caught error: {e}") - raise - else: - log.info("process_worker_pool main event loop exiting normally") - print("PROCESS_WORKER_POOL main event loop exiting normally") - - -if __name__ == "__main__": - cli_run() diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/messages.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/messages.py deleted file mode 100644 index 2c821ee75..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/messages.py +++ /dev/null @@ -1,326 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from abc import ABC, abstractmethod -from collections import defaultdict -from enum import Enum, auto -from struct import Struct - -from globus_compute_common.messagepack.message_types import TaskTransition - -MESSAGE_TYPE_FORMATTER = Struct("b") - - -class MessageType(Enum): - HEARTBEAT_REQ = auto() - HEARTBEAT = auto() - EP_STATUS_REPORT = auto() - MANAGER_STATUS_REPORT = auto() - TASK = auto() - RESULTS_ACK = auto() - TASK_CANCEL = auto() - BAD_COMMAND = auto() - - def pack(self): - return MESSAGE_TYPE_FORMATTER.pack(self.value) - - @classmethod - def unpack(cls, buffer): - (mtype,) = MESSAGE_TYPE_FORMATTER.unpack_from(buffer, offset=0) - return MessageType(mtype), buffer[MESSAGE_TYPE_FORMATTER.size :] - - -COMMAND_TYPES = {MessageType.HEARTBEAT_REQ, MessageType.TASK_CANCEL} - - -class Message(ABC): - def __init__(self): - self._payload = None - self._header = None - - @property - def header(self): - return self._header - - @property - def type(self): - raise NotImplementedError() - - @property - def payload(self): - return self._payload - - @classmethod - def unpack(cls, msg): - message_type, remaining = MessageType.unpack(msg) - if message_type is MessageType.HEARTBEAT_REQ: - return HeartbeatReq.unpack(remaining) - elif message_type is MessageType.HEARTBEAT: - return Heartbeat.unpack(remaining) - elif message_type is MessageType.EP_STATUS_REPORT: - return EPStatusReport.unpack(remaining) - elif message_type is MessageType.MANAGER_STATUS_REPORT: - return ManagerStatusReport.unpack(remaining) - elif message_type is MessageType.TASK: - return Task.unpack(remaining) - elif message_type is MessageType.RESULTS_ACK: - return ResultsAck.unpack(remaining) - elif message_type is MessageType.TASK_CANCEL: - return TaskCancel.unpack(remaining) - elif message_type is MessageType.BAD_COMMAND: - return BadCommand.unpack(remaining) - - raise Exception(f"Unknown Message Type Code: {message_type}") - - @abstractmethod - def pack(self): - raise NotImplementedError() - - -class Task(Message): - """ - Task message from the forwarder->interchange - """ - - type = MessageType.TASK - - def __init__( - self, task_id: str, container_id: str, task_buffer: str | bytes, raw_buffer=None - ): - super().__init__() - self.task_id = task_id - self.container_id = container_id - self.task_buffer = task_buffer - self.raw_buffer = raw_buffer - - def pack(self) -> bytes: - if self.raw_buffer is None: - # a type:ignore is needed here - # task_buffer might be a str or it might be a bytes (see `unpack`) - # and we can't tell which is correct - # - # rather than thinking hard, preserve the exact current runtime behavior - # - # all of this code is going to be eliminated soonish by - # globus_compute_common.messagepack in part because of issues like this - if isinstance(self.task_buffer, bytes): - buf = self.task_buffer.decode() - else: - buf = self.task_buffer - add_ons = f"TID={self.task_id};CID={self.container_id};{buf}" - self.raw_buffer = add_ons.encode("utf-8") - - return self.type.pack() + self.raw_buffer - - @classmethod - def unpack(cls, raw_buffer: bytes): - b_tid, b_cid, task_buf = raw_buffer.decode("utf-8").split(";", 2) - return cls( - b_tid[4:], b_cid[4:], task_buf.encode("utf-8"), raw_buffer=raw_buffer - ) - - def set_local_container(self, container_id): - self.local_container = container_id - - -class HeartbeatReq(Message): - """ - Synchronous request for a Heartbeat. - - This is sent from the Forwarder to the endpoint on start to get an initial - connection and ensure liveness. - """ - - type = MessageType.HEARTBEAT_REQ - - @property - def header(self): - return None - - @property - def payload(self): - return None - - @classmethod - def unpack(cls, msg): - return cls() - - def pack(self): - return self.type.pack() - - -class Heartbeat(Message): - """ - Generic Heartbeat message, sent in both directions between Forwarder and - Interchange. - """ - - type = MessageType.HEARTBEAT - - def __init__(self, endpoint_id): - super().__init__() - self.endpoint_id = endpoint_id - - @classmethod - def unpack(cls, msg): - return cls(msg.decode("ascii")) - - def pack(self): - return self.type.pack() + self.endpoint_id.encode("ascii") - - -class EPStatusReport(Message): - """ - Status report for an endpoint, sent from Interchange to Forwarder. - - Includes EP-wide info such as utilization, as well as per-task status information. - """ - - type = MessageType.EP_STATUS_REPORT - - def __init__(self, endpoint_id, global_state, task_statuses): - super().__init__() - self._header = uuid.UUID(endpoint_id).bytes - self.global_state = global_state - self.task_statuses = task_statuses - - def __repr__(self): - name = type(self).__name__ - ep_id = uuid.UUID(bytes=self._header) - return f"{name}(ep:{ep_id}; task count:{len(self.task_statuses)})" - - @classmethod - def unpack(cls, msg): - endpoint_id = str(uuid.UUID(bytes=msg[:16])) - msg = msg[16:] - jsonified = msg.decode("ascii") - global_state, statuses = json.loads(jsonified) - task_statuses = defaultdict(list) - for tid, tt in statuses.items(): - for trans in tt: - task_statuses[tid].append( - TaskTransition( - timestamp=trans["timestamp"], - actor=trans["actor"], - state=trans["state"], - ) - ) - return cls(endpoint_id, global_state, task_statuses) - - def pack(self): - statuses = {} - for tid, tt in self.task_statuses.items(): - for status in tt: - statuses[tid] = statuses.get(tid, []) - statuses[tid].append(status.to_dict()) - jsonified = json.dumps([self.global_state, statuses]) - return self.type.pack() + self._header + jsonified.encode("ascii") - - -class ManagerStatusReport(Message): - """ - Status report sent from the Manager to the Interchange, which mostly just amounts - to saying which tasks are now RUNNING. - """ - - type = MessageType.MANAGER_STATUS_REPORT - - def __init__(self, task_statuses, container_switch_count): - super().__init__() - self.task_statuses = task_statuses - self.container_switch_count = container_switch_count - - @classmethod - def unpack(cls, msg): - container_switch_count = int.from_bytes(msg[:10], "little") - msg = msg[10:] - jsonified = msg.decode("ascii") - statuses = json.loads(jsonified) - task_statuses = defaultdict(list) - for tid, tt in statuses.items(): - for trans in tt: - task_statuses[tid].append( - TaskTransition( - timestamp=trans["timestamp"], - actor=trans["actor"], - state=trans["state"], - ) - ) - return cls(task_statuses, container_switch_count) - - def pack(self): - # TODO: do better than JSON? - statuses = {} - for tid, tt in self.task_statuses.items(): - for status in tt: - statuses[tid] = statuses.get(tid, []) - statuses[tid].append(status.to_dict()) - - jsonified = json.dumps(statuses) - return ( - self.type.pack() - + self.container_switch_count.to_bytes(10, "little") - + jsonified.encode("ascii") - ) - - -class ResultsAck(Message): - """ - Results acknowledgement to acknowledge a task result was received by - the forwarder. Sent from forwarder->interchange - """ - - type = MessageType.RESULTS_ACK - - def __init__(self, task_id): - super().__init__() - self.task_id = task_id - - @classmethod - def unpack(cls, msg): - return cls(msg.decode("ascii")) - - def pack(self): - return self.type.pack() + self.task_id.encode("ascii") - - -class TaskCancel(Message): - """ - Synchronous request for to cancel a Task. - - This is sent from the Executor to the Interchange - """ - - type = MessageType.TASK_CANCEL - - def __init__(self, task_id): - super().__init__() - self.task_id = task_id - - @classmethod - def unpack(cls, msg): - return cls(json.loads(msg.decode("ascii"))) - - def pack(self): - return self.type.pack() + json.dumps(self.task_id).encode("ascii") - - -class BadCommand(Message): - """ - Error message send to indicate that a command is either - unknown, malformed or unsupported. - """ - - type = MessageType.BAD_COMMAND - - def __init__(self, reason: str): - super().__init__() - self.reason = reason - - @classmethod - def unpack(cls, msg): - return cls(msg.decode("ascii")) - - def pack(self): - return self.type.pack() + self.reason.encode("ascii") diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/worker.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/worker.py deleted file mode 100644 index bd3a30abd..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/worker.py +++ /dev/null @@ -1,221 +0,0 @@ -from __future__ import annotations - -import argparse -import logging -import os -import signal -import sys -import typing as t - -import dill -import zmq -from globus_compute_common import messagepack -from globus_compute_common.messagepack.message_types import Result as OutgoingResult -from globus_compute_common.messagepack.message_types import ( - ResultErrorDetails as OutgoingResultErrorDetails, -) -from globus_compute_endpoint.engines.high_throughput.messages import Message -from globus_compute_endpoint.exception_handling import ( - get_error_string, - get_result_error_details, -) -from globus_compute_endpoint.logging_config import setup_logging -from globus_compute_sdk.sdk.utils import get_env_details -from globus_compute_sdk.serialize import ComputeSerializer - -log = logging.getLogger(__name__) - -DEFAULT_RESULT_SIZE_LIMIT_MB = 10 -DEFAULT_RESULT_SIZE_LIMIT_B = DEFAULT_RESULT_SIZE_LIMIT_MB * 1024 * 1024 - - -class Worker: - """The Globus Compute worker - Parameters - ---------- - - worker_id : str - Worker id string - - address : str - Address at which the manager might be reached. This is usually the ipv4 - or ipv6 loopback address 127.0.0.1 or ::1 - - port : int - Port at which the manager can be reached - - result_size_limit : int - Maximum result size allowed in Bytes - Default = 10 MB - - Globus Compute worker will use the REP sockets to: - task = recv () - result = execute(task) - send(result) - """ - - def __init__( - self, - worker_id, - address, - port, - worker_type="RAW", - result_size_limit=DEFAULT_RESULT_SIZE_LIMIT_B, - ): - self.worker_id = worker_id - self.address = address - self.port = port - self.worker_type = worker_type - self.serializer = ComputeSerializer() - self.serialize = self.serializer.serialize - self.deserialize = self.serializer.deserialize - self.result_size_limit = result_size_limit - - log.info(f"Initializing worker {worker_id}") - log.info(f"Worker is of type: {worker_type}") - - self.context = zmq.Context() - self.poller = zmq.Poller() - self.identity = worker_id.encode() - - self.task_socket = self.context.socket(zmq.DEALER) - self.task_socket.setsockopt(zmq.IDENTITY, self.identity) - self.task_socket.setsockopt(zmq.IPV6, True) - - log.info(f"Trying to connect to : tcp://{self.address}:{self.port}") - self.task_socket.connect(f"tcp://{self.address}:{self.port}") - self.poller.register(self.task_socket, zmq.POLLIN) - signal.signal(signal.SIGTERM, self.handler) - - def handler(self, signum, frame): - log.error(f"Signal handler called with signal {signum}") - sys.exit(1) - - def _send_registration_message(self): - log.debug("Sending registration") - payload = {"worker_id": self.worker_id, "worker_type": self.worker_type} - self.task_socket.send_multipart([b"REGISTER", dill.dumps(payload)]) - - def start(self): - log.info("Starting worker") - self._send_registration_message() - - while True: - log.debug("Waiting for task") - p_task_id, p_container_id, msg = self.task_socket.recv_multipart() - task_id: str = dill.loads(p_task_id) - container_id: str = dill.loads(p_container_id) - log.debug(f"Received task with task_id='{task_id}' and msg='{msg}'") - - if task_id == "KILL": - log.info("[KILL] -- Worker KILL message received! ") - # send a "worker die" message back to the manager - self.task_socket.send_multipart([b"WRKR_DIE", b""]) - log.info(f"*** WORKER {self.worker_id} ABOUT TO DIE ***") - # Kill the worker after accepting death in message to manager. - sys.exit() - else: - result = self._worker_execute_task(task_id, msg) - result["container_id"] = container_id - log.debug("Sending result") - # send bytes over the socket back to the manager - self.task_socket.send_multipart([b"TASK_RET", dill.dumps(result)]) - - log.warning("Broke out of the loop... dying") - - def compose_exception_message(self, task_id: str) -> bytes: - code, user_message = get_result_error_details() - outgoing_result = OutgoingResult( - task_id=task_id, - data=get_error_string(), - error_details=OutgoingResultErrorDetails( - code=code, - user_message=user_message, - ), - details=get_env_details(), - ) - return messagepack.pack(outgoing_result) - - def _worker_execute_task( - self, task_id: str, msg: bytes - ) -> dict[str, t.Union[str, bytes]]: - result_message: dict[str, t.Union[str, bytes]] = {"task_id": task_id} - try: - # Unwrap HTEX's Task packing - task_message = Message.unpack(msg) - serialized_fn_package = task_message.task_buffer.decode() - - # Deserialize HTEX Engines' wrapping of - # execute_task, messagepack_payload) - function, args, kwargs = self.deserialize(serialized_fn_package) - - # Execute - serialized_result: bytes = function(*args, **kwargs) - result_message["data"] = serialized_result - - except Exception: - log.exception("Failed to execute task") - serialized_error = self.compose_exception_message(task_id) - result_message["data"] = serialized_error - - return result_message - - -def cli_run(): - parser = argparse.ArgumentParser() - parser.add_argument( - "-w", "--worker_id", required=True, help="ID of worker from process_worker_pool" - ) - parser.add_argument( - "-t", "--type", required=False, help="Container type of worker", default="RAW" - ) - parser.add_argument( - "-a", "--address", required=True, help="Address for the manager, eg X,Y," - ) - parser.add_argument( - "-p", - "--port", - required=True, - help="Internal port at which the worker connects to the manager", - ) - parser.add_argument( - "--logdir", required=True, help="Directory path where worker log files written" - ) - parser.add_argument( - "-d", - "--debug", - action="store_true", - help="Directory path where worker log files written", - ) - args = parser.parse_args() - - setup_logging( - logfile=os.path.join(args.logdir, f"funcx_worker_{args.worker_id}.log"), - debug=args.debug, - ) - - # Redirect the stdout and stderr - stdout_path = os.path.join(args.logdir, f"funcx_worker_{args.worker_id}.stdout") - stderr_path = os.path.join(args.logdir, f"funcx_worker_{args.worker_id}.stderr") - with open(stdout_path, "w") as fo, open(stderr_path, "w") as fe: - # Redirect the stdout - old_stdout, old_stderr = sys.stdout, sys.stderr - sys.stdout = fo - sys.stderr = fe - - try: - worker = Worker( - args.worker_id, - args.address, - int(args.port), - worker_type=args.type, - ) - worker.start() - finally: - # Switch them back - sys.stdout = old_stdout - sys.stderr = old_stderr - - -if __name__ == "__main__": - cli_run() diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/worker_map.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/worker_map.py deleted file mode 100644 index 8a95bf0e9..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/worker_map.py +++ /dev/null @@ -1,516 +0,0 @@ -from __future__ import annotations - -import logging -import os -import random -import subprocess -import time -from collections import defaultdict -from queue import Empty, Queue -from typing import Any - -from globus_compute_endpoint.logging_config import ComputeLogger - -log: ComputeLogger = logging.getLogger(__name__) # type: ignore - - -class WorkerMap: - """WorkerMap keeps track of workers""" - - def __init__( - self, - max_worker_count: int, - available_accelerators: list[str], - ): - """ - - Parameters - ---------- - max_worker_count: - Maximum number of workers allowed - available_accelerators: - List of accelerator devices workers can be pinned to - """ - self.max_worker_count = max_worker_count - self.total_worker_type_counts: dict[str, int] = { - "unused": self.max_worker_count - } - self.ready_worker_type_counts: dict[str, int] = { - "unused": self.max_worker_count - } - self.pending_worker_type_counts: dict[str, Any] = {} - # a dict to keep track of all the worker_queues with the key of work_type - self.worker_queues: dict[str, Any] = {} - # a dict to keep track of all the worker_types with the key of worker_id - self.worker_types: dict[str, str] = {} - self.worker_id_counter = 0 # used to create worker_ids - - # Only spin up containers if active_workers + pending_workers < max_workers. - self.active_workers = 0 - self.pending_workers = 0 - - # Need to keep track of workers that are ABOUT to die - self.to_die_count: dict[str, int] = {} - - # Need to keep track of workers' last idle time by worker type - self.worker_idle_since: dict[str, float] = {} - - # Create a queue of available accelerators, if accelerators are defined - self.available_accelerators: Queue | None = None - if len(available_accelerators) != 0: - self.available_accelerators = Queue() - for device in available_accelerators: - self.available_accelerators.put(device) - self.assigned_accelerators: dict[str, str] = {} # Map worker ID -> accelerator - - self._noisy_log: dict[str, Any] = defaultdict(dict) - - def register_worker(self, worker_id, worker_type): - """Add a new worker""" - log.debug(f"In register worker worker_id: {worker_id} type:{worker_type}") - self.worker_types[worker_id] = worker_type - - if worker_type not in self.worker_queues: - self.worker_queues[worker_type] = Queue() - - self.total_worker_type_counts[worker_type] = ( - self.total_worker_type_counts.get(worker_type, 0) + 1 - ) - self.ready_worker_type_counts[worker_type] = ( - self.ready_worker_type_counts.get(worker_type, 0) + 1 - ) - self.pending_worker_type_counts[worker_type] = ( - self.pending_worker_type_counts.get(worker_type, 0) - 1 - ) - self.pending_workers -= 1 - self.active_workers += 1 - self.worker_queues[worker_type].put(worker_id) - self.worker_idle_since[worker_type] = time.time() - - if worker_type not in self.to_die_count: - self.to_die_count[worker_type] = 0 - - def start_remove_worker(self, worker_type): - """Increase the to_die_count in prep for a worker getting removed""" - self.to_die_count[worker_type] += 1 - - def remove_worker(self, worker_id): - """Remove the worker from the WorkerMap - - Should already be KILLed by this point. - """ - worker_type = self.worker_types[worker_id] - self.active_workers -= 1 - self.total_worker_type_counts[worker_type] -= 1 - self.to_die_count[worker_type] -= 1 - self.total_worker_type_counts["unused"] += 1 - self.ready_worker_type_counts["unused"] += 1 - - # Mark the accelerator as available, if provided - if worker_id in self.assigned_accelerators: - device = self.assigned_accelerators.pop(worker_id) - self.available_accelerators.put(device) - - def spin_up_workers( - self, - next_worker_q, - mode="no_container", - container_cmd_options="", - address=None, - debug=None, - uid=None, - logdir=None, - worker_port=None, - ): - """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. - - Parameters - ---------- - new_worker_q : queue.Queue() - Queue of worker types to be spun up next. - mode : str - Mode of the worker, no_container, singularity, etc. - address : str - Address at which to connect to the workers. - debug : bool - Whether debug logging is activated. - uid : str - Worker ID to be assigned to worker. - logdir: str - Directory in which to write logs - worker_port: int - Port at which to connect to the workers. - - Returns - --------- - Total number of spun-up workers. - """ - spin_ups = {} - - log.trace("Next Worker Qsize: %s", len(next_worker_q)) - log.trace("Active Workers: %s", self.active_workers) - log.trace("Pending Workers: %s", self.pending_workers) - log.trace("Max Worker Count: %s", self.max_worker_count) - - if ( - len(next_worker_q) > 0 - and self.active_workers + self.pending_workers < self.max_worker_count - ): - log.debug("Spinning up new workers") - log.debug( - "Empty slots: %s", - self.max_worker_count - self.active_workers - self.pending_workers, - ) - log.debug(f"New workers: {len(next_worker_q)}") - log.debug(f"Unused slots: {self.total_worker_type_counts['unused']}") - num_slots = min( - self.max_worker_count - self.active_workers - self.pending_workers, - len(next_worker_q), - self.total_worker_type_counts["unused"], - ) - for _ in range(num_slots): - try: - proc = self.add_worker( - worker_id=str(self.worker_id_counter), - worker_type=next_worker_q.pop(0), - container_cmd_options=container_cmd_options, - mode=mode, - address=address, - debug=debug, - uid=uid, - logdir=logdir, - worker_port=worker_port, - ) - except Exception: - log.exception("Error spinning up worker! Skipping...") - continue - else: - spin_ups.update(proc) - return spin_ups - - def spin_down_workers( - self, - new_worker_map, - worker_max_idletime=60, - need_more=False, - scheduler_mode="hard", - ): - """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. - - Parameters - ---------- - new_worker_map : dict - {worker_type: total_number_of_containers,...}. - need_more: bool - whether the manager needs to spin down some warm containers - Returns - --------- - List of removed worker types. - """ - if need_more: - return self._spin_down( - new_worker_map, - worker_max_idletime=worker_max_idletime, - scheduler_mode=scheduler_mode, - check_idle=False, - ) - else: - return self._spin_down( - new_worker_map, - worker_max_idletime=worker_max_idletime, - scheduler_mode=scheduler_mode, - check_idle=True, - ) - - def _spin_down( - self, - new_worker_map, - worker_max_idletime=60, - scheduler_mode="hard", - check_idle=True, - ): - """Helper function to call 'remove' for appropriate workers in 'new_worker_map'. - - Parameters - ---------- - new_worker_map : dict - {worker_type: total_number_of_containers,...}. - check_idle : boolean - A boolean to indicate whether to check the idle time of containers or not - - If checked, that means the workloads are not so busy, and we can leave the - container workers alive until the worker_max_idletime is reached. Otherwise, - that means the workloads are busy and we need to turn of some containers to - acommodate the workers, regardless of if it reaches the worker_max_idletime. - - Returns - --------- - List of removed worker types. - """ - spin_downs = [] - container_switch_count = 0 - now = time.time() - for worker_type in self.total_worker_type_counts: - if worker_type == "unused": - continue - if ( - check_idle - and now - self.worker_idle_since[worker_type] < worker_max_idletime - ): - log.trace( - "Current time: %s (idle since: %s). Worker type %s has not " - "exceeded maximum idle time %s; continuing", - now, - self.worker_idle_since[worker_type], - worker_type, - worker_max_idletime, - ) - continue - num_remove = max( - 0, - self.total_worker_type_counts[worker_type] - - self.to_die_count.get(worker_type, 0) - - new_worker_map.get(worker_type, 0), - ) - if scheduler_mode == "hard": - # Leave at least one worker alive in hard mode - max_remove = max(0, self.total_worker_type_counts[worker_type] - 1) - num_remove = min(num_remove, max_remove) - - if num_remove > 0: - log.debug(f"Removing {num_remove} workers of type {worker_type}") - for _i in range(num_remove): - spin_downs.append(worker_type) - # A container switching is defined as a warm container must be - # switched to another container to accommodate the workloads. - # If a container worker has been idle for worker_max_idletime, - # Then it is not counted as a container switching - if not check_idle: - container_switch_count += num_remove - return spin_downs, container_switch_count - - def add_worker( - self, - worker_id=None, - mode="no_container", - worker_type="RAW", - container_cmd_options="", - walltime=1, - address=None, - debug=None, - worker_port=None, - logdir=None, - uid=None, - ): - """Launch the appropriate worker - - Parameters - ---------- - worker_id : str - Worker identifier string - mode : str - Valid options are no_container, singularity - walltime : int - Walltime in seconds before we check status - - """ - if worker_id is None: - str(random.random()) - - debug = " --debug" if debug else "" - - worker_id = f" --worker_id {worker_id}" - - self.worker_id_counter += 1 - - cmd = ( - f"globus-compute-worker {debug}{worker_id} " - f"-a {address} " - f"-p {worker_port} " - f"-t {worker_type} " - f"--logdir={os.path.join(logdir, uid)} " - ) - - container_uri = None - if worker_type != "RAW": - container_uri = worker_type - - # If accelerator list is provided, get the next one off the queue - # and mark it as assigned - environment_variables = os.environ.copy() - if self.available_accelerators is not None: - try: - device = self.available_accelerators.get_nowait() - except Empty: - raise ValueError( - "No accelerators are available." - " New worker must be created only" - " after another is removed" - ) - self.assigned_accelerators[worker_id] = device - log.info(f"Assigned worker '{worker_id}' to accelerator '{device}'") - - # Create the - # TODO (wardlt): This code has only been tested for CUDA - environment_variables["CUDA_VISIBLE_DEVICES"] = device - environment_variables["ROCR_VISIBLE_DEVICES"] = device - environment_variables["SYCL_DEVICE_FILTER"] = f"*:*:{device}" - - log.info(f"Command string :\n {cmd}") - log.info(f"Mode: {mode}") - log.info(f"Container uri: {container_uri}") - log.info(f"Container cmd options: {container_cmd_options}") - log.info(f"Worker type: {worker_type}") - - if mode == "no_container": - modded_cmd = cmd - elif mode == "singularity_reuse": - if container_uri is None: - log.warning( - "No container is specified for singularity mode. " - "Spawning a worker in a raw process instead." - ) - modded_cmd = cmd - elif not os.path.exists(container_uri): - log.warning( - f"Container uri {container_uri} is not found. " - "Spawning a worker in a raw process instead." - ) - modded_cmd = cmd - else: - modded_cmd = ( - f"singularity exec {container_cmd_options} {container_uri} {cmd}" - ) - log.info(f"Command string with singularity:\n {modded_cmd}") - else: - raise NameError("Invalid container launch mode.") - - try: - proc = subprocess.Popen( - modded_cmd.split(), - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - shell=False, - env=environment_variables, - ) - - except Exception: - log.exception("Got an error in worker launch") - raise - - self.total_worker_type_counts["unused"] -= 1 - self.ready_worker_type_counts["unused"] -= 1 - self.pending_worker_type_counts[worker_type] = ( - self.pending_worker_type_counts.get(worker_type, 0) + 1 - ) - self.pending_workers += 1 - - return {str(self.worker_id_counter - 1): proc} - - def get_next_worker_q(self, new_worker_map) -> tuple[list[str], bool]: - """Helper function to generate a queue of next workers to spin up . - From a mapping generated by the scheduler - - Parameters - ---------- - new_worker_map : dict - {worker_type: total_number_of_containers,...} - - Returns - --------- - Queue containing the next workers the system should spin-up. - """ - - _log_data = self._noisy_log["get_next_worker_q"] - _l_total = _log_data.get("total_worker_type_counts") - _l_pending = _log_data.get("pending_worker_type_counts") - if (_l_total, _l_pending) != ( - self.total_worker_type_counts, - self.pending_worker_type_counts, - ): - log.debug( - "total_worker_type_counts: %s; pending_worker_type_counts: %s", - self.total_worker_type_counts, - self.pending_worker_type_counts, - ) - _log_data["total_worker_type_counts"] = self.total_worker_type_counts - _log_data["pending_worker_type_counts"] = self.pending_worker_type_counts - - new_worker_list = [] - for worker_type in new_worker_map: - cur_workers = self.total_worker_type_counts.get( - worker_type, 0 - ) + self.pending_worker_type_counts.get(worker_type, 0) - if new_worker_map[worker_type] > cur_workers: - for _i in range(new_worker_map[worker_type] - cur_workers): - # Add worker - new_worker_list.append(worker_type) - - # need_more is to reflect if a manager needs more workers than the current - # unused slots - # If yes, that means the manager needs to turn off some warm workers to serve - # the requests - need_more = len(new_worker_list) > self.total_worker_type_counts["unused"] - # Randomly assign order of newly needed containers... add to spin-up queue. - if len(new_worker_list) > 0: - random.shuffle(new_worker_list) - - return new_worker_list, need_more - - def update_worker_idle(self, worker_type): - """Update the workers' last idle time by worker type""" - log.debug(f"Worker idle since: {self.worker_idle_since}") - self.worker_idle_since[worker_type] = time.time() - - def put_worker(self, worker): - """Adds worker to the list of waiting workers""" - worker_type = self.worker_types[worker] - - if worker_type not in self.worker_queues: - self.worker_queues[worker_type] = Queue() - - self.ready_worker_type_counts[worker_type] += 1 - self.worker_queues[worker_type].put(worker) - - def get_worker(self, worker_type): - """Get a task and reduce the # of worker for that type by 1. - Raises queue.Empty if empty - """ - worker = self.worker_queues[worker_type].get_nowait() - self.ready_worker_type_counts[worker_type] -= 1 - return worker - - def get_worker_counts(self): - """Returns just the dict of worker_type and counts""" - return self.total_worker_type_counts - - def ready_worker_count(self): - return sum(self.ready_worker_type_counts.values()) - - def advertisement(self): - """ - Manager capacity advertisement to interchange. - - The advertisement includes two parts: - - One is the read_worker_type_counts, which reflects the capacity of different - types of containers on the manager. - - The other is the total number of workers of each type. This includes all the - pending workers and to_die workers when advertising. We need this "total" - advertisement because we use killer task mechanisms to kill a worker. When a - manager is advertising, there may be some killer tasks in queue, and we want to - ensure that the manager does not over-advertise its actual capacity. Instead, - let the interchange decide if it is sending too many tasks to the manager. - """ - ads = {"total": {}, "free": {}} - total = dict(self.total_worker_type_counts) - for worker_type in self.pending_worker_type_counts: - total[worker_type] = ( - total.get(worker_type, 0) - + self.pending_worker_type_counts[worker_type] - - self.to_die_count.get(worker_type, 0) - ) - ads["total"].update(total) - ads["free"].update(self.ready_worker_type_counts) - return ads diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py deleted file mode 100644 index aac200d55..000000000 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py +++ /dev/null @@ -1,201 +0,0 @@ -#!/usr/bin/env python3 - -from __future__ import annotations - -import ipaddress -import logging -import time - -import dill -import zmq -from globus_compute_endpoint.engines.high_throughput.messages import Message - -log = logging.getLogger(__name__) - - -def _zmq_canonicalize_address(addr: str | int) -> str: - try: - ip = ipaddress.ip_address(addr) - except ValueError: - # Not a valid IPv4 or IPv6 address - if isinstance(addr, int): - # If it was an integer, then it's just plain invalid - raise - - # Otherwise, it was likely a hostname; let another layer deal with it - return addr - - if ip.version == 4: - return str(ip) # like "12.34.56.78" - elif ip.version == 6: - return f"[{ip}]" # like "[::1]" - - -def _zmq_create_socket_port(context: zmq.Context, ip_address: str | int, port_range): - """ - Utility method with logic shared by all the pipes - """ - sock = context.socket(zmq.DEALER) - sock.set_hwm(0) - # This option should work for both IPv4 and IPv6...? - # May not work until Parsl is updated? - sock.setsockopt(zmq.IPV6, True) - - port = sock.bind_to_random_port( - f"tcp://{_zmq_canonicalize_address(ip_address)}", - min_port=port_range[0], - max_port=port_range[1], - ) - return sock, port - - -class CommandClient: - """CommandClient""" - - def __init__(self, ip_address, port_range): - """ - Parameters - ---------- - - ip_address: str - IP address of the client (where Parsl runs) - port_range: tuple(int, int) - Port range for the comms between client and interchange - - """ - - self.context = zmq.Context() - self.zmq_socket, self.port = _zmq_create_socket_port( - self.context, ip_address, port_range - ) - - def run(self, message): - """This function needs to be fast at the same time aware of the possibility of - ZMQ pipes overflowing. - - The timeout increases slowly if contention is detected on ZMQ pipes. - We could set copy=False and get slightly better latency but this results - in ZMQ sockets reaching a broken state once there are ~10k tasks in flight. - This issue can be magnified if each the serialized buffer itself is larger. - """ - self.zmq_socket.send(message.pack(), copy=True) - reply = self.zmq_socket.recv() - return Message.unpack(reply) - - def close(self): - self.zmq_socket.close() - self.context.term() - - -class TasksOutgoing: - """Outgoing task queue from the Engine to the Interchange""" - - def __init__(self, ip_address, port_range): - """ - Parameters - ---------- - - ip_address: str - IP address of the client (where Parsl runs) - port_range: tuple(int, int) - Port range for the comms between client and interchange - - """ - self.context = zmq.Context() - self.zmq_socket, self.port = _zmq_create_socket_port( - self.context, ip_address, port_range - ) - self.poller = zmq.Poller() - self.poller.register(self.zmq_socket, zmq.POLLOUT) - - def put(self, message, max_timeout=1000): - """This function needs to be fast at the same time aware of the possibility of - ZMQ pipes overflowing. - - The timeout increases slowly if contention is detected on ZMQ pipes. - We could set copy=False and get slightly better latency but this results - in ZMQ sockets reaching a broken state once there are ~10k tasks in flight. - This issue can be magnified if each the serialized buffer itself is larger. - - Parameters - ---------- - - message : py object - Python object to send - max_timeout : int - Max timeout in milliseconds that we will wait for before raising an - exception - - Raises - ------ - - zmq.EAGAIN if the send failed. - - """ - timeout_ms = 0 - current_wait = 0 - while current_wait < max_timeout: - socks = dict(self.poller.poll(timeout=timeout_ms)) - if self.zmq_socket in socks and socks[self.zmq_socket] == zmq.POLLOUT: - # The copy option adds latency but reduces the risk of ZMQ overflow - self.zmq_socket.send(message, copy=True) - return - else: - timeout_ms += 1 - log.debug( - "Not sending due to full zmq pipe, timeout: {} ms".format( - timeout_ms - ) - ) - current_wait += timeout_ms - - # Send has failed. - log.debug(f"Remote side has been unresponsive for {current_wait}") - raise zmq.error.Again - - def close(self): - self.zmq_socket.close() - self.context.term() - - -class ResultsIncoming: - """Incoming results queue from the Interchange to the Engine""" - - def __init__(self, ip_address, port_range): - """ - Parameters - ---------- - - ip_address: str - IP address of the client (where Parsl runs) - port_range: tuple(int, int) - Port range for the comms between client and interchange - - """ - self.context = zmq.Context() - self.results_receiver, self.port = _zmq_create_socket_port( - self.context, ip_address, port_range - ) - - def get(self, block=True, timeout=None): - block_messages = self.results_receiver.recv() - try: - res = dill.loads(block_messages) - except dill.UnpicklingError: - try: - res = Message.unpack(block_messages) - except Exception: - log.exception( - "Message in results queue is not pickle/Message formatted: %s", - block_messages, - ) - return res - - def request_close(self): - status = self.results_receiver.send(dill.dumps(None)) - time.sleep(0.1) - return status - - def close(self): - self.results_receiver.close() - self.context.term() diff --git a/compute_endpoint/globus_compute_endpoint/executors/high_throughput/__init__.py b/compute_endpoint/globus_compute_endpoint/executors/high_throughput/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/compute_endpoint/globus_compute_endpoint/executors/high_throughput/executor.py b/compute_endpoint/globus_compute_endpoint/executors/high_throughput/executor.py deleted file mode 100644 index 262686017..000000000 --- a/compute_endpoint/globus_compute_endpoint/executors/high_throughput/executor.py +++ /dev/null @@ -1,18 +0,0 @@ -import warnings - -from globus_compute_endpoint.engines import HighThroughputEngine - -warnings.warn( - f"{__name__} is deprecated. Please import from globus_compute_sdk.engines instead.", - category=DeprecationWarning, - stacklevel=2, -) - - -def HighThroughputExecutor(*args, **kwargs) -> HighThroughputEngine: - warnings.warn( - "HighThroughputExecutor is deprecated. Please use GlobusComputeEngine instead.", - category=DeprecationWarning, - stacklevel=2, - ) - return HighThroughputEngine(*args, **kwargs) diff --git a/compute_endpoint/tests/integration/endpoint/endpoint/test_messages_compat.py b/compute_endpoint/tests/integration/endpoint/endpoint/test_messages_compat.py deleted file mode 100644 index e60f18b45..000000000 --- a/compute_endpoint/tests/integration/endpoint/endpoint/test_messages_compat.py +++ /dev/null @@ -1,131 +0,0 @@ -import pickle -import uuid - -import pytest -from globus_compute_common.messagepack import unpack -from globus_compute_common.messagepack.message_types import Container, ContainerImage -from globus_compute_common.messagepack.message_types import ( - EPStatusReport as OutgoingEPStatusReport, -) -from globus_compute_common.messagepack.message_types import Task as OutgoingTask -from globus_compute_common.messagepack.message_types import TaskTransition -from globus_compute_common.tasks.constants import ActorName, TaskState -from globus_compute_endpoint.endpoint.messages_compat import ( - convert_to_internaltask, - try_convert_to_messagepack, -) -from globus_compute_endpoint.engines.high_throughput.messages import ( - EPStatusReport as InternalEPStatusReport, -) -from globus_compute_endpoint.engines.high_throughput.messages import ( - Message as InternalMessage, -) -from globus_compute_endpoint.engines.high_throughput.messages import ( - Task as InternalTask, -) - - -def test_ep_status_report_conversion(): - ep_id = uuid.uuid4() - global_state = {"looking": "good"} - task_statuses = { - "1": [ - TaskTransition( - timestamp=1, - state=TaskState.EXEC_END, - actor=ActorName.INTERCHANGE, - ) - ], - "2": [ - TaskTransition( - timestamp=1, - state=TaskState.EXEC_END, - actor=ActorName.INTERCHANGE, - ) - ], - } - - internal = InternalEPStatusReport(str(ep_id), global_state, task_statuses) - message = pickle.dumps(internal) - - outgoing = try_convert_to_messagepack(message) - external = unpack(outgoing) - - assert isinstance(external, OutgoingEPStatusReport) - assert external.endpoint_id == ep_id - assert external.global_state == global_state - assert external.task_statuses == task_statuses - - -def test_external_task_to_internal_task(randomstring): - task_id = uuid.uuid4() - task_buffer = b"task_buffer" - container_type = randomstring() - location = randomstring() - - external = OutgoingTask( - task_id=task_id, - container=Container( - container_id=uuid.uuid4(), - name="", - images=[ - ContainerImage( - image_type=container_type, - location=location, - created_at=0, - modified_at=0, - ) - ], - ), - task_buffer=task_buffer, - ) - - incoming = convert_to_internaltask(external, container_type) - internal = InternalMessage.unpack(incoming) - - assert isinstance(internal, InternalTask) - assert internal.task_id == str(task_id) - assert internal.container_id == location - assert internal.task_buffer == task_buffer - - -def test_external_task_without_container_id_converts_to_RAW(): - task_id = uuid.uuid4() - task_buffer = b"task_buffer" - - external = OutgoingTask(task_id=task_id, container_id=None, task_buffer=task_buffer) - - incoming = convert_to_internaltask(external, None) - internal = InternalMessage.unpack(incoming) - - assert isinstance(internal, InternalTask) - assert internal.task_id == str(task_id) - assert internal.container_id == "RAW" - assert internal.task_buffer == task_buffer - - -@pytest.mark.parametrize( - "packed_result", - [ - [ - ( - b'\x01{"message_type":"result","data":{"task_id":' - b'"1aa3202c-336d-4c43-9a6e-98711add151d","data":"abc 123",' - b'"error_details":{"code":"err_code","user_message":"msg"},' - b'"task_statuses":[]}}' - ), - "result", - "1aa3202c-336d-4c43-9a6e-98711add151d", - "abc 123", - "err_code", - ], - ], -) -def test_unpack_result_without_details(packed_result): - raw, result, task_id, data, err = packed_result - unpacked = unpack(raw) - assert unpacked.message_type == result - assert isinstance(unpacked.task_id, uuid.UUID) - assert str(unpacked.task_id) == task_id - assert unpacked.data == data - assert unpacked.error_details.code == err diff --git a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/__init__.py b/compute_endpoint/tests/integration/endpoint/executors/high_throughput/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_htex_regression.py b/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_htex_regression.py deleted file mode 100644 index 5f5c3cfc2..000000000 --- a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_htex_regression.py +++ /dev/null @@ -1,62 +0,0 @@ -import queue -import random -import uuid - -import pytest -from globus_compute_common import messagepack -from globus_compute_endpoint.engines import HighThroughputEngine -from tests.utils import double - - -@pytest.fixture -def engine(tmp_path, htex_warns): - ep_id = uuid.uuid4() - engine = HighThroughputEngine( - label="HTEXEngine", - heartbeat_period=1, - worker_debug=True, - ) - results_queue = queue.Queue() - engine.start(endpoint_id=ep_id, run_dir=tmp_path, results_passthrough=results_queue) - - yield engine - engine.shutdown() - - -def test_engine_submit(engine, serde, task_uuid, ez_pack_task): - q = engine.results_passthrough - task_arg = random.randint(1, 1000) - task_bytes = ez_pack_task(double, task_arg) - resource_spec = {} - future = engine.submit( - str(task_uuid), task_bytes, resource_specification=resource_spec - ) - packed_result = future.result() - - # Confirm that the future got the right answer - assert isinstance(packed_result, bytes) - result = messagepack.unpack(packed_result) - assert isinstance(result, messagepack.message_types.Result) - assert result.task_id == task_uuid - - # Confirm that the same result got back though the queue - for _i in range(10): - q_msg = q.get(timeout=5) - assert isinstance(q_msg, dict) - - packed_result_q = q_msg["message"] - result = messagepack.unpack(packed_result_q) - # Handle a sneaky EPStatusReport that popped in ahead of the result - if isinstance(result, messagepack.message_types.EPStatusReport): - continue - - # At this point the message should be the result - assert ( - packed_result == packed_result_q - ), "Result from passthrough_q and future should match" - - assert result.task_id == task_uuid - final_result = serde.deserialize(result.data) - expected = task_arg * 2 - assert final_result == expected, f"Expected {expected}, but got: {final_result}" - break diff --git a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_manager.py b/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_manager.py deleted file mode 100644 index 0cd7b0f38..000000000 --- a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_manager.py +++ /dev/null @@ -1,60 +0,0 @@ -import os -import pickle -import queue -import shutil -import subprocess - -import pytest -from globus_compute_endpoint.engines.high_throughput.manager import Manager -from globus_compute_endpoint.engines.high_throughput.messages import Task - -_MOCK_BASE = "globus_compute_endpoint.engines.high_throughput.manager." - - -class TestManager: - @pytest.fixture(autouse=True) - def test_setup_teardown(self): - os.makedirs(os.path.join(os.getcwd(), "mock_uid")) - yield - shutil.rmtree(os.path.join(os.getcwd(), "mock_uid")) - - def test_remove_worker_init(self, mocker): - # zmq is being mocked here because it was making tests hang - mocker.patch(f"{_MOCK_BASE}zmq.Context") # noqa: E501 - - manager = Manager(logdir="./", uid="mock_uid") - manager.worker_map.to_die_count["RAW"] = 0 - manager.task_queues["RAW"] = queue.Queue() - - manager.remove_worker_init("RAW") - task = manager.task_queues["RAW"].get() - assert isinstance(task, Task) - assert task.task_id == "KILL" - assert task.task_buffer == "KILL" - - def test_poll_funcx_task_socket(self, mocker): - # zmq is being mocked here because it was making tests hang - mocker.patch(f"{_MOCK_BASE}zmq.Context") # noqa: E501 - mock_worker_map = mocker.patch(f"{_MOCK_BASE}WorkerMap") - - manager = Manager(logdir="./", uid="mock_uid") - manager.task_queues["RAW"] = queue.Queue() - manager.worker_type = "RAW" - manager.worker_procs["0"] = mocker.Mock(spec=subprocess.Popen) - - manager.funcx_task_socket.recv_multipart.return_value = ( - b"0", - b"REGISTER", - pickle.dumps({"worker_type": "RAW"}), - ) - manager.poll_funcx_task_socket(test=True) - mock_worker_map.return_value.register_worker.assert_called_with(b"0", "RAW") - - manager.funcx_task_socket.recv_multipart.return_value = ( - b"0", - b"WRKR_DIE", - pickle.dumps(None), - ) - manager.poll_funcx_task_socket(test=True) - mock_worker_map.return_value.remove_worker.assert_called_with(b"0") - assert len(manager.worker_procs) == 0 diff --git a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_worker_map.py b/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_worker_map.py deleted file mode 100644 index ebf0dd9af..000000000 --- a/compute_endpoint/tests/integration/endpoint/executors/high_throughput/test_worker_map.py +++ /dev/null @@ -1,42 +0,0 @@ -import logging -import os - -from globus_compute_endpoint.engines.high_throughput.worker_map import WorkerMap - - -class TestWorkerMap: - def test_add_worker(self, mocker): - mock_popen = mocker.patch( - "globus_compute_endpoint.engines.high_throughput.worker_map.subprocess.Popen" # noqa: E501 - ) - mock_popen.return_value = "proc" - - # Test adding with no accelerators - worker_map = WorkerMap(1, []) - worker = worker_map.add_worker( - worker_id="0", - address="127.0.0.1", - debug=logging.DEBUG, - uid="test1", - logdir=os.getcwd(), - worker_port=50001, - ) - - assert list(worker.keys()) == ["0"] - assert worker["0"] == "proc" - assert worker_map.worker_id_counter == 1 - assert worker_map.available_accelerators is None - - # Test with an accelerator - worker_map = WorkerMap(1, ["0"]) - worker_map.add_worker( - worker_id="1", - address="127.0.0.1", - debug=logging.DEBUG, - uid="test1", - logdir=os.getcwd(), - worker_port=50001, - ) - - last_call = mock_popen.mock_calls[-1] - assert last_call[-1]["env"]["CUDA_VISIBLE_DEVICES"] == "0" diff --git a/compute_endpoint/tests/unit/test_bad_endpoint_config.py b/compute_endpoint/tests/unit/test_bad_endpoint_config.py deleted file mode 100644 index 3e68ea410..000000000 --- a/compute_endpoint/tests/unit/test_bad_endpoint_config.py +++ /dev/null @@ -1,21 +0,0 @@ -from unittest import mock - -import pytest -from globus_compute_endpoint.engines import HighThroughputEngine - -_MOCK_BASE = "globus_compute_endpoint.engines.high_throughput.engine." - - -@pytest.mark.parametrize("address", ("example", "a.b.c.d.e", "*")) -def test_invalid_address(address, htex_warns): - with mock.patch(f"{_MOCK_BASE}log") as mock_log: - with pytest.raises(ValueError): - HighThroughputEngine(address=address) - assert mock_log.critical.called - - -@pytest.mark.parametrize( - "address", ("192.168.64.12", "fe80::e643:4bff:fe61:8f72", "129.114.44.12") -) -def test_valid_address(address, htex_warns): - HighThroughputEngine(address=address) diff --git a/compute_endpoint/tests/unit/test_endpoint_config.py b/compute_endpoint/tests/unit/test_endpoint_config.py index 384b06a09..7f3b14557 100644 --- a/compute_endpoint/tests/unit/test_endpoint_config.py +++ b/compute_endpoint/tests/unit/test_endpoint_config.py @@ -133,32 +133,20 @@ def test_mu_public(public: t.Any): assert c.public is (public is True) -@pytest.mark.parametrize("engine_type", ("GlobusComputeEngine", "HighThroughputEngine")) @pytest.mark.parametrize("strategy", ("simple", {"type": "SimpleStrategy"}, None)) def test_conditional_engine_strategy( - engine_type: str, strategy: t.Union[str, dict, None], config_dict: dict + strategy: t.Union[str, dict, None], config_dict: dict ): - config_dict["engine"]["type"] = engine_type + config_dict["engine"]["type"] = "GlobusComputeEngine" config_dict["engine"]["strategy"] = strategy - config_dict["engine"]["address"] = ( - "::1" if engine_type != "HighThroughputEngine" else "127.0.0.1" - ) - - if engine_type == "GlobusComputeEngine": - if isinstance(strategy, str) or strategy is None: - UserEndpointConfigModel(**config_dict) - elif isinstance(strategy, dict): - with pytest.raises(ValidationError) as pyt_e: - UserEndpointConfigModel(**config_dict) - assert "object is incompatible" in str(pyt_e.value) + config_dict["engine"]["address"] = "::1" - elif engine_type == "HighThroughputEngine": - if isinstance(strategy, dict) or strategy is None: + if isinstance(strategy, str) or strategy is None: + UserEndpointConfigModel(**config_dict) + elif isinstance(strategy, dict): + with pytest.raises(ValidationError) as pyt_e: UserEndpointConfigModel(**config_dict) - elif isinstance(strategy, str): - with pytest.raises(ValidationError) as pyt_e: - UserEndpointConfigModel(**config_dict) - assert "string is incompatible" in str(pyt_e.value) + assert "object is incompatible" in str(pyt_e.value) @pytest.mark.parametrize( diff --git a/compute_endpoint/tests/unit/test_engines.py b/compute_endpoint/tests/unit/test_engines.py index af5618b79..194b7ca18 100644 --- a/compute_endpoint/tests/unit/test_engines.py +++ b/compute_endpoint/tests/unit/test_engines.py @@ -3,7 +3,6 @@ import pathlib import random import time -import typing as t from queue import Queue from unittest import mock @@ -16,7 +15,6 @@ from globus_compute_endpoint.engines import ( GlobusComputeEngine, GlobusMPIEngine, - HighThroughputEngine, ProcessPoolEngine, ThreadPoolEngine, ) @@ -183,12 +181,9 @@ def test_gc_engine_system_failure(ez_pack_task, task_uuid, engine_runner): future.result() -@pytest.mark.parametrize("engine_type", (GlobusComputeEngine, HighThroughputEngine)) -def test_serialized_engine_config_has_provider( - engine_type: t.Type[GlobusComputeEngineBase], -): - loopback = "127.0.0.1" if engine_type != "HighThroughputEngine" else "::1" - ep_config = UserEndpointConfig(executors=[engine_type(address=loopback)]) +def test_serialized_engine_config_has_provider(): + loopback = "::1" + ep_config = UserEndpointConfig(executors=[GlobusComputeEngine(address=loopback)]) res = serialize_config(ep_config) executor = res["executors"][0].get("executor") or res["executors"][0] @@ -396,23 +391,3 @@ def test_gcmpiengine_accepts_resource_specification(task_uuid, randomstring): a, _k = engine.executor.submit.call_args assert spec in a - - -@pytest.mark.parametrize( - ("input", "is_valid"), - ( - [None, False], - ["", False], - ["localhost.1", False], - ["localhost", True], - ["1.2.3.4.5", False], - ["127.0.0.1", True], - ["example.com", True], - ["0:0:0:0:0:0:0:1", True], - ["11111:0:0:0:0:0:0:1", False], - ["::1", True], - ["abc", False], - ), -) -def test_hostname_or_ip_validation(input, is_valid): - assert HighThroughputEngine.is_hostname_or_ip(input) is is_valid diff --git a/compute_endpoint/tests/unit/test_execute_task.py b/compute_endpoint/tests/unit/test_execute_task.py index f30f14776..927e9339f 100644 --- a/compute_endpoint/tests/unit/test_execute_task.py +++ b/compute_endpoint/tests/unit/test_execute_task.py @@ -14,6 +14,16 @@ _MOCK_BASE = "globus_compute_endpoint.engines.helper." +def sleeper(t: float): + import time + + now = start = time.monotonic() + while now - start < t: + time.sleep(0.0001) + now = time.monotonic() + return True + + @pytest.mark.parametrize("run_dir", ("", ".", "./", "../", "tmp", "$HOME")) def test_bad_run_dir(endpoint_uuid, task_uuid, run_dir): with pytest.raises(ValueError): # not absolute @@ -137,3 +147,17 @@ def test_execute_task_with_exception(ez_pack_task, execute_task_runner): assert "dill_version" in result.details assert "endpoint_id" in result.details assert "ZeroDivisionError" in result.data + + +def test_execute_task_timeout(execute_task_runner, task_uuid, ez_pack_task): + task_bytes = ez_pack_task(sleeper, 1) + + with mock.patch("globus_compute_endpoint.engines.helper.log") as mock_log: + with mock.patch.dict(os.environ, {"GC_TASK_TIMEOUT": "0.01"}): + packed_result = execute_task_runner(task_bytes) + + result = messagepack.unpack(packed_result) + assert isinstance(result, messagepack.message_types.Result) + assert result.task_id == task_uuid + assert "AppTimeout" in result.data + assert mock_log.exception.called diff --git a/compute_endpoint/tests/unit/test_highthroughputinterchange.py b/compute_endpoint/tests/unit/test_highthroughputinterchange.py deleted file mode 100644 index 7063adf0e..000000000 --- a/compute_endpoint/tests/unit/test_highthroughputinterchange.py +++ /dev/null @@ -1,67 +0,0 @@ -import time -import uuid -from unittest import mock - -import pytest -from globus_compute_common.tasks import TaskState -from globus_compute_endpoint.engines.high_throughput.interchange import ( - Interchange, - starter, -) -from globus_compute_endpoint.engines.high_throughput.messages import Task - -# Work with linter's 88 char limit, and be uniform in this file how we do it -mod_dot_path = "globus_compute_endpoint.engines.high_throughput.interchange" - - -@mock.patch(f"{mod_dot_path}.zmq") -@mock.patch(f"{mod_dot_path}.Interchange.load_config") -class TestHighThroughputInterchange: - def test_migrate_internal_task_status(self, _mzmq, _mfn_conf, tmp_path, mocker): - mock_evt = mock.Mock() - mock_evt.is_set.side_effect = [False, True] # run once, please - task_id = str(uuid.uuid4()) - packed_task = Task(task_id, "RAW", b"").pack() - - ix = Interchange(logdir=tmp_path, worker_ports=(1, 1)) - ix.task_incoming.recv.return_value = packed_task - ix.migrate_tasks_to_internal(mock_evt) - - assert task_id in ix.task_status_deltas - - tt = ix.task_status_deltas[task_id][0] - assert 0 <= time.time_ns() - tt.timestamp < 2000000000, "Expecting a timestamp" - assert tt.state == TaskState.WAITING_FOR_NODES - - def test_start_task_status(self, _mzmq, _mfn_conf, tmp_path, mocker, reset_signals): - mock_evt = mock.Mock() - mock_evt.is_set.side_effect = [False, False, True] # run once, please - task_id = str(uuid.uuid4()) - - mocker.patch(f"{mod_dot_path}.log") - mock_thread = mocker.patch(f"{mod_dot_path}.threading") - mock_thread.Event.return_value = mock_evt - - mock_dispatch = mocker.patch(f"{mod_dot_path}.naive_interchange_task_dispatch") - mock_dispatch.return_value = ({"mgr": [{"task_id": task_id}]}, 0) - - ix = Interchange(logdir=tmp_path, worker_ports=(1, 1)) - ix.strategy = mock.Mock() - ix.start() - - assert task_id in ix.task_status_deltas - - tt = ix.task_status_deltas[task_id][0] - assert 0 <= time.time_ns() - tt.timestamp < 2000000000, "Expecting a timestamp" - assert tt.state == TaskState.WAITING_FOR_LAUNCH - - -def test_starter_sends_sentinel_upon_error(mocker): - q = mocker.Mock() - mock_ix = mocker.patch(f"{mod_dot_path}.Interchange") - mock_ix.side_effect = ArithmeticError - with pytest.raises(ArithmeticError): - starter(q) - q.put.assert_called() - q.close.assert_called() - q.join_thread.assert_called() diff --git a/compute_endpoint/tests/unit/test_htex.py b/compute_endpoint/tests/unit/test_htex.py deleted file mode 100644 index e57b9a72e..000000000 --- a/compute_endpoint/tests/unit/test_htex.py +++ /dev/null @@ -1,106 +0,0 @@ -import queue -import threading -import typing as t -import uuid -from unittest import mock - -import dill -import pytest -from globus_compute_common import messagepack -from globus_compute_endpoint.engines import HighThroughputEngine -from globus_compute_endpoint.engines.high_throughput.messages import ( - Task as InternalTask, -) -from globus_compute_sdk.serialize import ComputeSerializer -from pytest_mock import MockFixture -from tests.utils import double, ez_pack_function, try_assert - - -@pytest.fixture(autouse=True) -def warning_invoked(htex_warns): - yield - - -@pytest.fixture -def htex(tmp_path): - ep_id = uuid.uuid4() - executor = HighThroughputEngine( - address="127.0.0.1", - heartbeat_period=1, - heartbeat_threshold=2, - worker_debug=True, - ) - q = queue.Queue() - executor.start(endpoint_id=ep_id, run_dir=str(tmp_path), results_passthrough=q) - - yield executor - executor.shutdown() - - -def test_engine_submit_container_location( - mocker: MockFixture, htex: HighThroughputEngine, serde: ComputeSerializer -): - engine = htex - task_id = uuid.uuid4() - container_id = uuid.uuid4() - container_type = "singularity" - container_loc = "/path/to/container" - task_body = ez_pack_function(serde, double, (10,), {}) - task_message = messagepack.pack( - messagepack.message_types.Task( - task_id=task_id, - container_id=container_id, - container=messagepack.message_types.Container( - container_id=container_id, - name="RedSolo", - images=[ - messagepack.message_types.ContainerImage( - image_type=container_type, - location=container_loc, - created_at=1699389746.8433976, - modified_at=1699389746.8433977, - ) - ], - ), - task_buffer=task_body, - ) - ) - - mock_put = mocker.patch.object(engine.outgoing_q, "put") - - engine.container_type = container_type - engine.submit(str(task_id), task_message, {}) - - a, _ = mock_put.call_args - unpacked_msg = InternalTask.unpack(a[0]) - assert unpacked_msg.container_id == container_loc - - -@pytest.mark.parametrize("task_id", (str(uuid.uuid4()), None)) -def test_engine_invalid_result_data(task_id: t.Optional[str]): - htex = HighThroughputEngine(address="127.0.0.1") - htex.incoming_q = mock.MagicMock() - htex.results_passthrough = mock.MagicMock() - htex.tasks = mock.MagicMock() - htex.is_alive = True - htex._engine_bad_state = threading.Event() - - result_message = {"task_id": task_id} - htex.incoming_q.get.return_value = [dill.dumps(result_message)] - - queue_mgmt_thread = threading.Thread(target=htex._queue_management_worker) - queue_mgmt_thread.start() - - if task_id: - try_assert(lambda: htex.results_passthrough.put.called) - res = htex.results_passthrough.put.call_args[0][0] - msg = messagepack.unpack(res["message"]) - assert res["task_id"] == task_id - assert f"{task_id} failed to run" in msg.data - else: - try_assert(lambda: htex.incoming_q.get.call_count > 1) - assert not htex.results_passthrough.put.called - - htex.is_alive = False - htex._engine_bad_state.set() - queue_mgmt_thread.join() diff --git a/compute_endpoint/tests/unit/test_htex_facade.py b/compute_endpoint/tests/unit/test_htex_facade.py deleted file mode 100644 index 725f44ea1..000000000 --- a/compute_endpoint/tests/unit/test_htex_facade.py +++ /dev/null @@ -1,23 +0,0 @@ -from globus_compute_endpoint.engines import HighThroughputEngine -from globus_compute_endpoint.executors import HighThroughputExecutor - - -def test_deprecation_notice(mocker): - """Instantiating HTEX should throw a WARNING notice about - deprecation - """ - mock_warn = mocker.patch( - "globus_compute_endpoint.executors.high_throughput.executor.warnings" - ) - HighThroughputExecutor() - assert mock_warn.warn.called - # call_args is weirdly nested - assert "deprecated" in mock_warn.warn.call_args[0][0] - - -def test_htex_returns_engine(htex_warns): - """An instance of HighThroughputExecutor should now return - a HighThroughputEngine object - """ - htex = HighThroughputExecutor() - assert isinstance(htex, HighThroughputEngine) diff --git a/compute_endpoint/tests/unit/test_manager_unit.py b/compute_endpoint/tests/unit/test_manager_unit.py deleted file mode 100644 index ae4767d1e..000000000 --- a/compute_endpoint/tests/unit/test_manager_unit.py +++ /dev/null @@ -1,27 +0,0 @@ -import time -import uuid -from unittest import mock - -from globus_compute_common.tasks import TaskState -from globus_compute_endpoint.engines.high_throughput.manager import Manager -from globus_compute_endpoint.engines.high_throughput.messages import Task - - -@mock.patch("globus_compute_endpoint.engines.high_throughput.manager.zmq") -class TestManager: - def test_task_to_worker_status_change(self, randomstring): - task_type = randomstring() - task_id = str(uuid.uuid4()) - task = Task(task_id, "RAW", b"") - - mgr = Manager(uid="some_uid", worker_type=task_type) - mgr.worker_map = mock.Mock() - mgr.worker_map.get_worker.return_value = "some_work_id" - mgr.task_queues[task_type].put(task) - mgr.send_task_to_worker(task_type) - - assert task_id in mgr.task_status_deltas - - tt = mgr.task_status_deltas[task_id][0] - assert time.time_ns() - tt.timestamp < 2000000000, "Expecting a timestamp" - assert tt.state == TaskState.RUNNING diff --git a/compute_endpoint/tests/unit/test_worker.py b/compute_endpoint/tests/unit/test_worker.py deleted file mode 100644 index 4347dbd6b..000000000 --- a/compute_endpoint/tests/unit/test_worker.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import pickle -import uuid -from unittest import mock - -import pytest -from globus_compute_common import messagepack -from globus_compute_endpoint.engines.high_throughput.messages import Task -from globus_compute_endpoint.engines.high_throughput.worker import Worker - -_MOCK_BASE = "globus_compute_endpoint.engines.high_throughput.worker." - - -def hello_world(): - return "hello world" - - -def failing_function(): - x = {} - return x["foo"] # will fail, but in a "natural" way - - -def large_result(size): - return bytearray(size) - - -def sleeper(t: float): - import time - - now = start = time.monotonic() - while now - start < t: - time.sleep(0.0001) - now = time.monotonic() - return True - - -@pytest.fixture(autouse=True) -def reset_signals_auto(reset_signals): - yield - - -@pytest.fixture -def test_worker(): - with mock.patch(f"{_MOCK_BASE}zmq.Context") as mock_context: - # the worker will receive tasks and send messages on this mock socket - mock_socket = mock.Mock() - mock_context.return_value.socket.return_value = mock_socket - yield Worker("0", "::1", 50001) - - -def test_register_and_kill(test_worker): - # send a kill message on the mock socket - task = Task(task_id="KILL", container_id="RAW", task_buffer="KILL") - test_worker.task_socket.recv_multipart.return_value = ( - pickle.dumps("KILL"), - pickle.dumps("abc"), - task.pack(), - ) - - # calling worker.start begins a while loop, where first a REGISTER - # message is sent out, then the worker receives the KILL task, which - # triggers a WRKR_DIE message to be sent before the while loop exits - - # confirm that it raises a SystemExit because of the kill message - with pytest.raises(SystemExit): - test_worker.start() - - # these 2 calls to send_multipart happen in a sequence - test_worker.task_socket.send_multipart.assert_called() - arglist = test_worker.task_socket.send_multipart.call_args_list - assert len(arglist) == 2, arglist - for x in arglist: - assert len(x[0]) == 1, x - assert len(x[1]) == 0, x - messages = [x[0][0] for x in arglist] - assert all(isinstance(m, list) and len(m) == 2 for m in messages), messages - assert messages[0][0] == b"REGISTER", messages - assert messages[1][0] == b"WRKR_DIE", messages - - -def test_execute_hello_world(test_worker, tmp_path): - task_id = uuid.uuid1() - task_body = test_worker.serializer.serialize((hello_world, (), {})) - internal_task = Task(task_id, "RAW", task_body) - payload = internal_task.pack() - - result = test_worker._worker_execute_task(str(task_id), payload) - assert isinstance(result, dict) - assert "exception" not in result - assert isinstance(result.get("data"), str) - - assert result["data"] == "hello world" - - -def test_execute_failing_function(test_worker): - task_id = uuid.uuid1() - task_body = test_worker.serializer.serialize((failing_function, (), {})) - task_message = Task(task_id, "RAW", task_body).pack() - - with mock.patch(f"{_MOCK_BASE}log") as mock_log: - result = test_worker._worker_execute_task(str(task_id), task_message) - assert isinstance(result, dict) - assert "data" in result - - result = messagepack.unpack(result["data"]) - assert isinstance(result, messagepack.message_types.Result) - assert result.task_id == task_id - - a, _k = mock_log.exception.call_args - assert "Failed to execute task" in a[0] - - # error string contains the KeyError which failed - assert "KeyError" in result.data - assert result.is_error is True - - assert isinstance( - result.error_details, messagepack.message_types.ResultErrorDetails - ) - assert result.error_details.code == "RemoteExecutionError" - assert ( - result.error_details.user_message - == "An error occurred during the execution of this task" - ) - - -def test_app_timeout(test_worker, execute_task_runner, task_uuid, ez_pack_task): - task_bytes = ez_pack_task(sleeper, 1) - - with mock.patch("globus_compute_endpoint.engines.helper.log") as mock_log: - with mock.patch.dict(os.environ, {"GC_TASK_TIMEOUT": "0.01"}): - packed_result = execute_task_runner(task_bytes) - - result = messagepack.unpack(packed_result) - assert isinstance(result, messagepack.message_types.Result) - assert result.task_id == task_uuid - assert "AppTimeout" in result.data - assert mock_log.exception.called