From 37f4181201f9586b93ba8f7dd7739ced86776fb0 Mon Sep 17 00:00:00 2001 From: A Gleeson Date: Thu, 9 Jan 2025 14:03:52 +0000 Subject: [PATCH] refactor: Fixing the Vpc provider in Database copy (#709) Co-authored-by: chiaramapellimt Co-authored-by: Chiara <95863059+chiaramapellimt@users.noreply.github.com> --- dbt_platform_helper/domain/database_copy.py | 12 +- dbt_platform_helper/providers/vpc.py | 57 ++++++++++ dbt_platform_helper/utils/application.py | 23 ++-- dbt_platform_helper/utils/aws.py | 50 -------- .../domain/test_database_copy.py | 35 ++++-- tests/platform_helper/providers/test_vpc.py | 107 ++++++++++++++++++ tests/platform_helper/utils/test_aws.py | 106 +---------------- 7 files changed, 207 insertions(+), 183 deletions(-) create mode 100644 dbt_platform_helper/providers/vpc.py create mode 100644 tests/platform_helper/providers/test_vpc.py diff --git a/dbt_platform_helper/domain/database_copy.py b/dbt_platform_helper/domain/database_copy.py index e19233abe..504e5b4df 100644 --- a/dbt_platform_helper/domain/database_copy.py +++ b/dbt_platform_helper/domain/database_copy.py @@ -11,12 +11,12 @@ from dbt_platform_helper.domain.maintenance_page import MaintenancePage from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.config import ConfigProvider +from dbt_platform_helper.providers.vpc import Vpc +from dbt_platform_helper.providers.vpc import VpcProvider from dbt_platform_helper.utils.application import Application from dbt_platform_helper.utils.application import ApplicationNotFoundException from dbt_platform_helper.utils.application import load_application -from dbt_platform_helper.utils.aws import Vpc from dbt_platform_helper.utils.aws import get_connection_string -from dbt_platform_helper.utils.aws import get_vpc_info_by_name from dbt_platform_helper.utils.aws import wait_for_log_group_to_exist from dbt_platform_helper.utils.messages import abort_with_error @@ -28,7 +28,8 @@ def __init__( database: str, auto_approve: bool = False, load_application: Callable[[str], Application] = load_application, - vpc_config: Callable[[Session, str, str, str], Vpc] = get_vpc_info_by_name, + # TODO We inject VpcProvider as a callable here so that it can be instantiated within the method. To be improved + vpc_provider: Callable[[Session], VpcProvider] = VpcProvider, db_connection_string: Callable[ [Session, str, str, str, Callable], str ] = get_connection_string, @@ -43,7 +44,7 @@ def __init__( self.app = app self.database = database self.auto_approve = auto_approve - self.vpc_config = vpc_config + self.vpc_provider = vpc_provider self.db_connection_string = db_connection_string self.maintenance_page_provider = maintenance_page_provider self.input = input @@ -76,7 +77,8 @@ def _execute_operation(self, is_dump: bool, env: str, vpc_name: str, filename: s env_session = environment.session try: - vpc_config = self.vpc_config(env_session, self.app, env, vpc_name) + vpc_provider = self.vpc_provider(env_session) + vpc_config = vpc_provider.get_vpc_info_by_name(self.app, env, vpc_name) except AWSException as ex: self.abort(str(ex)) diff --git a/dbt_platform_helper/providers/vpc.py b/dbt_platform_helper/providers/vpc.py new file mode 100644 index 000000000..7bdd9eee6 --- /dev/null +++ b/dbt_platform_helper/providers/vpc.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass + +from dbt_platform_helper.providers.aws import AWSException + + +@dataclass +class Vpc: + subnets: list[str] + security_groups: list[str] + + +class VpcProvider: + def __init__(self, session): + self.ec2_client = session.client("ec2") + self.ec2_resource = session.resource("ec2") + + def get_vpc_info_by_name(self, app: str, env: str, vpc_name: str) -> Vpc: + vpc_response = self.ec2_client.describe_vpcs( + Filters=[{"Name": "tag:Name", "Values": [vpc_name]}] + ) + + matching_vpcs = vpc_response.get("Vpcs", []) + + if not matching_vpcs: + raise AWSException(f"VPC not found for name '{vpc_name}'") + + vpc_id = vpc_response["Vpcs"][0].get("VpcId") + + if not vpc_id: + raise AWSException(f"VPC id not present in vpc '{vpc_name}'") + + vpc = self.ec2_resource.Vpc(vpc_id) + + route_tables = self.ec2_client.describe_route_tables( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + )["RouteTables"] + + subnets = [] + for route_table in route_tables: + private_routes = [route for route in route_table["Routes"] if "NatGatewayId" in route] + if not private_routes: + continue + for association in route_table["Associations"]: + if "SubnetId" in association: + subnet_id = association["SubnetId"] + subnets.append(subnet_id) + + if not subnets: + raise AWSException(f"No private subnets found in vpc '{vpc_name}'") + + tag_value = {"Key": "Name", "Value": f"copilot-{app}-{env}-env"} + sec_groups = [sg.id for sg in vpc.security_groups.all() if sg.tags and tag_value in sg.tags] + + if not sec_groups: + raise AWSException(f"No matching security groups found in vpc '{vpc_name}'") + + return Vpc(subnets, sec_groups) diff --git a/dbt_platform_helper/utils/application.py b/dbt_platform_helper/utils/application.py index e0909884c..2522bb4d2 100644 --- a/dbt_platform_helper/utils/application.py +++ b/dbt_platform_helper/utils/application.py @@ -1,6 +1,8 @@ import json import os import re +from dataclasses import dataclass +from dataclasses import field from pathlib import Path from typing import Dict @@ -16,16 +18,12 @@ from dbt_platform_helper.utils.messages import abort_with_error +@dataclass class Environment: name: str account_id: str sessions: Dict[str, boto3.Session] - def __init__(self, name: str, account_id: str, sessions: Dict[str, boto3.Session]): - self.name = name - self.account_id = account_id - self.sessions = sessions - @property def session(self): if self.account_id not in self.sessions: @@ -36,24 +34,17 @@ def session(self): return self.sessions[self.account_id] +@dataclass class Service: name: str kind: str - def __init__(self, name: str, kind: str): - self.name = name - self.kind = kind - +@dataclass class Application: name: str - environments: Dict[str, Environment] - services: Dict[str, Service] - - def __init__(self, name: str): - self.name = name - self.environments = {} - self.services = {} + environments: Dict[str, Environment] = field(default_factory=dict) + services: Dict[str, Service] = field(default_factory=dict) def __str__(self): output = f"Application {self.name} with" diff --git a/dbt_platform_helper/utils/aws.py b/dbt_platform_helper/utils/aws.py index b36511961..7131e9110 100644 --- a/dbt_platform_helper/utils/aws.py +++ b/dbt_platform_helper/utils/aws.py @@ -15,7 +15,6 @@ from dbt_platform_helper.constants import REFRESH_TOKEN_MESSAGE from dbt_platform_helper.platform_exception import PlatformException -from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.aws import CopilotCodebaseNotFoundException from dbt_platform_helper.providers.aws import ImageNotFoundException from dbt_platform_helper.providers.aws import LogGroupNotFoundException @@ -377,55 +376,6 @@ def get_connection_string( return f"postgres://{conn['username']}:{conn['password']}@{conn['host']}:{conn['port']}/{conn['dbname']}" -class Vpc: - def __init__(self, subnets: list[str], security_groups: list[str]): - self.subnets = subnets - self.security_groups = security_groups - - -def get_vpc_info_by_name(session: Session, app: str, env: str, vpc_name: str) -> Vpc: - ec2_client = session.client("ec2") - vpc_response = ec2_client.describe_vpcs(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]) - - matching_vpcs = vpc_response.get("Vpcs", []) - - if not matching_vpcs: - raise AWSException(f"VPC not found for name '{vpc_name}'") - - vpc_id = vpc_response["Vpcs"][0].get("VpcId") - - if not vpc_id: - raise AWSException(f"VPC id not present in vpc '{vpc_name}'") - - ec2_resource = session.resource("ec2") - vpc = ec2_resource.Vpc(vpc_id) - - route_tables = ec2_client.describe_route_tables( - Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] - )["RouteTables"] - - subnets = [] - for route_table in route_tables: - private_routes = [route for route in route_table["Routes"] if "NatGatewayId" in route] - if not private_routes: - continue - for association in route_table["Associations"]: - if "SubnetId" in association: - subnet_id = association["SubnetId"] - subnets.append(subnet_id) - - if not subnets: - raise AWSException(f"No private subnets found in vpc '{vpc_name}'") - - tag_value = {"Key": "Name", "Value": f"copilot-{app}-{env}-env"} - sec_groups = [sg.id for sg in vpc.security_groups.all() if sg.tags and tag_value in sg.tags] - - if not sec_groups: - raise AWSException(f"No matching security groups found in vpc '{vpc_name}'") - - return Vpc(subnets, sec_groups) - - def start_build_extraction(codebuild_client, build_options): response = codebuild_client.start_build(**build_options) return response["build"]["arn"] diff --git a/tests/platform_helper/domain/test_database_copy.py b/tests/platform_helper/domain/test_database_copy.py index 55f031709..61b8b8865 100644 --- a/tests/platform_helper/domain/test_database_copy.py +++ b/tests/platform_helper/domain/test_database_copy.py @@ -8,9 +8,9 @@ from dbt_platform_helper.domain.database_copy import DatabaseCopy from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.config import ConfigProvider +from dbt_platform_helper.providers.vpc import Vpc from dbt_platform_helper.utils.application import Application from dbt_platform_helper.utils.application import ApplicationNotFoundException -from dbt_platform_helper.utils.aws import Vpc class DataCopyMocks: @@ -25,8 +25,12 @@ def __init__(self, app="test-app", env="test-env", acc="12345", vpc=Vpc([], []), self.environment.session.client.return_value = self.client self.vpc = vpc - self.vpc_config = Mock() - self.vpc_config.return_value = vpc + self.vpc_provider = ( + Mock() + ) # this is the callable class so should return a class when called + self.instantiated_vpc_provider = Mock() + self.instantiated_vpc_provider.get_vpc_info_by_name.return_value = self.vpc + self.vpc_provider.return_value = self.instantiated_vpc_provider self.db_connection_string = Mock(return_value="test-db-connection-string") self.maintenance_page_provider = Mock() @@ -38,7 +42,7 @@ def __init__(self, app="test-app", env="test-env", acc="12345", vpc=Vpc([], []), def params(self): return { "load_application": self.load_application, - "vpc_config": self.vpc_config, + "vpc_provider": self.vpc_provider, "db_connection_string": self.db_connection_string, "maintenance_page_provider": self.maintenance_page_provider, "input": self.input, @@ -124,8 +128,9 @@ def test_database_dump(): db_copy.dump(env, vpc_name) mocks.load_application.assert_called_once() - mocks.vpc_config.assert_called_once_with( - mocks.environment.session, app, env, "test-vpc-override" + mocks.vpc_provider.assert_called_once_with(mocks.environment.session) + mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_called_once_with( + app, env, "test-vpc-override" ) mocks.db_connection_string.assert_called_once_with( mocks.environment.session, app, env, "test-app-test-env-test-db" @@ -165,8 +170,9 @@ def test_database_load_with_response_of_yes(): mocks.load_application.assert_called_once() - mocks.vpc_config.assert_called_once_with( - mocks.environment.session, app, env, "test-vpc-override" + mocks.vpc_provider.assert_called_once_with(mocks.environment.session) + mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_called_once_with( + app, env, "test-vpc-override" ) mocks.db_connection_string.assert_called_once_with( @@ -212,7 +218,8 @@ def test_database_load_with_response_of_no(): mocks.environment.session.assert_not_called() - mocks.vpc_config.assert_not_called() + mocks.vpc_provider.assert_not_called() + mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_not_called() mocks.db_connection_string.assert_not_called() @@ -228,7 +235,9 @@ def test_database_load_with_response_of_no(): @pytest.mark.parametrize("is_dump", (True, False)) def test_database_dump_handles_vpc_errors(is_dump): mocks = DataCopyMocks() - mocks.vpc_config.side_effect = AWSException("A VPC error occurred") + mocks.instantiated_vpc_provider.get_vpc_info_by_name.side_effect = AWSException( + "A VPC error occurred" + ) db_copy = DatabaseCopy("test-app", "test-db", **mocks.params()) @@ -239,6 +248,7 @@ def test_database_dump_handles_vpc_errors(is_dump): db_copy.load("test-env", "bad-vpc-name") assert exc.value.code == 1 + mocks.vpc_provider.assert_called_once_with(mocks.environment.session) mocks.abort.assert_called_once_with("A VPC error occurred") @@ -536,8 +546,9 @@ def test_database_dump_with_no_vpc_works_in_deploy_repo(fs, is_dump): else: db_copy.load(env, None) - mocks.vpc_config.assert_called_once_with( - mocks.environment.session, "test-app", env, "test-env-vpc" + mocks.vpc_provider.assert_called_once_with(mocks.environment.session) + mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_called_once_with( + "test-app", env, "test-env-vpc" ) diff --git a/tests/platform_helper/providers/test_vpc.py b/tests/platform_helper/providers/test_vpc.py new file mode 100644 index 000000000..2d8fe505a --- /dev/null +++ b/tests/platform_helper/providers/test_vpc.py @@ -0,0 +1,107 @@ +import pytest +from moto import mock_aws + +from dbt_platform_helper.providers.aws import AWSException +from dbt_platform_helper.providers.vpc import Vpc +from dbt_platform_helper.providers.vpc import VpcProvider +from tests.platform_helper.utils.test_aws import ObjectWithId +from tests.platform_helper.utils.test_aws import mock_vpc_info_session + + +@mock_aws +def test_get_vpc_info_by_name_success(): + mock_session, mock_client, _ = mock_vpc_info_session() + vpc_provider = VpcProvider(mock_session) + + result = vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") + + expected_vpc = Vpc( + subnets=["subnet-private-1", "subnet-private-2"], security_groups=["sg-abc123"] + ) + + mock_client.describe_vpcs.assert_called_once_with( + Filters=[{"Name": "tag:Name", "Values": ["my_vpc"]}] + ) + + assert result.subnets == expected_vpc.subnets + assert result.security_groups == expected_vpc.security_groups + + +@mock_aws +def test_get_vpc_info_by_name_failure_no_matching_vpc(): + mock_session, mock_client, _ = mock_vpc_info_session() + vpc_provider = VpcProvider(mock_session) + + vpc_data = {"Vpcs": []} + mock_client.describe_vpcs.return_value = vpc_data + + with pytest.raises(AWSException) as ex: + vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") + + assert "VPC not found for name 'my_vpc'" in str(ex) + + +@mock_aws +def test_get_vpc_info_by_name_failure_no_vpc_id_in_response(): + mock_session, mock_client, _ = mock_vpc_info_session() + vpc_provider = VpcProvider(mock_session) + + vpc_data = {"Vpcs": [{"Id": "abc123"}]} + mock_client.describe_vpcs.return_value = vpc_data + + with pytest.raises(AWSException) as ex: + vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") + + assert "VPC id not present in vpc 'my_vpc'" in str(ex) + + +@mock_aws +def test_get_vpc_info_by_name_failure_no_private_subnets_in_vpc(): + mock_session, mock_client, _ = mock_vpc_info_session() + vpc_provider = VpcProvider(mock_session) + + mock_client.describe_route_tables.return_value = { + "RouteTables": [ + { + "Associations": [ + { + "Main": True, + "RouteTableId": "rtb-00cbf3c8d611a46b8", + } + ], + "Routes": [ + { + "DestinationCidrBlock": "10.151.0.0/16", + "GatewayId": "local", + "Origin": "CreateRouteTable", + "State": "active", + } + ], + "VpcId": "vpc-010327b71b948b4bc", + "OwnerId": "891377058512", + } + ] + } + + with pytest.raises(AWSException) as ex: + vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") + + assert "No private subnets found in vpc 'my_vpc'" in str(ex) + + +@mock_aws +def test_get_vpc_info_by_name_failure_no_matching_security_groups(): + mock_session, _, mock_vpc = mock_vpc_info_session() + vpc_provider = VpcProvider(mock_session) + + mock_vpc.security_groups.all.return_value = [ + ObjectWithId("sg-abc345", tags=[]), + ObjectWithId("sg-abc567", tags=[{"Key": "Name", "Value": "copilot-other_app-my_env-env"}]), + ObjectWithId("sg-abc456"), + ObjectWithId("sg-abc678", tags=[{"Key": "Name", "Value": "copilot-my_app-other_env-env"}]), + ] + + with pytest.raises(AWSException) as ex: + vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") + + assert "No matching security groups found in vpc 'my_vpc'" in str(ex) diff --git a/tests/platform_helper/utils/test_aws.py b/tests/platform_helper/utils/test_aws.py index 6fbabf28f..2916d96d4 100644 --- a/tests/platform_helper/utils/test_aws.py +++ b/tests/platform_helper/utils/test_aws.py @@ -7,23 +7,19 @@ import boto3 import botocore -from dbt_platform_helper.constants import ( - ALPHANUMERIC_ENVIRONMENT_NAME, - ALPHANUMERIC_SERVICE_NAME, - CLUSTER_NAME_SUFFIX, - HYPHENATED_APPLICATION_NAME, - REFRESH_TOKEN_MESSAGE, - SERVICE_NAME_SUFFIX, -) import pytest from moto import mock_aws -from dbt_platform_helper.providers.aws import AWSException +from dbt_platform_helper.constants import ALPHANUMERIC_ENVIRONMENT_NAME +from dbt_platform_helper.constants import ALPHANUMERIC_SERVICE_NAME +from dbt_platform_helper.constants import CLUSTER_NAME_SUFFIX +from dbt_platform_helper.constants import HYPHENATED_APPLICATION_NAME +from dbt_platform_helper.constants import REFRESH_TOKEN_MESSAGE +from dbt_platform_helper.constants import SERVICE_NAME_SUFFIX from dbt_platform_helper.providers.aws import CopilotCodebaseNotFoundException from dbt_platform_helper.providers.aws import LogGroupNotFoundException from dbt_platform_helper.providers.validation import ValidationException from dbt_platform_helper.utils.aws import NoProfileForAccountIdException -from dbt_platform_helper.utils.aws import Vpc from dbt_platform_helper.utils.aws import check_codebase_exists from dbt_platform_helper.utils.aws import get_account_details from dbt_platform_helper.utils.aws import get_aws_session_or_abort @@ -36,7 +32,6 @@ from dbt_platform_helper.utils.aws import get_profile_name_from_account_id from dbt_platform_helper.utils.aws import get_public_repository_arn from dbt_platform_helper.utils.aws import get_ssm_secrets -from dbt_platform_helper.utils.aws import get_vpc_info_by_name from dbt_platform_helper.utils.aws import set_ssm_param from dbt_platform_helper.utils.aws import wait_for_log_group_to_exist from tests.platform_helper.conftest import mock_aws_client @@ -852,95 +847,6 @@ def mock_vpc_info_session(): return mock_session, mock_client, mock_vpc -def test_get_vpc_info_by_name_success(): - mock_session, mock_client, _ = mock_vpc_info_session() - - result = get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") - - expected_vpc = Vpc( - subnets=["subnet-private-1", "subnet-private-2"], security_groups=["sg-abc123"] - ) - - mock_client.describe_vpcs.assert_called_once_with( - Filters=[{"Name": "tag:Name", "Values": ["my_vpc"]}] - ) - - assert result.subnets == expected_vpc.subnets - assert result.security_groups == expected_vpc.security_groups - - -def test_get_vpc_info_by_name_failure_no_matching_vpc(): - mock_session, mock_client, _ = mock_vpc_info_session() - - vpc_data = {"Vpcs": []} - mock_client.describe_vpcs.return_value = vpc_data - - with pytest.raises(AWSException) as ex: - get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") - - assert "VPC not found for name 'my_vpc'" in str(ex) - - -def test_get_vpc_info_by_name_failure_no_vpc_id_in_response(): - mock_session, mock_client, _ = mock_vpc_info_session() - - vpc_data = {"Vpcs": [{"Id": "abc123"}]} - mock_client.describe_vpcs.return_value = vpc_data - - with pytest.raises(AWSException) as ex: - get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") - - assert "VPC id not present in vpc 'my_vpc'" in str(ex) - - -def test_get_vpc_info_by_name_failure_no_private_subnets_in_vpc(): - mock_session, mock_client, mock_vpc = mock_vpc_info_session() - - mock_client.describe_route_tables.return_value = { - "RouteTables": [ - { - "Associations": [ - { - "Main": True, - "RouteTableId": "rtb-00cbf3c8d611a46b8", - } - ], - "Routes": [ - { - "DestinationCidrBlock": "10.151.0.0/16", - "GatewayId": "local", - "Origin": "CreateRouteTable", - "State": "active", - } - ], - "VpcId": "vpc-010327b71b948b4bc", - "OwnerId": "891377058512", - } - ] - } - - with pytest.raises(AWSException) as ex: - get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") - - assert "No private subnets found in vpc 'my_vpc'" in str(ex) - - -def test_get_vpc_info_by_name_failure_no_matching_security_groups(): - mock_session, mock_client, mock_vpc = mock_vpc_info_session() - - mock_vpc.security_groups.all.return_value = [ - ObjectWithId("sg-abc345", tags=[]), - ObjectWithId("sg-abc567", tags=[{"Key": "Name", "Value": "copilot-other_app-my_env-env"}]), - ObjectWithId("sg-abc456"), - ObjectWithId("sg-abc678", tags=[{"Key": "Name", "Value": "copilot-my_app-other_env-env"}]), - ] - - with pytest.raises(AWSException) as ex: - get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") - - assert "No matching security groups found in vpc 'my_vpc'" in str(ex) - - def test_wait_for_log_group_to_exist_success(): log_group_name = "/ecs/test-log-group" mock_client = Mock()