Skip to content

Commit

Permalink
Add Django connection pools feature (#5332)
Browse files Browse the repository at this point in the history
* Add Django connection pools feature

* Update api/conf/settings/databases.py

Co-authored-by: Dhruv Bhanushali <hi@dhruvkb.dev>

* Add transaction=True

* Add a workaround for tests dropping db

* Fix SyntaxWarning: invalid escape sequence '\d'

* Improve comments

* Hard-code licenses in the catalog

---------

Co-authored-by: Dhruv Bhanushali <hi@dhruvkb.dev>
  • Loading branch information
obulat and dhruvkb authored Jan 15, 2025
1 parent 24e9e4a commit 3e68724
Show file tree
Hide file tree
Showing 17 changed files with 92 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ jobs:
run: just api/init

- name: Run API tests
run: just api/test
run: just api/test-ci

- name: Print API test logs
if: success() || failure()
Expand Down
2 changes: 1 addition & 1 deletion api/api/templatetags/get_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django import template


numeric_test = re.compile("^\d+$")
numeric_test = re.compile(r"^\d+$")
register = template.Library()


Expand Down
6 changes: 2 additions & 4 deletions api/conf/settings/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


# Database
# https://docs.djangoproject.com/en/4.2/ref/settings/#databases
# https://docs.djangoproject.com/en/stable/ref/settings/#databases

DATABASES = {
"default": {
Expand All @@ -12,16 +12,14 @@
"USER": config("DJANGO_DATABASE_USER", default="deploy"),
"PASSWORD": config("DJANGO_DATABASE_PASSWORD", default="deploy"),
"NAME": config("DJANGO_DATABASE_NAME", default="openledger"),
# Default of 30 matches RDS documentation's advised max DNS caching time
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_BestPractices.html#CHAP_BestPractices.DiskPerformance
"CONN_MAX_AGE": config("DJANGO_CONN_MAX_AGE", default=30),
"CONN_HEALTH_CHECKS": config(
"DJANGO_CONN_HEALTH_CHECKS", default=True, cast=bool
),
"OPTIONS": {
"application_name": config(
"DJANGO_DATABASE_APPLICATION_NAME", default="openverse-api"
),
"pool": True,
},
}
}
7 changes: 7 additions & 0 deletions api/conf/settings/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ def suppress_unwanted_logs(record: LogRecord) -> bool:
"propagate": False,
}

# Add connection pool logging
LOGGING["loggers"]["django.db.backends.pool"] = {
"level": "DEBUG",
"handlers": ["console_structured"],
"propagate": False,
}

if not DEBUG:
# WARNING: Do not run in production long-term as it can impact performance.
middleware = (
Expand Down
8 changes: 8 additions & 0 deletions api/justfile
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,14 @@ generate-docs doc="media-props" fail_on_diff="true":
test *args: wait-up
env DC_USER="ov_user" just ../exec web pytest "$@"

# Run the API tests in the CI
test-ci: wait-up
# The order is important here: the unit tests drop the database in the end,
# and when ran concurrently with the integration tests, the integration
# tests' database is dropped.
just test -k unit
just test -k "not unit"

# Run API tests locally
[positional-arguments]
test-local *args:
Expand Down
32 changes: 31 additions & 1 deletion api/pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"future >=1, <1.1",
"limit >=0.2.3, <0.3",
"pillow >=11, <12",
"psycopg >=3.1.18, <4",
"psycopg[pool] >=3.2.3, <4",
"python-decouple >=3.8, <4",
"python-xmp-toolkit >=2.0.2, <3",
"sentry-sdk >=2.19, <3",
Expand Down
11 changes: 2 additions & 9 deletions api/test/integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from api.models import OAuth2Verification, ThrottledApplication


pytestmark = pytest.mark.django_db

cache_availability_params = pytest.mark.parametrize(
"is_cache_reachable, cache_name",
[(True, "oauth_cache"), (False, "unreachable_oauth_cache")],
Expand Down Expand Up @@ -73,7 +75,6 @@ def test_auth_token_exchange(api_client, test_auth_tokens_registration):
return res_data


@pytest.mark.django_db
def test_auth_token_exchange_unsupported_method(api_client):
res = api_client.get(
"/v1/auth_tokens/token/",
Expand All @@ -90,7 +91,6 @@ def _integration_verify_most_recent_token(api_client):
return api_client.get(path)


@pytest.mark.django_db
@pytest.mark.parametrize(
"rate_limit_model",
[x[0] for x in ThrottledApplication.RATE_LIMIT_MODELS],
Expand Down Expand Up @@ -125,7 +125,6 @@ def test_auth_email_verification(
)


@pytest.mark.django_db
@pytest.mark.parametrize(
"rate_limit_model",
[x[0] for x in ThrottledApplication.RATE_LIMIT_MODELS],
Expand Down Expand Up @@ -166,7 +165,6 @@ def test_auth_rate_limit_reporting(
assert res_data["verified"] is False


@pytest.mark.django_db
@pytest.mark.parametrize(
"sort_dir, exp_indexed_on",
[
Expand All @@ -191,7 +189,6 @@ def test_sorting_authed(api_client, test_auth_token_exchange, sort_dir, exp_inde
assert indexed_on == exp_indexed_on


@pytest.mark.django_db
@pytest.mark.parametrize(
"authority_boost, exp_source",
[
Expand Down Expand Up @@ -219,15 +216,13 @@ def test_authority_authed(
assert source == exp_source


@pytest.mark.django_db
def test_invalid_credentials_401(api_client):
res = api_client.get(
"/v1/images/", HTTP_AUTHORIZATION="Bearer thisIsNot_ARealToken"
)
assert res.status_code == 401


@pytest.mark.django_db
def test_revoked_application_access(api_client, test_auth_token_exchange):
token = test_auth_token_exchange["access_token"]
application = AccessToken.objects.get(token=token).application
Expand Down Expand Up @@ -258,7 +253,6 @@ def test_revoked_application_access(api_client, test_auth_token_exchange):
)
),
)
@pytest.mark.django_db
def test_page_size_privileges(
api_client, test_auth_token_exchange, level, page_size_modification, allowed
):
Expand Down Expand Up @@ -304,7 +298,6 @@ def test_page_size_privileges(
)
),
)
@pytest.mark.django_db
def test_pagination_depth_privileges(
api_client, test_auth_token_exchange, level, pagination_depth_modification, allowed
):
Expand Down
8 changes: 3 additions & 5 deletions api/test/unit/management/commands/test_generatewaveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from test.factory.models.audio import AudioAddOnFactory, AudioFactory


pytestmark = pytest.mark.django_db


@mock.patch("api.models.audio.generate_peaks")
def call_generatewaveforms(mock_generate_peaks: mock.MagicMock) -> tuple[str, str]:
mock_generate_peaks.side_effect = lambda _: WaveformProvider.generate_waveform()
Expand All @@ -35,7 +38,6 @@ def assert_all_audio_have_waveforms():
)


@pytest.mark.django_db
def test_creates_waveforms_for_audio():
AudioFactory.create_batch(153)

Expand All @@ -46,7 +48,6 @@ def test_creates_waveforms_for_audio():
assert_all_audio_have_waveforms()


@pytest.mark.django_db
def test_does_not_reprocess_existing_waveforms():
waveformless_audio = AudioFactory.create_batch(3)

Expand All @@ -66,7 +67,6 @@ def test_does_not_reprocess_existing_waveforms():
assert_all_audio_have_waveforms()


@pytest.mark.django_db
@mock.patch("api.models.audio.generate_peaks")
def test_paginates_audio_waveforms_to_generate(
mock_generate_peaks, django_assert_num_queries
Expand Down Expand Up @@ -101,7 +101,6 @@ def test_paginates_audio_waveforms_to_generate(
assert_all_audio_have_waveforms()


@pytest.mark.django_db
@pytest.mark.parametrize(
("exception_class", "exception_args", "exception_kwargs"),
(
Expand Down Expand Up @@ -150,7 +149,6 @@ def test_logs_and_continues_if_waveform_generation_fails(
)


@pytest.mark.django_db
@mock.patch("api.models.audio.generate_peaks")
def test_keyboard_interrupt_should_halt_processing(mock_generate_peaks):
audio_count = 23
Expand Down
2 changes: 1 addition & 1 deletion api/test/unit/models/test_media_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


pytestmark = pytest.mark.django_db
pytestmark = pytest.mark.django_db(transaction=True)

reason_params = pytest.mark.parametrize("reason", [DMCA, MATURE, OTHER])

Expand Down
2 changes: 1 addition & 1 deletion api/test/unit/utils/test_moderation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from test.factory.models.oauth2 import UserFactory


pytestmark = pytest.mark.django_db
pytestmark = pytest.mark.django_db(transaction=True)


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import logging

from common.licenses import get_license_info
from common.licenses import LicenseInfo
from common.loader import provider_details as prov
from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester


logger = logging.getLogger(__name__)

CC0_LICENSE = get_license_info(license_="cc0", license_version="1.0")
CC0_LICENSE = LicenseInfo(
license="cc0",
version="1.0",
url="https://creativecommons.org/publicdomain/zero/1.0/",
)


class ClevelandDataIngester(ProviderDataIngester):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import argparse
import logging

from common.licenses import get_license_info
from common.licenses import LicenseInfo
from common.loader import provider_details as prov
from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester

Expand All @@ -42,7 +42,11 @@
class MetMuseumDataIngester(ProviderDataIngester):
providers = {"image": prov.METROPOLITAN_MUSEUM_DEFAULT_PROVIDER}
endpoint = "https://collectionapi.metmuseum.org/public/collection/v1/objects"
DEFAULT_LICENSE_INFO = get_license_info(license_="cc0", license_version="1.0")
DEFAULT_LICENSE_INFO = LicenseInfo(
license="cc0",
version="1.0",
url="https://creativecommons.org/publicdomain/zero/1.0/",
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
9 changes: 6 additions & 3 deletions catalog/dags/providers/provider_api_scripts/nappy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging

from common import constants
from common.licenses import get_license_info
from common.licenses import LicenseInfo
from common.loader import provider_details as prov
from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester

Expand All @@ -28,8 +28,11 @@ class NappyDataIngester(ProviderDataIngester):
headers = {"Accept": "application/json"}

# Hardcoded to CC0, the only license Nappy.co uses
license_info = get_license_info(
"https://creativecommons.org/publicdomain/zero/1.0/"
license_info = LicenseInfo(
license="cc0",
version="1.0",
url="https://creativecommons.org/publicdomain/zero/1.0/",
raw_url=None,
)

def get_next_query_params(self, prev_query_params: dict | None) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@

LIMIT = 100

CC0_LICENSE = get_license_info(license_="cc0", license_version="1.0")
CC0_LICENSE = LicenseInfo(
license="cc0",
version="1.0",
url="https://creativecommons.org/publicdomain/zero/1.0/",
)


class ScienceMuseumDataIngester(ProviderDataIngester):
Expand Down
8 changes: 5 additions & 3 deletions catalog/dags/providers/provider_api_scripts/smithsonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from airflow.exceptions import AirflowException
from airflow.models import Variable

from common.licenses import get_license_info
from common.licenses import LicenseInfo
from common.loader import provider_details as prov
from providers.provider_api_scripts.provider_data_ingester import ProviderDataIngester

Expand Down Expand Up @@ -110,8 +110,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.api_key = Variable.get("API_KEY_DATA_GOV")
self.units_endpoint = f"{self.base_endpoint}terms/unit_code"
self.license_info = get_license_info(
license_url="https://creativecommons.org/publicdomain/zero/1.0/"
self.license_info = LicenseInfo(
license="cc0",
version="1.0",
url="https://creativecommons.org/publicdomain/zero/1.0/",
)

def get_fixed_query_params(self):
Expand Down
Loading

0 comments on commit 3e68724

Please sign in to comment.