diff --git a/django/contrib/gis/tests/relatedapp/tests.py b/django/contrib/gis/tests/relatedapp/tests.py index 3d162b065e..acc616c74f 100644 --- a/django/contrib/gis/tests/relatedapp/tests.py +++ b/django/contrib/gis/tests/relatedapp/tests.py @@ -184,12 +184,12 @@ class RelatedGeoModelTest(unittest.TestCase): self.assertEqual(m.point, t[1]) # Test disabled until #10572 is resolved. - #def test08_defer_only(self): - # "Testing defer() and only() on Geographic models." - # qs = Location.objects.all() - # def_qs = Location.objects.defer('point') - # for loc, def_loc in zip(qs, def_qs): - # self.assertEqual(loc.point, def_loc.point) + def test08_defer_only(self): + "Testing defer() and only() on Geographic models." + qs = Location.objects.all() + def_qs = Location.objects.defer('point') + for loc, def_loc in zip(qs, def_qs): + self.assertEqual(loc.point, def_loc.point) # TODO: Related tests for KML, GML, and distance lookups. diff --git a/django/db/models/base.py b/django/db/models/base.py index 01e2ca7011..05cd0d9ea1 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -362,9 +362,8 @@ class Model(object): # DeferredAttribute classes, so we only need to do this # once. obj = self.__class__.__dict__[field.attname] - pk_val = obj.pk_value model = obj.model_ref() - return (model_unpickle, (model, pk_val, defers), data) + return (model_unpickle, (model, defers), data) def _get_pk_val(self, meta=None): if not meta: @@ -635,12 +634,12 @@ def get_absolute_url(opts, func, self, *args, **kwargs): class Empty(object): pass -def model_unpickle(model, pk_val, attrs): +def model_unpickle(model, attrs): """ Used to unpickle Model subclasses with deferred fields. """ from django.db.models.query_utils import deferred_class_factory - cls = deferred_class_factory(model, pk_val, attrs) + cls = deferred_class_factory(model, attrs) return cls.__new__(cls) model_unpickle.__safe_for_unpickle__ = True diff --git a/django/db/models/query.py b/django/db/models/query.py index ea7129b693..9dcc031a39 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -190,6 +190,20 @@ class QuerySet(object): index_start = len(extra_select) aggregate_start = index_start + len(self.model._meta.fields) + load_fields = only_load.get(self.model) + skip = None + if load_fields and not fill_cache: + # Some fields have been deferred, so we have to initialise + # via keyword arguments. + skip = set() + init_list = [] + for field in fields: + if field.name not in load_fields: + skip.add(field.attname) + else: + init_list.append(field.attname) + model_cls = deferred_class_factory(self.model, skip) + for row in self.query.results_iter(): if fill_cache: obj, _ = get_cached_row(self.model, row, @@ -197,25 +211,10 @@ class QuerySet(object): requested=requested, offset=len(aggregate_select), only_load=only_load) else: - load_fields = only_load.get(self.model) - if load_fields: - # Some fields have been deferred, so we have to initialise - # via keyword arguments. + if skip: row_data = row[index_start:aggregate_start] pk_val = row_data[pk_idx] - skip = set() - init_list = [] - for field in fields: - if field.name not in load_fields: - skip.add(field.attname) - else: - init_list.append(field.attname) - if skip: - model_cls = deferred_class_factory(self.model, pk_val, - skip) - obj = model_cls(**dict(zip(init_list, row_data))) - else: - obj = self.model(*row[index_start:aggregate_start]) + obj = model_cls(**dict(zip(init_list, row_data))) else: # Omit aggregates in object creation. obj = self.model(*row[index_start:aggregate_start]) @@ -927,7 +926,7 @@ def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, else: init_list.append(field.attname) if skip: - klass = deferred_class_factory(klass, pk_val, skip) + klass = deferred_class_factory(klass, skip) obj = klass(**dict(zip(init_list, fields))) else: obj = klass(*fields) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 8baa654344..7a5ad919a1 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -158,9 +158,8 @@ class DeferredAttribute(object): A wrapper for a deferred-loading field. When the value is read from this object the first time, the query is executed. """ - def __init__(self, field_name, pk_value, model): + def __init__(self, field_name, model): self.field_name = field_name - self.pk_value = pk_value self.model_ref = weakref.ref(model) self.loaded = False @@ -170,21 +169,18 @@ class DeferredAttribute(object): Returns the cached value. """ assert instance is not None - if not self.loaded: - obj = self.model_ref() - if obj is None: - return - self.value = list(obj._base_manager.filter(pk=self.pk_value).values_list(self.field_name, flat=True))[0] - self.loaded = True - return self.value + cls = self.model_ref() + data = instance.__dict__ + if data.get(self.field_name, self) is self: + data[self.field_name] = cls._base_manager.filter(pk=instance.pk).values_list(self.field_name, flat=True).get() + return data[self.field_name] - def __set__(self, name, value): + def __set__(self, instance, value): """ Deferred loading attributes can be set normally (which means there will never be a database lookup involved. """ - self.value = value - self.loaded = True + instance.__dict__[self.field_name] = value def select_related_descend(field, restricted, requested): """ @@ -206,7 +202,7 @@ def select_related_descend(field, restricted, requested): # This function is needed because data descriptors must be defined on a class # object, not an instance, to have any effect. -def deferred_class_factory(model, pk_value, attrs): +def deferred_class_factory(model, attrs): """ Returns a class object that is a copy of "model" with the specified "attrs" being replaced with DeferredAttribute objects. The "pk_value" ties the @@ -223,7 +219,7 @@ def deferred_class_factory(model, pk_value, attrs): # are identical. name = "%s_Deferred_%s" % (model.__name__, '_'.join(sorted(list(attrs)))) - overrides = dict([(attr, DeferredAttribute(attr, pk_value, model)) + overrides = dict([(attr, DeferredAttribute(attr, model)) for attr in attrs]) overrides["Meta"] = Meta overrides["__module__"] = model.__module__ @@ -233,4 +229,3 @@ def deferred_class_factory(model, pk_value, attrs): # The above function is also used to unpickle model instances with deferred # fields. deferred_class_factory.__safe_for_unpickling__ = True - diff --git a/tests/regressiontests/defer_regress/models.py b/tests/regressiontests/defer_regress/models.py index c46d7ce176..0cd6facff7 100644 --- a/tests/regressiontests/defer_regress/models.py +++ b/tests/regressiontests/defer_regress/models.py @@ -6,7 +6,7 @@ from django.conf import settings from django.db import connection, models class Item(models.Model): - name = models.CharField(max_length=10) + name = models.CharField(max_length=15) text = models.TextField(default="xyzzy") value = models.IntegerField() other_value = models.IntegerField(default=0) @@ -14,6 +14,9 @@ class Item(models.Model): def __unicode__(self): return self.name +class RelatedItem(models.Model): + item = models.ForeignKey(Item) + __test__ = {"regression_tests": """ Deferred fields should really be deferred and not accidentally use the field's default value just because they aren't passed to __init__. @@ -39,9 +42,31 @@ True u"xyzzy" >>> len(connection.queries) == num + 2 # Effect of text lookup. True +>>> obj.text +u"xyzzy" +>>> len(connection.queries) == num + 2 +True >>> settings.DEBUG = False +Regression test for #10695. Make sure different instances don't inadvertently +share data in the deferred descriptor objects. + +>>> i = Item.objects.create(name="no I'm first", value=37) +>>> items = Item.objects.only('value').order_by('-value') +>>> items[0].name +u'first' +>>> items[1].name +u"no I'm first" + +>>> _ = RelatedItem.objects.create(item=i) +>>> r = RelatedItem.objects.defer('item').get() +>>> r.item_id == i.id +True +>>> r.item == i +True + + + """ } -