Skip to content

Commit

Permalink
Teach EP to get allowed_functions via stdin
Browse files Browse the repository at this point in the history
The EP now can receive a list of allowed functions via the `allowed_functions`
key in the JSON dictionary received on `stdin`.  Crucially, if a list is passed
in via `stdin`, it takes precedence over what may or may not be in the
configuration file.  This makes it easier for a parent MEP to ensure that child
UEPs have exactly the same allow list, enabling that both the web-service *and*
the UEP verify that a function is allowed.
  • Loading branch information
khk-globus committed Dec 4, 2024
1 parent 73763ee commit a6656f9
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 5 deletions.
8 changes: 7 additions & 1 deletion compute_endpoint/globus_compute_endpoint/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,10 @@ def _do_start_endpoint(
no_color=state.no_color,
)

_no_fn_list_canary = -15 # an arbitrary random integer; invalid as an allow_list
reg_info = {}
config_str = None
config_str: str | None = None
fn_allow_list: list[str] | None | int = _no_fn_list_canary
if sys.stdin and not (sys.stdin.closed or sys.stdin.isatty()):
try:
stdin_data = json.loads(sys.stdin.read())
Expand All @@ -641,6 +643,7 @@ def _do_start_endpoint(

reg_info = stdin_data.get("amqp_creds", {})
config_str = stdin_data.get("config", None)
fn_allow_list = stdin_data.get("allowed_functions", _no_fn_list_canary)

del stdin_data # clarity for intended scope

Expand All @@ -656,6 +659,9 @@ def _do_start_endpoint(
ep_config = get_config(ep_dir)
del config_str

if fn_allow_list != _no_fn_list_canary:
ep_config.allowed_functions = fn_allow_list

if not state.debug and ep_config.debug:
setup_logging(
logfile=ep_dir / "endpoint.log",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def heartbeat_period(self, val: float | int):

@property
def allowed_functions(self):
if self._allowed_functions:
if self._allowed_functions is not None:
return tuple(map(str, self._allowed_functions))
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1073,10 +1073,11 @@ def cmd_start_endpoint(
self._config, template_str, user_config_schema, user_opts, user_runtime
)
stdin_data_dict = {
"allowed_functions": self._config.allowed_functions,
"amqp_creds": kwargs.get("amqp_creds"),
"config": user_config,
}
stdin_data = json.dumps(stdin_data_dict)
stdin_data = json.dumps(stdin_data_dict, separators=(",", ":"))
exit_code += 1

# Reminder: this is *os*.open, not *open*. Descriptors will not be closed
Expand Down
28 changes: 28 additions & 0 deletions compute_endpoint/tests/unit/test_cli_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,34 @@ def test_start_ep_reads_stdin(
assert reg_info_found == {}


@pytest.mark.parametrize("fn_count", range(-1, 5))
def test_start_ep_stdin_allowed_fns_overrides_conf(
mocker, run_line, mock_cli_state, make_endpoint_dir, ep_name, fn_count
):
if fn_count == -1:
allowed_fns = None
else:
allowed_fns = tuple(str(uuid.uuid4()) for _ in range(fn_count))

conf = UserEndpointConfig(executors=[ThreadPoolEngine])
conf.allowed_functions = [uuid.uuid4() for _ in range(5)] # to be overridden
mock_get_config = mocker.patch(f"{_MOCK_BASE}get_config")
mock_get_config.return_value = conf

mock_sys = mocker.patch(f"{_MOCK_BASE}sys")
mock_sys.stdin.closed = False
mock_sys.stdin.isatty.return_value = False
mock_sys.stdin.read.return_value = json.dumps({"allowed_functions": allowed_fns})

make_endpoint_dir()

run_line(f"start {ep_name}")
mock_ep, _ = mock_cli_state
assert mock_ep.start_endpoint.called
(_, _, found_conf, *_), _k = mock_ep.start_endpoint.call_args
assert found_conf.allowed_functions == allowed_fns, "allowed field not overridden!"


@pytest.mark.parametrize("use_uuid", (True, False))
@mock.patch(f"{_MOCK_BASE}get_config")
def test_stop_endpoint(
Expand Down
27 changes: 25 additions & 2 deletions compute_endpoint/tests/unit/test_endpointmanager_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2006,8 +2006,7 @@ def test_pipe_size_limit(mocker, mock_log, successful_exec_from_mocked_root, con

conf_str = "v: " + "$" * (conf_size - 3)

# Add 34 bytes for dict keys, etc.
stdin_data_size = conf_size + 34
stdin_data_size = conf_size + 56 # overhead for JSON dict keys, etc.
pipe_buffer_size = 512
# Subtract 256 for hard-coded buffer in-code
is_valid = pipe_buffer_size - 256 - stdin_data_size >= 0
Expand Down Expand Up @@ -2039,6 +2038,30 @@ def _remove_user_config_template(*args, **kwargs):
assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'"


@pytest.mark.parametrize("fn_count", (0, 1, 2, 3, random.randint(4, 100)))
def test_set_uep_allowed_functions(
successful_exec_from_mocked_root, mock_conf_root, fn_count
):
mock_os, *_, em = successful_exec_from_mocked_root

m = mock.Mock()
mock_os.fdopen.return_value.__enter__.return_value = m

fns = [str(uuid.uuid4()) for _ in range(fn_count)]
mock_conf_root.allowed_functions = fns
with mock.patch.object(fcntl, "fcntl", return_value=2**20):
# 2**20 == plenty for test
with pytest.raises(SystemExit) as pyexc:
em._event_loop()

assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'"

(received_stdin,), _k = m.write.call_args
parsed_stdin = json.loads(received_stdin)
assert "allowed_functions" in parsed_stdin, "Even empty list should be stated"
assert parsed_stdin["allowed_functions"] == fns


def test_redirect_stdstreams_to_user_log(
successful_exec_from_mocked_root, conf_dir, command_payload
):
Expand Down

0 comments on commit a6656f9

Please sign in to comment.