Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-mm committed Jan 3, 2025
2 parents 486ae8c + a4d3f25 commit 9130980
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 53 deletions.
1 change: 0 additions & 1 deletion django/contrib/postgres/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def check(self, **kwargs):
)
)
else:
# Remove the field name checks as they are not needed here.
base_checks = self.base_field.check()
if base_checks:
error_messages = "\n ".join(
Expand Down
16 changes: 9 additions & 7 deletions django/core/management/commands/inspectdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,12 @@ def handle_inspection(self, options):
connection.introspection.get_primary_key_columns(
cursor, table_name
)
or []
)
primary_key_column = (
primary_key_columns[0] if primary_key_columns else None
primary_key_columns[0]
if len(primary_key_columns) == 1
else None
)
unique_columns = [
c["columns"][0]
Expand All @@ -128,6 +131,11 @@ def handle_inspection(self, options):
yield ""
yield "class %s(models.Model):" % model_name
known_models.append(model_name)

if len(primary_key_columns) > 1:
fields = ", ".join([f"'{col}'" for col in primary_key_columns])
yield f" pk = models.CompositePrimaryKey({fields})"

used_column_names = [] # Holds column names used in the table so far
column_to_field_name = {} # Maps column names to names of model fields
used_relations = set() # Holds foreign relations used in the table.
Expand All @@ -151,12 +159,6 @@ def handle_inspection(self, options):
# Add primary_key and unique, if necessary.
if column_name == primary_key_column:
extra_params["primary_key"] = True
if len(primary_key_columns) > 1:
comment_notes.append(
"The composite primary key (%s) found, that is not "
"supported. The first column is selected."
% ", ".join(primary_key_columns)
)
elif column_name in unique_columns:
extra_params["unique"] = True

Expand Down
12 changes: 4 additions & 8 deletions django/db/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from django.db.models.expressions import Case, F, Value, When
from django.db.models.functions import Cast, Trunc
from django.db.models.query_utils import FilteredRelation, Q
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE
from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, ROW_COUNT
from django.db.models.utils import (
AltersData,
create_namedtuple_class,
Expand Down Expand Up @@ -1209,11 +1209,7 @@ def _raw_delete(self, using):
"""
query = self.query.clone()
query.__class__ = sql.DeleteQuery
cursor = query.get_compiler(using).execute_sql(CURSOR)
if cursor:
with cursor:
return cursor.rowcount
return 0
return query.get_compiler(using).execute_sql(ROW_COUNT)

_raw_delete.alters_data = True

Expand Down Expand Up @@ -1252,7 +1248,7 @@ def update(self, **kwargs):
# Clear any annotations so that they won't be present in subqueries.
query.annotations = {}
with transaction.mark_for_rollback_on_error(using=self.db):
rows = query.get_compiler(self.db).execute_sql(CURSOR)
rows = query.get_compiler(self.db).execute_sql(ROW_COUNT)
self._result_cache = None
return rows

Expand All @@ -1277,7 +1273,7 @@ def _update(self, values):
# Clear any annotations so that they won't be present in subqueries.
query.annotations = {}
self._result_cache = None
return query.get_compiler(self.db).execute_sql(CURSOR)
return query.get_compiler(self.db).execute_sql(ROW_COUNT)

_update.alters_data = True
_update.queryset_only = False
Expand Down
46 changes: 26 additions & 20 deletions django/db/models/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MULTI,
NO_RESULTS,
ORDER_DIR,
ROW_COUNT,
SINGLE,
)
from django.db.models.sql.query import Query, get_order_dir
Expand Down Expand Up @@ -1596,15 +1597,15 @@ def execute_sql(
):
"""
Run the query against the database and return the result(s). The
return value is a single data item if result_type is SINGLE, or an
iterator over the results if the result_type is MULTI.
result_type is either MULTI (use fetchmany() to retrieve all rows),
SINGLE (only retrieve a single row), or None. In this last case, the
cursor is returned if any query is executed, since it's used by
subclasses such as InsertQuery). It's possible, however, that no query
is needed, as the filters describe an empty set. In that case, None is
returned, to avoid any unnecessary database interaction.
return value depends on the value of result_type.
When result_type is:
- MULTI: Retrieves all rows using fetchmany(). Wraps in an iterator for
chunked reads when supported.
- SINGLE: Retrieves a single row using fetchone().
- ROW_COUNT: Retrieves the number of rows in the result.
- CURSOR: Runs the query, and returns the cursor object. It is the
caller's responsibility to close the cursor.
"""
result_type = result_type or NO_RESULTS
try:
Expand All @@ -1627,6 +1628,11 @@ def execute_sql(
cursor.close()
raise

if result_type == ROW_COUNT:
try:
return cursor.rowcount
finally:
cursor.close()
if result_type == CURSOR:
# Give the caller the cursor to process and close.
return cursor
Expand Down Expand Up @@ -2069,19 +2075,19 @@ def execute_sql(self, result_type):
non-empty query that is executed. Row counts for any subsequent,
related queries are not available.
"""
cursor = super().execute_sql(result_type)
try:
rows = cursor.rowcount if cursor else 0
is_empty = cursor is None
finally:
if cursor:
cursor.close()
row_count = super().execute_sql(result_type)
is_empty = row_count is None
row_count = row_count or 0

for query in self.query.get_related_updates():
aux_rows = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_rows:
rows = aux_rows
# If the result_type is NO_RESULTS then the aux_row_count is None.
aux_row_count = query.get_compiler(self.using).execute_sql(result_type)
if is_empty and aux_row_count:
# Returns the row count for any related updates as the number of
# rows updated.
row_count = aux_row_count
is_empty = False
return rows
return row_count

def pre_sql_setup(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion django/db/models/sql/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
# How many results to expect from a cursor.execute call
MULTI = "multi"
SINGLE = "single"
CURSOR = "cursor"
NO_RESULTS = "no results"
# Rather than returning results, returns:
CURSOR = "cursor"
ROW_COUNT = "row count"

ORDER_DIR = {
"ASC": ("ASC", "DESC"),
Expand Down
12 changes: 6 additions & 6 deletions django/db/models/sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,12 +1704,12 @@ def add_filtered_relation(self, filtered_relation, alias):
"relations outside the %r (got %r)."
% (filtered_relation.relation_name, lookup)
)
else:
raise ValueError(
"FilteredRelation's condition doesn't support nested "
"relations deeper than the relation_name (got %r for "
"%r)." % (lookup, filtered_relation.relation_name)
)
if len(lookup_field_parts) > len(relation_field_parts) + 1:
raise ValueError(
"FilteredRelation's condition doesn't support nested "
"relations deeper than the relation_name (got %r for "
"%r)." % (lookup, filtered_relation.relation_name)
)
filtered_relation.condition = rename_prefix_from_q(
filtered_relation.relation_name,
alias,
Expand Down
12 changes: 6 additions & 6 deletions django/db/models/sql/subqueries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
"""

from django.core.exceptions import FieldError
from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
from django.db.models.sql.constants import (
GET_ITERATOR_CHUNK_SIZE,
NO_RESULTS,
ROW_COUNT,
)
from django.db.models.sql.query import Query

__all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
Expand All @@ -17,11 +21,7 @@ class DeleteQuery(Query):
def do_query(self, table, where, using):
self.alias_map = {table: self.alias_map[table]}
self.where = where
cursor = self.get_compiler(using).execute_sql(CURSOR)
if cursor:
with cursor:
return cursor.rowcount
return 0
return self.get_compiler(using).execute_sql(ROW_COUNT)

def delete_batch(self, pk_list, using):
"""
Expand Down
2 changes: 2 additions & 0 deletions django/test/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,8 @@ def generic(
"scheme": "https" if secure else "http",
"headers": [(b"host", b"testserver")],
}
if self.defaults:
extra = {**self.defaults, **extra}
if data:
s["headers"].extend(
[
Expand Down
5 changes: 5 additions & 0 deletions docs/ref/class-based-views/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ MRO is an acronym for Method Resolution Order.

Performs key view initialization prior to :meth:`dispatch`.

Assigns the :class:`~django.http.HttpRequest` to the view's ``request``
attribute, and any positional and/or keyword arguments
:ref:`captured from the URL pattern <how-django-processes-a-request>`
to the ``args`` and ``kwargs`` attributes, respectively.

If overriding this method, you must call ``super()``.

.. method:: dispatch(request, *args, **kwargs)
Expand Down
9 changes: 9 additions & 0 deletions tests/delete/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,15 @@ def test_fast_delete_aggregation(self):
)
self.assertIs(Base.objects.exists(), False)

def test_fast_delete_empty_result_set(self):
user = User.objects.create()
with self.assertNumQueries(0):
self.assertEqual(
User.objects.filter(pk__in=[]).delete(),
(0, {}),
)
self.assertSequenceEqual(User.objects.all(), [user])

def test_fast_delete_full_match(self):
avatar = Avatar.objects.create(desc="bar")
User.objects.create(avatar=avatar)
Expand Down
2 changes: 1 addition & 1 deletion tests/file_storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,5 @@ def pathlib_upload_to(self, filename):
storage=temp_storage, upload_to="tests", max_length=20
)
extended_length = models.FileField(
storage=temp_storage, upload_to="tests", max_length=300
storage=temp_storage, upload_to="tests", max_length=1024
)
13 changes: 13 additions & 0 deletions tests/filtered_relation/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,19 @@ def test_condition_deeper_relation_name(self):
),
)

def test_condition_deeper_relation_name_implicit_exact(self):
msg = (
"FilteredRelation's condition doesn't support nested relations "
"deeper than the relation_name (got 'book__editor__name' for 'book')."
)
with self.assertRaisesMessage(ValueError, msg):
Author.objects.annotate(
book_editor=FilteredRelation(
"book",
condition=Q(book__editor__name="b"),
),
)

def test_with_empty_relation_name_error(self):
with self.assertRaisesMessage(ValueError, "relation_name cannot be empty."):
FilteredRelation("", condition=Q(blank=""))
Expand Down
5 changes: 2 additions & 3 deletions tests/inspectdb/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,11 +655,10 @@ def test_composite_primary_key(self):
call_command("inspectdb", table_name, stdout=out)
output = out.getvalue()
self.assertIn(
f"column_1 = models.{field_type}(primary_key=True) # The composite "
f"primary key (column_1, column_2) found, that is not supported. The "
f"first column is selected.",
"pk = models.CompositePrimaryKey('column_1', 'column_2')",
output,
)
self.assertIn(f"column_1 = models.{field_type}()", output)
self.assertIn(
"column_2 = models.%s()"
% connection.features.introspected_field_types["IntegerField"],
Expand Down
13 changes: 13 additions & 0 deletions tests/test_client/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,19 @@ def test_request_factory_sets_headers(self):
self.assertEqual(request.headers["x-another-header"], "some other value")
self.assertIn("HTTP_X_ANOTHER_HEADER", request.META)

def test_async_request_factory_default_headers(self):
request_factory_with_headers = AsyncRequestFactory(
**{
"Authorization": "Bearer faketoken",
"X-Another-Header": "some other value",
}
)
request = request_factory_with_headers.get("/somewhere/")
self.assertEqual(request.headers["authorization"], "Bearer faketoken")
self.assertIn("HTTP_AUTHORIZATION", request.META)
self.assertEqual(request.headers["x-another-header"], "some other value")
self.assertIn("HTTP_X_ANOTHER_HEADER", request.META)

def test_request_factory_query_string(self):
request = self.request_factory.get("/somewhere/", {"example": "data"})
self.assertNotIn("Query-String", request.headers)
Expand Down

0 comments on commit 9130980

Please sign in to comment.