From f352f2a0daaba229881a4f978acca422127b4fd5 Mon Sep 17 00:00:00 2001 From: Thiago Bellini Ribeiro Date: Thu, 30 May 2024 18:51:47 -0300 Subject: [PATCH] refactor: Use graphql-core's collect_sub_fields instead of our own implementation (#537) This allows us to simplify our optimizer implementation, and also paves the way for nested optimizations implementation in the near future. --- strawberry_django/optimizer.py | 263 +++++++++++++++++++---------- strawberry_django/utils/inspect.py | 95 ----------- 2 files changed, 172 insertions(+), 186 deletions(-) diff --git a/strawberry_django/optimizer.py b/strawberry_django/optimizer.py index 47df7606..a9c26f9e 100644 --- a/strawberry_django/optimizer.py +++ b/strawberry_django/optimizer.py @@ -28,6 +28,8 @@ ) from django.db.models.manager import BaseManager from django.db.models.query import QuerySet +from graphql import FieldNode, GraphQLObjectType, GraphQLOutputType, GraphQLWrappingType +from graphql.execution.collect_fields import collect_sub_fields from graphql.language.ast import OperationType from graphql.type.definition import GraphQLResolveInfo, get_named_type from strawberry import relay @@ -37,7 +39,6 @@ from strawberry.schema.schema import Schema from strawberry.type import get_object_definition from strawberry.types.info import Info -from strawberry.types.nodes import InlineFragment, Selection, convert_selections from strawberry.utils.typing import eval_type from typing_extensions import assert_never, assert_type, get_args @@ -50,7 +51,6 @@ PrefetchInspector, get_model_fields, get_possible_type_definitions, - get_selections, ) from .utils.typing import ( AnnotateCallable, @@ -364,12 +364,48 @@ def _get_prefetch_queryset( ) +def _get_selections( + info: GraphQLResolveInfo, + parent_type: GraphQLObjectType, +) -> dict[str, list[FieldNode]]: + return collect_sub_fields( + info.schema, + info.fragments, + info.variable_values, + parent_type, + info.field_nodes, + ) + + +def _generate_selection_resolve_info( + info: GraphQLResolveInfo, + field_nodes: list[FieldNode], + return_type: GraphQLOutputType, + parent_type: GraphQLObjectType, +): + field_node = field_nodes[0] + return GraphQLResolveInfo( + field_name=field_node.name.value, + field_nodes=field_nodes, + return_type=return_type, + parent_type=parent_type, + path=info.path.add_key(0).add_key(field_node.name.value, parent_type.name), + schema=info.schema, + fragments=info.fragments, + root_value=info.root_value, + operation=info.operation, + variable_values=info.variable_values, + context=info.context, + is_awaitable=info.is_awaitable, + ) + + def _get_model_hints( model: type[models.Model], schema: Schema, object_definition: StrawberryObjectDefinition, - selection: Selection, *, + parent_type: GraphQLObjectType, info: GraphQLResolveInfo, config: OptimizerConfig | None = None, prefix: str = "", @@ -393,63 +429,24 @@ def _get_model_hints( GenericRelation, ) - store = OptimizerStore() cache = cache or {} - t_name = schema.config.name_converter.from_object(object_definition) # In case this is a relay field, find the selected edges/nodes, the selected fields # are actually inside edges -> node selection... - if issubclass( - object_definition.origin, - relay.Connection, - ): - # TODO: Connections are mostly used for pagination so it doesn't make sense for - # us to optimize those, as our prefetch would be thrown away causing an extra - # useless query. Is there a way for us to properly optimize this in the future? - if level > 0: - return None - - n_type = object_definition.type_var_map.get("NodeType") - if n_type is None: - specialized_type_var_map = object_definition.specialized_type_var_map or {} - - n_type = specialized_type_var_map["NodeType"] - if isinstance(n_type, LazyType): - n_type = n_type.resolve_type() - - n_definition = get_object_definition(n_type, strict=True) - - for edges in get_selections(selection, typename=t_name).values(): - if edges.name != "edges": - continue - - e_definition = get_object_definition(relay.Edge, strict=True) - e_type = e_definition.resolve_generic( - relay.Edge[cast(Type[relay.Node], n_type)], - ) - e_name = schema.config.name_converter.from_object( - get_object_definition(e_type, strict=True), - ) - for node in get_selections(edges, typename=e_name).values(): - if node.name != "node": - continue - - new_store = _get_model_hints( - model=model, - schema=schema, - object_definition=n_definition, - selection=node, - info=info, - config=config, - prefix=prefix, - cache=cache, - level=level, - ) - if new_store is not None: - store |= new_store - - return store + if issubclass(object_definition.origin, relay.Connection): + return _get_model_hints_from_connection( + model, + schema, + object_definition, + parent_type=parent_type, + info=info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + store = OptimizerStore() fields = { schema.config.name_converter.get_graphql_name(f): f for f in object_definition.fields @@ -473,8 +470,9 @@ def _get_model_hints( if pk is not None: store.only.append(pk.attname) - for f_selection in get_selections(selection, typename=t_name).values(): - field = fields.get(f_selection.name, None) + for f_selections in _get_selections(info, parent_type).values(): + f_selection = f_selections[0] + field = fields.get(f_selection.name.value, None) if not field: continue @@ -482,9 +480,20 @@ def _get_model_hints( if getattr(field, "disable_optimization", False): continue + field_definition = parent_type.fields[f_selection.name.value].type + while isinstance(field_definition, GraphQLWrappingType): + field_definition = field_definition.of_type + + f_info = _generate_selection_resolve_info( + info, + f_selections, + field_definition, + parent_type, + ) + # Add annotations from the field if they exist field_store = getattr(field, "store", None) - if field_store is not None: + if field_store: if ( len(field_store.annotate) == 1 and _annotate_placeholder in field_store.annotate @@ -500,14 +509,20 @@ def _get_model_hints( field.name: field_store.annotate[_annotate_placeholder], } store |= ( - field_store.with_prefix(prefix, info=info) if prefix else field_store + field_store.with_prefix(prefix, info=f_info) if prefix else field_store ) # Then from the model property if one is defined model_attr = getattr(model, field.python_name, None) - if model_attr is not None and isinstance(model_attr, ModelProperty): + if ( + model_attr is not None + and isinstance(model_attr, ModelProperty) + and model_attr.store + ): attr_store = model_attr.store - store |= attr_store.with_prefix(prefix, info=info) if prefix else attr_store + store |= ( + attr_store.with_prefix(prefix, info=f_info) if prefix else attr_store + ) # Lastly, from the django field itself model_fieldname: str = getattr(field, "django_name", None) or field.python_name @@ -533,24 +548,21 @@ def _get_model_hints( f_model, schema, f_type_def, - f_selection, - info=info, + parent_type=cast(GraphQLObjectType, field_definition), + info=f_info, config=config, cache=cache, level=level + 1, ) if f_store is not None: cache.setdefault(f_model, []).append((level, f_store)) - store |= f_store.with_prefix(path, info=info) + store |= f_store.with_prefix(path, info=f_info) elif GenericForeignKey and isinstance(model_field, GenericForeignKey): # There's not much we can do to optimize generic foreign keys regarding # only/select_related because they can be anything. # Just prefetch_related them store.prefetch_related.append(model_fieldname) - elif isinstance( - model_field, - _relation_fields, - ): + elif isinstance(model_field, _relation_fields): f_types = list(get_possible_type_definitions(field.type)) if len(f_types) > 1: # This might be a generic foreign key. @@ -563,8 +575,8 @@ def _get_model_hints( remote_model, schema, f_types[0], - f_selection, - info=info, + parent_type=cast(GraphQLObjectType, field_definition), + info=f_info, config=config, cache=cache, level=level + 1, @@ -625,11 +637,7 @@ def _get_model_hints( config, info, ) - f_qs = f_store.apply( - base_qs, - info=info, - config=config, - ) + f_qs = f_store.apply(base_qs, info=f_info, config=config) f_prefetch = Prefetch(path, queryset=f_qs) f_prefetch._optimizer_sentinel = _sentinel # type: ignore store.prefetch_related.append(f_prefetch) @@ -651,6 +659,83 @@ def _get_model_hints( return store +def _get_model_hints_from_connection( + model: type[models.Model], + schema: Schema, + object_definition: StrawberryObjectDefinition, + *, + parent_type: GraphQLObjectType, + info: GraphQLResolveInfo, + config: OptimizerConfig | None = None, + prefix: str = "", + cache: dict[type[models.Model], list[tuple[int, OptimizerStore]]] | None = None, + level: int = 0, +) -> OptimizerStore | None: + # TODO: Connections are mostly used for pagination so it doesn't make sense for + # us to optimize those, as our prefetch would be thrown away causing an extra + # useless query. Is there a way for us to properly optimize this in the future? + if level > 0: + return None + + store = None + + n_type = object_definition.type_var_map.get("NodeType") + if n_type is None: + specialized_type_var_map = object_definition.specialized_type_var_map or {} + + n_type = specialized_type_var_map["NodeType"] + + if isinstance(n_type, LazyType): + n_type = n_type.resolve_type() + + n_definition = get_object_definition(n_type, strict=True) + + for edges in _get_selections(info, parent_type).values(): + edge = edges[0] + if edge.name.value != "edges": + continue + + e_definition = get_object_definition(relay.Edge, strict=True) + e_type = e_definition.resolve_generic( + relay.Edge[cast(Type[relay.Node], n_type)], + ) + e_gql_definition = schema.schema_converter.from_object( + get_object_definition(e_type, strict=True), + ) + e_info = _generate_selection_resolve_info( + info, + edges, + e_gql_definition, + parent_type, + ) + for nodes in _get_selections(e_info, e_gql_definition).values(): + node = nodes[0] + if node.name.value != "node": + continue + + n_gql_definition = schema.schema_converter.from_object(n_definition) + n_info = _generate_selection_resolve_info( + info, + nodes, + n_gql_definition, + e_gql_definition, + ) + + store = _get_model_hints( + model=model, + schema=schema, + object_definition=n_definition, + parent_type=n_gql_definition, + info=n_info, + config=config, + prefix=prefix, + cache=cache, + level=level, + ) + + return store + + def optimize( qs: QuerySet[_M] | BaseManager[_M], info: GraphQLResolveInfo | Info, @@ -711,7 +796,6 @@ def optimize( store = store or OptimizerStore() schema = cast(Schema, info.schema._strawberry_schema) # type: ignore - field_name = info.field_name gql_type = get_named_type(info.return_type) strawberry_type = schema.get_type_by_name(gql_type.name) if strawberry_type is None: @@ -738,21 +822,18 @@ def optimize( else: object_definitions = [object_definition] - for selection in convert_selections(info, info.field_nodes): - if isinstance(selection, InlineFragment) or selection.name != field_name: - continue - - for t_definition in object_definitions: - new_store = _get_model_hints( - qs.model, - schema, - t_definition, - selection, - info=info, - config=config, - ) - if new_store is not None: - store |= new_store + for inner_object_definition in object_definitions: + parent_type = schema.schema_converter.from_object(inner_object_definition) + new_store = _get_model_hints( + qs.model, + schema, + inner_object_definition, + parent_type=parent_type, + info=info, + config=config, + ) + if new_store is not None: + store |= new_store if store: qs = store.apply(qs, info=info, config=config) diff --git a/strawberry_django/utils/inspect.py b/strawberry_django/utils/inspect.py index 65ea8c92..9d1badf6 100644 --- a/strawberry_django/utils/inspect.py +++ b/strawberry_django/utils/inspect.py @@ -27,12 +27,6 @@ StrawberryTypeVar, has_object_definition, ) -from strawberry.types.nodes import ( - FragmentSpread, - InlineFragment, - SelectedField, - Selection, -) from strawberry.types.types import StrawberryObjectDefinition from strawberry.union import StrawberryUnion from strawberry.utils.str_converters import to_camel_case @@ -133,95 +127,6 @@ def get_possible_type_definitions( yield t.__strawberry_definition__ -def get_selections( - selection: Selection, - *, - typename: Optional[str] = None, -) -> Dict[str, SelectedField]: - """Resolve subselections considering fragments. - - Args: - ---- - selection: - The selection to retrieve subselections from - typename: - Only resolve fragments for that typename - - Yields: - ------ - All possibilities for the type - - """ - # Because of the way graphql spreads fragments, - # later selections should replace previous ones - ret: Dict[str, SelectedField] = {} - - def merge_selections(f1: SelectedField, f2: SelectedField) -> SelectedField: - if not f1.selections: - return f2 - if not f2.selections: - return f1 - - f1_selections = { - s.name: s for s in f1.selections if isinstance(s, SelectedField) - } - f2_selections = { - s.name: s for s in f2.selections if isinstance(s, SelectedField) - } - - selections: dict[str, SelectedField] = {} - for f_name in set(f1_selections) - set(f2_selections): - selections[f_name] = f1_selections[f_name] - for f_name in set(f2_selections) - set(f1_selections): - selections[f_name] = f2_selections[f_name] - for f_name in set(f2_selections) & set(f1_selections): - selections[f_name] = f1_selections[f_name] - selections[f_name] = merge_selections( - f1_selections[f_name], - f2_selections[f_name], - ) - - f1.selections = list(selections.values()) + [ - s - for s in (f1.selections + f2.selections) - if isinstance(s, (FragmentSpread, InlineFragment)) - ] - return f1 - - for s in selection.selections: - if isinstance(s, SelectedField): - # @include(if: ) - include = s.directives.get("include") - if include and not include["if"]: - continue - - # @skip(if: ) - skip = s.directives.get("skip") - if skip and skip["if"]: - continue - - f_name = s.alias or s.name - existing = ret.get(f_name) - if existing is not None: - ret[f_name] = merge_selections(existing, s) - else: - ret[f_name] = s - elif isinstance(s, (FragmentSpread, InlineFragment)): - if typename is not None and s.type_condition != typename: - continue - - for f_name, f in get_selections(s, typename=typename).items(): - existing = ret.get(f_name) - if existing is not None: - ret[f_name] = merge_selections(existing, f) - else: - ret[f_name] = f - else: # pragma:nocover - raise TypeError(s) - - return ret - - @dataclasses.dataclass(eq=True) class PrefetchInspector: """Prefetch hints."""