diff --git a/compute_endpoint/globus_compute_endpoint/cli.py b/compute_endpoint/globus_compute_endpoint/cli.py index 9f03ec499..fa52f28cb 100644 --- a/compute_endpoint/globus_compute_endpoint/cli.py +++ b/compute_endpoint/globus_compute_endpoint/cli.py @@ -300,6 +300,13 @@ def version_command(): default=False, help="Configure endpoint as multi-user capable", ) +@click.option( + "--high-assurance", + is_flag=True, + default=False, + hidden=True, # Until HA features are complete + help="Configure endpoint as high assurance capable", +) @click.option( "--display-name", help="A human readable display name for the endpoint, if desired", @@ -352,7 +359,7 @@ def version_command(): "--auth-timeout", help=( "How old (in seconds) a login session can be and still be compliant. " - "Makes the auth policy high assurance" + "If set, the auth policy will be high assurance" ), type=click.IntRange(min=0), default=None, @@ -366,6 +373,7 @@ def configure_endpoint( name: str, endpoint_config: str | None, multi_user: bool, + high_assurance: bool, display_name: str | None, auth_policy: str | None, auth_policy_project_id: str | None, @@ -389,20 +397,28 @@ def configure_endpoint( if multi_user and not _has_multi_user: raise ClickException("multi-user endpoints are not supported on this system") - if ( + create_policy = ( auth_policy_project_id is not None or auth_policy_display_name != _AUTH_POLICY_DEFAULT_NAME or auth_policy_description != _AUTH_POLICY_DEFAULT_DESC or allowed_domains is not None or excluded_domains is not None or auth_timeout is not None - ): - if auth_policy: - raise ClickException( - "Cannot specify an existing auth policy and " - "create a new one at the same time" - ) + ) + if create_policy and auth_policy: + raise ClickException( + "Cannot specify an existing auth policy and " + "create a new one at the same time" + ) + elif ( + (not create_policy and not bool(auth_policy)) or not subscription_id + ) and high_assurance: + raise ClickException( + "high-assurance(HA) endpoints require both a HA policy and " + "a HA subscription id" + ) + elif create_policy: app = get_globus_app_with_scopes() ac = ComputeAuthClient(app=app) @@ -433,6 +449,7 @@ def configure_endpoint( ep_dir, endpoint_config, multi_user, + high_assurance, display_name, auth_policy, subscription_id, diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py b/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py index a666288ae..ebd6a3df8 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/config/config.py @@ -62,6 +62,7 @@ def __init__( self, *, multi_user: bool = False, + high_assurance: bool = False, display_name: str | None = None, allowed_functions: t.Iterable[UUID_LIKE_T] | None = None, authentication_policy: UUID_LIKE_T | None = None, @@ -76,6 +77,7 @@ def __init__( self.display_name = display_name self.debug = debug is True self.multi_user = multi_user is True + self.high_assurance = high_assurance is True # Connection info and tuning self.amqp_port = amqp_port diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/config/default_config.py b/compute_endpoint/globus_compute_endpoint/endpoint/config/default_config.py index 8cd39642b..4464715a0 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/config/default_config.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/config/default_config.py @@ -4,6 +4,7 @@ config = UserEndpointConfig( display_name=None, # If None, defaults to the endpoint name + high_assurance=False, executors=[ GlobusComputeEngine( provider=LocalProvider( diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint.py b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint.py index 99d647911..a1443d8ce 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint.py @@ -76,6 +76,7 @@ def update_config_file( original_path: pathlib.Path, target_path: pathlib.Path, multi_user: bool, + high_assurance: bool, display_name: str | None, auth_policy: str | None, subscription_id: str | None, @@ -89,6 +90,9 @@ def update_config_file( if auth_policy: config_dict["authentication_policy"] = auth_policy + if high_assurance: + config_dict["high_assurance"] = high_assurance + if multi_user: config_dict["multi_user"] = multi_user config_dict.pop("engine", None) @@ -109,6 +113,7 @@ def init_endpoint_dir( endpoint_dir: pathlib.Path, endpoint_config: pathlib.Path | None = None, multi_user=False, + high_assurance=False, display_name: str | None = None, auth_policy: str | None = None, subscription_id: str | None = None, @@ -143,6 +148,7 @@ def init_endpoint_dir( endpoint_config, config_target_path, multi_user, + high_assurance, display_name, auth_policy, subscription_id, @@ -187,6 +193,7 @@ def configure_endpoint( conf_dir: pathlib.Path, endpoint_config: str | None, multi_user: bool = False, + high_assurance: bool = False, display_name: str | None = None, auth_policy: str | None = None, subscription_id: str | None = None, @@ -202,6 +209,7 @@ def configure_endpoint( conf_dir, templ_conf_path, multi_user, + high_assurance, display_name, auth_policy, subscription_id, @@ -428,6 +436,7 @@ def start_endpoint( allowed_functions=endpoint_config.allowed_functions, auth_policy=endpoint_config.authentication_policy, subscription_id=endpoint_config.subscription_id, + high_assurance=endpoint_config.high_assurance, ) except GlobusAPIError as e: diff --git a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py index 0f6af17ac..0f64a1189 100644 --- a/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py +++ b/compute_endpoint/globus_compute_endpoint/endpoint/endpoint_manager.py @@ -181,6 +181,7 @@ def __init__( auth_policy=config.authentication_policy, subscription_id=config.subscription_id, public=config.public, + high_assurance=config.high_assurance, ) # Mostly to appease mypy, but also a useful text if it ever diff --git a/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint_manager.py b/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint_manager.py index fe53506c7..670dafde8 100644 --- a/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint_manager.py +++ b/compute_endpoint/tests/integration/endpoint/endpoint/test_endpoint_manager.py @@ -59,24 +59,33 @@ def test_double_configure(self): with pytest.raises(Exception, match="ConfigExists"): manager.configure_endpoint(config_dir, None) + @pytest.mark.parametrize("ha", [None, True, False]) @pytest.mark.parametrize("mu", [None, True, False]) - def test_configure_multi_user_existing_config(self, mu): + def test_configure_multi_user_ha_existing_config(self, ha, mu): manager = Endpoint() config_dir = pathlib.Path("/some/path/mock_endpoint") config_file = Endpoint._config_file_path(config_dir) config_copy = str(config_dir.parent / "config2.yaml") - # First, make an entry with multi_user - manager.configure_endpoint(config_dir, None, multi_user=True) + # First, make an entry with multi_user/ha + manager.configure_endpoint( + config_dir, None, multi_user=True, high_assurance=False + ) shutil.move(config_file, config_copy) shutil.rmtree(config_dir) - # Then, modify it with new setting - manager.configure_endpoint(config_dir, config_copy, multi_user=mu) + # Then, modify it with new settings + manager.configure_endpoint( + config_dir, config_copy, multi_user=mu, high_assurance=ha + ) with open(config_file) as f: config_dict = yaml.safe_load(f) assert "multi_user" in config_dict + if ha is True: + assert "high_assurance" in config_dict + else: + assert "high_assurance" not in config_dict os.remove(config_copy) diff --git a/compute_endpoint/tests/unit/conftest.py b/compute_endpoint/tests/unit/conftest.py index cb5ff38e4..f50e278ea 100644 --- a/compute_endpoint/tests/unit/conftest.py +++ b/compute_endpoint/tests/unit/conftest.py @@ -38,6 +38,7 @@ def pytest_configure(config): "local_compute_services": True, "environment": str, "multi_user": False, + "high_assurance": False, "executors": None, } @@ -57,6 +58,7 @@ def pytest_configure(config): "local_compute_services": True, "environment": str, "multi_user": True, + "high_assurance": True, } diff --git a/compute_endpoint/tests/unit/test_cli_behavior.py b/compute_endpoint/tests/unit/test_cli_behavior.py index 112d87c00..8fa9a2b63 100644 --- a/compute_endpoint/tests/unit/test_cli_behavior.py +++ b/compute_endpoint/tests/unit/test_cli_behavior.py @@ -449,6 +449,37 @@ def test_start_ep_display_name_in_config( assert conf_dict["display_name"] == display_name +@pytest.mark.parametrize( + "ha_test_info", + [ + ["ep0", None], + ["ep1", False], + ["ep2", True], + ], +) +def test_start_ep_high_assurance_in_config( + run_line, mock_command_ensure, make_endpoint_dir, ha_test_info +): + ep_name, is_ha = ha_test_info + + conf = mock_command_ensure.endpoint_config_dir / ep_name / "config.yaml" + configure_arg = "" + auth_policy = str(uuid.uuid4()) + sub_id = str(uuid.uuid4()) + if is_ha is not None: + configure_arg = ( + f" --subscription-id {sub_id} --auth-policy " + f"{auth_policy} --high-assurance" + ) + run_line(f"configure {ep_name}{configure_arg}") + + with open(conf) as f: + conf_dict = yaml.safe_load(f) + + if is_ha is True: + assert conf_dict["high_assurance"] == is_ha + + def test_configure_ep_auth_policy_in_config( run_line, mock_command_ensure, make_endpoint_dir ): @@ -1068,6 +1099,73 @@ def test_configure_ep_auth_policy_timeout_sets_ha( ) +@pytest.mark.parametrize( + ("is_ha", "policy_id", "auth_desc", "allowed_domains", "sub_id", "exc_text"), + ( + ( + [True, "pid", None, None, "sub_id", None], + [True, "pid", None, None, "sub_id", None], + [True, "pid", "desc", "globus.org", "sub_id", "Cannot specify an existing"], + [True, "pid", "desc", None, "sub_id", "Cannot specify an existing"], + [True, None, None, None, "sub_id", "require both a HA policy and a HA sub"], + [True, "pid", None, None, None, "require both a HA policy and a HA sub"], + [ + True, + None, + "auth_desc", + "globus.org", + None, + "require both a HA policy and a HA sub", + ], + [False, None, "auth_desc", "globus.org", None, None], + [False, "pid", None, None, None, None], + ) + ), +) +def test_configure_ha_ep_requirements( + mocker, + run_line, + mock_cli_state, + make_endpoint_dir, + ep_name, + mock_app, + mock_auth_client, + is_ha, + policy_id, + auth_desc, + allowed_domains, + sub_id, + exc_text, +): + mock_auth_client.create_policy.return_value = {"policy": {"id": "foo"}} + mock_auth_client.get_projects.return_value = [] + mocker.patch(f"{_MOCK_BASE}create_or_choose_auth_project") + + mock_ep, _ = mock_cli_state + + args = ["configure"] + if is_ha: + args.append("--high-assurance") + if policy_id: + args.append(f"--auth-policy {policy_id}") + if auth_desc: + args.append(f"--auth-policy-description {auth_desc}") + if allowed_domains: + args.append(f"--allowed-domains {allowed_domains}") + if sub_id: + args.append(f"--subscription-id {sub_id}") + + args.append("ep_name") + + line = " ".join(args) + if exc_text: + res = run_line(line, assert_exit_code=1) + assert exc_text in res.stderr + else: + run_line(line) + assert mock_ep.configure_endpoint.called + + @pytest.mark.parametrize( ("delete_cmd", "use_uuid", "exit_code", "delete_done"), [ diff --git a/compute_endpoint/tests/unit/test_endpoint_config.py b/compute_endpoint/tests/unit/test_endpoint_config.py index 384b06a09..598d1afa0 100644 --- a/compute_endpoint/tests/unit/test_endpoint_config.py +++ b/compute_endpoint/tests/unit/test_endpoint_config.py @@ -211,8 +211,8 @@ def test_userconfig_repr_nondefault_kwargs( repr_c = repr(UserEndpointConfig(**kwds)) - if kw == "multi_user": - assert f"{kw}={repr(val)}" not in repr_c, "Multi user *off* by default" + if kw in ["multi_user", "high_assurance"]: + assert f"{kw}={repr(val)}" not in repr_c, "Multi-user and HA *off* by default" else: assert f"{kw}={repr(val)}" in repr_c diff --git a/compute_endpoint/tests/unit/test_endpointmanager_unit.py b/compute_endpoint/tests/unit/test_endpointmanager_unit.py index 52abfca46..46ecc4325 100644 --- a/compute_endpoint/tests/unit/test_endpointmanager_unit.py +++ b/compute_endpoint/tests/unit/test_endpointmanager_unit.py @@ -507,6 +507,7 @@ def test_sends_data_during_registration( "endpoint_id", "metadata", "multi_user", + "high_assurance", "display_name", "allowed_functions", "auth_policy", diff --git a/compute_sdk/globus_compute_sdk/sdk/client.py b/compute_sdk/globus_compute_sdk/sdk/client.py index b707533a1..9cd71d06a 100644 --- a/compute_sdk/globus_compute_sdk/sdk/client.py +++ b/compute_sdk/globus_compute_sdk/sdk/client.py @@ -429,6 +429,7 @@ def register_endpoint( auth_policy: UUID_LIKE_T | None = None, subscription_id: UUID_LIKE_T | None = None, public: bool | None = None, + high_assurance: bool | None = None, ): """Register an endpoint with the Globus Compute service. @@ -442,6 +443,8 @@ def register_endpoint( Endpoint metadata multi_user : bool | None Whether the endpoint supports multiple users + high_assurance : bool | None + Whether the endpoint should be high assurance capable display_name : str | None The display name of the endpoint allowed_functions: list[str | UUID] | None @@ -472,6 +475,7 @@ def register_endpoint( auth_policy=auth_policy, subscription_id=subscription_id, public=public, + high_assurance=high_assurance, ) return r.data diff --git a/compute_sdk/globus_compute_sdk/sdk/web_client.py b/compute_sdk/globus_compute_sdk/sdk/web_client.py index 8edacc1a0..51c33aac7 100644 --- a/compute_sdk/globus_compute_sdk/sdk/web_client.py +++ b/compute_sdk/globus_compute_sdk/sdk/web_client.py @@ -194,6 +194,7 @@ def register_endpoint( subscription_id: t.Optional[UUID_LIKE_T] = None, public: t.Optional[bool] = None, additional_fields: t.Optional[t.Dict[str, t.Any]] = None, + high_assurance: t.Optional[bool] = None, ) -> globus_sdk.GlobusHTTPResponse: data: t.Dict[str, t.Any] = {"endpoint_name": endpoint_name} @@ -218,6 +219,8 @@ def register_endpoint( data["subscription_uuid"] = subscription_id if public is not None: data["public"] = public + if high_assurance is not None: + data["high_assurance"] = high_assurance if additional_fields is not None: data.update(additional_fields)