diff --git a/changelog.d/20231020_110321_chris_443_ampqs.rst b/changelog.d/20231020_110321_chris_443_ampqs.rst new file mode 100644 index 000000000..a41934270 --- /dev/null +++ b/changelog.d/20231020_110321_chris_443_ampqs.rst @@ -0,0 +1,7 @@ +.. A new scriv changelog fragment. +.. +New Functionality +^^^^^^^^^^^^^^^^^ + +- The ``Executor`` can now be told which port to use to listen to AMQP results, via + either the ``amqp_port`` keyword argument or the ``amqp_port`` property. diff --git a/compute_sdk/globus_compute_sdk/sdk/executor.py b/compute_sdk/globus_compute_sdk/sdk/executor.py index 01599ec9a..a96358c62 100644 --- a/compute_sdk/globus_compute_sdk/sdk/executor.py +++ b/compute_sdk/globus_compute_sdk/sdk/executor.py @@ -120,6 +120,7 @@ def __init__( user_endpoint_config: dict[str, t.Any] | None = None, label: str = "", batch_size: int = 128, + amqp_port: int | None = None, **kwargs, ): """ @@ -140,6 +141,8 @@ def __init__( :param batch_interval: [DEPRECATED; unused] number of seconds to coalesce tasks before submitting upstream :param batch_enabled: [DEPRECATED; unused] whether to batch results + :param amqp_port: Port to use when connecting to results queue. Note that the + Compute web services only support 5671, 5672, and 443. """ deprecated_kwargs = {"batch_interval", "batch_enabled"} for key in kwargs: @@ -164,6 +167,9 @@ def __init__( self._task_group_id: uuid.UUID | None = None # help mypy out self.task_group_id = task_group_id + self._amqp_port: int | None = None + self.amqp_port = amqp_port + self.user_endpoint_config = user_endpoint_config self.label = label @@ -296,6 +302,19 @@ def container_id(self) -> uuid.UUID | None: def container_id(self, c_id: UUID_LIKE_T | None): self._container_id = as_optional_uuid(c_id) + @property + def amqp_port(self) -> int | None: + """ + The port to use when connecting to the result queue. Can be one of 443, 5671, + 5672, or None. If None, the port is assigned by the Compute web services + (typically 5671). + """ + return self._amqp_port + + @amqp_port.setter + def amqp_port(self, p: int | None): + self._amqp_port = p + def _fn_cache_key(self, fn: t.Callable): return fn, self.container_id @@ -616,7 +635,7 @@ def reload_tasks( fut.set_exception(funcx_err) if pending: - self._result_watcher = _ResultWatcher(self) + self._result_watcher = _ResultWatcher(self, port=self.amqp_port) self._result_watcher.watch_for_task_results(pending) self._result_watcher.start() else: @@ -748,14 +767,18 @@ class SubmitGroup(t.NamedTuple): ): # Don't initialize the result watcher unless at least # one batch has been sent - self._result_watcher = _ResultWatcher(self) + self._result_watcher = _ResultWatcher( + self, port=self.amqp_port + ) self._result_watcher.start() try: self._result_watcher.watch_for_task_results(to_watch) except self._result_watcher.__class__.ShuttingDownError: log.debug("Waiting for previous ResultWatcher to shutdown") self._result_watcher.join() - self._result_watcher = _ResultWatcher(self) + self._result_watcher = _ResultWatcher( + self, port=self.amqp_port + ) self._result_watcher.start() self._result_watcher.watch_for_task_results(to_watch) @@ -944,6 +967,7 @@ def __init__( connect_attempt_limit=5, channel_close_window_s=10, channel_close_window_limit=3, + port: int | None = None, ): super().__init__() self.funcx_executor = funcx_executor @@ -988,6 +1012,8 @@ def __init__( # window before giving up and shutting down the thread self.channel_close_window_limit = channel_close_window_limit + self.port = port + def __repr__(self): return "{}<{}; pid={}; fut={:,d}; res={:,d}; qp={}>".format( self.__class__.__name__, @@ -1265,6 +1291,8 @@ def _connect(self) -> pika.SelectConnection: connection_url = res["connection_url"] pika_params = pika.URLParameters(connection_url) + if self.port is not None: + pika_params.port = self.port return pika.SelectConnection( pika_params, on_close_callback=self._on_connection_closed, diff --git a/compute_sdk/tests/unit/test_executor.py b/compute_sdk/tests/unit/test_executor.py index 59578bb56..1e0b94bb7 100644 --- a/compute_sdk/tests/unit/test_executor.py +++ b/compute_sdk/tests/unit/test_executor.py @@ -1089,3 +1089,17 @@ def test_resultwatcher_amqp_acks_in_bulk(): assert not mrw._to_ack assert mrw._channel.basic_ack.call_count == 1 mrw.shutdown() + + +def test_result_queue_watcher_custom_port(mocker, gc_executor): + gcc, gce = gc_executor + rw = _ResultWatcher(gce, port=1234) + gcc.get_result_amqp_url.return_value = { + "queue_prefix": "", + "connection_url": "amqp://some.address:1111", + } + connect = mocker.patch(f"{_MOCK_BASE}pika.SelectConnection") + + rw._connect() + + assert connect.call_args[0][0].port == 1234