Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support alternative AMQP ports #1331

Merged
merged 2 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions changelog.d/20231020_110321_chris_443_ampqs.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
.. A new scriv changelog fragment.
..
New Functionality
^^^^^^^^^^^^^^^^^

- The ``Executor`` can now be told which port to use to listen to AMQP results, via
either the ``amqp_port`` keyword argument or the ``amqp_port`` property.

- Endpoints can be configured to talk to RMQ over a different port via the ``amqp_port``
configuration option.
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class Config(RepresentationMixin):
start command for that single-user endpoint within this timeframe, the
single-user endpoint is started back up again.
Default: 30 seconds

amqp_port : int | None
Port to use for AMQP connections. Note that only 5671, 5672, and 443 are
supported by the Compute web services. If None, the port is assigned by the
services (typically 5671).
Default: None
"""

def __init__(
Expand All @@ -137,6 +143,7 @@ def __init__(
multi_user: bool | None = None,
allowed_functions: list[str] | None = None,
authentication_policy: str | None = None,
amqp_port: int | None = None,
# Tuning info
heartbeat_period=30,
heartbeat_threshold=120,
Expand Down Expand Up @@ -192,6 +199,8 @@ def __init__(
self.allowed_functions = allowed_functions
self.authentication_policy = authentication_policy

self.amqp_port = amqp_port

# Single-user tuning
self.heartbeat_period = heartbeat_period
self.heartbeat_threshold = heartbeat_threshold
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class ConfigModel(BaseConfigModel):
log_dir: t.Optional[str]
stdout: t.Optional[str]
stderr: t.Optional[str]
amqp_port: t.Optional[int]

_validate_engine = _validate_params("engine")

Expand Down
17 changes: 16 additions & 1 deletion compute_endpoint/globus_compute_endpoint/endpoint/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from globus_compute_endpoint.endpoint.config.utils import serialize_config
from globus_compute_endpoint.endpoint.interchange import EndpointInterchange
from globus_compute_endpoint.endpoint.result_store import ResultStore
from globus_compute_endpoint.endpoint.utils import _redact_url_creds
from globus_compute_endpoint.endpoint.utils import _redact_url_creds, update_url_port
from globus_compute_endpoint.logging_config import setup_logging
from globus_compute_sdk.sdk.client import Client
from globus_sdk import AuthAPIError, GlobusAPIError, NetworkError
Expand Down Expand Up @@ -437,6 +437,21 @@ def start_endpoint(
log.error("Invalid credential structure")
exit(os.EX_DATAERR)

try:
tq_info, rq_info = (
reg_info["task_queue_info"],
reg_info["result_queue_info"],
)
except KeyError:
log.error("Invalid credential structure")
exit(os.EX_DATAERR)

if endpoint_config.amqp_port is not None:
for q_info in tq_info, rq_info:
q_info["connection_url"] = update_url_port(
q_info["connection_url"], endpoint_config.amqp_port
)

# sanitize passwords in logs
log_reg_info = _redact_url_creds(repr(reg_info))
log.debug(f"Registration information: {log_reg_info}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
from globus_compute_endpoint.endpoint.rabbit_mq.command_queue_subscriber import (
CommandQueueSubscriber,
)
from globus_compute_endpoint.endpoint.utils import _redact_url_creds, is_privileged
from globus_compute_endpoint.endpoint.utils import (
_redact_url_creds,
is_privileged,
update_url_port,
)
from globus_sdk import GlobusAPIError, NetworkError

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -171,6 +175,11 @@ def __init__(
)
exit(os.EX_DATAERR)

if config.amqp_port:
cq_info["connection_url"] = update_url_port(
cq_info["connection_url"], config.amqp_port
)

self._mu_user = pwd.getpwuid(os.getuid())
if config.force_mu_allow_same_user:
self._allow_same_user = True
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import os as _os
import pwd as _pwd
import re as _re
import typing as t
import urllib.parse

try:
import pyprctl as _pyprctl
Expand Down Expand Up @@ -90,3 +93,13 @@ def is_privileged(posix_user=None):
has_privileges |= posix_user.pw_name == "root"
has_privileges |= any(c in proc_caps.effective for c in _MULTI_USER_CAPS)
return has_privileges


def update_url_port(url_string: str, new_port: int | str) -> str:
c_url = urllib.parse.urlparse(url_string)
if c_url.port:
netloc = c_url.netloc.replace(f":{c_url.port}", f":{new_port}")
else:
netloc = c_url.netloc + f":{new_port}"
c_url = c_url._replace(netloc=netloc)
return urllib.parse.urlunparse(c_url)
54 changes: 47 additions & 7 deletions compute_endpoint/tests/unit/test_endpoint_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def _wrapped_umask(new_umask: int | None) -> int:
os.umask(orig_umask)


@pytest.fixture
def mock_reg_info():
yield {
"endpoint_id": str(uuid.uuid4()),
"task_queue_info": {"connection_url": "amqp://some.domain:1234"},
"result_queue_info": {"connection_url": "amqp://some.domain"},
}


@responses.activate
def test_start_endpoint(
mocker,
Expand Down Expand Up @@ -490,15 +499,17 @@ def test_endpoint_get_metadata(mocker):


@pytest.mark.parametrize("env", [None, "blar", "local", "production"])
def test_endpoint_sets_process_title(mocker, fs, randomstring, mock_ep_data, env):
def test_endpoint_sets_process_title(
mocker, fs, randomstring, mock_ep_data, env, mock_reg_info
):
ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data
ep_id = str(uuid.uuid4())
ep_conf.environment = env

orig_proc_title = randomstring()

mock_gcc = mocker.Mock()
mock_gcc.register_endpoint.return_value = {"endpoint_id": ep_id}
mock_gcc.register_endpoint.return_value = {**mock_reg_info, "endpoint_id": ep_id}
mocker.patch(f"{_mock_base}Endpoint.get_funcx_client", return_value=mock_gcc)

mock_spt = mocker.patch(f"{_mock_base}setproctitle")
Expand All @@ -520,19 +531,48 @@ def test_endpoint_sets_process_title(mocker, fs, randomstring, mock_ep_data, env
assert a[0].endswith(f"[{orig_proc_title}]"), "Save original cmdline for debugging"


def test_endpoint_needs_no_client_if_reg_info(mocker, fs, randomstring, mock_ep_data):
@pytest.mark.parametrize("port", [random.randint(0, 65535)])
def test_endpoint_respects_port(mocker, fs, mock_ep_data, port):
ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data
ep_id = str(uuid.uuid4())
ep_conf.amqp_port = port

tq_url = "amqp://some.domain:1234"
rq_url = "amqp://some.domain"

mock_reg_info = {
"endpoint_id": ep_id,
"task_queue_info": {"connection_url": tq_url},
"result_queue_info": {"connection_url": rq_url},
}

mock_update_url_port = mocker.patch(f"{_mock_base}update_url_port")
mock_update_url_port.side_effect = (None, StopIteration("Sentinel"))

with pytest.raises(StopIteration, match="Sentinel"):
ep.start_endpoint(
ep_dir, ep_id, ep_conf, log_to_console, no_color, mock_reg_info
)

assert mock_update_url_port.call_args_list[0] == ((tq_url, port),)
assert mock_update_url_port.call_args_list[1] == ((rq_url, port),)


def test_endpoint_needs_no_client_if_reg_info(
mocker, fs, randomstring, mock_ep_data, mock_reg_info
):
ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data
ep_id = str(uuid.uuid4())

mock_gcc = mocker.Mock()
mock_gcc.register_endpoint.return_value = {"endpoint_id": ep_id}
mock_gcc.register_endpoint.return_value = {**mock_reg_info, "endpoint_id": ep_id}
mock_get_compute_client = mocker.patch(
f"{_mock_base}Endpoint.get_funcx_client", return_value=mock_gcc
)
mock_daemon = mocker.patch(f"{_mock_base}daemon")
mock_epinterchange = mocker.patch(f"{_mock_base}EndpointInterchange")

reg_info = {"endpoint_id": ep_id}
reg_info = {**mock_reg_info, "endpoint_id": ep_id}
ep.start_endpoint(ep_dir, ep_id, ep_conf, log_to_console, no_color, reg_info)

assert mock_epinterchange.called, "Has registration, should start."
Expand Down Expand Up @@ -586,7 +626,7 @@ def test_mu_endpoint_user_ep_sensible_default(tmp_path):
render_config_user_template(ep_dir, {})


def test_always_prints_endpoint_id_to_terminal(mocker, mock_ep_data):
def test_always_prints_endpoint_id_to_terminal(mocker, mock_ep_data, mock_reg_info):
ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data
ep_id = str(uuid.uuid4())

Expand All @@ -598,7 +638,7 @@ def test_always_prints_endpoint_id_to_terminal(mocker, mock_ep_data):

expected_text = f"Starting endpoint; registered ID: {ep_id}"

reg_info = {"endpoint_id": ep_id}
reg_info = {**mock_reg_info, "endpoint_id": ep_id}

mock_sys.stdout.isatty.return_value = True
ep.start_endpoint(ep_dir, ep_id, ep_conf, log_to_console, no_color, reg_info)
Expand Down
12 changes: 12 additions & 0 deletions compute_endpoint/tests/unit/test_endpointmanager_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,3 +1450,15 @@ def test_load_user_config_schema(
assert mock_log.error.called
a, *_ = mock_log.error.call_args
assert "user config schema is not valid JSON" in str(a)


@pytest.mark.parametrize("port", [random.randint(0, 65535)])
def test_port_is_respected(mocker, mock_client, mock_conf, conf_dir, port):
ep_uuid, _ = mock_client
mock_conf.amqp_port = port

mock_update_url_port = mocker.patch(f"{_MOCK_BASE}update_url_port")

EndpointManager(conf_dir, ep_uuid, mock_conf)

assert mock_update_url_port.call_args[0][1] == port
56 changes: 47 additions & 9 deletions compute_endpoint/tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from collections import namedtuple

import pytest
from globus_compute_endpoint.endpoint.utils import _redact_url_creds, is_privileged
from globus_compute_endpoint.endpoint.utils import (
_redact_url_creds,
is_privileged,
update_url_port,
)

try:
import pyprctl
import pyprctl # noqa

_has_pyprctl = True
except AttributeError:
pytest.skip(allow_module_level=True)
_has_pyprctl = False


_MOCK_BASE = "globus_compute_endpoint.endpoint.utils."
Expand Down Expand Up @@ -36,6 +42,9 @@ def test_url_redaction(randomstring):

@pytest.mark.parametrize("uid", (0, 1000))
def test_is_privileged_tests_against_uid(mocker, uid):
if not _has_pyprctl:
pytest.skip()

user = namedtuple("posix_user", "pw_uid,pw_name")(uid, "asdf")
mock_prctl = mocker.patch(f"{_MOCK_BASE}_pyprctl")
mock_prctl.CapState.get_current.return_value.effective = {}
Expand All @@ -45,16 +54,45 @@ def test_is_privileged_tests_against_uid(mocker, uid):

@pytest.mark.parametrize("uname", ("root", "not_root_uname"))
def test_is_privileged_tests_for_root_username(mocker, uname):
if not _has_pyprctl:
pytest.skip()

user = namedtuple("posix_user", "pw_uid,pw_name")(987, uname)
mock_prctl = mocker.patch(f"{_MOCK_BASE}_pyprctl")
mock_prctl.CapState.get_current.return_value.effective = {}

assert is_privileged(user) is bool("root" == uname)


@pytest.mark.parametrize("cap", ({pyprctl.Cap.SYS_ADMIN}, {}))
def test_is_privileged_checks_for_privileges(mocker, cap):
user = namedtuple("posix_user", "pw_uid,pw_name")(987, "asdf")
mock_prctl = mocker.patch(f"{_MOCK_BASE}_pyprctl")
mock_prctl.CapState.get_current.return_value.effective = cap
assert is_privileged(user) is bool(cap)
if _has_pyprctl:

@pytest.mark.parametrize("cap", ({pyprctl.Cap.SYS_ADMIN}, {}))
def test_is_privileged_checks_for_privileges(mocker, cap):
if not _has_pyprctl:
pytest.skip()

user = namedtuple("posix_user", "pw_uid,pw_name")(987, "asdf")
mock_prctl = mocker.patch(f"{_MOCK_BASE}_pyprctl")
mock_prctl.CapState.get_current.return_value.effective = cap
assert is_privileged(user) is bool(cap)


@pytest.mark.parametrize(
"start_url, port, end_url",
[
("amqp://some.domain:1234", 1111, "amqp://some.domain:1111"),
("https://domain.com:4567/homepage", 2222, "https://domain.com:2222/homepage"),
(
"postgres://user:pass@some.domain:5678",
3333,
"postgres://user:pass@some.domain:3333",
),
(
"postgres://user:pass@some.domain/funcx",
4444,
"postgres://user:pass@some.domain:4444/funcx",
),
],
)
def test_update_url_port(start_url, port, end_url):
assert update_url_port(start_url, port) == end_url
Loading
Loading