diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py b/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py index 66e3903df..288af40e9 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py @@ -119,6 +119,12 @@ class Config(RepresentationMixin): endpoint_teardown : str | None Command or commands to be run during the endpoint shutdown process. + + mu_child_ep_grace_period_s : float + If a single-user endpoint dies, and then the multi-user endpoint receives a + start command for that single-user endpoint within this timeframe, the + single-user endpoint is started back up again. + Default: 30 seconds """ def __init__( @@ -140,6 +146,7 @@ def __init__( endpoint_setup: str | None = None, endpoint_teardown: str | None = None, force_mu_allow_same_user: bool = False, + mu_child_ep_grace_period_s: float = 30, # Misc info display_name: str | None = None, # Logging info @@ -176,12 +183,16 @@ def __init__( self.environment = environment self.funcx_service_address = funcx_service_address + # Multi-user tuning self.multi_user = multi_user is True self.force_mu_allow_same_user = force_mu_allow_same_user is True + self.mu_child_ep_grace_period_s = mu_child_ep_grace_period_s + + # Auth self.allowed_functions = allowed_functions self.authentication_policy = authentication_policy - # Tuning info + # Single-user tuning self.heartbeat_period = heartbeat_period self.heartbeat_threshold = heartbeat_threshold self.idle_heartbeats_soft = int(max(0, idle_heartbeats_soft)) diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py index 415f366cb..6b3726e7a 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py @@ -18,6 +18,8 @@ from datetime import datetime import globus_compute_sdk as GC +from cachetools import TTLCache +from pydantic import BaseModel try: import pyprctl @@ -54,6 +56,29 @@ class InvalidUserError(Exception): pass +class UserEndpointRecord(BaseModel): + ep_name: str + local_user_info: pwd.struct_passwd + arguments: str + + @property + def uid(self) -> int: + return self.local_user_info.pw_uid + + @property + def gid(self) -> int: + return self.local_user_info.pw_gid + + @property + def uname(self) -> str: + return self.local_user_info.pw_name + + +T_CMD_START_ARGS = t.Tuple[ + pwd.struct_passwd, t.Optional[t.List[str]], t.Optional[t.Dict] +] + + class EndpointManager: def __init__( self, @@ -69,7 +94,8 @@ def __init__( self._time_to_stop = False self._kill_event = threading.Event() - self._child_args: dict[int, tuple[int, int, str, str]] = {} + self._children: dict[int, UserEndpointRecord] = {} + self._wait_for_child = False self._command_queue: queue.SimpleQueue[ @@ -77,6 +103,10 @@ def __init__( ] = queue.SimpleQueue() self._command_stop_event = threading.Event() + self._cached_cmd_start_args: TTLCache[int, T_CMD_START_ARGS] = TTLCache( + maxsize=32768, ttl=config.mu_child_ep_grace_period_s + ) + endpoint_uuid = Endpoint.get_or_create_endpoint_uuid(conf_dir, endpoint_uuid) if not reg_info: @@ -200,10 +230,18 @@ def wait_for_children(self): except ValueError: rc = -127 # invalid signal number - *_, proc_args = self._child_args.pop(pid, (None, None, None, None)) - proc_args = f" [{proc_args}]" if proc_args else "" + try: + uep_record = self._children.pop(pid) + except KeyError: + log.exception(f"unknown child PID {pid}") + uep_record = None + + proc_args = f" [{uep_record.arguments}]" if uep_record else "" if not rc: log.info(f"Command stopped normally ({pid}){proc_args}") + cmd_start_args = self._cached_cmd_start_args.pop(pid, None) + if not self._time_to_stop and cmd_start_args is not None: + self._revive_child(uep_record, cmd_start_args) elif rc > 0: log.warning(f"Command return code: {rc} ({pid}){proc_args}") elif rc == -127: @@ -212,6 +250,7 @@ def wait_for_children(self): log.warning( f"Command terminated by signal: {-rc} ({pid}){proc_args}" ) + pid, exit_status_ind = os.waitpid(-1, wait_flags) except ChildProcessError: @@ -219,6 +258,34 @@ def wait_for_children(self): except Exception as e: log.exception(f"Failed to wait for a child process: {e}") + def _revive_child( + self, uep_record: UserEndpointRecord | None, cmd_start_args: T_CMD_START_ARGS + ): + ep_name = uep_record.ep_name if uep_record else "" + log.info( + "User EP stopped within grace period; using cached arguments " + f"to start a new instance (name: {ep_name})" + ) + + try: + cached_rec, args, kwargs = cmd_start_args + updated_rec = pwd.getpwuid(cached_rec.pw_uid) + except Exception as e: + log.warning( + "Unable to update local user information; user EP will not be revived." + f" ({e.__class__.__name__}) {e}" + ) + return + + try: + self.cmd_start_endpoint(updated_rec, args, kwargs) + except Exception: + log.exception( + f"Unable to execute command: cmd_start_endpoint\n" + f" args: {args}\n" + f" kwargs: {kwargs}" + ) + def _install_signal_handlers(self): signal.signal(signal.SIGTERM, self.request_shutdown) signal.signal(signal.SIGINT, self.request_shutdown) @@ -259,7 +326,8 @@ def start(self): ("Signaling shutdown", signal.SIGTERM), ("Forcibly killing", signal.SIGKILL), ): - for pid, (uid, gid, uname, proc_args) in list(self._child_args.items()): + for pid, rec in self._children.items(): + uid, gid, uname, proc_args = rec.uid, rec.gid, rec.uname, rec.arguments proc_ident = f"PID: {pid}, UID: {uid}, GID: {gid}, User: {uname}" log.info(f"{msg_prefix} of user endpoint ({proc_ident}) [{proc_args}]") try: @@ -275,7 +343,7 @@ def start(self): os.setresgid(proc_gid, proc_gid, -1) deadline = time.time() + 10 - while self._child_args and time.time() < deadline: + while self._children and time.time() < deadline: time.sleep(0.5) self.wait_for_children() @@ -372,6 +440,15 @@ def _event_loop(self): self._command.ack(d_tag) continue + try: + local_user_rec = pwd.getpwnam(local_user) + except Exception as e: + log.warning( + f"Invalid or unknown local username. ({e.__class__.__name__}) {e}" + ) + self._command.ack(d_tag) + continue + try: if not (command and valid_method_name_re.match(command)): raise InvalidCommandError(f"Unknown or invalid command: {command}") @@ -380,7 +457,7 @@ def _event_loop(self): if not command_func: raise InvalidCommandError(f"Unknown or invalid command: {command}") - command_func(local_user, command_args, command_kwargs) + command_func(local_user_rec, command_args, command_kwargs) log.info( f"Command process successfully forked for '{globus_username}'" f" ('{globus_uuid}')." @@ -399,7 +476,7 @@ def _event_loop(self): def cmd_start_endpoint( self, - local_username: str, + local_user_rec: pwd.struct_passwd, args: list[str] | None, kwargs: dict | None, ): @@ -412,9 +489,21 @@ def cmd_start_endpoint( if not ep_name: raise InvalidCommandError("Missing endpoint name") - pw_rec = pwd.getpwnam(local_username) - udir, uid, gid = pw_rec.pw_dir, pw_rec.pw_uid, pw_rec.pw_gid - uname = pw_rec.pw_name + for p, r in self._children.items(): + if r.ep_name == ep_name: + log.info( + f"User endpoint {ep_name} is already running (pid: {p}); " + "caching arguments in case it's about to shut down" + ) + self._cached_cmd_start_args[p] = (local_user_rec, args, kwargs) + return + + udir, uid, gid, uname = ( + local_user_rec.pw_dir, + local_user_rec.pw_uid, + local_user_rec.pw_gid, + local_user_rec.pw_name, + ) if not self._allow_same_user: p_uname = self._mu_user.pw_name @@ -445,7 +534,9 @@ def cmd_start_endpoint( if pid > 0: proc_args_s = f"({uname}, {ep_name}) {' '.join(proc_args)}" - self._child_args[pid] = (uid, gid, local_username, proc_args_s) + self._children[pid] = UserEndpointRecord( + ep_name=ep_name, local_user_info=local_user_rec, arguments=proc_args_s + ) log.info(f"Creating new user endpoint (pid: {pid}) [{proc_args_s}]") return diff --git a/compute_endpoint/setup.py b/compute_endpoint/setup.py index 008da7a83..eda0a3d1b 100644 --- a/compute_endpoint/setup.py +++ b/compute_endpoint/setup.py @@ -40,6 +40,8 @@ "pyyaml>=6.0,<7.0", "jinja2>=3.1.2,<3.2", "jsonschema>=4.19.0,<4.20", + "cachetools>=5.3.1", + "types-cachetools>=5.3.0.6", ] TEST_REQUIRES = [ diff --git a/compute_endpoint/tests/unit/test_endpointmanager_unit.py b/compute_endpoint/tests/unit/test_endpointmanager_unit.py index 83029c838..42a1eed6e 100644 --- a/compute_endpoint/tests/unit/test_endpointmanager_unit.py +++ b/compute_endpoint/tests/unit/test_endpointmanager_unit.py @@ -324,6 +324,34 @@ def test_handles_invalid_reg_info( EndpointManager(conf_dir, ep_uuid, mock_conf) +def test_records_user_ep_as_running(successful_exec): + mock_os, *_, em = successful_exec + mock_os.fork.return_value = 1 + + em._event_loop() + + uep_rec = em._children.pop(1) + assert uep_rec.ep_name == "some_ep_name" + + +def test_caches_start_cmd_args_if_ep_already_running(successful_exec, mocker): + *_, em = successful_exec + child_pid = random.randrange(1, 32768 + 1) + mock_uep = mocker.MagicMock() + mock_uep.ep_name = "some_ep_name" + em._children[child_pid] = mock_uep + + em._event_loop() + + assert child_pid in em._children + cached_args = em._cached_cmd_start_args.pop(child_pid) + assert cached_args is not None + urec, args, kwargs = cached_args + assert urec == pwd.getpwnam(getpass.getuser()) + assert args == [] + assert kwargs == {"name": "some_ep_name", "user_opts": {"heartbeat": 10}} + + def test_writes_endpoint_uuid(epmanager): conf_dir, _mock_conf, mock_client, _em = epmanager _ep_uuid, mock_gcc = mock_client @@ -397,7 +425,9 @@ def test_children_signaled_at_shutdown( uid, gid, pid = tuple(random.randint(1, 2**30) for _ in range(3)) uname = randomstring() expected.append((uid, gid, uname, "some process command line")) - em._child_args[pid] = expected[-1] + mock_rec = mocker.MagicMock() + mock_rec.uid, mock_rec.gid = uid, gid + em._children[pid] = mock_rec gid_expected_calls = ( a @@ -417,8 +447,8 @@ def test_children_signaled_at_shutdown( ) # test that SIGTERM, *then* SIGKILL sent - killpg_expected_calls = [(pid, signal.SIGTERM) for pid in em._child_args] - killpg_expected_calls.extend((pid, signal.SIGKILL) for pid in em._child_args) + killpg_expected_calls = [(pid, signal.SIGTERM) for pid in em._children] + killpg_expected_calls.extend((pid, signal.SIGKILL) for pid in em._children) em.start() assert em._event_loop.called, "Verify test setup" @@ -435,6 +465,41 @@ def test_children_signaled_at_shutdown( assert killpg_call[0] == exp_args, "Expected SIGTERM, *then* SIGKILL" +def test_restarts_running_endpoint_with_cached_args(epmanager, mocker): + *_, em = epmanager + child_pid = random.randrange(1, 32768 + 1) + args_tup = ( + pwd.getpwnam(getpass.getuser()), + [], + {"name": "some_ep_name", "user_opts": {"heartbeat": 10}}, + ) + + mock_os = mocker.patch(f"{_MOCK_BASE}os") + mock_os.waitpid.side_effect = [(child_pid, -1), (0, -1)] + mock_os.waitstatus_to_exitcode.return_value = 0 + + em._cached_cmd_start_args[child_pid] = args_tup + em.cmd_start_endpoint = mocker.Mock() + + em.wait_for_children() + + assert em.cmd_start_endpoint.call_args.args == args_tup + + +def test_no_cached_args_means_no_restart(epmanager, mocker): + *_, em = epmanager + child_pid = random.randrange(1, 32768 + 1) + + mock_os = mocker.patch(f"{_MOCK_BASE}os") + mock_os.waitpid.side_effect = [(child_pid, -1), (0, -1)] + mock_os.waitstatus_to_exitcode.return_value = -127 + em.cmd_start_endpoint = mocker.Mock() + + em.wait_for_children() + + assert em.cmd_start_endpoint.call_count == 0 + + def test_emits_endpoint_id_if_isatty(mocker, epmanager): mock_log = mocker.patch(f"{_MOCK_BASE}log") *_, em = epmanager @@ -680,6 +745,37 @@ def test_handles_unknown_user_gracefully(mocker, epmanager): assert em._command.ack.called, "Command always ACKed" +def test_handles_unknown_local_username_gracefully(mocker, epmanager): + mock_log = mocker.patch(f"{_MOCK_BASE}log") + conf_dir, mock_conf, mock_client, em = epmanager + + with open("local_user_lookup.json", "w") as f: + json.dump({"a": "a_user"}, f) + + props = pika.BasicProperties( + content_type="application/json", + content_encoding="utf-8", + timestamp=round(time.time()), + expiration="10000", + ) + + pld = { + "globus_uuid": "a", + "globus_username": "a", + } + queue_item = (1, props, json.dumps(pld).encode()) + + mocker.patch(f"{_MOCK_BASE}pwd.getpwnam", side_effect=Exception()) + + em._command_queue = mocker.Mock() + em._command_stop_event.set() + em._command_queue.get.side_effect = [queue_item, queue.Empty()] + em._event_loop() + a = mock_log.warning.call_args[0][0] + assert "Invalid or unknown local user" in a + assert em._command.ack.called, "Command always ACKed" + + @pytest.mark.parametrize( "cmd_name", ("", "_private", "9c", "valid_but_do_not_exist", " ", "a" * 101) ) @@ -705,6 +801,8 @@ def test_handles_invalid_command_gracefully(mocker, epmanager, cmd_name): } queue_item = (1, props, json.dumps(pld).encode()) + mocker.patch(f"{_MOCK_BASE}pwd.getpwnam") + em._command_queue = mocker.Mock() em._command_stop_event.set() em._command_queue.get.side_effect = [queue_item, queue.Empty()] @@ -740,6 +838,8 @@ def test_handles_failed_command(mocker, epmanager): } queue_item = (1, props, json.dumps(pld).encode()) + mocker.patch(f"{_MOCK_BASE}pwd.getpwnam") + em._command_queue = mocker.Mock() em._command_stop_event.set() em._command_queue.get.side_effect = [queue_item, queue.Empty()] @@ -981,7 +1081,7 @@ def test_run_as_same_user_fails_if_admin(successful_exec): em._allow_same_user = False kwargs = {"name": "some_endpoint_name"} with pytest.raises(InvalidUserError) as pyexc: - em.cmd_start_endpoint(pw_rec.pw_name, None, kwargs) + em.cmd_start_endpoint(pw_rec, None, kwargs) assert "UID is same as" in str(pyexc.value) assert "using a non-root user" in str(pyexc.value), "Expected suggested fix"