Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restart user endpoint if it dies soon after start command #1312

Merged
merged 4 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ class Config(RepresentationMixin):

endpoint_teardown : str | None
Command or commands to be run during the endpoint shutdown process.

mu_child_ep_grace_period_s : float
If a single-user endpoint dies, and then the multi-user endpoint receives a
start command for that single-user endpoint within this timeframe, the
single-user endpoint is started back up again.
Default: 30 seconds
"""

def __init__(
Expand All @@ -140,6 +146,7 @@ def __init__(
endpoint_setup: str | None = None,
endpoint_teardown: str | None = None,
force_mu_allow_same_user: bool = False,
mu_child_ep_grace_period_s: float = 30,
# Misc info
display_name: str | None = None,
# Logging info
Expand Down Expand Up @@ -176,12 +183,16 @@ def __init__(
self.environment = environment
self.funcx_service_address = funcx_service_address

# Multi-user tuning
self.multi_user = multi_user is True
self.force_mu_allow_same_user = force_mu_allow_same_user is True
self.mu_child_ep_grace_period_s = mu_child_ep_grace_period_s

# Auth
self.allowed_functions = allowed_functions
self.authentication_policy = authentication_policy

# Tuning info
# Single-user tuning
self.heartbeat_period = heartbeat_period
self.heartbeat_threshold = heartbeat_threshold
self.idle_heartbeats_soft = int(max(0, idle_heartbeats_soft))
Expand Down
113 changes: 102 additions & 11 deletions compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from datetime import datetime

import globus_compute_sdk as GC
from cachetools import TTLCache
from pydantic import BaseModel

try:
import pyprctl
Expand Down Expand Up @@ -54,6 +56,29 @@ class InvalidUserError(Exception):
pass


class UserEndpointRecord(BaseModel):
ep_name: str
local_user_info: pwd.struct_passwd
arguments: str

@property
def uid(self) -> int:
return self.local_user_info.pw_uid

@property
def gid(self) -> int:
return self.local_user_info.pw_gid

@property
def uname(self) -> str:
return self.local_user_info.pw_name


T_CMD_START_ARGS = t.Tuple[
pwd.struct_passwd, t.Optional[t.List[str]], t.Optional[t.Dict]
]


class EndpointManager:
def __init__(
self,
Expand All @@ -69,14 +94,19 @@ def __init__(
self._time_to_stop = False
self._kill_event = threading.Event()

self._child_args: dict[int, tuple[int, int, str, str]] = {}
self._children: dict[int, UserEndpointRecord] = {}

self._wait_for_child = False

self._command_queue: queue.SimpleQueue[
tuple[int, BasicProperties, bytes]
] = queue.SimpleQueue()
self._command_stop_event = threading.Event()

self._cached_cmd_start_args: TTLCache[int, T_CMD_START_ARGS] = TTLCache(
maxsize=32768, ttl=config.mu_child_ep_grace_period_s
)

endpoint_uuid = Endpoint.get_or_create_endpoint_uuid(conf_dir, endpoint_uuid)

if not reg_info:
Expand Down Expand Up @@ -200,10 +230,18 @@ def wait_for_children(self):
except ValueError:
rc = -127 # invalid signal number

*_, proc_args = self._child_args.pop(pid, (None, None, None, None))
proc_args = f" [{proc_args}]" if proc_args else ""
try:
uep_record = self._children.pop(pid)
except KeyError:
log.exception(f"unknown child PID {pid}")
uep_record = None

proc_args = f" [{uep_record.arguments}]" if uep_record else ""
if not rc:
log.info(f"Command stopped normally ({pid}){proc_args}")
cmd_start_args = self._cached_cmd_start_args.pop(pid, None)
if not self._time_to_stop and cmd_start_args is not None:
self._revive_child(uep_record, cmd_start_args)
elif rc > 0:
log.warning(f"Command return code: {rc} ({pid}){proc_args}")
elif rc == -127:
Expand All @@ -212,13 +250,42 @@ def wait_for_children(self):
log.warning(
f"Command terminated by signal: {-rc} ({pid}){proc_args}"
)

pid, exit_status_ind = os.waitpid(-1, wait_flags)

except ChildProcessError:
pass
except Exception as e:
log.exception(f"Failed to wait for a child process: {e}")

def _revive_child(
self, uep_record: UserEndpointRecord | None, cmd_start_args: T_CMD_START_ARGS
):
ep_name = uep_record.ep_name if uep_record else "<unknown>"
log.info(
"User EP stopped within grace period; using cached arguments "
f"to start a new instance (name: {ep_name})"
)

try:
cached_rec, args, kwargs = cmd_start_args
updated_rec = pwd.getpwuid(cached_rec.pw_uid)
except Exception as e:
log.warning(
"Unable to update local user information; user EP will not be revived."
f" ({e.__class__.__name__}) {e}"
)
return

try:
self.cmd_start_endpoint(updated_rec, args, kwargs)
except Exception:
log.exception(
f"Unable to execute command: cmd_start_endpoint\n"
f" args: {args}\n"
f" kwargs: {kwargs}"
)

def _install_signal_handlers(self):
signal.signal(signal.SIGTERM, self.request_shutdown)
signal.signal(signal.SIGINT, self.request_shutdown)
Expand Down Expand Up @@ -259,7 +326,8 @@ def start(self):
("Signaling shutdown", signal.SIGTERM),
("Forcibly killing", signal.SIGKILL),
):
for pid, (uid, gid, uname, proc_args) in list(self._child_args.items()):
for pid, rec in self._children.items():
uid, gid, uname, proc_args = rec.uid, rec.gid, rec.uname, rec.arguments
proc_ident = f"PID: {pid}, UID: {uid}, GID: {gid}, User: {uname}"
log.info(f"{msg_prefix} of user endpoint ({proc_ident}) [{proc_args}]")
try:
Expand All @@ -275,7 +343,7 @@ def start(self):
os.setresgid(proc_gid, proc_gid, -1)

deadline = time.time() + 10
while self._child_args and time.time() < deadline:
while self._children and time.time() < deadline:
time.sleep(0.5)
self.wait_for_children()

Expand Down Expand Up @@ -372,6 +440,15 @@ def _event_loop(self):
self._command.ack(d_tag)
continue

try:
local_user_rec = pwd.getpwnam(local_user)
except Exception as e:
log.warning(
f"Invalid or unknown local username. ({e.__class__.__name__}) {e}"
)
self._command.ack(d_tag)
continue

try:
if not (command and valid_method_name_re.match(command)):
raise InvalidCommandError(f"Unknown or invalid command: {command}")
Expand All @@ -380,7 +457,7 @@ def _event_loop(self):
if not command_func:
raise InvalidCommandError(f"Unknown or invalid command: {command}")

command_func(local_user, command_args, command_kwargs)
command_func(local_user_rec, command_args, command_kwargs)
log.info(
f"Command process successfully forked for '{globus_username}'"
f" ('{globus_uuid}')."
Expand All @@ -399,7 +476,7 @@ def _event_loop(self):

def cmd_start_endpoint(
self,
local_username: str,
local_user_rec: pwd.struct_passwd,
args: list[str] | None,
kwargs: dict | None,
):
Expand All @@ -412,9 +489,21 @@ def cmd_start_endpoint(
if not ep_name:
raise InvalidCommandError("Missing endpoint name")

pw_rec = pwd.getpwnam(local_username)
udir, uid, gid = pw_rec.pw_dir, pw_rec.pw_uid, pw_rec.pw_gid
uname = pw_rec.pw_name
for p, r in self._children.items():
if r.ep_name == ep_name:
log.info(
f"User endpoint {ep_name} is already running (pid: {p}); "
"caching arguments in case it's about to shut down"
)
self._cached_cmd_start_args[p] = (local_user_rec, args, kwargs)
return

udir, uid, gid, uname = (
local_user_rec.pw_dir,
local_user_rec.pw_uid,
local_user_rec.pw_gid,
local_user_rec.pw_name,
)

if not self._allow_same_user:
p_uname = self._mu_user.pw_name
Expand Down Expand Up @@ -445,7 +534,9 @@ def cmd_start_endpoint(

if pid > 0:
proc_args_s = f"({uname}, {ep_name}) {' '.join(proc_args)}"
self._child_args[pid] = (uid, gid, local_username, proc_args_s)
self._children[pid] = UserEndpointRecord(
ep_name=ep_name, local_user_info=local_user_rec, arguments=proc_args_s
)
log.info(f"Creating new user endpoint (pid: {pid}) [{proc_args_s}]")
return

Expand Down
2 changes: 2 additions & 0 deletions compute_endpoint/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
"pyyaml>=6.0,<7.0",
"jinja2>=3.1.2,<3.2",
"jsonschema>=4.19.0,<4.20",
"cachetools>=5.3.1",
"types-cachetools>=5.3.0.6",
]

TEST_REQUIRES = [
Expand Down
Loading
Loading