diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py index ae56ab42e..537c21561 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py @@ -14,6 +14,7 @@ import sys import threading import time +import types import typing as t import uuid from concurrent.futures import Future @@ -75,11 +76,11 @@ def _import_pyprctl(): return pyprctl -def _import_pamhandle(): +def _import_pam() -> types.ModuleType: # Enable conditional import, and create a hook-point for testing to mock - from globus_compute_endpoint.pam import PamHandle + from globus_compute_endpoint import pam - return PamHandle + return pam class UserEndpointRecord(BaseModel): @@ -117,6 +118,14 @@ def __init__( self.conf_dir = conf_dir self._config = config + + # UX - test conditional imports *now*, rather than when a request comes in; + # this gives immediate feedback to an implementing admin if something is awry + if config.pam.enable: + _import_pam() + else: + _import_pyprctl() + self._reload_requested = False self._time_to_stop = False self._kill_event = threading.Event() @@ -815,11 +824,11 @@ def _wrap(msg, *a, **k): sname = self._config.pam.service_name try: logd = _affix_logd(f"PAM ({sname}, {username}): ") - logd("Importing library") - PamHandle = _import_pamhandle() + logd("Importing module") + pam = _import_pam() logd("Creating handle") - with PamHandle(sname, username=username) as pamh: + with pam.PamHandle(sname, username=username) as pamh: logd("Invoking account stage") pamh.pam_acct_mgmt() logd("Creating credentials") @@ -838,12 +847,19 @@ def _wrap(msg, *a, **k): pamh.credentials_delete() logd("Closing handle") - except Exception as e: - log.error(str(e)) # Share (very likely) pamlib error with admin ... + + except pam.PamError as e: + log.error(str(e)) # Share pamlib error with admin ... # ... but be opaque with user. raise PermissionError("see your system administrator") from None + except Exception: + log.exception(f"Unhandled error during PAM session for {username}") + + # Regardless, be opaque with user. + raise PermissionError("see your system administrator") from None + def cmd_start_endpoint( self, user_record: pwd.struct_passwd, diff --git a/compute_endpoint/tests/unit/test_endpointmanager_unit.py b/compute_endpoint/tests/unit/test_endpointmanager_unit.py index 8f6d9d791..95c5d1596 100644 --- a/compute_endpoint/tests/unit/test_endpointmanager_unit.py +++ b/compute_endpoint/tests/unit/test_endpointmanager_unit.py @@ -35,7 +35,6 @@ ResultPublisher, ) from globus_compute_endpoint.endpoint.utils import _redact_url_creds -from globus_compute_endpoint.pam import PamHandle from globus_sdk import GlobusAPIError, NetworkError try: @@ -69,6 +68,11 @@ ) +class MockPamError(Exception): + def __init__(self, *a, **k): + pass + + def mock_ensure_compute_dir(): return pathlib.Path(_mock_localuser_rec.pw_dir) / ".globus_compute" @@ -333,6 +337,55 @@ def create_response( return create_response +def _create_pam_handle_mock(): + try: + # attempt to play nice with systems that do not have PAM installed, and + # rely on those that do to test with spec=PamHandle + from globus_compute_endpoint.pam import PamHandle + + _has_pam = True + except ImportError: + _has_pam = False + + def _create_mock(): + # work with other fixtures (namely fs), that don't like multiple attempts to + # import; do the work once and cache it via closure + while True: + if _has_pam: + yield mock.MagicMock(spec=PamHandle) + else: + yield mock.MagicMock() + + return _create_mock() + + +create_pam_handle_mock = _create_pam_handle_mock() + + +@pytest.fixture +def mock_pamh(): + m = next(create_pam_handle_mock) + m.return_value = m + m.__enter__.return_value = m + yield m + + +@pytest.fixture +def mock_pam(mock_pamh): + with mock.patch(f"{_MOCK_BASE}_import_pam") as m: + m.return_value = m + m.PamHandle = mock_pamh + m.PamError = MockPamError + yield m + + +@pytest.fixture +def mock_ctl(): + with mock.patch(f"{_MOCK_BASE}_import_pyprctl") as m: + m.return_value = m + yield m + + @pytest.mark.parametrize("env", [None, "blar", "local", "production"]) def test_sets_process_title( randomstring, conf_dir, mock_conf, mock_client, mock_setproctitle, env @@ -2058,24 +2111,35 @@ def test_port_is_respected(mocker, mock_client, mock_conf, conf_dir, port): assert mock_update_url_port.call_args[0][1] == port -def test_pam_disabled(conf_dir, mock_conf, mock_ep_uuid, mock_reg_info): - em = EndpointManager(conf_dir, mock_ep_uuid, mock_conf, mock_reg_info) +@pytest.mark.parametrize( + "fn_name,pam_enable", + ( + ("_import_pam", True), + ("_import_pyprctl", False), + ), +) +def test_conditional_imports_verified_at_init_for_ux( + conf_dir, mock_conf, ep_uuid, mock_reg_info, mock_ctl, fn_name, pam_enable +): + mock_conf.pam.enable = pam_enable + with mock.patch(f"{_MOCK_BASE}{fn_name}") as m: + m.side_effect = MemoryError("test induced") + with pytest.raises(MemoryError): + EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info) + + +def test_pam_disabled(conf_dir, mock_conf, ep_uuid, mock_reg_info, mock_ctl, mock_pam): + em = EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info) mock_conf.pam.enable = False - with mock.patch(f"{_MOCK_BASE}_import_pyprctl") as mock_ctl: - mock_ctl.return_value = mock_ctl - with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam: - mock_pam.return_value = mock_pam - with em.do_host_auth("some user name"): - pass + with em.do_host_auth("some user name"): + pass assert not mock_pam.called, "PAM was disable; should *not* attempt PAM" assert mock_ctl.CapState.called, "No PAM? No privileges." assert mock_ctl.set_no_new_privs.called, "No PAM? No privileges." -def test_pam_enabled(conf_dir, mock_conf, mock_ep_uuid, mock_reg_info): - em = EndpointManager(conf_dir, mock_ep_uuid, mock_conf, mock_reg_info) - +def test_pam_enabled(conf_dir, mock_conf, ep_uuid, mock_reg_info, mock_ctl, mock_pam): def install_next_pamf(): # ensure PAM functions called in appropriate order fns = [ # reversed because we pop() to get each fn @@ -2094,21 +2158,18 @@ def _install_next_test_func(): return _install_next_test_func mock_conf.pam.enable = True - pamh = mock.Mock(spec=PamHandle) - with mock.patch(f"{_MOCK_BASE}_import_pyprctl") as mock_ctl: - mock_ctl.return_value = mock_ctl - with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam: - mock_pam.return_value = mock_pam - mock_pam.return_value.__enter__.return_value = pamh - pamh.pam_acct_mgmt.side_effect = install_next_pamf() - pamh.credentials_establish.side_effect = AssertionError("Out of order") - pamh.pam_open_session.side_effect = AssertionError("Out of order") - pamh.pam_close_session.side_effect = AssertionError("Out of order") - pamh.credentials_delete.side_effect = AssertionError("Out of order") - with em.do_host_auth("some user name"): - assert pamh.pam_open_session.called, "Complete authentication" - assert not pamh.credentials_delete.called, "PAM session *not* over yet" - assert pamh.credentials_delete.called, "PAM session completes" + pamh = mock_pam.PamHandle + pamh.pam_acct_mgmt.side_effect = install_next_pamf() + pamh.credentials_establish.side_effect = AssertionError("Out of order") + pamh.pam_open_session.side_effect = AssertionError("Out of order") + pamh.pam_close_session.side_effect = AssertionError("Out of order") + pamh.credentials_delete.side_effect = AssertionError("Out of order") + + em = EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info) + with em.do_host_auth("some user name"): + assert pamh.pam_open_session.called, "Complete authentication" + assert not pamh.credentials_delete.called, "PAM session *not* over yet" + assert pamh.credentials_delete.called, "PAM session completes" assert not mock_ctl.CapState.called, "Using PAM; admin manages privs" assert not mock_ctl.set_no_new_privs.called, "Using PAM; admin manages privs" @@ -2124,35 +2185,33 @@ def _install_next_test_func(): "credentials_delete", ), ) +@pytest.mark.parametrize("exc", (MockPamError("test err"), MemoryError("test err"))) def test_pam_error( - mock_log, conf_dir, mock_conf, mock_ep_uuid, mock_reg_info, fn_name, randomstring + mock_log, conf_dir, mock_conf, ep_uuid, mock_reg_info, fn_name, mock_pam, exc ): - em = EndpointManager(conf_dir, mock_ep_uuid, mock_conf, mock_reg_info) - - exc_text = randomstring() - exc = MemoryError(exc_text) + em = EndpointManager(conf_dir, ep_uuid, mock_conf, mock_reg_info) mock_conf.pam.enable = True - pamh = mock.Mock(spec=PamHandle) - with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam: - mock_pam.return_value = mock_pam - mock_pam.return_value.__enter__.return_value = pamh - getattr(pamh, fn_name).side_effect = exc - with pytest.raises(PermissionError) as pyt_e: - with em.do_host_auth("some user name"): - pass + pamh = mock_pam.PamHandle + username = "some username" + getattr(pamh, fn_name).side_effect = exc + with pytest.raises(PermissionError) as pyt_e: + with em.do_host_auth(username): + pass e_str = str(pyt_e.value) assert "PAM" not in e_str, "User-visible exception should be opaque" - assert "see your system administrator" in e_str + assert "see your system administrator" in e_str, "User-visible should have action" - a, _k = mock_log.error.call_args + if not isinstance(exc, MockPamError): + assert mock_log.exception.called, "Admin log should contain entire exception" + a, _k = mock_log.exception.call_args - assert exc_text in a[0], "Admin logs should specific error msg" + assert username in a[0], "Admin log should contain related username" def test_do_auth_change_uid_then_close( - mock_conf_root, successful_exec_from_mocked_root + mock_conf_root, successful_exec_from_mocked_root, mock_pam ): mock_os, *_, em = successful_exec_from_mocked_root @@ -2181,20 +2240,17 @@ def _called(fn_name): return _called mock_conf_root.pam.enable = True - pamh = mock.Mock(spec=PamHandle) fn_opener = set_called() - with mock.patch(f"{_MOCK_BASE}_import_pamhandle") as mock_pam: - mock_pam.return_value = mock_pam - mock_pam.return_value.__enter__.return_value = pamh - pamh.pam_open_session.side_effect = this_func(fn_opener, "pam_open_session") - pamh.pam_close_session.side_effect = AssertionError("Out of order") - mock_os.setresuid.side_effect = AssertionError("Out of order") - mock_os.setresgid.side_effect = AssertionError("Out of order") - mock_os.initgroups.side_effect = AssertionError("Out of order") + pamh = mock_pam.PamHandle + pamh.pam_open_session.side_effect = this_func(fn_opener, "pam_open_session") + pamh.pam_close_session.side_effect = AssertionError("Out of order") + mock_os.setresuid.side_effect = AssertionError("Out of order") + mock_os.setresgid.side_effect = AssertionError("Out of order") + mock_os.initgroups.side_effect = AssertionError("Out of order") - with pytest.raises(SystemExit) as pyexc: - em._event_loop() + with pytest.raises(SystemExit) as pyexc: + em._event_loop() assert pyexc.value.code == _GOOD_EC, "Q&D: verify we exec'ed, based on '+= 1'" assert pamh.pam_close_session.called