Fixed #26151 -- Refactored MigrationWriter.serialize()

Thanks Markus Holtermann for review.
This commit is contained in:
Yoong Kang Lim 2016-01-28 01:22:39 +11:00 committed by Tim Graham
parent fc584f0685
commit 4b1529e2cb
2 changed files with 392 additions and 266 deletions

View File

@ -0,0 +1,388 @@
from __future__ import unicode_literals
import collections
import datetime
import decimal
import functools
import math
import types
from importlib import import_module
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
from django.utils import datetime_safe, six
from django.utils.encoding import force_text
from django.utils.functional import LazyObject, Promise
from django.utils.timezone import utc
from django.utils.version import get_docs_version
try:
import enum
except ImportError:
# No support on Python 2 if enum34 isn't installed.
enum = None
class BaseSerializer(object):
def __init__(self, value):
self.value = value
def serialize(self):
raise NotImplementedError('Subclasses of BaseSerializer must implement the serialize() method.')
class BaseSequenceSerializer(BaseSerializer):
def _format(self):
raise NotImplementedError('Subclasses of BaseSequenceSerializer must implement the _format() method.')
def serialize(self):
imports = set()
strings = []
for item in self.value:
item_string, item_imports = serializer_factory(item).serialize()
imports.update(item_imports)
strings.append(item_string)
value = self._format()
return value % (", ".join(strings)), imports
class BaseSimpleSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), set()
class ByteTypeSerializer(BaseSerializer):
def serialize(self):
value_repr = repr(self.value)
if six.PY2:
# Prepend the `b` prefix since we're importing unicode_literals
value_repr = 'b' + value_repr
return value_repr, set()
class DatetimeSerializer(BaseSerializer):
def serialize(self):
if self.value.tzinfo is not None and self.value.tzinfo != utc:
self.value = self.value.astimezone(utc)
value_repr = repr(self.value).replace("<UTC>", "utc")
if isinstance(self.value, datetime_safe.datetime):
value_repr = "datetime.%s" % value_repr
imports = ["import datetime"]
if self.value.tzinfo is not None:
imports.append("from django.utils.timezone import utc")
return value_repr, set(imports)
class DateSerializer(BaseSerializer):
def serialize(self):
value_repr = repr(self.value)
if isinstance(self.value, datetime_safe.date):
value_repr = "datetime.%s" % value_repr
return value_repr, {"import datetime"}
class DecimalSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), {"from decimal import Decimal"}
class DeconstructableSerializer(BaseSerializer):
@staticmethod
def serialize_deconstructed(path, args, kwargs):
name, imports = DeconstructableSerializer._serialize_path(path)
strings = []
for arg in args:
arg_string, arg_imports = serializer_factory(arg).serialize()
strings.append(arg_string)
imports.update(arg_imports)
for kw, arg in sorted(kwargs.items()):
arg_string, arg_imports = serializer_factory(arg).serialize()
imports.update(arg_imports)
strings.append("%s=%s" % (kw, arg_string))
return "%s(%s)" % (name, ", ".join(strings)), imports
@staticmethod
def _serialize_path(path):
module, name = path.rsplit(".", 1)
if module == "django.db.models":
imports = {"from django.db import models"}
name = "models.%s" % name
else:
imports = {"import %s" % module}
name = path
return name, imports
def serialize(self):
return self.serialize_deconstructed(*self.value.deconstruct())
class DictionarySerializer(BaseSerializer):
def serialize(self):
imports = set()
strings = []
for k, v in sorted(self.value.items()):
k_string, k_imports = serializer_factory(k).serialize()
v_string, v_imports = serializer_factory(v).serialize()
imports.update(k_imports)
imports.update(v_imports)
strings.append((k_string, v_string))
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
class EnumSerializer(BaseSerializer):
def serialize(self):
enum_class = self.value.__class__
module = enum_class.__module__
imports = {"import %s" % module}
v_string, v_imports = serializer_factory(self.value.value).serialize()
imports.update(v_imports)
return "%s.%s(%s)" % (module, enum_class.__name__, v_string), imports
class FloatSerializer(BaseSimpleSerializer):
def serialize(self):
if math.isnan(self.value) or math.isinf(self.value):
return 'float("{}")'.format(self.value), set()
return super(FloatSerializer, self).serialize()
class FrozensetSerializer(BaseSequenceSerializer):
def _format(self):
return "frozenset([%s])"
class FunctionTypeSerializer(BaseSerializer):
def serialize(self):
if getattr(self.value, "__self__", None) and isinstance(self.value.__self__, type):
klass = self.value.__self__
module = klass.__module__
return "%s.%s.%s" % (module, klass.__name__, self.value.__name__), {"import %s" % module}
# Further error checking
if self.value.__name__ == '<lambda>':
raise ValueError("Cannot serialize function: lambda")
if self.value.__module__ is None:
raise ValueError("Cannot serialize function %r: No module" % self.value)
# Python 3 is a lot easier, and only uses this branch if it's not local.
if getattr(self.value, "__qualname__", None) and getattr(self.value, "__module__", None):
if "<" not in self.value.__qualname__: # Qualname can include <locals>
return "%s.%s" % \
(self.value.__module__, self.value.__qualname__), {"import %s" % self.value.__module__}
# Python 2/fallback version
module_name = self.value.__module__
# Make sure it's actually there and not an unbound method
module = import_module(module_name)
if not hasattr(module, self.value.__name__):
raise ValueError(
"Could not find function %s in %s.\n"
"Please note that due to Python 2 limitations, you cannot "
"serialize unbound method functions (e.g. a method "
"declared and used in the same class body). Please move "
"the function into the main module body to use migrations.\n"
"For more information, see "
"https://docs.djangoproject.com/en/%s/topics/migrations/#serializing-values"
% (self.value.__name__, module_name, get_docs_version())
)
# Needed on Python 2 only
if module_name == '__builtin__':
return self.value.__name__, set()
return "%s.%s" % (module_name, self.value.__name__), {"import %s" % module_name}
class FunctoolsPartialSerializer(BaseSerializer):
def serialize(self):
imports = {'import functools'}
# Serialize functools.partial() arguments
func_string, func_imports = serializer_factory(self.value.func).serialize()
args_string, args_imports = serializer_factory(self.value.args).serialize()
keywords_string, keywords_imports = serializer_factory(self.value.keywords).serialize()
# Add any imports needed by arguments
imports.update(func_imports)
imports.update(args_imports)
imports.update(keywords_imports)
return (
"functools.partial(%s, *%s, **%s)" % (
func_string, args_string, keywords_string,
),
imports,
)
class IterableSerializer(BaseSerializer):
def serialize(self):
imports = set()
strings = []
for item in self.value:
item_string, item_imports = serializer_factory(item).serialize()
imports.update(item_imports)
strings.append(item_string)
# When len(strings)==0, the empty iterable should be serialized as
# "()", not "(,)" because (,) is invalid Python syntax.
value = "(%s)" if len(strings) != 1 else "(%s,)"
return value % (", ".join(strings)), imports
class ModelFieldSerializer(DeconstructableSerializer):
def serialize(self):
attr_name, path, args, kwargs = self.value.deconstruct()
return self.serialize_deconstructed(path, args, kwargs)
class ModelManagerSerializer(DeconstructableSerializer):
def serialize(self):
as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
if as_manager:
name, imports = self._serialize_path(qs_path)
return "%s.as_manager()" % name, imports
else:
return self.serialize_deconstructed(manager_path, args, kwargs)
class OperationSerializer(BaseSerializer):
def serialize(self):
from django.db.migrations.writer import OperationWriter
string, imports = OperationWriter(self.value, indentation=0).serialize()
# Nested operation, trailing comma is handled in upper OperationWriter._write()
return string.rstrip(','), imports
class RegexSerializer(BaseSerializer):
def serialize(self):
imports = {"import re"}
regex_pattern, pattern_imports = serializer_factory(self.value.pattern).serialize()
regex_flags, flag_imports = serializer_factory(self.value.flags).serialize()
imports.update(pattern_imports)
imports.update(flag_imports)
args = [regex_pattern]
if self.value.flags:
args.append(regex_flags)
return "re.compile(%s)" % ', '.join(args), imports
class SequenceSerializer(BaseSequenceSerializer):
def _format(self):
return "[%s]"
class SetSerializer(BaseSequenceSerializer):
def _format(self):
# Don't use the literal "{%s}" as it doesn't support empty set
return "set([%s])"
class SettingsReferenceSerializer(BaseSerializer):
def serialize(self):
return "settings.%s" % self.value.setting_name, {"from django.conf import settings"}
class TextTypeSerializer(BaseSerializer):
def serialize(self):
value_repr = repr(self.value)
if six.PY2:
# Strip the `u` prefix since we're importing unicode_literals
value_repr = value_repr[1:]
return value_repr, set()
class TimedeltaSerializer(BaseSerializer):
def serialize(self):
return repr(self.value), {"import datetime"}
class TimeSerializer(BaseSerializer):
def serialize(self):
value_repr = repr(self.value)
if isinstance(self.value, datetime_safe.time):
value_repr = "datetime.%s" % value_repr
return value_repr, {"import datetime"}
class TupleSerializer(BaseSequenceSerializer):
def _format(self):
# When len(value)==0, the empty tuple should be serialized as "()",
# not "(,)" because (,) is invalid Python syntax.
return "(%s)" if len(self.value) != 1 else "(%s,)"
class TypeSerializer(BaseSerializer):
def serialize(self):
special_cases = [
(models.Model, "models.Model", []),
]
for case, string, imports in special_cases:
if case is self.value:
return string, set(imports)
if hasattr(self.value, "__module__"):
module = self.value.__module__
if module == six.moves.builtins.__name__:
return self.value.__name__, set()
else:
return "%s.%s" % (module, self.value.__name__), {"import %s" % module}
def serializer_factory(value):
from django.db.migrations.writer import SettingsReference
if isinstance(value, Promise):
value = force_text(value)
elif isinstance(value, LazyObject):
# The unwrapped value is returned as the first item of the arguments
# tuple.
value = value.__reduce__()[1][0]
# Unfortunately some of these are order-dependent.
if isinstance(value, frozenset):
return FrozensetSerializer(value)
if isinstance(value, list):
return SequenceSerializer(value)
if isinstance(value, set):
return SetSerializer(value)
if isinstance(value, tuple):
return TupleSerializer(value)
if isinstance(value, dict):
return DictionarySerializer(value)
if enum and isinstance(value, enum.Enum):
return EnumSerializer(value)
if isinstance(value, datetime.datetime):
return DatetimeSerializer(value)
if isinstance(value, datetime.date):
return DateSerializer(value)
if isinstance(value, datetime.time):
return TimeSerializer(value)
if isinstance(value, datetime.timedelta):
return TimedeltaSerializer(value)
if isinstance(value, SettingsReference):
return SettingsReferenceSerializer(value)
if isinstance(value, float):
return FloatSerializer(value)
if isinstance(value, six.integer_types + (bool, type(None))):
return BaseSimpleSerializer(value)
if isinstance(value, six.binary_type):
return ByteTypeSerializer(value)
if isinstance(value, six.text_type):
return TextTypeSerializer(value)
if isinstance(value, decimal.Decimal):
return DecimalSerializer(value)
if isinstance(value, models.Field):
return ModelFieldSerializer(value)
if isinstance(value, type):
return TypeSerializer(value)
if isinstance(value, models.manager.BaseManager):
return ModelManagerSerializer(value)
if isinstance(value, Operation):
return OperationSerializer(value)
if isinstance(value, functools.partial):
return FunctoolsPartialSerializer(value)
# Anything that knows how to deconstruct itself.
if hasattr(value, 'deconstruct'):
return DeconstructableSerializer(value)
if isinstance(value, (types.FunctionType, types.BuiltinFunctionType)):
return FunctionTypeSerializer(value)
if isinstance(value, collections.Iterable):
return IterableSerializer(value)
if isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)):
return RegexSerializer(value)
raise ValueError(
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
"topics/migrations/#migration-serializing" % (value, get_docs_version())
)

View File

@ -1,29 +1,19 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import collections
import datetime
import decimal
import functools
import math
import os import os
import re import re
import types
from importlib import import_module from importlib import import_module
from django import get_version from django import get_version
from django.apps import apps from django.apps import apps
from django.db import migrations, models from django.db import migrations
from django.db.migrations.loader import MigrationLoader from django.db.migrations.loader import MigrationLoader
from django.db.migrations.operations.base import Operation from django.db.migrations.serializer import serializer_factory
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
from django.utils import datetime_safe, six
from django.utils._os import upath from django.utils._os import upath
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import LazyObject, Promise
from django.utils.inspect import get_func_args from django.utils.inspect import get_func_args
from django.utils.module_loading import module_dir from django.utils.module_loading import module_dir
from django.utils.timezone import now, utc from django.utils.timezone import now
from django.utils.version import get_docs_version
try: try:
import enum import enum
@ -229,20 +219,6 @@ class MigrationWriter(object):
return (MIGRATION_TEMPLATE % items).encode("utf8") return (MIGRATION_TEMPLATE % items).encode("utf8")
@staticmethod
def serialize_datetime(value):
"""
Returns a serialized version of a datetime object that is valid,
executable python code. It converts timezone-aware values to utc with
an 'executable' utc representation of tzinfo.
"""
if value.tzinfo is not None and value.tzinfo != utc:
value = value.astimezone(utc)
value_repr = repr(value).replace("<UTC>", "utc")
if isinstance(value, datetime_safe.datetime):
value_repr = "datetime.%s" % value_repr
return value_repr
@property @property
def basedir(self): def basedir(self):
migrations_package_name = MigrationLoader.migrations_module(self.migration.app_label) migrations_package_name = MigrationLoader.migrations_module(self.migration.app_label)
@ -312,247 +288,9 @@ class MigrationWriter(object):
def path(self): def path(self):
return os.path.join(self.basedir, self.filename) return os.path.join(self.basedir, self.filename)
@classmethod
def serialize_deconstructed(cls, path, args, kwargs):
name, imports = cls._serialize_path(path)
strings = []
for arg in args:
arg_string, arg_imports = cls.serialize(arg)
strings.append(arg_string)
imports.update(arg_imports)
for kw, arg in sorted(kwargs.items()):
arg_string, arg_imports = cls.serialize(arg)
imports.update(arg_imports)
strings.append("%s=%s" % (kw, arg_string))
return "%s(%s)" % (name, ", ".join(strings)), imports
@classmethod
def _serialize_path(cls, path):
module, name = path.rsplit(".", 1)
if module == "django.db.models":
imports = {"from django.db import models"}
name = "models.%s" % name
else:
imports = {"import %s" % module}
name = path
return name, imports
@classmethod @classmethod
def serialize(cls, value): def serialize(cls, value):
""" return serializer_factory(value).serialize()
Serializes the value to a string that's parsable by Python, along
with any needed imports to make that string work.
More advanced than repr() as it can encode things
like datetime.datetime.now.
"""
# FIXME: Ideally Promise would be reconstructible, but for now we
# use force_text on them and defer to the normal string serialization
# process.
if isinstance(value, Promise):
value = force_text(value)
elif isinstance(value, LazyObject):
# The unwrapped value is returned as the first item of the
# arguments tuple.
value = value.__reduce__()[1][0]
# Sequences
if isinstance(value, (frozenset, list, set, tuple)):
imports = set()
strings = []
for item in value:
item_string, item_imports = cls.serialize(item)
imports.update(item_imports)
strings.append(item_string)
if isinstance(value, set):
# Don't use the literal "{%s}" as it doesn't support empty set
format = "set([%s])"
elif isinstance(value, frozenset):
format = "frozenset([%s])"
elif isinstance(value, tuple):
# When len(value)==0, the empty tuple should be serialized as
# "()", not "(,)" because (,) is invalid Python syntax.
format = "(%s)" if len(value) != 1 else "(%s,)"
else:
format = "[%s]"
return format % (", ".join(strings)), imports
# Dictionaries
elif isinstance(value, dict):
imports = set()
strings = []
for k, v in sorted(value.items()):
k_string, k_imports = cls.serialize(k)
v_string, v_imports = cls.serialize(v)
imports.update(k_imports)
imports.update(v_imports)
strings.append((k_string, v_string))
return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports
# Enums
elif enum and isinstance(value, enum.Enum):
enum_class = value.__class__
module = enum_class.__module__
imports = {"import %s" % module}
v_string, v_imports = cls.serialize(value.value)
imports.update(v_imports)
return "%s.%s(%s)" % (module, enum_class.__name__, v_string), imports
# Datetimes
elif isinstance(value, datetime.datetime):
value_repr = cls.serialize_datetime(value)
imports = ["import datetime"]
if value.tzinfo is not None:
imports.append("from django.utils.timezone import utc")
return value_repr, set(imports)
# Dates
elif isinstance(value, datetime.date):
value_repr = repr(value)
if isinstance(value, datetime_safe.date):
value_repr = "datetime.%s" % value_repr
return value_repr, {"import datetime"}
# Times
elif isinstance(value, datetime.time):
value_repr = repr(value)
if isinstance(value, datetime_safe.time):
value_repr = "datetime.%s" % value_repr
return value_repr, {"import datetime"}
# Timedeltas
elif isinstance(value, datetime.timedelta):
return repr(value), {"import datetime"}
# Settings references
elif isinstance(value, SettingsReference):
return "settings.%s" % value.setting_name, {"from django.conf import settings"}
# Simple types
elif isinstance(value, float):
if math.isnan(value) or math.isinf(value):
return 'float("{}")'.format(value), set()
return repr(value), set()
elif isinstance(value, six.integer_types + (bool, type(None))):
return repr(value), set()
elif isinstance(value, six.binary_type):
value_repr = repr(value)
if six.PY2:
# Prepend the `b` prefix since we're importing unicode_literals
value_repr = 'b' + value_repr
return value_repr, set()
elif isinstance(value, six.text_type):
value_repr = repr(value)
if six.PY2:
# Strip the `u` prefix since we're importing unicode_literals
value_repr = value_repr[1:]
return value_repr, set()
# Decimal
elif isinstance(value, decimal.Decimal):
return repr(value), {"from decimal import Decimal"}
# Django fields
elif isinstance(value, models.Field):
attr_name, path, args, kwargs = value.deconstruct()
return cls.serialize_deconstructed(path, args, kwargs)
# Classes
elif isinstance(value, type):
special_cases = [
(models.Model, "models.Model", []),
]
for case, string, imports in special_cases:
if case is value:
return string, set(imports)
if hasattr(value, "__module__"):
module = value.__module__
if module == six.moves.builtins.__name__:
return value.__name__, set()
else:
return "%s.%s" % (module, value.__name__), {"import %s" % module}
elif isinstance(value, models.manager.BaseManager):
as_manager, manager_path, qs_path, args, kwargs = value.deconstruct()
if as_manager:
name, imports = cls._serialize_path(qs_path)
return "%s.as_manager()" % name, imports
else:
return cls.serialize_deconstructed(manager_path, args, kwargs)
elif isinstance(value, Operation):
string, imports = OperationWriter(value, indentation=0).serialize()
# Nested operation, trailing comma is handled in upper OperationWriter._write()
return string.rstrip(','), imports
elif isinstance(value, functools.partial):
imports = {'import functools'}
# Serialize functools.partial() arguments
func_string, func_imports = cls.serialize(value.func)
args_string, args_imports = cls.serialize(value.args)
keywords_string, keywords_imports = cls.serialize(value.keywords)
# Add any imports needed by arguments
imports.update(func_imports)
imports.update(args_imports)
imports.update(keywords_imports)
return (
"functools.partial(%s, *%s, **%s)" % (
func_string, args_string, keywords_string,
),
imports,
)
# Anything that knows how to deconstruct itself.
elif hasattr(value, 'deconstruct'):
return cls.serialize_deconstructed(*value.deconstruct())
# Functions
elif isinstance(value, (types.FunctionType, types.BuiltinFunctionType)):
# @classmethod?
if getattr(value, "__self__", None) and isinstance(value.__self__, type):
klass = value.__self__
module = klass.__module__
return "%s.%s.%s" % (module, klass.__name__, value.__name__), {"import %s" % module}
# Further error checking
if value.__name__ == '<lambda>':
raise ValueError("Cannot serialize function: lambda")
if value.__module__ is None:
raise ValueError("Cannot serialize function %r: No module" % value)
# Python 3 is a lot easier, and only uses this branch if it's not local.
if getattr(value, "__qualname__", None) and getattr(value, "__module__", None):
if "<" not in value.__qualname__: # Qualname can include <locals>
return "%s.%s" % (value.__module__, value.__qualname__), {"import %s" % value.__module__}
# Python 2/fallback version
module_name = value.__module__
# Make sure it's actually there and not an unbound method
module = import_module(module_name)
if not hasattr(module, value.__name__):
raise ValueError(
"Could not find function %s in %s.\n"
"Please note that due to Python 2 limitations, you cannot "
"serialize unbound method functions (e.g. a method "
"declared and used in the same class body). Please move "
"the function into the main module body to use migrations.\n"
"For more information, see "
"https://docs.djangoproject.com/en/%s/topics/migrations/#serializing-values"
% (value.__name__, module_name, get_docs_version()))
# Needed on Python 2 only
if module_name == '__builtin__':
return value.__name__, set()
return "%s.%s" % (module_name, value.__name__), {"import %s" % module_name}
# Other iterables
elif isinstance(value, collections.Iterable):
imports = set()
strings = []
for item in value:
item_string, item_imports = cls.serialize(item)
imports.update(item_imports)
strings.append(item_string)
# When len(strings)==0, the empty iterable should be serialized as
# "()", not "(,)" because (,) is invalid Python syntax.
format = "(%s)" if len(strings) != 1 else "(%s,)"
return format % (", ".join(strings)), imports
# Compiled regex
elif isinstance(value, (COMPILED_REGEX_TYPE, RegexObject)):
imports = {"import re"}
regex_pattern, pattern_imports = cls.serialize(value.pattern)
regex_flags, flag_imports = cls.serialize(value.flags)
imports.update(pattern_imports)
imports.update(flag_imports)
args = [regex_pattern]
if value.flags:
args.append(regex_flags)
return "re.compile(%s)" % ', '.join(args), imports
# Uh oh.
else:
raise ValueError(
"Cannot serialize: %r\nThere are some values Django cannot serialize into "
"migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
"topics/migrations/#migration-serializing" % (value, get_docs_version())
)
MIGRATION_TEMPLATE = """\ MIGRATION_TEMPLATE = """\