Spaces:
Build error
Build error
| # functions to transform a c class into a dataclass | |
| from collections import OrderedDict | |
| from textwrap import dedent | |
| import operator | |
| from . import ExprNodes | |
| from . import Nodes | |
| from . import PyrexTypes | |
| from . import Builtin | |
| from . import Naming | |
| from .Errors import error, warning | |
| from .Code import UtilityCode, TempitaUtilityCode, PyxCodeWriter | |
| from .Visitor import VisitorTransform | |
| from .StringEncoding import EncodedString | |
| from .TreeFragment import TreeFragment | |
| from .ParseTreeTransforms import NormalizeTree, SkipDeclarations | |
| from .Options import copy_inherited_directives | |
_dataclass_loader_utilitycode = None


def make_dataclasses_module_callnode(pos):
    """
    Return a node that calls `__Pyx_Load_dataclasses_Module` to obtain the
    dataclasses module at runtime.

    The loader utility code is built on first use from the "SpecificModuleLoader"
    template in Dataclasses.c. Its `py_code` context is the "Dataclasses_fallback"
    Python utility code, embedded as a C string literal. The result is cached in
    the module-level `_dataclass_loader_utilitycode`.
    """
    global _dataclass_loader_utilitycode
    if not _dataclass_loader_utilitycode:
        fallback_source = EncodedString(
            UtilityCode.load_cached("Dataclasses_fallback", "Dataclasses.py").impl)
        loader_context = {
            'cname': "dataclasses",
            'py_code': fallback_source.as_c_string_literal(),
        }
        _dataclass_loader_utilitycode = TempitaUtilityCode.load(
            "SpecificModuleLoader", "Dataclasses.c", context=loader_context)

    loader_func_type = PyrexTypes.CFuncType(PyrexTypes.py_object_type, [])
    return ExprNodes.PythonCapiCallNode(
        pos, "__Pyx_Load_dataclasses_Module", loader_func_type,
        utility_code=_dataclass_loader_utilitycode,
        args=[],
    )
def make_dataclass_call_helper(pos, callable, kwds):
    """
    Return a node invoking the C helper `__Pyx_DataclassesCallHelper(callable, kwds)`.

    The helper comes from the "DataclassesCallHelper" utility code in
    Dataclasses.c. Both arguments and the result are Python objects.
    NOTE(review): the helper presumably calls `callable` with `kwds` as keyword
    arguments; its C implementation is not visible here.
    """
    helper_args = [
        PyrexTypes.CFuncTypeArg(arg_name, PyrexTypes.py_object_type, None)
        for arg_name in ("callable", "kwds")
    ]
    return ExprNodes.PythonCapiCallNode(
        pos,
        function_name="__Pyx_DataclassesCallHelper",
        func_type=PyrexTypes.CFuncType(PyrexTypes.py_object_type, helper_args),
        utility_code=UtilityCode.load_cached("DataclassesCallHelper", "Dataclasses.c"),
        args=[callable, kwds],
    )
class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
    """
    Cython (and Python) normally treats

        class A:
            x = 1

    as generating a class attribute. For dataclasses, however, the `= 1` is a
    default value used to initialize an instance attribute.

    This transform removes such `x = 1` assignments (for the names it was
    given) so that no class attribute is generated. It records each removed
    right-hand side in `removed_assignments` so it can be used for initialization.
    It does not descend into nested Python classes or functions.
    """

    def __init__(self, names):
        super(RemoveAssignmentsToNames, self).__init__()
        self.names = names
        # name -> rhs node of the last assignment removed for that name
        self.removed_assignments = {}

    def visit_CClassNode(self, node):
        self.visitchildren(node)
        return node

    def visit_PyClassNode(self, node):
        # assignments in a nested Python class belong to that class
        return node

    def visit_FuncDefNode(self, node):
        # assignments inside functions are local variables, not fields
        return node

    def visit_SingleAssignmentNode(self, node):
        target = node.lhs
        if not (target.is_name and target.name in self.names):
            return node
        name = target.name
        if name in self.removed_assignments:
            warning(node.pos, ("Multiple assignments for '%s' in dataclass; "
                               "using most recent") % name, 1)
        self.removed_assignments[name] = node.rhs
        return []  # drop the statement from the class body

    # Cascaded assignment is believed to always be a syntax error together
    # with annotations, so no visit_CascadedAssignmentNode is needed.

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class TemplateCode(object):
    """
    A PyxCodeWriter wrapper that additionally tracks:

    * `placeholders`: generated placeholder name -> node, passed to
      TreeFragment.substitute() by generate_tree();
    * `extra_stats`: nodes appended after the parsed template when this is
      converted to a tree.

    Writers created by insertion_point() share both containers with their parent.
    """
    _placeholder_count = 0

    def __init__(self, writer=None, placeholders=None, extra_stats=None):
        if writer is None:
            writer = PyxCodeWriter()
        if placeholders is None:
            placeholders = {}
        if extra_stats is None:
            extra_stats = []
        self.writer = writer
        self.placeholders = placeholders
        self.extra_stats = extra_stats

    def add_code_line(self, code_line):
        self.writer.putln(code_line)

    def add_code_lines(self, code_lines):
        putln = self.writer.putln
        for code_line in code_lines:
            putln(code_line)

    def reset(self):
        # Placeholders are deliberately kept: unused ones are harmless.
        self.writer.reset()

    def empty(self):
        return self.writer.empty()

    def indenter(self):
        return self.writer.indenter()

    def new_placeholder(self, field_names, value):
        """Register `value` under a fresh placeholder name and return that name."""
        placeholder_name = self._new_placeholder_name(field_names)
        self.placeholders[placeholder_name] = value
        return placeholder_name

    def add_extra_statements(self, statements):
        assert self.extra_stats is not None, "Can only use add_extra_statements on top-level writer"
        self.extra_stats.extend(statements)

    def _new_placeholder_name(self, field_names):
        # Advance the counter until the candidate is neither an existing
        # placeholder nor one of the user's field names (unlikely, but possible).
        candidate = "DATACLASS_PLACEHOLDER_%d" % self._placeholder_count
        while candidate in self.placeholders or candidate in field_names:
            self._placeholder_count += 1
            candidate = "DATACLASS_PLACEHOLDER_%d" % self._placeholder_count
        return candidate

    def generate_tree(self, level='c_class'):
        """Parse the written code, substitute placeholders and append extra_stats."""
        fragment = TreeFragment(
            self.writer.getvalue(),
            level=level,
            pipeline=[NormalizeTree(None)],
        )
        stat_list_node = fragment.substitute(self.placeholders)
        stat_list_node.stats += self.extra_stats
        return stat_list_node

    def insertion_point(self):
        return TemplateCode(
            writer=self.writer.insertion_point(),
            placeholders=self.placeholders,
            extra_stats=self.extra_stats,
        )
class _MISSING_TYPE(object):
    """Type of the MISSING sentinel, used by Field for attributes that were not given."""


MISSING = _MISSING_TYPE()
class Field(object):
    """
    Field is based on the dataclasses.field class from the standard library module.
    It is used internally during the generation of Cython dataclasses to keep track
    of the settings for individual attributes.

    Attributes of this class are stored as nodes so they can be used in code
    construction more readily (i.e. we store BoolNode rather than bool).
    """
    # Class-level fallbacks; instances only shadow these when a value is given.
    default = MISSING
    default_factory = MISSING
    private = False

    # Arguments that must be compile-time literals.
    literal_keys = ("repr", "hash", "init", "compare", "metadata")

    # default values are defined by the CPython dataclasses.field
    def __init__(self, pos, default=MISSING, default_factory=MISSING,
                 repr=None, hash=None, init=None,
                 compare=None, metadata=None,
                 is_initvar=False, is_classvar=False,
                 **additional_kwds):
        if default is not MISSING:
            self.default = default
        if default_factory is not MISSING:
            self.default_factory = default_factory

        # Unspecified options get the same values as CPython's field() defaults.
        self.repr = repr if repr else ExprNodes.BoolNode(pos, value=True)
        self.hash = hash if hash else ExprNodes.NoneNode(pos)
        self.init = init if init else ExprNodes.BoolNode(pos, value=True)
        self.compare = compare if compare else ExprNodes.BoolNode(pos, value=True)
        self.metadata = metadata if metadata else ExprNodes.NoneNode(pos)

        self.is_initvar = is_initvar
        self.is_classvar = is_classvar

        # field() accepts no other keywords; report each one.
        for bad_name, bad_node in additional_kwds.items():
            error(bad_node.pos, "cython.dataclasses.field() got an unexpected keyword argument '%s'" % bad_name)

        for option_name in self.literal_keys:
            option_node = getattr(self, option_name)
            if not option_node.is_literal:
                error(option_node.pos,
                      "cython.dataclasses.field parameter '%s' must be a literal value" % option_name)

    def iterate_record_node_arguments(self):
        """Yield (keyword, node) pairs for all options that have a value, defaults last."""
        for option_name in self.literal_keys + ('default', 'default_factory'):
            option_node = getattr(self, option_name)
            if option_node is not MISSING:
                yield option_name, option_node
def process_class_get_fields(node):
    """
    Work out the dataclass fields of the cdef class `node`.

    Returns an OrderedDict mapping field name -> Field. It starts as a copy of
    `dataclass_fields` from the nearest base type that has any. One entry is
    then added per attribute of this class, in order of definition.

    Side effects:
    * assignments like `x = 1` to those attributes are removed from the class
      body and become field defaults;
    * the result is stored as `node.entry.type.dataclass_fields`.
    """
    # order of definition is used in the dataclass
    own_entries = sorted(node.scope.var_entries, key=operator.attrgetter('pos'))

    # don't treat `x = 1` as an assignment of a class attribute within the dataclass
    stripper = RemoveAssignmentsToNames([entry.name for entry in own_entries])
    stripper(node)
    removed_rhs = stripper.removed_assignments

    fields = OrderedDict()
    ancestor = node.base_type
    while ancestor:
        if ancestor.is_external or not ancestor.scope.implemented:
            warning(node.pos, "Cannot reliably handle Cython dataclasses with base types "
                    "in external modules since it is not possible to tell what fields they have", 2)
        if ancestor.dataclass_fields:
            fields = ancestor.dataclass_fields.copy()
            break
        ancestor = ancestor.base_type

    for entry in own_entries:
        field_name = entry.name
        is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar")
        # TODO - classvars aren't included in "var_entries" so are missed here
        # and thus this code is never triggered
        is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar")

        if field_name not in removed_rhs:
            field = Field(node.pos)
        else:
            rhs = removed_rhs[field_name]
            calls_field = isinstance(rhs, ExprNodes.CallNode) and (
                rhs.function.as_cython_attribute() == "dataclasses.field" or
                Builtin.exprnode_to_known_standard_library_name(
                    rhs.function, node.scope) == "dataclasses.field")
            if not calls_field:
                # A plain default value.
                if rhs.type in [Builtin.list_type, Builtin.dict_type, Builtin.set_type]:
                    # The standard library module generates a TypeError at runtime
                    # in this situation. Error message is copied from CPython.
                    error(rhs.pos, "mutable default <class '{0}'> for field {1} is not allowed: "
                          "use default_factory".format(rhs.type.name, field_name))
                field = Field(node.pos, default=rhs)
            else:
                # Mostly enforced already when field() is treated as a
                # directive, but it doesn't hurt to make sure.
                is_general_call = (
                    isinstance(rhs, ExprNodes.GeneralCallNode)
                    and isinstance(rhs.positional_args, ExprNodes.TupleNode)
                    and not rhs.positional_args.args
                    and (rhs.keyword_args is None or isinstance(rhs.keyword_args, ExprNodes.DictNode)))
                is_simple_call = isinstance(rhs, ExprNodes.SimpleCallNode) and not rhs.args
                if not (is_general_call or is_simple_call):
                    error(rhs.pos, "Call to 'cython.dataclasses.field' must only consist "
                          "of compile-time keyword arguments")
                    continue  # the attribute is not recorded as a field
                field_kwargs = {}
                if is_general_call and rhs.keyword_args:
                    field_kwargs = rhs.keyword_args.as_python_dict()
                if 'default' in field_kwargs and 'default_factory' in field_kwargs:
                    error(rhs.pos, "cannot specify both default and default_factory")
                    continue  # the attribute is not recorded as a field
                field = Field(node.pos, **field_kwargs)

        field.is_initvar = is_initvar
        field.is_classvar = is_classvar
        if entry.visibility == "private":
            field.private = True
        fields[field_name] = field

    node.entry.type.dataclass_fields = fields
    return fields
def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
    """
    Turn the cdef class `node` into a dataclass.

    `dataclass_args` is None or a (positional_args, keyword_args) pair. Only
    keyword arguments whose values are BoolNodes are accepted; invalid ones are
    reported and ignored.

    The generated statements are:
    * `__dataclass_params__`;
    * `__dataclass_fields__`;
    * the special methods.

    They are wrapped in a CompilerDirectivesNode with annotation_typing turned
    off, have their declarations analysed, and are appended to `node.body`.
    """
    # default argument values from https://docs.python.org/3/library/dataclasses.html
    kwargs = dict(init=True, repr=True, eq=True,
                  order=False, unsafe_hash=False,
                  frozen=False, kw_only=False)
    if dataclass_args is not None:
        if dataclass_args[0]:
            error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
        for k, v in dataclass_args[1].items():
            if k not in kwargs:
                error(node.pos,
                      "cython.dataclasses.dataclass() got an unexpected keyword argument '%s'" % k)
                # Don't record it: it would otherwise be forwarded to _DataclassParams.
                continue
            if not isinstance(v, ExprNodes.BoolNode):
                error(node.pos,
                      "Arguments passed to cython.dataclasses.dataclass must be True or False")
                # Other node types may have no usable `.value`.
                continue
            kwargs[k] = v.value

    kw_only = kwargs['kw_only']

    fields = process_class_get_fields(node)

    dataclass_module = make_dataclasses_module_callnode(node.pos)

    # Create the __dataclass_params__ attribute. Use the exact `_DataclassParams`
    # class from the standard library module if at all possible, for maximum
    # duck-typing compatibility.
    dataclass_params_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                                    attribute=EncodedString("_DataclassParams"))
    # kw_only is already part of kwargs; only add the options that are fixed here.
    fixed_params = [('match_args', False), ('slots', False), ('weakref_slot', False)]
    dataclass_params_keywords = ExprNodes.DictNode.from_pairs(
        node.pos,
        [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
          ExprNodes.BoolNode(node.pos, value=v))
         for k, v in list(kwargs.items()) + fixed_params])
    dataclass_params = make_dataclass_call_helper(
        node.pos, dataclass_params_func, dataclass_params_keywords)
    dataclass_params_assignment = Nodes.SingleAssignmentNode(
        node.pos,
        lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_params__")),
        rhs=dataclass_params)

    dataclass_fields_stats = _set_up_dataclass_fields(node, fields, dataclass_module)

    stats = Nodes.StatListNode(node.pos,
                               stats=[dataclass_params_assignment] + dataclass_fields_stats)

    code = TemplateCode()
    generate_init_code(code, kwargs['init'], node, fields, kw_only)
    generate_repr_code(code, kwargs['repr'], node, fields)
    generate_eq_code(code, kwargs['eq'], node, fields)
    generate_order_code(code, kwargs['order'], node, fields)
    generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields)

    stats.stats += code.generate_tree().stats

    # Turn off annotation typing, so all arguments to __init__ are accepted as
    # generic objects and thus can accept _HAS_DEFAULT_FACTORY.
    # Type conversion comes later.
    comp_directives = Nodes.CompilerDirectivesNode(
        node.pos,
        directives=copy_inherited_directives(node.scope.directives, annotation_typing=False),
        body=stats)

    comp_directives.analyse_declarations(node.scope)
    # probably already in this scope, but it doesn't hurt to make sure
    analyse_decs_transform.enter_scope(node, node.scope)
    analyse_decs_transform.visit(comp_directives)
    analyse_decs_transform.exit_scope()

    node.body.stats.append(comp_directives)
def generate_init_code(code, init, node, fields, kw_only):
    """
    Generate `__init__` for the dataclass unless `init` is false or the class
    already defines one.

    Notes on the CPython-generated "__init__" (implemented in `_init_fn`):

    * The `dataclasses._HAS_DEFAULT_FACTORY` sentinel is used as the default
      argument for fields that need constructing with a factory function,
      copied from CPython. `None` isn't suitable because the user could also
      pass it. There's no real reason it must come from the dataclasses
      module; it could equally be a value generated by Cython at module load.
    * seen_default and the associated error message are copied from Python.
      As in CPython, only fields that are arguments of `__init__`
      (init=True) count when checking that defaults come last.
    * The call to a user-defined __post_init__ function (if it exists) is
      copied from CPython.

    Cython behaviour deviates a little here (to be decided if this is right):
    the class variable from the assignment does not exist, so uninitialized
    Cython fields return None (or their type's default). Python dataclasses
    instead fall back to looking up the class variable.
    """
    if not init or node.scope.lookup_here("__init__"):
        return

    # selfname behaviour copied from the cpython module
    selfname = "__dataclass_self__" if "self" in fields else "self"
    args = [selfname]

    if kw_only:
        args.append("*")

    # The "def" line is written last (once all arguments are known) into an
    # insertion point that precedes the body.
    function_start_point = code.insertion_point()
    code = code.insertion_point()

    # create a temp to get _HAS_DEFAULT_FACTORY
    dataclass_module = make_dataclasses_module_callnode(node.pos)
    has_default_factory = ExprNodes.AttributeNode(
        node.pos,
        obj=dataclass_module,
        attribute=EncodedString("_HAS_DEFAULT_FACTORY")
    )

    default_factory_placeholder = code.new_placeholder(fields, has_default_factory)

    seen_default = False
    for name, field in fields.items():
        entry = node.scope.lookup(name)
        if entry.annotation:
            annotation = u": %s" % entry.annotation.string.value
        else:
            annotation = u""
        assignment = u''
        if field.default is not MISSING or field.default_factory is not MISSING:
            if field.init.value:
                # Only fields that are __init__ arguments affect argument
                # ordering (matches CPython's _init_fn).
                seen_default = True
            if field.default_factory is not MISSING:
                ph_name = default_factory_placeholder
            else:
                ph_name = code.new_placeholder(fields, field.default)  # 'default' should be a node
            assignment = u" = %s" % ph_name
        elif seen_default and not kw_only and field.init.value:
            error(entry.pos, ("non-default argument '%s' follows default argument "
                              "in dataclass __init__") % name)
            code.reset()
            return

        if field.init.value:
            args.append(u"%s%s%s" % (name, annotation, assignment))

        if field.is_initvar:
            continue
        elif field.default_factory is MISSING:
            if field.init.value:
                code.add_code_line(u"    %s.%s = %s" % (selfname, name, name))
            elif assignment:
                # not an argument to the function, but is still initialized
                code.add_code_line(u"    %s.%s%s" % (selfname, name, assignment))
        else:
            ph_name = code.new_placeholder(fields, field.default_factory)
            if field.init.value:
                # close to:
                # def __init__(self, name=_PLACEHOLDER_VALUE):
                #     self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
                code.add_code_line(u"    %s.%s = %s() if %s is %s else %s" % (
                    selfname, name, ph_name, name, default_factory_placeholder, name))
            else:
                # still need to use the default factory to initialize
                code.add_code_line(u"    %s.%s = %s()" % (
                    selfname, name, ph_name))

    if node.scope.lookup("__post_init__"):
        post_init_vars = ", ".join(name for name, field in fields.items()
                                   if field.is_initvar)
        code.add_code_line("    %s.__post_init__(%s)" % (selfname, post_init_vars))

    if code.empty():
        code.add_code_line("    pass")

    args = u", ".join(args)
    function_start_point.add_code_line(u"def __init__(%s):" % args)
def generate_repr_code(code, repr, node, fields):
    """
    Generate `__repr__` unless `repr` is false or `__repr__` can already be found.

    The core of the CPython implementation is just:
        ['return self.__class__.__qualname__ + f"(' +
                         ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                    for f in fields]) +
                         ')"'],
    The generated code here takes the name from
    getattr(type(self), "__qualname__", type(self).__name__), so it falls back
    to __name__ on types without __qualname__.

    CPython also guards against recursive repr invocations using a wrapper
    decorator that captures a set keyed by id and thread. Here, when any field
    may hold Python objects that are not "gc simple", a threading.local() is
    stored as a class attribute. Each thread lazily gets its own `running`
    set, keyed only by id.
    """
    if not repr or node.scope.lookup("__repr__"):
        return

    # The recursive guard is likely a little costly, so skip it if possible.
    # is_gc_simple defines where it can contain recursive objects
    needs_recursive_guard = False
    for name in fields.keys():
        entry = node.scope.lookup(name)
        type_ = entry.type
        if type_.is_memoryviewslice:
            type_ = type_.dtype
        if not type_.is_pyobject:
            continue  # no GC
        if not type_.is_gc_simple:
            needs_recursive_guard = True
            break

    if needs_recursive_guard:
        code.add_code_line("__pyx_recursive_repr_guard = __import__('threading').local()")
        code.add_code_line("__pyx_recursive_repr_guard.running = set()")
    code.add_code_line("def __repr__(self):")
    if needs_recursive_guard:
        code.add_code_line("    key = id(self)")
        # threading.local attributes only exist in the thread that set them, so
        # the class-level `.running` above is invisible to other threads; create
        # this thread's set on first use instead of assuming it exists.
        code.add_code_line("    guard_local = self.__pyx_recursive_repr_guard")
        code.add_code_line("    guard_set = getattr(guard_local, 'running', None)")
        code.add_code_line("    if guard_set is None:")
        code.add_code_line("        guard_set = set()")
        code.add_code_line("        guard_local.running = guard_set")
        code.add_code_line("    if key in guard_set: return '...'")
        code.add_code_line("    guard_set.add(key)")
        code.add_code_line("    try:")
    strs = [u"%s={self.%s!r}" % (name, name)
            for name, field in fields.items()
            if field.repr.value and not field.is_initvar]
    format_string = u", ".join(strs)

    # Indented by 8 so the body nests under the optional `try:` above; without
    # the guard it is simply the (consistently indented) function body.
    code.add_code_line(u'        name = getattr(type(self), "__qualname__", type(self).__name__)')
    code.add_code_line(u"        return f'{name}(%s)'" % format_string)
    if needs_recursive_guard:
        code.add_code_line("    finally:")
        code.add_code_line("        guard_set.remove(key)")
def generate_cmp_code(code, op, funcname, node, fields):
    """
    Generate the rich comparison method `funcname` implementing `op`.

    Nothing is generated if the class already defines `funcname`. Only fields
    with compare=True that are not InitVars take part. The generated method
    returns NotImplemented unless `other` has exactly the same class.
    """
    if node.scope.lookup_here(funcname):
        return

    names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]

    code.add_code_lines([
        "def %s(self, other):" % funcname,
        "    if other.__class__ is not self.__class__:",
        "        return NotImplemented",
        #
        "    cdef %s other_cast" % node.class_name,
        "    other_cast = <%s>other" % node.class_name,
    ])

    # The Python implementation of dataclasses.py does a tuple comparison
    # (roughly):
    #     return self._attributes_to_tuple() {op} other._attributes_to_tuple()
    #
    # For the Cython implementation a tuple comparison isn't an option because
    # not all attributes can be converted to Python objects and stored in a tuple.
    #
    # TODO - better diagnostics of whether the types support comparison before
    # generating the code. Plus, do we want to convert C structs to dicts and
    # compare them that way (I think not, but it might be in demand)?
    op_without_equals = op.replace('=', '')

    for name in names:
        if op != '==':
            # tuple comparison rules - early elements take precedence
            code.add_code_line("    if self.%s %s other_cast.%s: return True" % (
                name, op_without_equals, name))
        code.add_code_line("    if self.%s != other_cast.%s: return False" % (
            name, name))
    if "=" in op:
        code.add_code_line("    return True")  # "() == ()" is True
    else:
        code.add_code_line("    return False")
def generate_eq_code(code, eq, node, fields):
    """Generate `__eq__` through generate_cmp_code when `eq` is true."""
    if eq:
        generate_cmp_code(code, "==", "__eq__", node, fields)
def generate_order_code(code, order, node, fields):
    """Generate `__lt__`, `__le__`, `__gt__` and `__ge__` when `order` is true."""
    if order:
        ordering_methods = (
            ("<", "__lt__"),
            ("<=", "__le__"),
            (">", "__gt__"),
            (">=", "__ge__"),
        )
        for operator_text, method_name in ordering_methods:
            generate_cmp_code(code, operator_text, method_name, node, fields)
def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
    """
    Copied from CPython implementation - the intention is to follow this as far as
    is possible:
    #    +------------------- unsafe_hash= parameter
    #    |       +----------- eq= parameter
    #    |       |       +--- frozen= parameter
    #    |       |       |
    #    v       v       v    |        |        |
    #                         |   no   |  yes   |  <--- class has explicitly defined __hash__
    # +=======+=======+=======+========+========+
    # | False | False | False |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | False | True  |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | True  | False | None   |        | <-- the default, not hashable
    # +-------+-------+-------+--------+--------+
    # | False | True  | True  | add    |        | Frozen, so hashable, allows override
    # +-------+-------+-------+--------+--------+
    # | True  | False | False | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | False | True  | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | False | add    | raise  | Not frozen, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | True  | add    | raise  | Frozen, so hashable
    # +=======+=======+=======+========+========+
    # For boxes that are blank, __hash__ is untouched and therefore
    # inherited from the base class.  If the base is object, then
    # id-based hashing is used.

    Like the Python implementation, the generated __hash__ hashes a tuple of
    the selected field values. A field is selected by its `hash` option, or
    by its `compare` option when `hash` is None (as in CPython). InitVars are
    never included.
    """
    hash_entry = node.scope.lookup_here("__hash__")
    if hash_entry:
        # TODO ideally assignment of __hash__ to None shouldn't trigger this
        # but difficult to get the right information here
        if unsafe_hash:
            # error message taken from CPython dataclasses module
            error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
        return

    if not unsafe_hash:
        if not eq:
            return
        if not frozen:
            code.add_extra_statements([
                Nodes.SingleAssignmentNode(
                    node.pos,
                    lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
                    rhs=ExprNodes.NoneNode(node.pos),
                )
            ])
            return

    # A NoneNode's `.value` is not Python None, so test for "hash=None" via
    # `is_none` rather than `.value is None`.
    names = [
        name for name, field in fields.items()
        if not field.is_initvar and (
            field.compare.value if field.hash.is_none else field.hash.value)
    ]

    # make a tuple of the selected field values
    hash_tuple_items = u", ".join(u"self.%s" % name for name in names)
    if hash_tuple_items:
        hash_tuple_items += u","  # ensure that one arg form is a tuple

    # if we're here we want to generate a hash
    code.add_code_lines([
        "def __hash__(self):",
        "    return hash((%s))" % hash_tuple_items,
    ])
def get_field_type(pos, entry):
    """
    Return the node to store as a field's `.type` in __dataclass_fields__.

    If the attribute has an annotation, its string form is returned, since
    that is what the dataclasses module records. Otherwise (e.g. attributes
    declared with cdef) a StringNode holding the display form of the C type
    declaration is returned.
    """
    if not entry.annotation:
        # It's slightly unclear what the best option is here - we could try
        # to return PyType_Type. This case should only happen with attributes
        # defined with cdef, so Cython is free to make its own decision.
        type_text = EncodedString(entry.type.declaration_code("", for_display=1))
        return ExprNodes.StringNode(pos, value=type_text)

    # Right now it doesn't look like cdef classes generate an __annotations__
    # dict, so it's safe to return entry.annotation directly.
    # (TODO: remove .string if we ditch PEP563)
    # If they do in future, we may need to look up into __annotations__ rather
    # than duplicating the node, i.e. build roughly:
    #     IndexNode(base=AttributeNode(obj=NameNode(<class name>),
    #                                  attribute="__annotations__"),
    #               index=StringNode(entry.name))
    return entry.annotation.string
class FieldRecordNode(ExprNodes.ExprNode):
    """
    __dataclass_fields__ contains a bunch of field objects recording how each field
    of the dataclass was initialized (mainly corresponding to the arguments passed to
    the "field" function). This node is used for the attributes of these field objects.

    If possible, coerces `arg` to a Python object.
    Otherwise, generates a sensible backup string.
    """
    subexprs = ['arg']

    def __init__(self, pos, arg):
        super(FieldRecordNode, self).__init__(pos, arg=arg)

    def analyse_types(self, env):
        # analyse_types() returns the analysed node, which may be a replacement
        # for the original one, so it must be stored rather than discarded.
        self.arg = self.arg.analyse_types(env)
        self.type = self.arg.type
        return self

    def coerce_to_pyobject(self, env):
        if self.arg.type.can_coerce_to_pyobject(env):
            return self.arg.coerce_to_pyobject(env)
        else:
            # A string representation of the code that gave the field seems like
            # a reasonable fallback. This'll mostly happen for "default" and
            # "default_factory", where the type may be a C type that can't be
            # converted to Python.
            return self._make_string()

    def _make_string(self):
        """Return a StringNode with the source-like text of `arg`."""
        from .AutoDocTransforms import AnnotationWriter
        writer = AnnotationWriter(description="Dataclass field")
        string = writer.write(self.arg)
        return ExprNodes.StringNode(self.pos, value=EncodedString(string))

    def generate_evaluation_code(self, code):
        return self.arg.generate_evaluation_code(code)
def _set_up_dataclass_fields(node, fields, dataclass_module):
    """
    Build the statements that populate `__dataclass_fields__` for `node`.

    Returns a list of statement nodes, in this order:
      1. assignments of non-trivial defaults / default factories to
         module-level variables (see below);
      2. `__dataclass_fields__ = {name: <field(...) call>, ...}`;
      3. per-field assignments of `.name`, `.type` and `._field_type`.

    Fields marked `private` are left out. The Field objects are mutated: each
    hoisted default is replaced by a NameNode for its module-level variable.
    """
    # Defaults and default_factories containing things like lambda are already
    # declared in the class scope. It creates a big problem if multiple copies
    # float around in both the __init__ function and the __dataclass_fields__
    # structure. Therefore, create module-level constants holding these values
    # and pass those around instead.
    #
    # If possible we use the `Field` class defined in the standard library
    # module so that the information stored here is as close to a regular
    # dataclass as is possible.
    public_fields = [(name, field) for name, field in fields.items() if not field.private]
    global_scope = node.scope.global_scope()

    hoisting_stats = []
    for name, field in public_fields:
        for attrname in ("default", "default_factory"):
            value_node = getattr(field, attrname)
            if value_node is MISSING or value_node.is_literal or value_node.is_name:
                # simple enough to be used in place without a module-level constant
                continue
            module_var_name = EncodedString(global_scope.mangle(
                global_scope.mangle(Naming.dataclass_field_default_cname, node.class_name),
                name))
            var_node = ExprNodes.NameNode(value_node.pos, name=module_var_name)
            # declare the global-scope variable that will hold the value
            var_node.entry = global_scope.declare_var(
                var_node.name, type=value_node.type or PyrexTypes.unspecified_type,
                pos=value_node.pos, cname=var_node.name, is_cdef=True,
                # TODO: do we need to set 'pytyping_modifiers' here?
            )
            # later users of the field receive the NameNode instead
            setattr(field, attrname, var_node)
            hoisting_stats.append(
                Nodes.SingleAssignmentNode(value_node.pos, lhs=var_node, rhs=value_node))

    field_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                         attribute=EncodedString("field"))
    dc_fields = ExprNodes.DictNode(node.pos, key_value_pairs=[])
    placeholders = {}
    template_chunks = []

    for name, field in public_fields:
        type_ph = "PLACEHOLDER_%s" % name
        placeholders[type_ph] = get_field_type(node.pos, node.scope.entries[name])

        # defining these make the fields introspect more like a Python dataclass
        field_type_ph = "PLACEHOLDER_FIELD_TYPE_%s" % name
        if field.is_initvar:
            field_type_attr = "_FIELD_INITVAR"
        elif field.is_classvar:
            # TODO - currently this isn't triggered
            field_type_attr = "_FIELD_CLASSVAR"
        else:
            field_type_attr = "_FIELD"
        placeholders[field_type_ph] = ExprNodes.AttributeNode(
            node.pos, obj=dataclass_module, attribute=EncodedString(field_type_attr))

        record_pairs = [
            (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(key)),
             FieldRecordNode(node.pos, arg=value))
            for key, value in field.iterate_record_node_arguments()
        ]
        field_call = make_dataclass_call_helper(
            node.pos, field_func, ExprNodes.DictNode.from_pairs(node.pos, record_pairs))
        dc_fields.key_value_pairs.append(ExprNodes.DictItemNode(
            node.pos,
            key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
            value=field_call))

        template_chunks.append(dedent(u"""\
            __dataclass_fields__[{0!r}].name = {0!r}
            __dataclass_fields__[{0!r}].type = {1}
            __dataclass_fields__[{0!r}]._field_type = {2}
        """).format(name, type_ph, field_type_ph))

    dataclass_fields_assignment = Nodes.SingleAssignmentNode(
        node.pos,
        lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_fields__")),
        rhs=dc_fields)

    fragment = TreeFragment(u"\n".join(template_chunks),
                            level="c_class",
                            pipeline=[NormalizeTree(None)])
    field_attribute_stats = fragment.substitute(placeholders).stats

    return hoisting_stats + [dataclass_fields_assignment] + field_attribute_stats