[4.1.x] Fixed #34139 -- Fixed acreate(), aget_or_create(), and aupdate_or_create() methods for related managers.

Bug in 58b27e0dbb3d31ca1438790870b2b51ecdb10500.

Backport of 7b94847e384b1a8c05a7d4c8778958c0290bdf9a from main
This commit is contained in:
Jon Janzen 2022-11-04 15:22:32 +01:00 committed by Mariusz Felisiak
parent 8740d2f452
commit 9fb57fcc70
8 changed files with 155 additions and 0 deletions

View File

@ -491,6 +491,7 @@ answer newbie questions, and generally made Django that much better:
John Shaffer <jshaffer2112@gmail.com> John Shaffer <jshaffer2112@gmail.com>
Jökull Sólberg Auðunsson <jokullsolberg@gmail.com> Jökull Sólberg Auðunsson <jokullsolberg@gmail.com>
Jon Dufresne <jon.dufresne@gmail.com> Jon Dufresne <jon.dufresne@gmail.com>
Jon Janzen <jon@jonjanzen.com>
Jonas Haag <jonas@lophus.org> Jonas Haag <jonas@lophus.org>
Jonas Lundberg <jonas.lundberg@gmail.com> Jonas Lundberg <jonas.lundberg@gmail.com>
Jonathan Davis <jonathandavis47780@gmail.com> Jonathan Davis <jonathandavis47780@gmail.com>

View File

@ -2,6 +2,8 @@ import functools
import itertools import itertools
from collections import defaultdict from collections import defaultdict
from asgiref.sync import sync_to_async
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.core import checks from django.core import checks
from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
@ -744,6 +746,11 @@ def create_generic_related_manager(superclass, rel):
create.alters_data = True create.alters_data = True
async def acreate(self, **kwargs):
return await sync_to_async(self.create)(**kwargs)
acreate.alters_data = True
def get_or_create(self, **kwargs): def get_or_create(self, **kwargs):
kwargs[self.content_type_field_name] = self.content_type kwargs[self.content_type_field_name] = self.content_type
kwargs[self.object_id_field_name] = self.pk_val kwargs[self.object_id_field_name] = self.pk_val
@ -752,6 +759,11 @@ def create_generic_related_manager(superclass, rel):
get_or_create.alters_data = True get_or_create.alters_data = True
async def aget_or_create(self, **kwargs):
return await sync_to_async(self.get_or_create)(**kwargs)
aget_or_create.alters_data = True
def update_or_create(self, **kwargs): def update_or_create(self, **kwargs):
kwargs[self.content_type_field_name] = self.content_type kwargs[self.content_type_field_name] = self.content_type
kwargs[self.object_id_field_name] = self.pk_val kwargs[self.object_id_field_name] = self.pk_val
@ -760,4 +772,9 @@ def create_generic_related_manager(superclass, rel):
update_or_create.alters_data = True update_or_create.alters_data = True
async def aupdate_or_create(self, **kwargs):
return await sync_to_async(self.update_or_create)(**kwargs)
aupdate_or_create.alters_data = True
return GenericRelatedObjectManager return GenericRelatedObjectManager

View File

@ -63,6 +63,8 @@ and two directions (forward and reverse) for a total of six combinations.
``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead. ``ReverseManyToManyDescriptor``, use ``ManyToManyDescriptor`` instead.
""" """
from asgiref.sync import sync_to_async
from django.core.exceptions import FieldError from django.core.exceptions import FieldError
from django.db import connections, router, transaction from django.db import connections, router, transaction
from django.db.models import Q, signals from django.db.models import Q, signals
@ -765,6 +767,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
create.alters_data = True create.alters_data = True
async def acreate(self, **kwargs):
return await sync_to_async(self.create)(**kwargs)
acreate.alters_data = True
def get_or_create(self, **kwargs): def get_or_create(self, **kwargs):
self._check_fk_val() self._check_fk_val()
kwargs[self.field.name] = self.instance kwargs[self.field.name] = self.instance
@ -773,6 +780,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
get_or_create.alters_data = True get_or_create.alters_data = True
async def aget_or_create(self, **kwargs):
return await sync_to_async(self.get_or_create)(**kwargs)
aget_or_create.alters_data = True
def update_or_create(self, **kwargs): def update_or_create(self, **kwargs):
self._check_fk_val() self._check_fk_val()
kwargs[self.field.name] = self.instance kwargs[self.field.name] = self.instance
@ -781,6 +793,11 @@ def create_reverse_many_to_one_manager(superclass, rel):
update_or_create.alters_data = True update_or_create.alters_data = True
async def aupdate_or_create(self, **kwargs):
return await sync_to_async(self.update_or_create)(**kwargs)
aupdate_or_create.alters_data = True
# remove() and clear() are only provided if the ForeignKey can have a # remove() and clear() are only provided if the ForeignKey can have a
# value of null. # value of null.
if rel.field.null: if rel.field.null:
@ -1161,6 +1178,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
create.alters_data = True create.alters_data = True
async def acreate(self, *, through_defaults=None, **kwargs):
return await sync_to_async(self.create)(
through_defaults=through_defaults, **kwargs
)
acreate.alters_data = True
def get_or_create(self, *, through_defaults=None, **kwargs): def get_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance) db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create( obj, created = super(ManyRelatedManager, self.db_manager(db)).get_or_create(
@ -1174,6 +1198,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
get_or_create.alters_data = True get_or_create.alters_data = True
async def aget_or_create(self, *, through_defaults=None, **kwargs):
return await sync_to_async(self.get_or_create)(
through_defaults=through_defaults, **kwargs
)
aget_or_create.alters_data = True
def update_or_create(self, *, through_defaults=None, **kwargs): def update_or_create(self, *, through_defaults=None, **kwargs):
db = router.db_for_write(self.instance.__class__, instance=self.instance) db = router.db_for_write(self.instance.__class__, instance=self.instance)
obj, created = super( obj, created = super(
@ -1187,6 +1218,13 @@ def create_forward_many_to_many_manager(superclass, rel, reverse):
update_or_create.alters_data = True update_or_create.alters_data = True
async def aupdate_or_create(self, *, through_defaults=None, **kwargs):
return await sync_to_async(self.update_or_create)(
through_defaults=through_defaults, **kwargs
)
aupdate_or_create.alters_data = True
def _get_target_ids(self, target_field_name, objs): def _get_target_ids(self, target_field_name, objs):
""" """
Return the set of ids of `objs` that the target field references. Return the set of ids of `objs` that the target field references.

View File

@ -76,6 +76,9 @@ Related objects reference
intermediate instance(s). intermediate instance(s).
.. method:: create(through_defaults=None, **kwargs) .. method:: create(through_defaults=None, **kwargs)
.. method:: acreate(through_defaults=None, **kwargs)
*Asynchronous version*: ``acreate``
Creates a new object, saves it and puts it in the related object set. Creates a new object, saves it and puts it in the related object set.
Returns the newly created object:: Returns the newly created object::
@ -110,6 +113,10 @@ Related objects reference
needed. You can use callables as values in the ``through_defaults`` needed. You can use callables as values in the ``through_defaults``
dictionary. dictionary.
.. versionchanged:: 4.1
``acreate()`` method was added.
.. method:: remove(*objs, bulk=True) .. method:: remove(*objs, bulk=True)
Removes the specified model objects from the related object set:: Removes the specified model objects from the related object set::

View File

@ -16,3 +16,7 @@ Bugfixes
an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and an empty :meth:`Sitemap.items() <django.contrib.sitemaps.Sitemap.items>` and
a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod` a callable :attr:`~django.contrib.sitemaps.Sitemap.lastmod`
(:ticket:`34088`). (:ticket:`34088`).
* Fixed a bug in Django 4.1 that caused a crash of ``acreate()``,
``aget_or_create()``, and ``aupdate_or_create()`` asynchronous methods for
related managers (:ticket:`34139`).

View File

@ -9,3 +9,7 @@ class RelatedModel(models.Model):
class SimpleModel(models.Model): class SimpleModel(models.Model):
field = models.IntegerField() field = models.IntegerField()
created = models.DateTimeField(default=timezone.now) created = models.DateTimeField(default=timezone.now)
class ManyToManyModel(models.Model):
simples = models.ManyToManyField("SimpleModel")

View File

@ -0,0 +1,56 @@
from django.test import TestCase
from .models import ManyToManyModel, SimpleModel
class AsyncRelatedManagersOperationTest(TestCase):
@classmethod
def setUpTestData(cls):
cls.mtm1 = ManyToManyModel.objects.create()
cls.s1 = SimpleModel.objects.create(field=0)
async def test_acreate(self):
await self.mtm1.simples.acreate(field=2)
new_simple = await self.mtm1.simples.aget()
self.assertEqual(new_simple.field, 2)
async def test_acreate_reverse(self):
await self.s1.relatedmodel_set.acreate()
new_relatedmodel = await self.s1.relatedmodel_set.aget()
self.assertEqual(new_relatedmodel.simple, self.s1)
async def test_aget_or_create(self):
new_simple, created = await self.mtm1.simples.aget_or_create(field=2)
self.assertIs(created, True)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2)
new_simple, created = await self.mtm1.simples.aget_or_create(
id=new_simple.id, through_defaults={"field": 3}
)
self.assertIs(created, False)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2)
async def test_aget_or_create_reverse(self):
new_relatedmodel, created = await self.s1.relatedmodel_set.aget_or_create()
self.assertIs(created, True)
self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
self.assertEqual(new_relatedmodel.simple, self.s1)
async def test_aupdate_or_create(self):
new_simple, created = await self.mtm1.simples.aupdate_or_create(field=2)
self.assertIs(created, True)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 2)
new_simple, created = await self.mtm1.simples.aupdate_or_create(
id=new_simple.id, defaults={"field": 3}
)
self.assertIs(created, False)
self.assertEqual(await self.mtm1.simples.acount(), 1)
self.assertEqual(new_simple.field, 3)
async def test_aupdate_or_create_reverse(self):
new_relatedmodel, created = await self.s1.relatedmodel_set.aupdate_or_create()
self.assertIs(created, True)
self.assertEqual(await self.s1.relatedmodel_set.acount(), 1)
self.assertEqual(new_relatedmodel.simple, self.s1)

View File

@ -45,6 +45,10 @@ class GenericRelationsTests(TestCase):
# Original list of tags: # Original list of tags:
return obj.tag, obj.content_type.model_class(), obj.object_id return obj.tag, obj.content_type.model_class(), obj.object_id
async def test_generic_async_acreate(self):
await self.bacon.tags.acreate(tag="orange")
self.assertEqual(await self.bacon.tags.acount(), 3)
def test_generic_update_or_create_when_created(self): def test_generic_update_or_create_when_created(self):
""" """
Should be able to use update_or_create from the generic related manager Should be able to use update_or_create from the generic related manager
@ -70,6 +74,18 @@ class GenericRelationsTests(TestCase):
self.assertEqual(count + 1, self.bacon.tags.count()) self.assertEqual(count + 1, self.bacon.tags.count())
self.assertEqual(tag.tag, "juicy") self.assertEqual(tag.tag, "juicy")
async def test_generic_async_aupdate_or_create(self):
tag, created = await self.bacon.tags.aupdate_or_create(
id=self.fatty.id, defaults={"tag": "orange"}
)
self.assertIs(created, False)
self.assertEqual(tag.tag, "orange")
self.assertEqual(await self.bacon.tags.acount(), 2)
tag, created = await self.bacon.tags.aupdate_or_create(tag="pink")
self.assertIs(created, True)
self.assertEqual(await self.bacon.tags.acount(), 3)
self.assertEqual(tag.tag, "pink")
def test_generic_get_or_create_when_created(self): def test_generic_get_or_create_when_created(self):
""" """
Should be able to use get_or_create from the generic related manager Should be able to use get_or_create from the generic related manager
@ -96,6 +112,18 @@ class GenericRelationsTests(TestCase):
# shouldn't had changed the tag # shouldn't had changed the tag
self.assertEqual(tag.tag, "stinky") self.assertEqual(tag.tag, "stinky")
async def test_generic_async_aget_or_create(self):
tag, created = await self.bacon.tags.aget_or_create(
id=self.fatty.id, defaults={"tag": "orange"}
)
self.assertIs(created, False)
self.assertEqual(tag.tag, "fatty")
self.assertEqual(await self.bacon.tags.acount(), 2)
tag, created = await self.bacon.tags.aget_or_create(tag="orange")
self.assertIs(created, True)
self.assertEqual(await self.bacon.tags.acount(), 3)
self.assertEqual(tag.tag, "orange")
def test_generic_relations_m2m_mimic(self): def test_generic_relations_m2m_mimic(self):
""" """
Objects with declared GenericRelations can be tagged directly -- the Objects with declared GenericRelations can be tagged directly -- the