Skip to content

Commit

Permalink
Merge pull request #135 from yunojuno/feat/custom-field-cls
Browse files Browse the repository at this point in the history
feature: allow overriding field class
  • Loading branch information
bellini666 authored Jul 4, 2022
2 parents c63641f + 4fac45d commit fe0479e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 7 deletions.
4 changes: 4 additions & 0 deletions strawberry_django/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def django_model(self):
type_ = utils.unwrap_type(self.type)
return utils.get_django_model(type_)

@classmethod
def from_field(cls, field, django_type):
raise NotImplementedError


class StrawberryDjangoField(
StrawberryDjangoPagination,
Expand Down
32 changes: 25 additions & 7 deletions strawberry_django/type.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import dataclasses
from typing import Any, Optional
from typing import Any, Optional, TypeVar

import django
import strawberry
from strawberry import UNSET
from strawberry.annotation import StrawberryAnnotation

from . import utils
from .fields.field import StrawberryDjangoField
from .fields.field import StrawberryDjangoField, StrawberryDjangoFieldBase
from .fields.types import (
get_model_field,
is_optional,
Expand All @@ -18,6 +18,10 @@

_type = type

StrawberryDjangoFieldType = TypeVar(
"StrawberryDjangoFieldType", bound=StrawberryDjangoFieldBase
)


def get_type_attr(type_, field_name):
attr = getattr(type_, field_name, UNSET)
Expand All @@ -32,9 +36,9 @@ def get_field(django_type, field_name, field_annotation=None):
attr = get_type_attr(django_type.origin, field_name)

if utils.is_field(attr):
field = StrawberryDjangoField.from_field(attr, django_type)
field = django_type.field_cls.from_field(attr, django_type)
else:
field = StrawberryDjangoField(
field = django_type.field_cls(
default=attr,
type_annotation=field_annotation,
)
Expand Down Expand Up @@ -122,11 +126,24 @@ class StrawberryDjangoType:
filters: Any
order: Any
pagination: Any


def process_type(cls, model, *, filters=UNSET, pagination=UNSET, order=UNSET, **kwargs):
field_cls: StrawberryDjangoFieldType


def process_type(
cls,
model,
*,
filters=UNSET,
pagination=UNSET,
order=UNSET,
field_cls=UNSET,
**kwargs
):
original_annotations = cls.__dict__.get("__annotations__", {})

if not field_cls or field_cls is UNSET:
field_cls = StrawberryDjangoField

django_type = StrawberryDjangoType(
origin=cls,
model=model,
Expand All @@ -136,6 +153,7 @@ def process_type(cls, model, *, filters=UNSET, pagination=UNSET, order=UNSET, **
filters=filters,
order=order,
pagination=pagination,
field_cls=field_cls,
)

fields = get_fields(django_type)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from strawberry import auto

import strawberry_django
from strawberry_django.fields.field import StrawberryDjangoField

from .models import User

Expand Down Expand Up @@ -36,3 +37,42 @@ class InputType:
user = InputType(1, "user")
assert user.id == 1
assert user.name == "user"


def test_custom_field_cls():
"""Custom field_cls is applied to all fields."""

class CustomStrawberryDjangoField(StrawberryDjangoField):
pass

@strawberry_django.type(User, field_cls=CustomStrawberryDjangoField)
class UserType:
id: int
name: auto

assert all(
isinstance(field, CustomStrawberryDjangoField)
for field in UserType._type_definition.fields
)


def test_custom_field_cls__explicit_field_type():
"""Custom field_cls is applied to all fields."""

class CustomStrawberryDjangoField(StrawberryDjangoField):
pass

@strawberry_django.type(User, field_cls=CustomStrawberryDjangoField)
class UserType:
id: int
name: auto = strawberry_django.field()

assert isinstance(
UserType._type_definition.get_field("id"), CustomStrawberryDjangoField
)
assert isinstance(
UserType._type_definition.get_field("name"), StrawberryDjangoField
)
assert not isinstance(
UserType._type_definition.get_field("name"), CustomStrawberryDjangoField
)

0 comments on commit fe0479e

Please sign in to comment.