diff --git a/dbt_platform_helper/domain/database_copy.py b/dbt_platform_helper/domain/database_copy.py index 9f1f97acc..9ac016d61 100644 --- a/dbt_platform_helper/domain/database_copy.py +++ b/dbt_platform_helper/domain/database_copy.py @@ -11,12 +11,12 @@ from dbt_platform_helper.domain.maintenance_page import MaintenancePageProvider from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.config import ConfigProvider -from dbt_platform_helper.providers.vpc import Vpc -from dbt_platform_helper.providers.vpc import VpcProvider from dbt_platform_helper.utils.application import Application from dbt_platform_helper.utils.application import ApplicationNotFoundException from dbt_platform_helper.utils.application import load_application +from dbt_platform_helper.utils.aws import Vpc from dbt_platform_helper.utils.aws import get_connection_string +from dbt_platform_helper.utils.aws import get_vpc_info_by_name from dbt_platform_helper.utils.aws import wait_for_log_group_to_exist from dbt_platform_helper.utils.messages import abort_with_error @@ -28,7 +28,7 @@ def __init__( database: str, auto_approve: bool = False, load_application: Callable[[str], Application] = load_application, - vpc_config: Callable[[Session, str, str, str], Vpc] = VpcProvider.get_vpc_info_by_name, + vpc_config: Callable[[Session, str, str, str], Vpc] = get_vpc_info_by_name, db_connection_string: Callable[ [Session, str, str, str, Callable], str ] = get_connection_string, diff --git a/dbt_platform_helper/providers/vpc.py b/dbt_platform_helper/providers/vpc.py deleted file mode 100644 index 7bdd9eee6..000000000 --- a/dbt_platform_helper/providers/vpc.py +++ /dev/null @@ -1,57 +0,0 @@ -from dataclasses import dataclass - -from dbt_platform_helper.providers.aws import AWSException - - -@dataclass -class Vpc: - subnets: list[str] - security_groups: list[str] - - -class VpcProvider: - def __init__(self, session): - self.ec2_client = session.client("ec2") - self.ec2_resource = session.resource("ec2") - - def get_vpc_info_by_name(self, app: str, env: str, vpc_name: str) -> Vpc: - vpc_response = self.ec2_client.describe_vpcs( - Filters=[{"Name": "tag:Name", "Values": [vpc_name]}] - ) - - matching_vpcs = vpc_response.get("Vpcs", []) - - if not matching_vpcs: - raise AWSException(f"VPC not found for name '{vpc_name}'") - - vpc_id = vpc_response["Vpcs"][0].get("VpcId") - - if not vpc_id: - raise AWSException(f"VPC id not present in vpc '{vpc_name}'") - - vpc = self.ec2_resource.Vpc(vpc_id) - - route_tables = self.ec2_client.describe_route_tables( - Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] - )["RouteTables"] - - subnets = [] - for route_table in route_tables: - private_routes = [route for route in route_table["Routes"] if "NatGatewayId" in route] - if not private_routes: - continue - for association in route_table["Associations"]: - if "SubnetId" in association: - subnet_id = association["SubnetId"] - subnets.append(subnet_id) - - if not subnets: - raise AWSException(f"No private subnets found in vpc '{vpc_name}'") - - tag_value = {"Key": "Name", "Value": f"copilot-{app}-{env}-env"} - sec_groups = [sg.id for sg in vpc.security_groups.all() if sg.tags and tag_value in sg.tags] - - if not sec_groups: - raise AWSException(f"No matching security groups found in vpc '{vpc_name}'") - - return Vpc(subnets, sec_groups) diff --git a/dbt_platform_helper/utils/application.py b/dbt_platform_helper/utils/application.py index 2522bb4d2..e0909884c 100644 --- a/dbt_platform_helper/utils/application.py +++ b/dbt_platform_helper/utils/application.py @@ -1,8 +1,6 @@ import json import os import re -from dataclasses import dataclass -from dataclasses import field from pathlib import Path from typing import Dict @@ -18,12 +16,16 @@ from dbt_platform_helper.utils.messages import abort_with_error -@dataclass class Environment: name: str account_id: str sessions: Dict[str, boto3.Session] + def __init__(self, name: str, account_id: str, sessions: Dict[str, boto3.Session]): + self.name = name + self.account_id = account_id + self.sessions = sessions + @property def session(self): if self.account_id not in self.sessions: @@ -34,17 +36,24 @@ def session(self): return self.sessions[self.account_id] -@dataclass class Service: name: str kind: str + def __init__(self, name: str, kind: str): + self.name = name + self.kind = kind + -@dataclass class Application: name: str - environments: Dict[str, Environment] = field(default_factory=dict) - services: Dict[str, Service] = field(default_factory=dict) + environments: Dict[str, Environment] + services: Dict[str, Service] + + def __init__(self, name: str): + self.name = name + self.environments = {} + self.services = {} def __str__(self): output = f"Application {self.name} with" diff --git a/dbt_platform_helper/utils/aws.py b/dbt_platform_helper/utils/aws.py index 7131e9110..b36511961 100644 --- a/dbt_platform_helper/utils/aws.py +++ b/dbt_platform_helper/utils/aws.py @@ -15,6 +15,7 @@ from dbt_platform_helper.constants import REFRESH_TOKEN_MESSAGE from dbt_platform_helper.platform_exception import PlatformException +from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.aws import CopilotCodebaseNotFoundException from dbt_platform_helper.providers.aws import ImageNotFoundException from dbt_platform_helper.providers.aws import LogGroupNotFoundException @@ -376,6 +377,55 @@ def get_connection_string( return f"postgres://{conn['username']}:{conn['password']}@{conn['host']}:{conn['port']}/{conn['dbname']}" +class Vpc: + def __init__(self, subnets: list[str], security_groups: list[str]): + self.subnets = subnets + self.security_groups = security_groups + + +def get_vpc_info_by_name(session: Session, app: str, env: str, vpc_name: str) -> Vpc: + ec2_client = session.client("ec2") + vpc_response = ec2_client.describe_vpcs(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]) + + matching_vpcs = vpc_response.get("Vpcs", []) + + if not matching_vpcs: + raise AWSException(f"VPC not found for name '{vpc_name}'") + + vpc_id = vpc_response["Vpcs"][0].get("VpcId") + + if not vpc_id: + raise AWSException(f"VPC id not present in vpc '{vpc_name}'") + + ec2_resource = session.resource("ec2") + vpc = ec2_resource.Vpc(vpc_id) + + route_tables = ec2_client.describe_route_tables( + Filters=[{"Name": "vpc-id", "Values": [vpc_id]}] + )["RouteTables"] + + subnets = [] + for route_table in route_tables: + private_routes = [route for route in route_table["Routes"] if "NatGatewayId" in route] + if not private_routes: + continue + for association in route_table["Associations"]: + if "SubnetId" in association: + subnet_id = association["SubnetId"] + subnets.append(subnet_id) + + if not subnets: + raise AWSException(f"No private subnets found in vpc '{vpc_name}'") + + tag_value = {"Key": "Name", "Value": f"copilot-{app}-{env}-env"} + sec_groups = [sg.id for sg in vpc.security_groups.all() if sg.tags and tag_value in sg.tags] + + if not sec_groups: + raise AWSException(f"No matching security groups found in vpc '{vpc_name}'") + + return Vpc(subnets, sec_groups) + + def start_build_extraction(codebuild_client, build_options): response = codebuild_client.start_build(**build_options) return response["build"]["arn"] diff --git a/tests/platform_helper/domain/test_database_copy.py b/tests/platform_helper/domain/test_database_copy.py index daaa0e8b1..55f031709 100644 --- a/tests/platform_helper/domain/test_database_copy.py +++ b/tests/platform_helper/domain/test_database_copy.py @@ -8,9 +8,9 @@ from dbt_platform_helper.domain.database_copy import DatabaseCopy from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.config import ConfigProvider -from dbt_platform_helper.providers.vpc import Vpc from dbt_platform_helper.utils.application import Application from dbt_platform_helper.utils.application import ApplicationNotFoundException +from dbt_platform_helper.utils.aws import Vpc class DataCopyMocks: diff --git a/tests/platform_helper/providers/test_vpc.py b/tests/platform_helper/providers/test_vpc.py deleted file mode 100644 index 2d8fe505a..000000000 --- a/tests/platform_helper/providers/test_vpc.py +++ /dev/null @@ -1,107 +0,0 @@ -import pytest -from moto import mock_aws - -from dbt_platform_helper.providers.aws import AWSException -from dbt_platform_helper.providers.vpc import Vpc -from dbt_platform_helper.providers.vpc import VpcProvider -from tests.platform_helper.utils.test_aws import ObjectWithId -from tests.platform_helper.utils.test_aws import mock_vpc_info_session - - -@mock_aws -def test_get_vpc_info_by_name_success(): - mock_session, mock_client, _ = mock_vpc_info_session() - vpc_provider = VpcProvider(mock_session) - - result = vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") - - expected_vpc = Vpc( - subnets=["subnet-private-1", "subnet-private-2"], security_groups=["sg-abc123"] - ) - - mock_client.describe_vpcs.assert_called_once_with( - Filters=[{"Name": "tag:Name", "Values": ["my_vpc"]}] - ) - - assert result.subnets == expected_vpc.subnets - assert result.security_groups == expected_vpc.security_groups - - -@mock_aws -def test_get_vpc_info_by_name_failure_no_matching_vpc(): - mock_session, mock_client, _ = mock_vpc_info_session() - vpc_provider = VpcProvider(mock_session) - - vpc_data = {"Vpcs": []} - mock_client.describe_vpcs.return_value = vpc_data - - with pytest.raises(AWSException) as ex: - vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") - - assert "VPC not found for name 'my_vpc'" in str(ex) - - -@mock_aws -def test_get_vpc_info_by_name_failure_no_vpc_id_in_response(): - mock_session, mock_client, _ = mock_vpc_info_session() - vpc_provider = VpcProvider(mock_session) - - vpc_data = {"Vpcs": [{"Id": "abc123"}]} - mock_client.describe_vpcs.return_value = vpc_data - - with pytest.raises(AWSException) as ex: - vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") - - assert "VPC id not present in vpc 'my_vpc'" in str(ex) - - -@mock_aws -def test_get_vpc_info_by_name_failure_no_private_subnets_in_vpc(): - mock_session, mock_client, _ = mock_vpc_info_session() - vpc_provider = VpcProvider(mock_session) - - mock_client.describe_route_tables.return_value = { - "RouteTables": [ - { - "Associations": [ - { - "Main": True, - "RouteTableId": "rtb-00cbf3c8d611a46b8", - } - ], - "Routes": [ - { - "DestinationCidrBlock": "10.151.0.0/16", - "GatewayId": "local", - "Origin": "CreateRouteTable", - "State": "active", - } - ], - "VpcId": "vpc-010327b71b948b4bc", - "OwnerId": "891377058512", - } - ] - } - - with pytest.raises(AWSException) as ex: - vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") - - assert "No private subnets found in vpc 'my_vpc'" in str(ex) - - -@mock_aws -def test_get_vpc_info_by_name_failure_no_matching_security_groups(): - mock_session, _, mock_vpc = mock_vpc_info_session() - vpc_provider = VpcProvider(mock_session) - - mock_vpc.security_groups.all.return_value = [ - ObjectWithId("sg-abc345", tags=[]), - ObjectWithId("sg-abc567", tags=[{"Key": "Name", "Value": "copilot-other_app-my_env-env"}]), - ObjectWithId("sg-abc456"), - ObjectWithId("sg-abc678", tags=[{"Key": "Name", "Value": "copilot-my_app-other_env-env"}]), - ] - - with pytest.raises(AWSException) as ex: - vpc_provider.get_vpc_info_by_name("my_app", "my_env", "my_vpc") - - assert "No matching security groups found in vpc 'my_vpc'" in str(ex) diff --git a/tests/platform_helper/utils/test_aws.py b/tests/platform_helper/utils/test_aws.py index 2916d96d4..6fbabf28f 100644 --- a/tests/platform_helper/utils/test_aws.py +++ b/tests/platform_helper/utils/test_aws.py @@ -7,19 +7,23 @@ import boto3 import botocore +from dbt_platform_helper.constants import ( + ALPHANUMERIC_ENVIRONMENT_NAME, + ALPHANUMERIC_SERVICE_NAME, + CLUSTER_NAME_SUFFIX, + HYPHENATED_APPLICATION_NAME, + REFRESH_TOKEN_MESSAGE, + SERVICE_NAME_SUFFIX, +) import pytest from moto import mock_aws -from dbt_platform_helper.constants import ALPHANUMERIC_ENVIRONMENT_NAME -from dbt_platform_helper.constants import ALPHANUMERIC_SERVICE_NAME -from dbt_platform_helper.constants import CLUSTER_NAME_SUFFIX -from dbt_platform_helper.constants import HYPHENATED_APPLICATION_NAME -from dbt_platform_helper.constants import REFRESH_TOKEN_MESSAGE -from dbt_platform_helper.constants import SERVICE_NAME_SUFFIX +from dbt_platform_helper.providers.aws import AWSException from dbt_platform_helper.providers.aws import CopilotCodebaseNotFoundException from dbt_platform_helper.providers.aws import LogGroupNotFoundException from dbt_platform_helper.providers.validation import ValidationException from dbt_platform_helper.utils.aws import NoProfileForAccountIdException +from dbt_platform_helper.utils.aws import Vpc from dbt_platform_helper.utils.aws import check_codebase_exists from dbt_platform_helper.utils.aws import get_account_details from dbt_platform_helper.utils.aws import get_aws_session_or_abort @@ -32,6 +36,7 @@ from dbt_platform_helper.utils.aws import get_profile_name_from_account_id from dbt_platform_helper.utils.aws import get_public_repository_arn from dbt_platform_helper.utils.aws import get_ssm_secrets +from dbt_platform_helper.utils.aws import get_vpc_info_by_name from dbt_platform_helper.utils.aws import set_ssm_param from dbt_platform_helper.utils.aws import wait_for_log_group_to_exist from tests.platform_helper.conftest import mock_aws_client @@ -847,6 +852,95 @@ def mock_vpc_info_session(): return mock_session, mock_client, mock_vpc +def test_get_vpc_info_by_name_success(): + mock_session, mock_client, _ = mock_vpc_info_session() + + result = get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") + + expected_vpc = Vpc( + subnets=["subnet-private-1", "subnet-private-2"], security_groups=["sg-abc123"] + ) + + mock_client.describe_vpcs.assert_called_once_with( + Filters=[{"Name": "tag:Name", "Values": ["my_vpc"]}] + ) + + assert result.subnets == expected_vpc.subnets + assert result.security_groups == expected_vpc.security_groups + + +def test_get_vpc_info_by_name_failure_no_matching_vpc(): + mock_session, mock_client, _ = mock_vpc_info_session() + + vpc_data = {"Vpcs": []} + mock_client.describe_vpcs.return_value = vpc_data + + with pytest.raises(AWSException) as ex: + get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") + + assert "VPC not found for name 'my_vpc'" in str(ex) + + +def test_get_vpc_info_by_name_failure_no_vpc_id_in_response(): + mock_session, mock_client, _ = mock_vpc_info_session() + + vpc_data = {"Vpcs": [{"Id": "abc123"}]} + mock_client.describe_vpcs.return_value = vpc_data + + with pytest.raises(AWSException) as ex: + get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") + + assert "VPC id not present in vpc 'my_vpc'" in str(ex) + + +def test_get_vpc_info_by_name_failure_no_private_subnets_in_vpc(): + mock_session, mock_client, mock_vpc = mock_vpc_info_session() + + mock_client.describe_route_tables.return_value = { + "RouteTables": [ + { + "Associations": [ + { + "Main": True, + "RouteTableId": "rtb-00cbf3c8d611a46b8", + } + ], + "Routes": [ + { + "DestinationCidrBlock": "10.151.0.0/16", + "GatewayId": "local", + "Origin": "CreateRouteTable", + "State": "active", + } + ], + "VpcId": "vpc-010327b71b948b4bc", + "OwnerId": "891377058512", + } + ] + } + + with pytest.raises(AWSException) as ex: + get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") + + assert "No private subnets found in vpc 'my_vpc'" in str(ex) + + +def test_get_vpc_info_by_name_failure_no_matching_security_groups(): + mock_session, mock_client, mock_vpc = mock_vpc_info_session() + + mock_vpc.security_groups.all.return_value = [ + ObjectWithId("sg-abc345", tags=[]), + ObjectWithId("sg-abc567", tags=[{"Key": "Name", "Value": "copilot-other_app-my_env-env"}]), + ObjectWithId("sg-abc456"), + ObjectWithId("sg-abc678", tags=[{"Key": "Name", "Value": "copilot-my_app-other_env-env"}]), + ] + + with pytest.raises(AWSException) as ex: + get_vpc_info_by_name(mock_session, "my_app", "my_env", "my_vpc") + + assert "No matching security groups found in vpc 'my_vpc'" in str(ex) + + def test_wait_for_log_group_to_exist_success(): log_group_name = "/ecs/test-log-group" mock_client = Mock()