diff --git a/strawberry_django/mutations/fields.py b/strawberry_django/mutations/fields.py index 95421001..b58c3c8e 100644 --- a/strawberry_django/mutations/fields.py +++ b/strawberry_django/mutations/fields.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Annotated, Any, Union +from typing import TYPE_CHECKING, Annotated, Any, TypeVar, Union import strawberry from django.core.exceptions import ( @@ -22,7 +22,7 @@ StrawberryDjangoFieldFilters, ) from strawberry_django.fields.types import OperationInfo, OperationMessage -from strawberry_django.optimizer import DjangoOptimizerExtension +from strawberry_django.optimizer import DjangoOptimizerExtension, optimize from strawberry_django.permissions import filter_with_perms, get_with_perms from strawberry_django.resolvers import django_resolver from strawberry_django.settings import strawberry_django_settings @@ -45,6 +45,8 @@ from .types import FullCleanOptions +_T = TypeVar("_T", bound="models.Model | list[models.Model]") + def _get_validation_errors(error: Exception): if isinstance(error, PermissionDenied): @@ -225,6 +227,32 @@ def arguments(self, value: list[StrawberryArgument]): args_prop = super(DjangoMutationBase, self.__class__).arguments return args_prop.fset(self, value) # type: ignore + def refetch(self, resolved: _T, *, info: Info | None) -> _T: + if not DjangoOptimizerExtension.enabled or info is None: + return resolved + + if isinstance(resolved, list) and resolved: + model = type(resolved[0]) + if issubclass(model, models.Model): + original_order = {r.pk: i for i, r in enumerate(resolved)} + resolved_qs = optimize( + model._default_manager.filter(pk__in=[r.pk for r in resolved]), + info=info, + ) + # sort the resolved objects in the order they were given + resolved = sorted( # type: ignore + resolved_qs, + key=lambda r: original_order[r.pk], + ) + elif isinstance(resolved, models.Model): + model = type(resolved) + resolved = optimize( + model._default_manager.filter(pk=resolved.pk), + info=info, + ).get() + + return resolved + class DjangoCreateMutation(DjangoMutationCUD, StrawberryDjangoFieldFilters): @django_resolver @@ -240,37 +268,39 @@ def resolver( data: list[Any] | Any = kwargs.get(self.argument_name) - if self.is_list: - assert isinstance(data, list) - return [ - self.create( - resolvers.parse_input(info, vars(d), key_attr=self.key_attr), + # Do not optimize anything while retrieving the object to create + with DjangoOptimizerExtension.disabled(): + if self.is_list: + assert isinstance(data, list) + resolved = [ + self.create( + resolvers.parse_input(info, vars(d), key_attr=self.key_attr), + info=info, + ) + for d in data + ] + else: + assert not isinstance(data, list) + resolved = self.create( + resolvers.parse_input(info, vars(data), key_attr=self.key_attr) + if data is not None + else {}, info=info, ) - for d in data - ] - - assert not isinstance(data, list) - return self.create( - resolvers.parse_input(info, vars(data), key_attr=self.key_attr) - if data is not None - else {}, - info=info, - ) + + return self.refetch(resolved, info=info) def create(self, data: dict[str, Any], *, info: Info): model = self.django_model assert model is not None - # Do not optimize anything while retrieving the object to create - with DjangoOptimizerExtension.disabled(): - return resolvers.create( - info, - model, - data, - key_attr=self.key_attr, - full_clean=self.full_clean, - ) + return resolvers.create( + info, + model, + data, + key_attr=self.key_attr, + full_clean=self.full_clean, + ) def get_vdata(data: Any) -> dict[str, Any]: @@ -307,10 +337,14 @@ def resolver( data: list[Any] | Any = kwargs.get(self.argument_name) - if isinstance(data, list): - return [self.instance_level_update(info, kwargs, d) for d in data] + # Do not optimize anything while retrieving the object to update + with DjangoOptimizerExtension.disabled(): + if isinstance(data, list): + resolved = [self.instance_level_update(info, kwargs, d) for d in data] + else: + resolved = self.instance_level_update(info, kwargs, data) - return self.instance_level_update(info, kwargs, data) + return self.refetch(resolved, info=info) def instance_level_update( self, @@ -352,15 +386,13 @@ def update( instance: models.Model | Iterable[models.Model], data: dict[str, Any], ): - # Do not optimize anything while retrieving the object to update - with DjangoOptimizerExtension.disabled(): - return resolvers.update( - info, - instance, - data, - key_attr=self.key_attr, - full_clean=self.full_clean, - ) + return resolvers.update( + info, + instance, + data, + key_attr=self.key_attr, + full_clean=self.full_clean, + ) class DjangoDeleteMutation( @@ -405,7 +437,9 @@ def resolver( ) return self.delete( - info, instance, resolvers.parse_input(info, vdata, key_attr=self.key_attr) + info, + instance, + resolvers.parse_input(info, vdata, key_attr=self.key_attr), ) def delete( @@ -414,10 +448,8 @@ def delete( instance: models.Model | Iterable[models.Model], data: dict[str, Any] | None = None, ): - # Do not optimize anything while retrieving the object to update - with DjangoOptimizerExtension.disabled(): - return resolvers.delete( - info, - instance, - data=data, - ) + return resolvers.delete( + info, + instance, + data=data, + ) diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index f48d54e3..6f2a388b 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -541,6 +541,7 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien id name dueDate + isDelayed milestones { id name @@ -557,15 +558,17 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien milestone_1_id = to_base64("MilestoneType", milestone_1.pk) MilestoneFactory.create(project=project) - res = gql_client.query( - query, - { - "input": { - "id": to_base64("ProjectType", project.pk), - "milestones": [{"id": milestone_1_id}], + with assert_num_queries(14): + res = gql_client.query( + query, + { + "input": { + "id": to_base64("ProjectType", project.pk), + "milestones": [{"id": milestone_1_id}], + }, }, - }, - ) + ) + assert res.data assert isinstance(res.data["updateProject"], dict)