Skip to content

Commit

Permalink
fix(mutations): Refetch instances to optimize the return value (#674)
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 authored Dec 15, 2024
1 parent 0d59b84 commit ea97376
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 54 deletions.
124 changes: 78 additions & 46 deletions strawberry_django/mutations/fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Annotated, Any, Union
from typing import TYPE_CHECKING, Annotated, Any, TypeVar, Union

import strawberry
from django.core.exceptions import (
Expand All @@ -22,7 +22,7 @@
StrawberryDjangoFieldFilters,
)
from strawberry_django.fields.types import OperationInfo, OperationMessage
from strawberry_django.optimizer import DjangoOptimizerExtension
from strawberry_django.optimizer import DjangoOptimizerExtension, optimize
from strawberry_django.permissions import filter_with_perms, get_with_perms
from strawberry_django.resolvers import django_resolver
from strawberry_django.settings import strawberry_django_settings
Expand All @@ -45,6 +45,8 @@

from .types import FullCleanOptions

_T = TypeVar("_T", bound="models.Model | list[models.Model]")


def _get_validation_errors(error: Exception):
if isinstance(error, PermissionDenied):
Expand Down Expand Up @@ -225,6 +227,32 @@ def arguments(self, value: list[StrawberryArgument]):
args_prop = super(DjangoMutationBase, self.__class__).arguments
return args_prop.fset(self, value) # type: ignore

def refetch(self, resolved: _T, *, info: Info | None) -> _T:
if not DjangoOptimizerExtension.enabled or info is None:
return resolved

if isinstance(resolved, list) and resolved:
model = type(resolved[0])
if issubclass(model, models.Model):
original_order = {r.pk: i for i, r in enumerate(resolved)}
resolved_qs = optimize(
model._default_manager.filter(pk__in=[r.pk for r in resolved]),
info=info,
)
# sort the resolved objects in the order they were given
resolved = sorted( # type: ignore
resolved_qs,
key=lambda r: original_order[r.pk],
)
elif isinstance(resolved, models.Model):
model = type(resolved)
resolved = optimize(
model._default_manager.filter(pk=resolved.pk),
info=info,
).get()

return resolved


class DjangoCreateMutation(DjangoMutationCUD, StrawberryDjangoFieldFilters):
@django_resolver
Expand All @@ -240,37 +268,39 @@ def resolver(

data: list[Any] | Any = kwargs.get(self.argument_name)

if self.is_list:
assert isinstance(data, list)
return [
self.create(
resolvers.parse_input(info, vars(d), key_attr=self.key_attr),
# Do not optimize anything while retrieving the object to create
with DjangoOptimizerExtension.disabled():
if self.is_list:
assert isinstance(data, list)
resolved = [
self.create(
resolvers.parse_input(info, vars(d), key_attr=self.key_attr),
info=info,
)
for d in data
]
else:
assert not isinstance(data, list)
resolved = self.create(
resolvers.parse_input(info, vars(data), key_attr=self.key_attr)
if data is not None
else {},
info=info,
)
for d in data
]

assert not isinstance(data, list)
return self.create(
resolvers.parse_input(info, vars(data), key_attr=self.key_attr)
if data is not None
else {},
info=info,
)

return self.refetch(resolved, info=info)

def create(self, data: dict[str, Any], *, info: Info):
model = self.django_model
assert model is not None

# Do not optimize anything while retrieving the object to create
with DjangoOptimizerExtension.disabled():
return resolvers.create(
info,
model,
data,
key_attr=self.key_attr,
full_clean=self.full_clean,
)
return resolvers.create(
info,
model,
data,
key_attr=self.key_attr,
full_clean=self.full_clean,
)


def get_vdata(data: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -307,10 +337,14 @@ def resolver(

data: list[Any] | Any = kwargs.get(self.argument_name)

if isinstance(data, list):
return [self.instance_level_update(info, kwargs, d) for d in data]
# Do not optimize anything while retrieving the object to update
with DjangoOptimizerExtension.disabled():
if isinstance(data, list):
resolved = [self.instance_level_update(info, kwargs, d) for d in data]
else:
resolved = self.instance_level_update(info, kwargs, data)

return self.instance_level_update(info, kwargs, data)
return self.refetch(resolved, info=info)

def instance_level_update(
self,
Expand Down Expand Up @@ -352,15 +386,13 @@ def update(
instance: models.Model | Iterable[models.Model],
data: dict[str, Any],
):
# Do not optimize anything while retrieving the object to update
with DjangoOptimizerExtension.disabled():
return resolvers.update(
info,
instance,
data,
key_attr=self.key_attr,
full_clean=self.full_clean,
)
return resolvers.update(
info,
instance,
data,
key_attr=self.key_attr,
full_clean=self.full_clean,
)


class DjangoDeleteMutation(
Expand Down Expand Up @@ -405,7 +437,9 @@ def resolver(
)

return self.delete(
info, instance, resolvers.parse_input(info, vdata, key_attr=self.key_attr)
info,
instance,
resolvers.parse_input(info, vdata, key_attr=self.key_attr),
)

def delete(
Expand All @@ -414,10 +448,8 @@ def delete(
instance: models.Model | Iterable[models.Model],
data: dict[str, Any] | None = None,
):
# Do not optimize anything while retrieving the object to update
with DjangoOptimizerExtension.disabled():
return resolvers.delete(
info,
instance,
data=data,
)
return resolvers.delete(
info,
instance,
data=data,
)
19 changes: 11 additions & 8 deletions tests/test_input_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien
id
name
dueDate
isDelayed
milestones {
id
name
Expand All @@ -557,15 +558,17 @@ def test_input_update_m2m_set_not_null_mutation(db, gql_client: GraphQLTestClien
milestone_1_id = to_base64("MilestoneType", milestone_1.pk)
MilestoneFactory.create(project=project)

res = gql_client.query(
query,
{
"input": {
"id": to_base64("ProjectType", project.pk),
"milestones": [{"id": milestone_1_id}],
with assert_num_queries(14):
res = gql_client.query(
query,
{
"input": {
"id": to_base64("ProjectType", project.pk),
"milestones": [{"id": milestone_1_id}],
},
},
},
)
)

assert res.data
assert isinstance(res.data["updateProject"], dict)

Expand Down

1 comment on commit ea97376

@keithhackbarth
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome 🤩 ! Thank you!

Please sign in to comment.