Skip to content

Commit

Permalink
refactor: Fixing the Vpc provider in Database copy (#709)
Browse files Browse the repository at this point in the history
Co-authored-by: chiaramapellimt <chiara.mapelli@madetech.com>
Co-authored-by: Chiara <95863059+chiaramapellimt@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 9, 2025
1 parent f10efe2 commit 37f4181
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 183 deletions.
12 changes: 7 additions & 5 deletions dbt_platform_helper/domain/database_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from dbt_platform_helper.domain.maintenance_page import MaintenancePage
from dbt_platform_helper.providers.aws import AWSException
from dbt_platform_helper.providers.config import ConfigProvider
from dbt_platform_helper.providers.vpc import Vpc
from dbt_platform_helper.providers.vpc import VpcProvider
from dbt_platform_helper.utils.application import Application
from dbt_platform_helper.utils.application import ApplicationNotFoundException
from dbt_platform_helper.utils.application import load_application
from dbt_platform_helper.utils.aws import Vpc
from dbt_platform_helper.utils.aws import get_connection_string
from dbt_platform_helper.utils.aws import get_vpc_info_by_name
from dbt_platform_helper.utils.aws import wait_for_log_group_to_exist
from dbt_platform_helper.utils.messages import abort_with_error

Expand All @@ -28,7 +28,8 @@ def __init__(
database: str,
auto_approve: bool = False,
load_application: Callable[[str], Application] = load_application,
vpc_config: Callable[[Session, str, str, str], Vpc] = get_vpc_info_by_name,
# TODO We inject VpcProvider as a callable here so that it can be instantiated within the method. To be improved
vpc_provider: Callable[[Session], VpcProvider] = VpcProvider,
db_connection_string: Callable[
[Session, str, str, str, Callable], str
] = get_connection_string,
Expand All @@ -43,7 +44,7 @@ def __init__(
self.app = app
self.database = database
self.auto_approve = auto_approve
self.vpc_config = vpc_config
self.vpc_provider = vpc_provider
self.db_connection_string = db_connection_string
self.maintenance_page_provider = maintenance_page_provider
self.input = input
Expand Down Expand Up @@ -76,7 +77,8 @@ def _execute_operation(self, is_dump: bool, env: str, vpc_name: str, filename: s
env_session = environment.session

try:
vpc_config = self.vpc_config(env_session, self.app, env, vpc_name)
vpc_provider = self.vpc_provider(env_session)
vpc_config = vpc_provider.get_vpc_info_by_name(self.app, env, vpc_name)
except AWSException as ex:
self.abort(str(ex))

Expand Down
57 changes: 57 additions & 0 deletions dbt_platform_helper/providers/vpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from dataclasses import dataclass

from dbt_platform_helper.providers.aws import AWSException


@dataclass
class Vpc:
subnets: list[str]
security_groups: list[str]


class VpcProvider:
def __init__(self, session):
self.ec2_client = session.client("ec2")
self.ec2_resource = session.resource("ec2")

def get_vpc_info_by_name(self, app: str, env: str, vpc_name: str) -> Vpc:
vpc_response = self.ec2_client.describe_vpcs(
Filters=[{"Name": "tag:Name", "Values": [vpc_name]}]
)

matching_vpcs = vpc_response.get("Vpcs", [])

if not matching_vpcs:
raise AWSException(f"VPC not found for name '{vpc_name}'")

vpc_id = vpc_response["Vpcs"][0].get("VpcId")

if not vpc_id:
raise AWSException(f"VPC id not present in vpc '{vpc_name}'")

vpc = self.ec2_resource.Vpc(vpc_id)

route_tables = self.ec2_client.describe_route_tables(
Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]
)["RouteTables"]

subnets = []
for route_table in route_tables:
private_routes = [route for route in route_table["Routes"] if "NatGatewayId" in route]
if not private_routes:
continue
for association in route_table["Associations"]:
if "SubnetId" in association:
subnet_id = association["SubnetId"]
subnets.append(subnet_id)

if not subnets:
raise AWSException(f"No private subnets found in vpc '{vpc_name}'")

tag_value = {"Key": "Name", "Value": f"copilot-{app}-{env}-env"}
sec_groups = [sg.id for sg in vpc.security_groups.all() if sg.tags and tag_value in sg.tags]

if not sec_groups:
raise AWSException(f"No matching security groups found in vpc '{vpc_name}'")

return Vpc(subnets, sec_groups)
23 changes: 7 additions & 16 deletions dbt_platform_helper/utils/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
import re
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from typing import Dict

Expand All @@ -16,16 +18,12 @@
from dbt_platform_helper.utils.messages import abort_with_error


@dataclass
class Environment:
name: str
account_id: str
sessions: Dict[str, boto3.Session]

def __init__(self, name: str, account_id: str, sessions: Dict[str, boto3.Session]):
self.name = name
self.account_id = account_id
self.sessions = sessions

@property
def session(self):
if self.account_id not in self.sessions:
Expand All @@ -36,24 +34,17 @@ def session(self):
return self.sessions[self.account_id]


@dataclass
class Service:
name: str
kind: str

def __init__(self, name: str, kind: str):
self.name = name
self.kind = kind


@dataclass
class Application:
name: str
environments: Dict[str, Environment]
services: Dict[str, Service]

def __init__(self, name: str):
self.name = name
self.environments = {}
self.services = {}
environments: Dict[str, Environment] = field(default_factory=dict)
services: Dict[str, Service] = field(default_factory=dict)

def __str__(self):
output = f"Application {self.name} with"
Expand Down
50 changes: 0 additions & 50 deletions dbt_platform_helper/utils/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from dbt_platform_helper.constants import REFRESH_TOKEN_MESSAGE
from dbt_platform_helper.platform_exception import PlatformException
from dbt_platform_helper.providers.aws import AWSException
from dbt_platform_helper.providers.aws import CopilotCodebaseNotFoundException
from dbt_platform_helper.providers.aws import ImageNotFoundException
from dbt_platform_helper.providers.aws import LogGroupNotFoundException
Expand Down Expand Up @@ -377,55 +376,6 @@ def get_connection_string(
return f"postgres://{conn['username']}:{conn['password']}@{conn['host']}:{conn['port']}/{conn['dbname']}"


class Vpc:
def __init__(self, subnets: list[str], security_groups: list[str]):
self.subnets = subnets
self.security_groups = security_groups


def get_vpc_info_by_name(session: Session, app: str, env: str, vpc_name: str) -> Vpc:
ec2_client = session.client("ec2")
vpc_response = ec2_client.describe_vpcs(Filters=[{"Name": "tag:Name", "Values": [vpc_name]}])

matching_vpcs = vpc_response.get("Vpcs", [])

if not matching_vpcs:
raise AWSException(f"VPC not found for name '{vpc_name}'")

vpc_id = vpc_response["Vpcs"][0].get("VpcId")

if not vpc_id:
raise AWSException(f"VPC id not present in vpc '{vpc_name}'")

ec2_resource = session.resource("ec2")
vpc = ec2_resource.Vpc(vpc_id)

route_tables = ec2_client.describe_route_tables(
Filters=[{"Name": "vpc-id", "Values": [vpc_id]}]
)["RouteTables"]

subnets = []
for route_table in route_tables:
private_routes = [route for route in route_table["Routes"] if "NatGatewayId" in route]
if not private_routes:
continue
for association in route_table["Associations"]:
if "SubnetId" in association:
subnet_id = association["SubnetId"]
subnets.append(subnet_id)

if not subnets:
raise AWSException(f"No private subnets found in vpc '{vpc_name}'")

tag_value = {"Key": "Name", "Value": f"copilot-{app}-{env}-env"}
sec_groups = [sg.id for sg in vpc.security_groups.all() if sg.tags and tag_value in sg.tags]

if not sec_groups:
raise AWSException(f"No matching security groups found in vpc '{vpc_name}'")

return Vpc(subnets, sec_groups)


def start_build_extraction(codebuild_client, build_options):
response = codebuild_client.start_build(**build_options)
return response["build"]["arn"]
Expand Down
35 changes: 23 additions & 12 deletions tests/platform_helper/domain/test_database_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from dbt_platform_helper.domain.database_copy import DatabaseCopy
from dbt_platform_helper.providers.aws import AWSException
from dbt_platform_helper.providers.config import ConfigProvider
from dbt_platform_helper.providers.vpc import Vpc
from dbt_platform_helper.utils.application import Application
from dbt_platform_helper.utils.application import ApplicationNotFoundException
from dbt_platform_helper.utils.aws import Vpc


class DataCopyMocks:
Expand All @@ -25,8 +25,12 @@ def __init__(self, app="test-app", env="test-env", acc="12345", vpc=Vpc([], []),
self.environment.session.client.return_value = self.client

self.vpc = vpc
self.vpc_config = Mock()
self.vpc_config.return_value = vpc
self.vpc_provider = (
Mock()
) # this is the callable class so should return a class when called
self.instantiated_vpc_provider = Mock()
self.instantiated_vpc_provider.get_vpc_info_by_name.return_value = self.vpc
self.vpc_provider.return_value = self.instantiated_vpc_provider
self.db_connection_string = Mock(return_value="test-db-connection-string")
self.maintenance_page_provider = Mock()

Expand All @@ -38,7 +42,7 @@ def __init__(self, app="test-app", env="test-env", acc="12345", vpc=Vpc([], []),
def params(self):
return {
"load_application": self.load_application,
"vpc_config": self.vpc_config,
"vpc_provider": self.vpc_provider,
"db_connection_string": self.db_connection_string,
"maintenance_page_provider": self.maintenance_page_provider,
"input": self.input,
Expand Down Expand Up @@ -124,8 +128,9 @@ def test_database_dump():
db_copy.dump(env, vpc_name)

mocks.load_application.assert_called_once()
mocks.vpc_config.assert_called_once_with(
mocks.environment.session, app, env, "test-vpc-override"
mocks.vpc_provider.assert_called_once_with(mocks.environment.session)
mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_called_once_with(
app, env, "test-vpc-override"
)
mocks.db_connection_string.assert_called_once_with(
mocks.environment.session, app, env, "test-app-test-env-test-db"
Expand Down Expand Up @@ -165,8 +170,9 @@ def test_database_load_with_response_of_yes():

mocks.load_application.assert_called_once()

mocks.vpc_config.assert_called_once_with(
mocks.environment.session, app, env, "test-vpc-override"
mocks.vpc_provider.assert_called_once_with(mocks.environment.session)
mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_called_once_with(
app, env, "test-vpc-override"
)

mocks.db_connection_string.assert_called_once_with(
Expand Down Expand Up @@ -212,7 +218,8 @@ def test_database_load_with_response_of_no():

mocks.environment.session.assert_not_called()

mocks.vpc_config.assert_not_called()
mocks.vpc_provider.assert_not_called()
mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_not_called()

mocks.db_connection_string.assert_not_called()

Expand All @@ -228,7 +235,9 @@ def test_database_load_with_response_of_no():
@pytest.mark.parametrize("is_dump", (True, False))
def test_database_dump_handles_vpc_errors(is_dump):
mocks = DataCopyMocks()
mocks.vpc_config.side_effect = AWSException("A VPC error occurred")
mocks.instantiated_vpc_provider.get_vpc_info_by_name.side_effect = AWSException(
"A VPC error occurred"
)

db_copy = DatabaseCopy("test-app", "test-db", **mocks.params())

Expand All @@ -239,6 +248,7 @@ def test_database_dump_handles_vpc_errors(is_dump):
db_copy.load("test-env", "bad-vpc-name")

assert exc.value.code == 1
mocks.vpc_provider.assert_called_once_with(mocks.environment.session)
mocks.abort.assert_called_once_with("A VPC error occurred")


Expand Down Expand Up @@ -536,8 +546,9 @@ def test_database_dump_with_no_vpc_works_in_deploy_repo(fs, is_dump):
else:
db_copy.load(env, None)

mocks.vpc_config.assert_called_once_with(
mocks.environment.session, "test-app", env, "test-env-vpc"
mocks.vpc_provider.assert_called_once_with(mocks.environment.session)
mocks.instantiated_vpc_provider.get_vpc_info_by_name.assert_called_once_with(
"test-app", env, "test-env-vpc"
)


Expand Down
Loading

0 comments on commit 37f4181

Please sign in to comment.