Allowed Context.push to behave as a context mananger.

Thanks Loic Bistuer for the review.
This commit is contained in:
Curtis Maloney 2013-07-16 21:11:32 +10:00 committed by Tim Graham
parent 828359e52d
commit a3e7d73ed7
7 changed files with 145 additions and 98 deletions

View File

@ -12,6 +12,21 @@ class ContextPopException(Exception):
"pop() has been called more times than push()" "pop() has been called more times than push()"
pass pass
class ContextDict(dict):
def __init__(self, context, *args, **kwargs):
super(ContextDict, self).__init__(*args, **kwargs)
context.dicts.append(self)
self.context = context
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
self.context.pop()
class BaseContext(object): class BaseContext(object):
def __init__(self, dict_=None): def __init__(self, dict_=None):
self._reset_dicts(dict_) self._reset_dicts(dict_)
@ -34,10 +49,8 @@ class BaseContext(object):
for d in reversed(self.dicts): for d in reversed(self.dicts):
yield d yield d
def push(self): def push(self, *args, **kwargs):
d = {} return ContextDict(self, *args, **kwargs)
self.dicts.append(d)
return d
def pop(self): def pop(self):
if len(self.dicts) == 1: if len(self.dicts) == 1:
@ -83,6 +96,7 @@ class BaseContext(object):
new_context._reset_dicts(values) new_context._reset_dicts(values)
return new_context return new_context
class Context(BaseContext): class Context(BaseContext):
"A stack container for variable context" "A stack container for variable context"
def __init__(self, dict_=None, autoescape=True, current_app=None, def __init__(self, dict_=None, autoescape=True, current_app=None,
@ -106,6 +120,7 @@ class Context(BaseContext):
self.dicts.append(other_dict) self.dicts.append(other_dict)
return other_dict return other_dict
class RenderContext(BaseContext): class RenderContext(BaseContext):
""" """
A stack container for storing Template state. A stack container for storing Template state.

View File

@ -95,10 +95,9 @@ class FilterNode(Node):
def render(self, context): def render(self, context):
output = self.nodelist.render(context) output = self.nodelist.render(context)
# Apply filters. # Apply filters.
context.update({'var': output}) with context.push(var=output):
filtered = self.filter_expr.resolve(context) return self.filter_expr.resolve(context)
context.pop()
return filtered
class FirstOfNode(Node): class FirstOfNode(Node):
def __init__(self, variables, escape=False): def __init__(self, variables, escape=False):
@ -143,71 +142,69 @@ class ForNode(Node):
parentloop = context['forloop'] parentloop = context['forloop']
else: else:
parentloop = {} parentloop = {}
context.push() with context.push():
try: try:
values = self.sequence.resolve(context, True) values = self.sequence.resolve(context, True)
except VariableDoesNotExist: except VariableDoesNotExist:
values = [] values = []
if values is None: if values is None:
values = [] values = []
if not hasattr(values, '__len__'): if not hasattr(values, '__len__'):
values = list(values) values = list(values)
len_values = len(values) len_values = len(values)
if len_values < 1: if len_values < 1:
context.pop() return self.nodelist_empty.render(context)
return self.nodelist_empty.render(context) nodelist = NodeList()
nodelist = NodeList() if self.is_reversed:
if self.is_reversed: values = reversed(values)
values = reversed(values) unpack = len(self.loopvars) > 1
unpack = len(self.loopvars) > 1 # Create a forloop value in the context. We'll update counters on each
# Create a forloop value in the context. We'll update counters on each # iteration just below.
# iteration just below. loop_dict = context['forloop'] = {'parentloop': parentloop}
loop_dict = context['forloop'] = {'parentloop': parentloop} for i, item in enumerate(values):
for i, item in enumerate(values): # Shortcuts for current loop iteration number.
# Shortcuts for current loop iteration number. loop_dict['counter0'] = i
loop_dict['counter0'] = i loop_dict['counter'] = i+1
loop_dict['counter'] = i+1 # Reverse counter iteration numbers.
# Reverse counter iteration numbers. loop_dict['revcounter'] = len_values - i
loop_dict['revcounter'] = len_values - i loop_dict['revcounter0'] = len_values - i - 1
loop_dict['revcounter0'] = len_values - i - 1 # Boolean values designating first and last times through loop.
# Boolean values designating first and last times through loop. loop_dict['first'] = (i == 0)
loop_dict['first'] = (i == 0) loop_dict['last'] = (i == len_values - 1)
loop_dict['last'] = (i == len_values - 1)
pop_context = False pop_context = False
if unpack: if unpack:
# If there are multiple loop variables, unpack the item into # If there are multiple loop variables, unpack the item into
# them. # them.
try:
unpacked_vars = dict(zip(self.loopvars, item))
except TypeError:
pass
else:
pop_context = True
context.update(unpacked_vars)
else:
context[self.loopvars[0]] = item
# In TEMPLATE_DEBUG mode provide source of the node which
# actually raised the exception
if settings.TEMPLATE_DEBUG:
for node in self.nodelist_loop:
try: try:
unpacked_vars = dict(zip(self.loopvars, item))
except TypeError:
pass
else:
pop_context = True
context.update(unpacked_vars)
else:
context[self.loopvars[0]] = item
# In TEMPLATE_DEBUG mode provide source of the node which
# actually raised the exception
if settings.TEMPLATE_DEBUG:
for node in self.nodelist_loop:
try:
nodelist.append(node.render(context))
except Exception as e:
if not hasattr(e, 'django_template_source'):
e.django_template_source = node.source
raise
else:
for node in self.nodelist_loop:
nodelist.append(node.render(context)) nodelist.append(node.render(context))
except Exception as e: if pop_context:
if not hasattr(e, 'django_template_source'): # The loop variables were pushed on to the context so pop them
e.django_template_source = node.source # off again. This is necessary because the tag lets the length
raise # of loopvars differ to the length of each set of items and we
else: # don't want to leave any vars from the previous loop on the
for node in self.nodelist_loop: # context.
nodelist.append(node.render(context)) context.pop()
if pop_context:
# The loop variables were pushed on to the context so pop them
# off again. This is necessary because the tag lets the length
# of loopvars differ to the length of each set of items and we
# don't want to leave any vars from the previous loop on the
# context.
context.pop()
context.pop()
return nodelist.render(context) return nodelist.render(context)
class IfChangedNode(Node): class IfChangedNode(Node):
@ -500,10 +497,9 @@ class WithNode(Node):
def render(self, context): def render(self, context):
values = dict([(key, val.resolve(context)) for key, val in values = dict([(key, val.resolve(context)) for key, val in
six.iteritems(self.extra_context)]) six.iteritems(self.extra_context)])
context.update(values) with context.push(**values):
output = self.nodelist.render(context) return self.nodelist.render(context)
context.pop()
return output
@register.tag @register.tag
def autoescape(parser, token): def autoescape(parser, token):

View File

@ -164,11 +164,8 @@ def render_to_string(template_name, dictionary=None, context_instance=None):
return t.render(Context(dictionary)) return t.render(Context(dictionary))
# Add the dictionary to the context stack, ensuring it gets removed again # Add the dictionary to the context stack, ensuring it gets removed again
# to keep the context_instance in the same state it started in. # to keep the context_instance in the same state it started in.
context_instance.update(dictionary) with context_instance.push(dictionary):
try:
return t.render(context_instance) return t.render(context_instance)
finally:
context_instance.pop()
def select_template(template_name_list): def select_template(template_name_list):
"Given a list of template names, returns the first that can be loaded." "Given a list of template names, returns the first that can be loaded."

View File

@ -47,22 +47,21 @@ class BlockNode(Node):
def render(self, context): def render(self, context):
block_context = context.render_context.get(BLOCK_CONTEXT_KEY) block_context = context.render_context.get(BLOCK_CONTEXT_KEY)
context.push() with context.push():
if block_context is None: if block_context is None:
context['block'] = self context['block'] = self
result = self.nodelist.render(context) result = self.nodelist.render(context)
else: else:
push = block = block_context.pop(self.name) push = block = block_context.pop(self.name)
if block is None: if block is None:
block = self block = self
# Create new block so we can store context without thread-safety issues. # Create new block so we can store context without thread-safety issues.
block = BlockNode(block.name, block.nodelist) block = BlockNode(block.name, block.nodelist)
block.context = context block.context = context
context['block'] = block context['block'] = block
result = block.nodelist.render(context) result = block.nodelist.render(context)
if push is not None: if push is not None:
block_context.push(self.name, push) block_context.push(self.name, push)
context.pop()
return result return result
def super(self): def super(self):
@ -133,10 +132,9 @@ class BaseIncludeNode(Node):
in six.iteritems(self.extra_context)]) in six.iteritems(self.extra_context)])
if self.isolated_context: if self.isolated_context:
return template.render(context.new(values)) return template.render(context.new(values))
context.update(values) with context.push(**values):
output = template.render(context) return template.render(context)
context.pop()
return output
class ConstantIncludeNode(BaseIncludeNode): class ConstantIncludeNode(BaseIncludeNode):
def __init__(self, template_path, *args, **kwargs): def __init__(self, template_path, *args, **kwargs):

View File

@ -325,6 +325,31 @@ If you ``pop()`` too much, it'll raise
... ...
django.template.ContextPopException django.template.ContextPopException
.. versionadded:: 1.7
You can also use ``push()`` as a context manager to ensure a matching ``pop()``
is called.
>>> c = Context()
>>> c['foo'] = 'first level'
>>> with c.push():
>>> c['foo'] = 'second level'
>>> c['foo']
'second level'
>>> c['foo']
'first level'
All arguments passed to ``push()`` will be passed to the ``dict`` constructor
used to build the new context level.
>>> c = Context()
>>> c['foo'] = 'first level'
>>> with c.push(foo='second level'):
>>> c['foo']
'second level'
>>> c['foo']
'first level'
.. method:: update(other_dict) .. method:: update(other_dict)
In addition to ``push()`` and ``pop()``, the ``Context`` In addition to ``push()`` and ``pop()``, the ``Context``

View File

@ -60,6 +60,13 @@ Minor features
* :attr:`~django.db.models.Options.app_label` is no longer required for models * :attr:`~django.db.models.Options.app_label` is no longer required for models
that are defined in a ``models`` package within an app. that are defined in a ``models`` package within an app.
* The :meth:`Context.push() <django.template.Context.push>` method now returns
a context manager which automatically calls :meth:`pop()
<django.template.Context.pop>` upon exiting the ``with`` statement.
Additionally, :meth:`push() <django.template.Context.push>` now accepts
parameters that are passed to the ``dict`` constructor used to build the new
context level.
Backwards incompatible changes in 1.7 Backwards incompatible changes in 1.7
===================================== =====================================

View File

@ -16,3 +16,12 @@ class ContextTests(TestCase):
self.assertEqual(c.pop(), {"a": 2}) self.assertEqual(c.pop(), {"a": 2})
self.assertEqual(c["a"], 1) self.assertEqual(c["a"], 1)
self.assertEqual(c.get("foo", 42), 42) self.assertEqual(c.get("foo", 42), 42)
with c.push():
c['a'] = 2
self.assertEqual(c['a'], 2)
self.assertEqual(c['a'], 1)
with c.push(a=3):
self.assertEqual(c['a'], 3)
self.assertEqual(c['a'], 1)