# Marshmallow schema generation for dataclasses-json.
# flake8: noqa | |
import typing | |
import warnings | |
import sys | |
from copy import deepcopy | |
from dataclasses import MISSING, is_dataclass, fields as dc_fields | |
from datetime import datetime | |
from decimal import Decimal | |
from uuid import UUID | |
from enum import Enum | |
from typing_inspect import is_union_type # type: ignore | |
from marshmallow import fields, Schema, post_load # type: ignore | |
from marshmallow.exceptions import ValidationError # type: ignore | |
from dataclasses_json.core import (_is_supported_generic, _decode_dataclass, | |
_ExtendedEncoder, _user_overrides_or_exts) | |
from dataclasses_json.utils import (_is_collection, _is_optional, | |
_issubclass_safe, _timestamp_to_dt_aware, | |
_is_new_type, _get_type_origin, | |
_handle_undefined_parameters_safe, | |
CatchAllVar) | |
class _TimestampField(fields.Field):
    """Marshmallow field mapping ``datetime`` <-> POSIX timestamp (float)."""

    def _serialize(self, value, attr, obj, **kwargs):
        """Render a datetime as its POSIX timestamp; a missing value passes
        through as ``None`` unless the field is required."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.timestamp()

    def _deserialize(self, value, attr, data, **kwargs):
        """Rebuild a timezone-aware datetime from a POSIX timestamp; a missing
        value passes through as ``None`` unless the field is required."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return _timestamp_to_dt_aware(value)
class _IsoField(fields.Field):
    """Marshmallow field mapping ``datetime`` <-> ISO-8601 string."""

    def _serialize(self, value, attr, obj, **kwargs):
        """Render a datetime in ISO-8601 form; a missing value passes through
        as ``None`` unless the field is required."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.isoformat()

    def _deserialize(self, value, attr, data, **kwargs):
        """Parse an ISO-8601 string back into a datetime; a missing value
        passes through as ``None`` unless the field is required."""
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return datetime.fromisoformat(value)
class _UnionField(fields.Field):
    """Marshmallow field for ``typing.Union`` annotations.

    ``desc`` maps each candidate member type to the marshmallow field/schema
    built for it; ``cls`` and ``field`` identify the owning dataclass and
    dataclass field (used only in warning messages).  Dataclass members are
    tagged with a ``'__type'`` key on serialization so the matching schema
    can be selected again on deserialization.
    """

    def __init__(self, desc, cls, field, *args, **kwargs):
        self.desc = desc
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if self.allow_none and value is None:
            return None
        for type_, schema_ in self.desc.items():
            if _issubclass_safe(type(value), type_):
                if is_dataclass(value):
                    res = schema_._serialize(value, attr, obj, **kwargs)
                    # Tag the payload so _deserialize can pick the right schema.
                    res['__type'] = str(type_.__name__)
                    return res
                # Matched a non-dataclass member: skip the for/else warning and
                # fall through to the default Field serialization below.
                break
            elif isinstance(value, _get_type_origin(type_)):
                return schema_._serialize(value, attr, obj, **kwargs)
        else:
            # for/else: no union member matched the runtime type at all.
            warnings.warn(
                f'The type "{type(value).__name__}" (value: "{value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value cannot be serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        tmp_value = deepcopy(value)
        if isinstance(tmp_value, dict) and '__type' in tmp_value:
            dc_name = tmp_value['__type']
            for type_, schema_ in self.desc.items():
                if is_dataclass(type_) and type_.__name__ == dc_name:
                    del tmp_value['__type']
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
        elif isinstance(tmp_value, dict):
            # Message fixed: balanced quotes around the value and spacing
            # between the concatenated fragments.
            warnings.warn(
                f'Attempting to deserialize "dict" (value: "{tmp_value}") '
                f'that does not have a "__type" type specifier field into '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Deserialization may fail, or deserialization to wrong type may occur.'
            )
            return super()._deserialize(tmp_value, attr, data, **kwargs)
        else:
            for type_, schema_ in self.desc.items():
                if isinstance(tmp_value, _get_type_origin(type_)):
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
            else:
                # for/else: no union member matched the runtime type.
                warnings.warn(
                    f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
                    f'is not in the list of possible types of typing.Union '
                    f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                    f'Value cannot be deserialized properly.')
        return super()._deserialize(tmp_value, attr, data, **kwargs)
class _TupleVarLen(fields.List):
    """
    variable-length homogeneous tuples
    """
    def _deserialize(self, value, attr, data, **kwargs):
        # Deserialize as a list, then convert; preserve None from the parent.
        items = super()._deserialize(value, attr, data, **kwargs)
        if items is None:
            return None
        return tuple(items)
# Default mapping from Python / ``typing`` types to marshmallow field
# classes, consulted by ``build_type`` when a dataclass field has no
# user-supplied ``mm_field`` override.  Note ``datetime`` maps to the
# timestamp-based field (seconds since epoch), not ISO-8601.
TYPES = {
    typing.Mapping: fields.Mapping,
    typing.MutableMapping: fields.Mapping,
    typing.List: fields.List,
    typing.Dict: fields.Dict,
    typing.Tuple: fields.Tuple,
    typing.Callable: fields.Function,
    typing.Any: fields.Raw,
    dict: fields.Dict,
    list: fields.List,
    tuple: fields.Tuple,
    str: fields.Str,
    int: fields.Int,
    float: fields.Float,
    bool: fields.Bool,
    datetime: _TimestampField,
    UUID: fields.UUID,
    Decimal: fields.Decimal,
    CatchAllVar: fields.Dict,
}
# Type variable standing for the dataclass type a generated schema handles.
A = typing.TypeVar('A')
# Raw JSON payloads accepted by ``Schema.loads``.
JsonData = typing.Union[str, bytes, bytearray]
# A single object already encoded to a JSON-compatible dict.
TEncoded = typing.Dict[str, typing.Any]
# Either one instance or a list of instances (decoded / encoded forms).
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
    class SchemaF(Schema, typing.Generic[A]):
        """Lift Schema into a type constructor"""

        def __init__(self, *args, **kwargs):
            """
            Raises exception because this class should not be inherited.
            This class is helper only.
            """
            super().__init__(*args, **kwargs)
            raise NotImplementedError()

        # The repeated definitions below are typing overloads: without the
        # @typing.overload decorators each def would silently shadow the
        # previous one (the inline mypy comments reference the decorators).
        @typing.overload
        def dump(self, obj: typing.List[A], many: typing.Optional[bool] = None) -> typing.List[TEncoded]:  # type: ignore
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            pass

        @typing.overload
        def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
            pass

        def dump(self, obj: TOneOrMulti,  # type: ignore
                 many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
            pass

        @typing.overload
        def dumps(self, obj: typing.List[A], many: typing.Optional[bool] = None, *args,
                  **kwargs) -> str:
            pass

        @typing.overload
        def dumps(self, obj: A, many: typing.Optional[bool] = None, *args, **kwargs) -> str:
            pass

        def dumps(self, obj: TOneOrMulti, many: typing.Optional[bool] = None, *args,  # type: ignore
                  **kwargs) -> str:
            pass

        @typing.overload  # type: ignore
        def load(self, data: typing.List[TEncoded],
                 many: bool = True, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> \
                typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define lists as an allowed input type
            pass

        @typing.overload
        def load(self, data: TEncoded,
                 many: None = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> A:
            pass

        def load(self, data: TOneOrMultiEncoded,
                 many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> TOneOrMulti:
            pass

        @typing.overload  # type: ignore
        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = True, partial: typing.Optional[bool] = None,
                  unknown: typing.Optional[str] = None,
                  **kwargs) -> typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define bytes as correct input data
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            # for the return type overlap
            pass

        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
                  unknown: typing.Optional[str] = None,
                  **kwargs) -> TOneOrMulti:
            pass

    SchemaType = SchemaF[A]
else:
    SchemaType = Schema
def build_type(type_, options, mixin, field, cls):
    """Build the marshmallow field for one dataclass field.

    ``type_`` is the (possibly generic) annotation, ``options`` the keyword
    arguments for the resulting marshmallow field, ``mixin`` the
    dataclasses-json mixin used to recognize nested dataclass_json classes,
    and ``field``/``cls`` the dataclass field and owning class (used for
    warnings and collection detection).
    """
    def inner(type_, options):
        # Unwrap typing.NewType layers down to the underlying supertype.
        while True:
            if not _is_new_type(type_):
                break
            type_ = type_.__supertype__
        if is_dataclass(type_):
            if _issubclass_safe(type_, mixin):
                # Nested dataclass_json: delegate to its own generated schema.
                options['field_many'] = bool(
                    _is_supported_generic(field.type) and _is_collection(
                        field.type))
                return fields.Nested(type_.schema(), **options)
            else:
                # Nested plain dataclass: cannot recurse; fall back to a raw
                # Field and tell the user how to opt in.
                warnings.warn(f"Nested dataclass field {field.name} of type "
                              f"{field.type} detected in "
                              f"{cls.__name__} that is not an instance of "
                              f"dataclass_json. Did you mean to recursively "
                              f"serialize this field? If so, make sure to "
                              f"augment {type_} with either the "
                              f"`dataclass_json` decorator or mixin.")
                return fields.Field(**options)
        # Recurse into generic type arguments, dropping NoneType — optionality
        # is expressed via allow_none instead.
        origin = getattr(type_, '__origin__', type_)
        args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
                a is not type(None)]
        if type_ == Ellipsis:
            # The `...` inside Tuple[X, ...]; returned as-is so the tuple
            # branch below can detect variable-length tuples.
            return type_
        if _is_optional(type_):
            options["allow_none"] = True
        if origin is tuple:
            if len(args) == 2 and args[1] == Ellipsis:
                # Tuple[X, ...]: homogeneous, variable length.
                return _TupleVarLen(args[0], **options)
            else:
                # Fixed-length, possibly heterogeneous tuple.
                return fields.Tuple(args, **options)
        if origin in TYPES:
            return TYPES[origin](*args, **options)
        if _issubclass_safe(origin, Enum):
            # Enums are (de)serialized by value, not by member name.
            return fields.Enum(enum=origin, by_value=True, *args, **options)
        if is_union_type(type_):
            # Pair each non-None union member with the field built for it.
            union_types = [a for a in getattr(type_, '__args__', []) if
                           a is not type(None)]
            union_desc = dict(zip(union_types, args))
            return _UnionField(union_desc, cls, field, **options)
        warnings.warn(
            f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
            f"It's advised to pass the correct marshmallow type to `mm_field`.")
        return fields.Field(**options)
    return inner(type_, options)
def schema(cls, mixin, infer_missing):
    """Build the ``{field_name: marshmallow field}`` mapping for ``cls``.

    A user-supplied ``mm_field`` override wins outright; otherwise the field
    is derived from the type annotation via ``build_type``.  ``infer_missing``
    selects whether dataclass defaults populate marshmallow's ``missing``
    (load-side) or ``default`` (dump-side) option.
    """
    schema = {}
    overrides = _user_overrides_or_exts(cls)
    # TODO check the undefined parameters and add the proper schema action
    # https://marshmallow.readthedocs.io/en/stable/quickstart.html
    for field in dc_fields(cls):
        metadata = overrides[field.name]
        if metadata.mm_field is not None:
            # Explicit marshmallow field supplied by the user: use as-is.
            schema[field.name] = metadata.mm_field
        else:
            type_ = field.type
            options: typing.Dict[str, typing.Any] = {}
            missing_key = 'missing' if infer_missing else 'default'
            if field.default is not MISSING:
                options[missing_key] = field.default
            elif field.default_factory is not MISSING:
                options[missing_key] = field.default_factory()
            else:
                # No default at all: the field must be present in the input.
                options['required'] = True
            # A None default implies the field may legitimately hold None.
            if options.get(missing_key, ...) is None:
                options['allow_none'] = True
            if _is_optional(type_):
                options.setdefault(missing_key, None)
                options['allow_none'] = True
                if len(type_.__args__) == 2:
                    # Union[str, int, None] is optional too, but it has more than 1 typed field.
                    type_ = [tp for tp in type_.__args__ if tp is not type(None)][0]
            if metadata.letter_case is not None:
                options['data_key'] = metadata.letter_case(field.name)
            t = build_type(type_, options, mixin, field, cls)
            if field.metadata.get('dataclasses_json', {}).get('decoder'):
                # If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion
                # logic.
                # From Marshmallow's documentation for the _deserialize method:
                # "Deserialize value. Concrete :class:`Field` classes should implement this method. "
                # This is the method that Field implementations override to perform the actual deserialization logic.
                # In this case we specifically override this method instead of `deserialize` to minimize potential
                # side effects, and only cancel the actual value deserialization.
                t._deserialize = lambda v, *_a, **_kw: v
            # if type(t) is not fields.Field:  # If we use `isinstance` we would return nothing.
            if field.type != typing.Optional[CatchAllVar]:
                schema[field.name] = t
    return schema
def build_schema(cls: typing.Type[A],
                 mixin,
                 infer_missing,
                 partial) -> typing.Type["SchemaType[A]"]:
    """Generate a marshmallow ``Schema`` subclass for dataclass ``cls``.

    The returned class carries one marshmallow field per dataclass field
    (built by ``schema``), a post-load hook that decodes validated dicts
    back into ``cls`` instances, and ``dump``/``dumps`` overrides that fold
    catch-all/undefined parameters into the output.
    """
    Meta = type('Meta',
                (),
                {'fields': tuple(field.name for field in dc_fields(cls)  # type: ignore
                                 if
                                 field.name != 'dataclass_json_config' and field.type !=
                                 typing.Optional[CatchAllVar]),
                 # TODO #180
                 # 'render_module': global_config.json_module
                 })

    # Restored @post_load (the import exists at the top of the file): without
    # it marshmallow never invokes this hook and load() returns plain dicts
    # instead of ``cls`` instances.
    @post_load
    def make_instance(self, kvs, **kwargs):
        return _decode_dataclass(cls, kvs, partial)

    def dumps(self, *args, **kwargs):
        # Default to the extended JSON encoder (UUID, Decimal, datetime, ...)
        # unless the caller supplied their own `cls`.
        if 'cls' not in kwargs:
            kwargs['cls'] = _ExtendedEncoder
        return Schema.dumps(self, *args, **kwargs)

    def dump(self, obj, *, many=None):
        many = self.many if many is None else bool(many)
        dumped = Schema.dump(self, obj, many=many)
        # TODO This is hacky, but the other option I can think of is to generate a different schema
        # depending on dump and load, which is even more hacky
        # The only problem is the catch-all field, we can't statically create a schema for it,
        # so we just update the dumped dict
        if many:
            for i, _obj in enumerate(obj):
                dumped[i].update(
                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
                                                      usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
                                                            usage="dump"))
        return dumped

    schema_ = schema(cls, mixin, infer_missing)
    DataClassSchema: typing.Type["SchemaType[A]"] = type(
        f'{cls.__name__.capitalize()}Schema',
        (Schema,),
        {'Meta': Meta,
         f'make_{cls.__name__.lower()}': make_instance,
         'dumps': dumps,
         'dump': dump,
         **schema_})
    return DataClassSchema