Skip to content

Commit

Permalink
refactor: Use graphql-core's collect_sub_fields instead of our own im…
Browse files Browse the repository at this point in the history
…plementation (#537)

This allows us to simplify our optimizer implementation, and also paves
the way for nested optimizations implementation in the near future.
  • Loading branch information
bellini666 authored May 30, 2024
1 parent ffd89fd commit f352f2a
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 186 deletions.
263 changes: 172 additions & 91 deletions strawberry_django/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
)
from django.db.models.manager import BaseManager
from django.db.models.query import QuerySet
from graphql import FieldNode, GraphQLObjectType, GraphQLOutputType, GraphQLWrappingType
from graphql.execution.collect_fields import collect_sub_fields
from graphql.language.ast import OperationType
from graphql.type.definition import GraphQLResolveInfo, get_named_type
from strawberry import relay
Expand All @@ -37,7 +39,6 @@
from strawberry.schema.schema import Schema
from strawberry.type import get_object_definition
from strawberry.types.info import Info
from strawberry.types.nodes import InlineFragment, Selection, convert_selections
from strawberry.utils.typing import eval_type
from typing_extensions import assert_never, assert_type, get_args

Expand All @@ -50,7 +51,6 @@
PrefetchInspector,
get_model_fields,
get_possible_type_definitions,
get_selections,
)
from .utils.typing import (
AnnotateCallable,
Expand Down Expand Up @@ -364,12 +364,48 @@ def _get_prefetch_queryset(
)


def _get_selections(
info: GraphQLResolveInfo,
parent_type: GraphQLObjectType,
) -> dict[str, list[FieldNode]]:
return collect_sub_fields(
info.schema,
info.fragments,
info.variable_values,
parent_type,
info.field_nodes,
)


def _generate_selection_resolve_info(
info: GraphQLResolveInfo,
field_nodes: list[FieldNode],
return_type: GraphQLOutputType,
parent_type: GraphQLObjectType,
):
field_node = field_nodes[0]
return GraphQLResolveInfo(
field_name=field_node.name.value,
field_nodes=field_nodes,
return_type=return_type,
parent_type=parent_type,
path=info.path.add_key(0).add_key(field_node.name.value, parent_type.name),
schema=info.schema,
fragments=info.fragments,
root_value=info.root_value,
operation=info.operation,
variable_values=info.variable_values,
context=info.context,
is_awaitable=info.is_awaitable,
)


def _get_model_hints(
model: type[models.Model],
schema: Schema,
object_definition: StrawberryObjectDefinition,
selection: Selection,
*,
parent_type: GraphQLObjectType,
info: GraphQLResolveInfo,
config: OptimizerConfig | None = None,
prefix: str = "",
Expand All @@ -393,63 +429,24 @@ def _get_model_hints(
GenericRelation,
)

store = OptimizerStore()
cache = cache or {}
t_name = schema.config.name_converter.from_object(object_definition)

# In case this is a relay field, find the selected edges/nodes, the selected fields
# are actually inside edges -> node selection...
if issubclass(
object_definition.origin,
relay.Connection,
):
# TODO: Connections are mostly used for pagination so it doesn't make sense for
# us to optimize those, as our prefetch would be thrown away causing an extra
# useless query. Is there a way for us to properly optimize this in the future?
if level > 0:
return None

n_type = object_definition.type_var_map.get("NodeType")
if n_type is None:
specialized_type_var_map = object_definition.specialized_type_var_map or {}

n_type = specialized_type_var_map["NodeType"]
if isinstance(n_type, LazyType):
n_type = n_type.resolve_type()

n_definition = get_object_definition(n_type, strict=True)

for edges in get_selections(selection, typename=t_name).values():
if edges.name != "edges":
continue

e_definition = get_object_definition(relay.Edge, strict=True)
e_type = e_definition.resolve_generic(
relay.Edge[cast(Type[relay.Node], n_type)],
)
e_name = schema.config.name_converter.from_object(
get_object_definition(e_type, strict=True),
)
for node in get_selections(edges, typename=e_name).values():
if node.name != "node":
continue

new_store = _get_model_hints(
model=model,
schema=schema,
object_definition=n_definition,
selection=node,
info=info,
config=config,
prefix=prefix,
cache=cache,
level=level,
)
if new_store is not None:
store |= new_store

return store
if issubclass(object_definition.origin, relay.Connection):
return _get_model_hints_from_connection(
model,
schema,
object_definition,
parent_type=parent_type,
info=info,
config=config,
prefix=prefix,
cache=cache,
level=level,
)

store = OptimizerStore()
fields = {
schema.config.name_converter.get_graphql_name(f): f
for f in object_definition.fields
Expand All @@ -473,18 +470,30 @@ def _get_model_hints(
if pk is not None:
store.only.append(pk.attname)

for f_selection in get_selections(selection, typename=t_name).values():
field = fields.get(f_selection.name, None)
for f_selections in _get_selections(info, parent_type).values():
f_selection = f_selections[0]
field = fields.get(f_selection.name.value, None)
if not field:
continue

# Do not optimize the field if the user asked not to
if getattr(field, "disable_optimization", False):
continue

field_definition = parent_type.fields[f_selection.name.value].type
while isinstance(field_definition, GraphQLWrappingType):
field_definition = field_definition.of_type

f_info = _generate_selection_resolve_info(
info,
f_selections,
field_definition,
parent_type,
)

# Add annotations from the field if they exist
field_store = getattr(field, "store", None)
if field_store is not None:
if field_store:
if (
len(field_store.annotate) == 1
and _annotate_placeholder in field_store.annotate
Expand All @@ -500,14 +509,20 @@ def _get_model_hints(
field.name: field_store.annotate[_annotate_placeholder],
}
store |= (
field_store.with_prefix(prefix, info=info) if prefix else field_store
field_store.with_prefix(prefix, info=f_info) if prefix else field_store
)

# Then from the model property if one is defined
model_attr = getattr(model, field.python_name, None)
if model_attr is not None and isinstance(model_attr, ModelProperty):
if (
model_attr is not None
and isinstance(model_attr, ModelProperty)
and model_attr.store
):
attr_store = model_attr.store
store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store
store |= (
attr_store.with_prefix(prefix, info=f_info) if prefix else attr_store
)

# Lastly, from the django field itself
model_fieldname: str = getattr(field, "django_name", None) or field.python_name
Expand All @@ -533,24 +548,21 @@ def _get_model_hints(
f_model,
schema,
f_type_def,
f_selection,
info=info,
parent_type=cast(GraphQLObjectType, field_definition),
info=f_info,
config=config,
cache=cache,
level=level + 1,
)
if f_store is not None:
cache.setdefault(f_model, []).append((level, f_store))
store |= f_store.with_prefix(path, info=info)
store |= f_store.with_prefix(path, info=f_info)
elif GenericForeignKey and isinstance(model_field, GenericForeignKey):
# There's not much we can do to optimize generic foreign keys regarding
# only/select_related because they can be anything.
# Just prefetch_related them
store.prefetch_related.append(model_fieldname)
elif isinstance(
model_field,
_relation_fields,
):
elif isinstance(model_field, _relation_fields):
f_types = list(get_possible_type_definitions(field.type))
if len(f_types) > 1:
# This might be a generic foreign key.
Expand All @@ -563,8 +575,8 @@ def _get_model_hints(
remote_model,
schema,
f_types[0],
f_selection,
info=info,
parent_type=cast(GraphQLObjectType, field_definition),
info=f_info,
config=config,
cache=cache,
level=level + 1,
Expand Down Expand Up @@ -625,11 +637,7 @@ def _get_model_hints(
config,
info,
)
f_qs = f_store.apply(
base_qs,
info=info,
config=config,
)
f_qs = f_store.apply(base_qs, info=f_info, config=config)
f_prefetch = Prefetch(path, queryset=f_qs)
f_prefetch._optimizer_sentinel = _sentinel # type: ignore
store.prefetch_related.append(f_prefetch)
Expand All @@ -651,6 +659,83 @@ def _get_model_hints(
return store


def _get_model_hints_from_connection(
model: type[models.Model],
schema: Schema,
object_definition: StrawberryObjectDefinition,
*,
parent_type: GraphQLObjectType,
info: GraphQLResolveInfo,
config: OptimizerConfig | None = None,
prefix: str = "",
cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None,
level: int = 0,
) -> OptimizerStore | None:
# TODO: Connections are mostly used for pagination so it doesn't make sense for
# us to optimize those, as our prefetch would be thrown away causing an extra
# useless query. Is there a way for us to properly optimize this in the future?
if level > 0:
return None

store = None

n_type = object_definition.type_var_map.get("NodeType")
if n_type is None:
specialized_type_var_map = object_definition.specialized_type_var_map or {}

n_type = specialized_type_var_map["NodeType"]

if isinstance(n_type, LazyType):
n_type = n_type.resolve_type()

n_definition = get_object_definition(n_type, strict=True)

for edges in _get_selections(info, parent_type).values():
edge = edges[0]
if edge.name.value != "edges":
continue

e_definition = get_object_definition(relay.Edge, strict=True)
e_type = e_definition.resolve_generic(
relay.Edge[cast(Type[relay.Node], n_type)],
)
e_gql_definition = schema.schema_converter.from_object(
get_object_definition(e_type, strict=True),
)
e_info = _generate_selection_resolve_info(
info,
edges,
e_gql_definition,
parent_type,
)
for nodes in _get_selections(e_info, e_gql_definition).values():
node = nodes[0]
if node.name.value != "node":
continue

n_gql_definition = schema.schema_converter.from_object(n_definition)
n_info = _generate_selection_resolve_info(
info,
nodes,
n_gql_definition,
e_gql_definition,
)

store = _get_model_hints(
model=model,
schema=schema,
object_definition=n_definition,
parent_type=n_gql_definition,
info=n_info,
config=config,
prefix=prefix,
cache=cache,
level=level,
)

return store


def optimize(
qs: QuerySet[_M] | BaseManager[_M],
info: GraphQLResolveInfo | Info,
Expand Down Expand Up @@ -711,7 +796,6 @@ def optimize(
store = store or OptimizerStore()
schema = cast(Schema, info.schema._strawberry_schema) # type: ignore

field_name = info.field_name
gql_type = get_named_type(info.return_type)
strawberry_type = schema.get_type_by_name(gql_type.name)
if strawberry_type is None:
Expand All @@ -738,21 +822,18 @@ def optimize(
else:
object_definitions = [object_definition]

for selection in convert_selections(info, info.field_nodes):
if isinstance(selection, InlineFragment) or selection.name != field_name:
continue

for t_definition in object_definitions:
new_store = _get_model_hints(
qs.model,
schema,
t_definition,
selection,
info=info,
config=config,
)
if new_store is not None:
store |= new_store
for inner_object_definition in object_definitions:
parent_type = schema.schema_converter.from_object(inner_object_definition)
new_store = _get_model_hints(
qs.model,
schema,
inner_object_definition,
parent_type=parent_type,
info=info,
config=config,
)
if new_store is not None:
store |= new_store

if store:
qs = store.apply(qs, info=info, config=config)
Expand Down
Loading

0 comments on commit f352f2a

Please sign in to comment.