Skip to content

Commit

Permalink
refactor: Use qs._result_cache instead of len() for safety reasons
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Dec 18, 2024
1 parent 51a6291 commit a98263b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ReverseOneToOneDescriptor,
)
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS
from django.db.models.query import MAX_GET_RESULTS # type: ignore
from django.db.models.query_utils import DeferredAttribute
from strawberry import UNSET, relay
from strawberry.annotation import StrawberryAnnotation
Expand Down Expand Up @@ -292,10 +292,10 @@ def qs_hook(qs: models.QuerySet):
# Don't use qs.get() if the queryset is optimized by prefetching.
# Calling get in that case would disregard the prefetched results, because get implicitly
# adds a limit to the query
if is_optimized_by_prefetching(qs):
if (result_cache := qs._result_cache) is not None: # type: ignore
# mimic behavior of get()
# the queryset is already prefetched, no issue with just using len()
qs_len = len(qs)
qs_len = len(result_cache)
if qs_len == 0:
raise qs.model.DoesNotExist(
f"{qs.model._meta.object_name} matching query does not exist."
Expand All @@ -305,7 +305,8 @@ def qs_hook(qs: models.QuerySet):
f"get() returned more than one {qs.model._meta.object_name} -- it returned "
f"{qs_len if qs_len < MAX_GET_RESULTS else f'more than {qs_len - 1}'}!"
)
return qs[0]
return result_cache[0]

return qs.get()

return qs_hook
Expand Down
2 changes: 1 addition & 1 deletion strawberry_django/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from django.db import DEFAULT_DB_ALIAS
from django.db.models import Count, QuerySet, Window
from django.db.models.functions import RowNumber
from django.db.models.query import MAX_GET_RESULTS
from django.db.models.query import MAX_GET_RESULTS # type: ignore
from strawberry.types import Info
from strawberry.types.arguments import StrawberryArgument
from strawberry.types.unset import UNSET, UnsetType
Expand Down
4 changes: 2 additions & 2 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ class MilestoneType(relay.Node, Named):
order=IssueOrder,
pagination=True,
)
first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues")
first_issue_required: "IssueType" = strawberry_django.field(field_name="issues")
issues_paginated: OffsetPaginated["IssueType"] = strawberry_django.offset_paginated(
field_name="issues",
order=IssueOrder,
Expand All @@ -169,8 +171,6 @@ class MilestoneType(relay.Node, Named):
filters=IssueFilter,
)
)
first_issue: Optional["IssueType"] = strawberry_django.field(field_name="issues")
first_issue_required: "IssueType" = strawberry_django.field(field_name="issues")

@strawberry_django.field(
prefetch_related=[
Expand Down

0 comments on commit a98263b

Please sign in to comment.