# NOTE: the three lines below are Hugging Face upload-page residue that was
# accidentally pasted into the source; they are commented out so the module
# parses as Python.
# lemesdaniel's picture
# Upload folder using huggingface_hub
# e00b837 verified
# flake8: noqa
import typing
import warnings
import sys
from copy import deepcopy
from dataclasses import MISSING, is_dataclass, fields as dc_fields
from datetime import datetime
from decimal import Decimal
from uuid import UUID
from enum import Enum
from typing_inspect import is_union_type # type: ignore
from marshmallow import fields, Schema, post_load # type: ignore
from marshmallow.exceptions import ValidationError # type: ignore
from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
_ExtendedEncoder, _user_overrides_or_exts)
from dataclasses_json.utils import (_is_collection, _is_optional,
_issubclass_safe, _timestamp_to_dt_aware,
_is_new_type, _get_type_origin,
_handle_undefined_parameters_safe,
CatchAllVar)
class _TimestampField(fields.Field):
    """Marshmallow field encoding datetimes as POSIX timestamps.

    Serialization calls ``value.timestamp()``; deserialization converts the
    numeric timestamp back into a timezone-aware datetime.  A ``None`` value is
    passed through unless the field is required, in which case the standard
    marshmallow "required" validation error is raised.
    """

    def _serialize(self, value, attr, obj, **kwargs):
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.timestamp()

    def _deserialize(self, value, attr, data, **kwargs):
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return _timestamp_to_dt_aware(value)
class _IsoField(fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
if value is not None:
return value.isoformat()
else:
if not self.required:
return None
else:
raise ValidationError(self.default_error_messages["required"])
def _deserialize(self, value, attr, data, **kwargs):
if value is not None:
return datetime.fromisoformat(value)
else:
if not self.required:
return None
else:
raise ValidationError(self.default_error_messages["required"])
class _UnionField(fields.Field):
    """Marshmallow field for a ``typing.Union`` dataclass field.

    ``desc`` maps each candidate type of the union to the marshmallow field or
    schema built for it; ``cls`` and ``field`` identify the owning dataclass and
    dataclass field and are used only in warning messages.
    """

    def __init__(self, desc, cls, field, *args, **kwargs):
        self.desc = desc
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if self.allow_none and value is None:
            return None
        for type_, schema_ in self.desc.items():
            if _issubclass_safe(type(value), type_):
                if is_dataclass(value):
                    res = schema_._serialize(value, attr, obj, **kwargs)
                    # Embed a type tag so _deserialize can pick the right
                    # candidate schema back out of the union.
                    res['__type'] = str(type_.__name__)
                    return res
                # Matched a non-dataclass candidate: stop searching and fall
                # through to the default Field serialization below.
                break
            elif isinstance(value, _get_type_origin(type_)):
                return schema_._serialize(value, attr, obj, **kwargs)
        else:
            # for/else: no candidate type matched the value at all.
            warnings.warn(
                f'The type "{type(value).__name__}" (value: "{value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value cannot be serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        # Copy first: the '__type' tag is deleted below and the caller's data
        # must not be mutated.
        tmp_value = deepcopy(value)
        if isinstance(tmp_value, dict) and '__type' in tmp_value:
            # Tagged dict: find the dataclass candidate whose name matches.
            dc_name = tmp_value['__type']
            for type_, schema_ in self.desc.items():
                if is_dataclass(type_) and type_.__name__ == dc_name:
                    del tmp_value['__type']
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
        elif isinstance(tmp_value, dict):
            # Untagged dict: we cannot tell which union member it is.
            warnings.warn(
                f'Attempting to deserialize "dict" (value: "{tmp_value}) '
                f'that does not have a "__type" type specifier field into'
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}).'
                f'Deserialization may fail, or deserialization to wrong type may occur.'
            )
            return super()._deserialize(tmp_value, attr, data, **kwargs)
        else:
            # Plain (non-dict) value: dispatch on the runtime type's origin.
            for type_, schema_ in self.desc.items():
                if isinstance(tmp_value, _get_type_origin(type_)):
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
            else:
                warnings.warn(
                    f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
                    f'is not in the list of possible types of typing.Union '
                    f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                    f'Value cannot be deserialized properly.')
        # Fallback for a tagged dict with no matching dataclass, or an
        # unmatched plain value.
        return super()._deserialize(tmp_value, attr, data, **kwargs)
class _TupleVarLen(fields.List):
    """
    variable-length homogeneous tuples
    """

    def _deserialize(self, value, attr, data, **kwargs):
        items = super()._deserialize(value, attr, data, **kwargs)
        if items is None:
            return None
        return tuple(items)
# Default mapping from Python / typing types to marshmallow field classes.
# Consulted by build_type() when a dataclass field has no explicit mm_field.
TYPES = {
    typing.Mapping: fields.Mapping,
    typing.MutableMapping: fields.Mapping,
    typing.List: fields.List,
    typing.Dict: fields.Dict,
    typing.Tuple: fields.Tuple,
    typing.Callable: fields.Function,
    typing.Any: fields.Raw,
    dict: fields.Dict,
    list: fields.List,
    tuple: fields.Tuple,
    str: fields.Str,
    int: fields.Int,
    float: fields.Float,
    bool: fields.Bool,
    datetime: _TimestampField,  # datetimes default to POSIX-timestamp encoding
    UUID: fields.UUID,
    Decimal: fields.Decimal,
    CatchAllVar: fields.Dict,
}
# Type variable and aliases used by the overloaded SchemaF signatures below.
A = typing.TypeVar('A')
# Raw JSON input accepted by loads().
JsonData = typing.Union[str, bytes, bytearray]
# A single dumped dataclass instance.
TEncoded = typing.Dict[str, typing.Any]
# A dataclass instance, or a list of them (the `many` cases).
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
# On Python 3.7+ (or while type checking) expose a generic Schema subclass so
# generated schemas can be typed as SchemaType[A]; otherwise fall back to the
# plain marshmallow Schema.  SchemaF exists purely for its typed signatures —
# all method bodies are `pass` and instantiation raises.
if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
    class SchemaF(Schema, typing.Generic[A]):
        """Lift Schema into a type constructor"""

        def __init__(self, *args, **kwargs):
            """
            Raises exception because this class should not be inherited.
            This class is helper only.
            """
            super().__init__(*args, **kwargs)
            raise NotImplementedError()

        @typing.overload
        def dump(self, obj: typing.List[A], many: typing.Optional[bool] = None) -> typing.List[TEncoded]:  # type: ignore
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            pass

        @typing.overload
        def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
            pass

        def dump(self, obj: TOneOrMulti,  # type: ignore
                 many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
            pass

        @typing.overload
        def dumps(self, obj: typing.List[A], many: typing.Optional[bool] = None, *args,
                  **kwargs) -> str:
            pass

        @typing.overload
        def dumps(self, obj: A, many: typing.Optional[bool] = None, *args, **kwargs) -> str:
            pass

        def dumps(self, obj: TOneOrMulti, many: typing.Optional[bool] = None, *args,  # type: ignore
                  **kwargs) -> str:
            pass

        @typing.overload  # type: ignore
        def load(self, data: typing.List[TEncoded],
                 many: bool = True, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> \
                typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define lists as an allowed input type
            pass

        @typing.overload
        def load(self, data: TEncoded,
                 many: None = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> A:
            pass

        def load(self, data: TOneOrMultiEncoded,
                 many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> TOneOrMulti:
            pass

        @typing.overload  # type: ignore
        def loads(self, json_data: JsonData,  # type: ignore
                  many: typing.Optional[bool] = True, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
                  **kwargs) -> typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define bytes as correct input data
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            # for the return type overlap
            pass

        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
                  **kwargs) -> TOneOrMulti:
            pass

    SchemaType = SchemaF[A]
else:
    SchemaType = Schema
def build_type(type_, options, mixin, field, cls):
    """Resolve the annotation ``type_`` to a marshmallow field instance.

    ``options`` are keyword arguments forwarded to the resulting field's
    constructor.  ``mixin`` is the dataclasses-json mixin used to recognize
    nested dataclass_json classes; ``field`` and ``cls`` identify the dataclass
    field being built (used for warnings and collection detection).
    """
    def inner(type_, options):
        # Unwrap typing.NewType chains down to the underlying supertype.
        while True:
            if not _is_new_type(type_):
                break
            type_ = type_.__supertype__

        if is_dataclass(type_):
            if _issubclass_safe(type_, mixin):
                # Nested dataclass_json: delegate to its generated schema.
                options['field_many'] = bool(
                    _is_supported_generic(field.type) and _is_collection(
                        field.type))
                return fields.Nested(type_.schema(), **options)
            else:
                warnings.warn(f"Nested dataclass field {field.name} of type "
                              f"{field.type} detected in "
                              f"{cls.__name__} that is not an instance of "
                              f"dataclass_json. Did you mean to recursively "
                              f"serialize this field? If so, make sure to "
                              f"augment {type_} with either the "
                              f"`dataclass_json` decorator or mixin.")
                return fields.Field(**options)

        origin = getattr(type_, '__origin__', type_)
        # Recursively build fields for type arguments, skipping NoneType
        # (optionality is expressed via allow_none instead).
        args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
                a is not type(None)]

        # Ellipsis marks a variable-length Tuple[X, ...]; pass it through so
        # the tuple branch below can detect it.
        if type_ == Ellipsis:
            return type_

        if _is_optional(type_):
            options["allow_none"] = True

        if origin is tuple:
            if len(args) == 2 and args[1] == Ellipsis:
                # Tuple[X, ...]: homogeneous, variable length.
                return _TupleVarLen(args[0], **options)
            else:
                return fields.Tuple(args, **options)

        if origin in TYPES:
            return TYPES[origin](*args, **options)

        if _issubclass_safe(origin, Enum):
            return fields.Enum(enum=origin, by_value=True, *args, **options)

        if is_union_type(type_):
            union_types = [a for a in getattr(type_, '__args__', []) if
                           a is not type(None)]
            union_desc = dict(zip(union_types, args))
            return _UnionField(union_desc, cls, field, **options)

        warnings.warn(
            f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
            f"It's advised to pass the correct marshmallow type to `mm_field`.")
        return fields.Field(**options)
    return inner(type_, options)
def schema(cls, mixin, infer_missing):
    """Build the ``{field_name: marshmallow field}`` mapping for dataclass ``cls``.

    An explicit ``metadata.mm_field`` override wins outright; otherwise a field
    is derived from the dataclass field's type, default, and letter_case
    settings via build_type().  Catch-all fields are excluded from the result.
    """
    schema = {}
    overrides = _user_overrides_or_exts(cls)
    # TODO check the undefined parameters and add the proper schema action
    #  https://marshmallow.readthedocs.io/en/stable/quickstart.html
    for field in dc_fields(cls):
        metadata = overrides[field.name]
        if metadata.mm_field is not None:
            schema[field.name] = metadata.mm_field
        else:
            type_ = field.type
            options: typing.Dict[str, typing.Any] = {}
            # infer_missing controls whether defaults become marshmallow
            # 'missing' (load-time) or 'default' (dump-time) values.
            missing_key = 'missing' if infer_missing else 'default'
            if field.default is not MISSING:
                options[missing_key] = field.default
            elif field.default_factory is not MISSING:
                options[missing_key] = field.default_factory()
            else:
                options['required'] = True

            # A None default implies the field may legitimately be None.
            if options.get(missing_key, ...) is None:
                options['allow_none'] = True

            if _is_optional(type_):
                options.setdefault(missing_key, None)
                options['allow_none'] = True
                if len(type_.__args__) == 2:
                    # Union[str, int, None] is optional too, but it has more than 1 typed field.
                    type_ = [tp for tp in type_.__args__ if tp is not type(None)][0]

            if metadata.letter_case is not None:
                options['data_key'] = metadata.letter_case(field.name)

            t = build_type(type_, options, mixin, field, cls)
            if field.metadata.get('dataclasses_json', {}).get('decoder'):
                # If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion
                # logic.
                # From Marshmallow's documentation for the _deserialize method:
                # "Deserialize value. Concrete :class:`Field` classes should implement this method. "
                # This is the method that Field implementations override to perform the actual deserialization logic.
                # In this case we specifically override this method instead of `deserialize` to minimize potential
                # side effects, and only cancel the actual value deserialization.
                t._deserialize = lambda v, *_a, **_kw: v

            # if type(t) is not fields.Field:  # If we use `isinstance` we would return nothing.
            if field.type != typing.Optional[CatchAllVar]:
                schema[field.name] = t
    return schema
def build_schema(cls: typing.Type[A],
                 mixin,
                 infer_missing,
                 partial) -> typing.Type["SchemaType[A]"]:
    """Create and return a marshmallow Schema subclass for dataclass ``cls``.

    The generated schema has a @post_load hook that reconstructs the dataclass
    via _decode_dataclass, a dumps() override that injects _ExtendedEncoder,
    and a dump() override that merges catch-all ("undefined") parameters into
    the dumped dict(s).
    """
    # Meta.fields lists every dataclass field except the internal config
    # attribute and any catch-all field (which cannot be declared statically).
    Meta = type('Meta',
                (),
                {'fields': tuple(field.name for field in dc_fields(cls)  # type: ignore
                                 if
                                 field.name != 'dataclass_json_config' and field.type !=
                                 typing.Optional[CatchAllVar]),
                 # TODO #180
                 # 'render_module': global_config.json_module
                 })

    @post_load
    def make_instance(self, kvs, **kwargs):
        # Reconstruct the dataclass from the loaded key/value pairs.
        return _decode_dataclass(cls, kvs, partial)

    def dumps(self, *args, **kwargs):
        # Default to the extended JSON encoder unless the caller supplied one.
        if 'cls' not in kwargs:
            kwargs['cls'] = _ExtendedEncoder

        return Schema.dumps(self, *args, **kwargs)

    def dump(self, obj, *, many=None):
        many = self.many if many is None else bool(many)
        dumped = Schema.dump(self, obj, many=many)
        # TODO This is hacky, but the other option I can think of is to generate a different schema
        #  depending on dump and load, which is even more hacky
        #  The only problem is the catch-all field, we can't statically create a schema for it,
        #  so we just update the dumped dict
        if many:
            for i, _obj in enumerate(obj):
                dumped[i].update(
                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
                                                      usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
                                                            usage="dump"))
        return dumped

    schema_ = schema(cls, mixin, infer_missing)
    # Assemble the schema class dynamically: Meta + post_load hook +
    # dump/dumps overrides + one marshmallow field per dataclass field.
    DataClassSchema: typing.Type["SchemaType[A]"] = type(
        f'{cls.__name__.capitalize()}Schema',
        (Schema,),
        {'Meta': Meta,
         f'make_{cls.__name__.lower()}': make_instance,
         'dumps': dumps,
         'dump': dump,
         **schema_})

    return DataClassSchema