Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-mm committed Jan 15, 2024
2 parents ce71e86 + 4fec1d2 commit e2bf4c8
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 21 deletions.
28 changes: 18 additions & 10 deletions django/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ def get_deferred_fields(self):
if f.attname not in self.__dict__
}

def refresh_from_db(self, using=None, fields=None):
def refresh_from_db(self, using=None, fields=None, from_queryset=None):
"""
Reload field values from the database.
Expand Down Expand Up @@ -705,10 +705,13 @@ def refresh_from_db(self, using=None, fields=None):
"are not allowed in fields." % LOOKUP_SEP
)

hints = {"instance": self}
db_instance_qs = self.__class__._base_manager.db_manager(
using, hints=hints
).filter(pk=self.pk)
if from_queryset is None:
hints = {"instance": self}
from_queryset = self.__class__._base_manager.db_manager(using, hints=hints)
elif using is not None:
from_queryset = from_queryset.using(using)

db_instance_qs = from_queryset.filter(pk=self.pk)

# Use provided fields, if not set then reload all non-deferred fields.
deferred_fields = self.get_deferred_fields()
Expand All @@ -729,9 +732,12 @@ def refresh_from_db(self, using=None, fields=None):
# This field wasn't refreshed - skip ahead.
continue
setattr(self, field.attname, getattr(db_instance, field.attname))
# Clear cached foreign keys.
if field.is_relation and field.is_cached(self):
field.delete_cached_value(self)
# Clear or copy cached foreign keys.
if field.is_relation:
if field.is_cached(db_instance):
field.set_cached_value(self, field.get_cached_value(db_instance))
elif field.is_cached(self):
field.delete_cached_value(self)

# Clear cached relations.
for field in self._meta.related_objects:
Expand All @@ -745,8 +751,10 @@ def refresh_from_db(self, using=None, fields=None):

self._state.db = db_instance._state.db

async def arefresh_from_db(self, using=None, fields=None):
return await sync_to_async(self.refresh_from_db)(using=using, fields=fields)
async def arefresh_from_db(self, using=None, fields=None, from_queryset=None):
return await sync_to_async(self.refresh_from_db)(
using=using, fields=fields, from_queryset=from_queryset
)

def serializable_value(self, field_name):
"""
Expand Down
18 changes: 13 additions & 5 deletions django/db/models/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from django.db.models.constants import LOOKUP_SEP
from django.db.models.query_utils import Q
from django.utils.deconstruct import deconstructible
from django.utils.functional import cached_property
from django.utils.functional import cached_property, classproperty
from django.utils.hashable import make_hashable


Expand Down Expand Up @@ -402,10 +402,13 @@ def relabeled_clone(self, change_map):
return clone

def replace_expressions(self, replacements):
if not replacements:
return self
if replacement := replacements.get(self):
return replacement
if not (source_expressions := self.get_source_expressions()):
return self
clone = self.copy()
source_expressions = clone.get_source_expressions()
clone.set_source_expressions(
[
expr.replace_expressions(replacements) if expr else None
Expand Down Expand Up @@ -485,13 +488,18 @@ def select_format(self, compiler, sql, params):
class Expression(BaseExpression, Combinable):
"""An expression that can be combined with other expressions."""

@classproperty
@functools.lru_cache(maxsize=128)
def _constructor_signature(cls):
return inspect.signature(cls.__init__)

@cached_property
def identity(self):
constructor_signature = inspect.signature(self.__init__)
args, kwargs = self._constructor_args
signature = constructor_signature.bind_partial(*args, **kwargs)
signature = self._constructor_signature.bind_partial(self, *args, **kwargs)
signature.apply_defaults()
arguments = signature.arguments.items()
arguments = iter(signature.arguments.items())
next(arguments)
identity = [self.__class__]
for arg, value in arguments:
if isinstance(value, fields.Field):
Expand Down
2 changes: 2 additions & 0 deletions django/db/models/sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,8 @@ def change_aliases(self, change_map):
relabelling any references to them in select columns and the where
clause.
"""
if not change_map:
return self
# If keys and values of change_map were to intersect, an alias might be
# updated twice (e.g. T4 -> T5, T5 -> T6, so also T4 -> T6) depending
# on their order in change_map.
Expand Down
4 changes: 4 additions & 0 deletions django/db/models/sql/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def relabel_aliases(self, change_map):
Relabel the alias values of any children. 'change_map' is a dictionary
mapping old (current) alias values to the new values.
"""
if not change_map:
return self
for pos, child in enumerate(self.children):
if hasattr(child, "relabel_aliases"):
# For example another WhereNode
Expand All @@ -225,6 +227,8 @@ def relabeled_clone(self, change_map):
return clone

def replace_expressions(self, replacements):
if not replacements:
return self
if replacement := replacements.get(self):
return replacement
clone = self.create(connector=self.connector, negated=self.negated)
Expand Down
6 changes: 4 additions & 2 deletions docs/ref/models/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ For example::
will allow filtering on ``room`` and ``date``, also selecting ``full_name``,
while fetching data only from the index.

``include`` is supported only on PostgreSQL.
Unique constraints with non-key columns are ignored for databases besides
PostgreSQL.

Non-key columns have the same database restrictions as :attr:`Index.include`.

Expand Down Expand Up @@ -272,7 +273,8 @@ For example::
creates a unique constraint that only allows one row to store a ``NULL`` value
in the ``ordering`` column.

``nulls_distinct`` is ignored for databases besides PostgreSQL 15+.
Unique constraints with ``nulls_distinct`` are ignored for databases besides
PostgreSQL 15+.

``violation_error_code``
------------------------
Expand Down
25 changes: 23 additions & 2 deletions docs/ref/models/instances.txt
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ value from the database:
>>> del obj.field
>>> obj.field # Loads the field from the database

.. method:: Model.refresh_from_db(using=None, fields=None)
.. method:: Model.arefresh_from_db(using=None, fields=None)
.. method:: Model.refresh_from_db(using=None, fields=None, from_queryset=None)
.. method:: Model.arefresh_from_db(using=None, fields=None, from_queryset=None)

*Asynchronous version*: ``arefresh_from_db()``

Expand Down Expand Up @@ -197,6 +197,27 @@ all of the instance's fields when a deferred field is reloaded::
fields = fields.union(deferred_fields)
super().refresh_from_db(using, fields, **kwargs)

The ``from_queryset`` argument allows using a different queryset than the one
created from :attr:`~django.db.models.Model._base_manager`. It gives you more
control over how the model is reloaded. For example, when your model uses soft
deletion you can make ``refresh_from_db()`` to take this into account::

obj.refresh_from_db(from_queryset=MyModel.active_objects.all())

You can cache related objects that otherwise would be cleared from the reloaded
instance::

obj.refresh_from_db(from_queryset=MyModel.objects.select_related("related_field"))

You can lock the row until the end of transaction before reloading a model's
values::

obj.refresh_from_db(from_queryset=MyModel.objects.select_for_update())

.. versionchanged:: 5.1

The ``from_queryset`` argument was added.

.. method:: Model.get_deferred_fields()

A helper method that returns a set containing the attribute names of all those
Expand Down
5 changes: 5 additions & 0 deletions docs/releases/5.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ Models
:class:`~django.contrib.postgres.fields.ArrayField` can now be :ref:`sliced
<slicing-using-f>`.

* The new ``from_queryset`` argument of :meth:`.Model.refresh_from_db` and
:meth:`.Model.arefresh_from_db` allows customizing the queryset used to
reload a model's value. This can be used to lock the row before reloading or
to select related objects.

Requests and Responses
~~~~~~~~~~~~~~~~~~~~~~

Expand Down
11 changes: 11 additions & 0 deletions tests/async/test_async_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,14 @@ async def test_arefresh_from_db(self):
await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
await self.s1.arefresh_from_db()
self.assertEqual(self.s1.field, 20)

async def test_arefresh_from_db_from_queryset(self):
await SimpleModel.objects.filter(pk=self.s1.pk).aupdate(field=20)
with self.assertRaises(SimpleModel.DoesNotExist):
await self.s1.arefresh_from_db(
from_queryset=SimpleModel.objects.filter(field=0)
)
await self.s1.arefresh_from_db(
from_queryset=SimpleModel.objects.filter(field__gt=0)
)
self.assertEqual(self.s1.field, 20)
56 changes: 54 additions & 2 deletions tests/basic/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
from unittest import mock

from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
from django.db import (
DEFAULT_DB_ALIAS,
DatabaseError,
connection,
connections,
models,
transaction,
)
from django.db.models.manager import BaseManager
from django.db.models.query import MAX_GET_RESULTS, EmptyQuerySet
from django.test import (
Expand All @@ -13,7 +20,8 @@
TransactionTestCase,
skipUnlessDBFeature,
)
from django.test.utils import ignore_warnings
from django.test.utils import CaptureQueriesContext, ignore_warnings
from django.utils.connection import ConnectionDoesNotExist
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.translation import gettext_lazy

Expand Down Expand Up @@ -1003,3 +1011,47 @@ def test_prefetched_cache_cleared(self):
# Cache was cleared and new results are available.
self.assertCountEqual(a2_prefetched.selfref_set.all(), [s])
self.assertCountEqual(a2_prefetched.cited.all(), [s])

@skipUnlessDBFeature("has_select_for_update")
def test_refresh_for_update(self):
a = Article.objects.create(pub_date=datetime.now())
for_update_sql = connection.ops.for_update_sql()

with transaction.atomic(), CaptureQueriesContext(connection) as ctx:
a.refresh_from_db(from_queryset=Article.objects.select_for_update())
self.assertTrue(
any(for_update_sql in query["sql"] for query in ctx.captured_queries)
)

def test_refresh_with_related(self):
a = Article.objects.create(pub_date=datetime.now())
fa = FeaturedArticle.objects.create(article=a)

from_queryset = FeaturedArticle.objects.select_related("article")
with self.assertNumQueries(1):
fa.refresh_from_db(from_queryset=from_queryset)
self.assertEqual(fa.article.pub_date, a.pub_date)
with self.assertNumQueries(2):
fa.refresh_from_db()
self.assertEqual(fa.article.pub_date, a.pub_date)

def test_refresh_overwrites_queryset_using(self):
a = Article.objects.create(pub_date=datetime.now())

from_queryset = Article.objects.using("nonexistent")
with self.assertRaises(ConnectionDoesNotExist):
a.refresh_from_db(from_queryset=from_queryset)
a.refresh_from_db(using="default", from_queryset=from_queryset)

def test_refresh_overwrites_queryset_fields(self):
a = Article.objects.create(pub_date=datetime.now())
headline = "headline"
Article.objects.filter(pk=a.pk).update(headline=headline)

from_queryset = Article.objects.only("pub_date")
with self.assertNumQueries(1):
a.refresh_from_db(from_queryset=from_queryset)
self.assertNotEqual(a.headline, headline)
with self.assertNumQueries(1):
a.refresh_from_db(fields=["headline"], from_queryset=from_queryset)
self.assertEqual(a.headline, headline)

0 comments on commit e2bf4c8

Please sign in to comment.