import copy import json import sys import warnings from collections import defaultdict, namedtuple from dataclasses import (MISSING, fields, is_dataclass # type: ignore ) from datetime import datetime, timezone from decimal import Decimal from enum import Enum from typing import (Any, Collection, Mapping, Union, get_type_hints, Tuple, TypeVar, Type) from uuid import UUID from typing_inspect import is_union_type # type: ignore from dataclasses_json import cfg from dataclasses_json.utils import (_get_type_cons, _get_type_origin, _handle_undefined_parameters_safe, _is_collection, _is_mapping, _is_new_type, _is_optional, _isinstance_safe, _get_type_arg_param, _get_type_args, _is_counter, _NO_ARGS, _issubclass_safe, _is_tuple) Json = Union[dict, list, str, int, float, bool, None] confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude'] FieldOverride = namedtuple('FieldOverride', confs) # type: ignore class _ExtendedEncoder(json.JSONEncoder): def default(self, o) -> Json: result: Json if _isinstance_safe(o, Collection): if _isinstance_safe(o, Mapping): result = dict(o) else: result = list(o) elif _isinstance_safe(o, datetime): result = o.timestamp() elif _isinstance_safe(o, UUID): result = str(o) elif _isinstance_safe(o, Enum): result = o.value elif _isinstance_safe(o, Decimal): result = str(o) else: result = json.JSONEncoder.default(self, o) return result def _user_overrides_or_exts(cls): global_metadata = defaultdict(dict) encoders = cfg.global_config.encoders decoders = cfg.global_config.decoders mm_fields = cfg.global_config.mm_fields for field in fields(cls): if field.type in encoders: global_metadata[field.name]['encoder'] = encoders[field.type] if field.type in decoders: global_metadata[field.name]['decoder'] = decoders[field.type] if field.type in mm_fields: global_metadata[field.name]['mm_field'] = mm_fields[field.type] try: cls_config = (cls.dataclass_json_config if cls.dataclass_json_config is not None else {}) except AttributeError: cls_config = {} overrides = {} for field in fields(cls): field_config = {} # first apply global overrides or extensions field_metadata = global_metadata[field.name] if 'encoder' in field_metadata: field_config['encoder'] = field_metadata['encoder'] if 'decoder' in field_metadata: field_config['decoder'] = field_metadata['decoder'] if 'mm_field' in field_metadata: field_config['mm_field'] = field_metadata['mm_field'] # then apply class-level overrides or extensions field_config.update(cls_config) # last apply field-level overrides or extensions field_config.update(field.metadata.get('dataclasses_json', {})) overrides[field.name] = FieldOverride(*map(field_config.get, confs)) return overrides def _encode_json_type(value, default=_ExtendedEncoder().default): if isinstance(value, Json.__args__): # type: ignore if isinstance(value, list): return [_encode_json_type(i) for i in value] elif isinstance(value, dict): return {k: _encode_json_type(v) for k, v in value.items()} else: return value return default(value) def _encode_overrides(kvs, overrides, encode_json=False): override_kvs = {} for k, v in kvs.items(): if k in overrides: exclude = overrides[k].exclude # If the exclude predicate returns true, the key should be # excluded from encoding, so skip the rest of the loop if exclude and exclude(v): continue letter_case = overrides[k].letter_case original_key = k k = letter_case(k) if letter_case is not None else k if k in override_kvs: raise ValueError( f"Multiple fields map to the same JSON " f"key after letter case encoding: {k}" ) encoder = overrides[original_key].encoder v = encoder(v) if encoder is not None else v if encode_json: v = _encode_json_type(v) override_kvs[k] = v return override_kvs def _decode_letter_case_overrides(field_names, overrides): """Override letter case of field names for encode/decode""" names = {} for field_name in field_names: field_override = overrides.get(field_name) if field_override is not None: letter_case = field_override.letter_case if letter_case is not None: names[letter_case(field_name)] = field_name return names def _decode_dataclass(cls, kvs, infer_missing): if _isinstance_safe(kvs, cls): return kvs overrides = _user_overrides_or_exts(cls) kvs = {} if kvs is None and infer_missing else kvs field_names = [field.name for field in fields(cls)] decode_names = _decode_letter_case_overrides(field_names, overrides) kvs = {decode_names.get(k, k): v for k, v in kvs.items()} missing_fields = {field for field in fields(cls) if field.name not in kvs} for field in missing_fields: if field.default is not MISSING: kvs[field.name] = field.default elif field.default_factory is not MISSING: kvs[field.name] = field.default_factory() elif infer_missing: kvs[field.name] = None # Perform undefined parameter action kvs = _handle_undefined_parameters_safe(cls, kvs, usage="from") init_kwargs = {} types = get_type_hints(cls) for field in fields(cls): # The field should be skipped from being added # to init_kwargs as it's not intended as a constructor argument. if not field.init: continue field_value = kvs[field.name] field_type = types[field.name] if field_value is None: if not _is_optional(field_type): warning = ( f"value of non-optional type {field.name} detected " f"when decoding {cls.__name__}" ) if infer_missing: warnings.warn( f"Missing {warning} and was defaulted to None by " f"infer_missing=True. " f"Set infer_missing=False (the default) to prevent " f"this behavior.", RuntimeWarning ) else: warnings.warn( f"'NoneType' object {warning}.", RuntimeWarning ) init_kwargs[field.name] = field_value continue while True: if not _is_new_type(field_type): break field_type = field_type.__supertype__ if (field.name in overrides and overrides[field.name].decoder is not None): # FIXME hack if field_type is type(field_value): init_kwargs[field.name] = field_value else: init_kwargs[field.name] = overrides[field.name].decoder( field_value) elif is_dataclass(field_type): # FIXME this is a band-aid to deal with the value already being # serialized when handling nested marshmallow schema # proper fix is to investigate the marshmallow schema generation # code if is_dataclass(field_value): value = field_value else: value = _decode_dataclass(field_type, field_value, infer_missing) init_kwargs[field.name] = value elif _is_supported_generic(field_type) and field_type != str: init_kwargs[field.name] = _decode_generic(field_type, field_value, infer_missing) else: init_kwargs[field.name] = _support_extended_types(field_type, field_value) return cls(**init_kwargs) def _support_extended_types(field_type, field_value): if _issubclass_safe(field_type, datetime): # FIXME this is a hack to deal with mm already decoding # the issue is we want to leverage mm fields' missing argument # but need this for the object creation hook if isinstance(field_value, datetime): res = field_value else: tz = datetime.now(timezone.utc).astimezone().tzinfo res = datetime.fromtimestamp(field_value, tz=tz) elif _issubclass_safe(field_type, Decimal): res = (field_value if isinstance(field_value, Decimal) else Decimal(field_value)) elif _issubclass_safe(field_type, UUID): res = (field_value if isinstance(field_value, UUID) else UUID(field_value)) elif _issubclass_safe(field_type, (int, float, str, bool)): res = (field_value if isinstance(field_value, field_type) else field_type(field_value)) else: res = field_value return res def _is_supported_generic(type_): if type_ is _NO_ARGS: return False not_str = not _issubclass_safe(type_, str) is_enum = _issubclass_safe(type_, Enum) return (not_str and _is_collection(type_)) or _is_optional( type_) or is_union_type(type_) or is_enum def _decode_generic(type_, value, infer_missing): if value is None: res = value elif _issubclass_safe(type_, Enum): # Convert to an Enum using the type as a constructor. # Assumes a direct match is found. res = type_(value) # FIXME this is a hack to fix a deeper underlying issue. A refactor is due. elif _is_collection(type_): if _is_mapping(type_) and not _is_counter(type_): k_type, v_type = _get_type_args(type_, (Any, Any)) # a mapping type has `.keys()` and `.values()` # (see collections.abc) ks = _decode_dict_keys(k_type, value.keys(), infer_missing) vs = _decode_items(v_type, value.values(), infer_missing) xs = zip(ks, vs) elif _is_tuple(type_): types = _get_type_args(type_) if Ellipsis in types: xs = _decode_items(types[0], value, infer_missing) else: xs = _decode_items(_get_type_args(type_) or _NO_ARGS, value, infer_missing) elif _is_counter(type_): xs = dict(zip(_decode_items(_get_type_arg_param(type_, 0), value.keys(), infer_missing), value.values())) else: xs = _decode_items(_get_type_arg_param(type_, 0), value, infer_missing) # get the constructor if using corresponding generic type in `typing` # otherwise fallback on constructing using type_ itself materialize_type = type_ try: materialize_type = _get_type_cons(type_) except (TypeError, AttributeError): pass res = materialize_type(xs) else: # Optional or Union _args = _get_type_args(type_) if _args is _NO_ARGS: # Any, just accept res = value elif _is_optional(type_) and len(_args) == 2: # Optional type_arg = _get_type_arg_param(type_, 0) if is_dataclass(type_arg) or is_dataclass(value): res = _decode_dataclass(type_arg, value, infer_missing) elif _is_supported_generic(type_arg): res = _decode_generic(type_arg, value, infer_missing) else: res = _support_extended_types(type_arg, value) else: # Union (already decoded or try to decode a dataclass) type_options = _get_type_args(type_) res = value # assume already decoded if type(value) is dict and dict not in type_options: for type_option in type_options: if is_dataclass(type_option): try: res = _decode_dataclass(type_option, value, infer_missing) break except (KeyError, ValueError, AttributeError): continue if res == value: warnings.warn( f"Failed to decode {value} Union dataclasses." f"Expected Union to include a matching dataclass and it didn't." ) return res def _decode_dict_keys(key_type, xs, infer_missing): """ Because JSON object keys must be strs, we need the extra step of decoding them back into the user's chosen python type """ decode_function = key_type # handle NoneType keys... it's weird to type a Dict as NoneType keys # but it's valid... # Issue #341 and PR #346: # This is a special case for Python 3.7 and Python 3.8. # By some reason, "unbound" dicts are counted # as having key type parameter to be TypeVar('KT') if key_type is None or key_type == Any or isinstance(key_type, TypeVar): decode_function = key_type = (lambda x: x) # handle a nested python dict that has tuples for keys. E.g. for # Dict[Tuple[int], int], key_type will be typing.Tuple[int], but # decode_function should be tuple, so map() doesn't break. # # Note: _get_type_origin() will return typing.Tuple for python # 3.6 and tuple for 3.7 and higher. elif _get_type_origin(key_type) in {tuple, Tuple}: decode_function = tuple key_type = key_type return map(decode_function, _decode_items(key_type, xs, infer_missing)) def _decode_items(type_args, xs, infer_missing): """ This is a tricky situation where we need to check both the annotated type info (which is usually a type from `typing`) and check the value's type directly using `type()`. If the type_arg is a generic we can use the annotated type, but if the type_arg is a typevar we need to extract the reified type information hence the check of `is_dataclass(vs)` """ def _decode_item(type_arg, x): if is_dataclass(type_arg) or is_dataclass(xs): return _decode_dataclass(type_arg, x, infer_missing) if _is_supported_generic(type_arg): return _decode_generic(type_arg, x, infer_missing) return x def handle_pep0673(pre_0673_hint: str) -> Union[Type, str]: for module in sys.modules: maybe_resolved = getattr(sys.modules[module], type_args, None) if maybe_resolved: return maybe_resolved warnings.warn(f"Could not resolve self-reference for type {pre_0673_hint}, " f"decoded type might be incorrect or decode might fail altogether.") return pre_0673_hint # Before https://peps.python.org/pep-0673 (3.11+) self-type hints are simply strings if sys.version_info.minor < 11 and type_args is not type and type(type_args) is str: type_args = handle_pep0673(type_args) if _isinstance_safe(type_args, Collection) and not _issubclass_safe(type_args, Enum): if len(type_args) == len(xs): return list(_decode_item(type_arg, x) for type_arg, x in zip(type_args, xs)) else: raise TypeError(f"Number of types specified in the collection type {str(type_args)} " f"does not match number of elements in the collection. In case you are working with tuples" f"take a look at this document " f"docs.python.org/3/library/typing.html#annotating-tuples.") return list(_decode_item(type_args, x) for x in xs) def _asdict(obj, encode_json=False): """ A re-implementation of `asdict` (based on the original in the `dataclasses` source) to support arbitrary Collection and Mapping types. """ if is_dataclass(obj): result = [] overrides = _user_overrides_or_exts(obj) for field in fields(obj): if overrides[field.name].encoder: value = getattr(obj, field.name) else: value = _asdict( getattr(obj, field.name), encode_json=encode_json ) result.append((field.name, value)) result = _handle_undefined_parameters_safe(cls=obj, kvs=dict(result), usage="to") return _encode_overrides(dict(result), _user_overrides_or_exts(obj), encode_json=encode_json) elif isinstance(obj, Mapping): return dict((_asdict(k, encode_json=encode_json), _asdict(v, encode_json=encode_json)) for k, v in obj.items()) # enum.IntFlag and enum.Flag are regarded as collections in Python 3.11, thus a check against Enum is needed elif isinstance(obj, Collection) and not isinstance(obj, (str, bytes, Enum)): return list(_asdict(v, encode_json=encode_json) for v in obj) else: return copy.deepcopy(obj)