Surbhi123's picture
Upload folder using huggingface_hub
64772a4 verified
# functions to transform a c class into a dataclass
from collections import OrderedDict
from textwrap import dedent
import operator
from . import ExprNodes
from . import Nodes
from . import PyrexTypes
from . import Builtin
from . import Naming
from .Errors import error, warning
from .Code import UtilityCode, TempitaUtilityCode, PyxCodeWriter
from .Visitor import VisitorTransform
from .StringEncoding import EncodedString
from .TreeFragment import TreeFragment
from .ParseTreeTransforms import NormalizeTree, SkipDeclarations
from .Options import copy_inherited_directives
_dataclass_loader_utilitycode = None
def make_dataclasses_module_callnode(pos):
global _dataclass_loader_utilitycode
if not _dataclass_loader_utilitycode:
python_utility_code = UtilityCode.load_cached("Dataclasses_fallback", "Dataclasses.py")
python_utility_code = EncodedString(python_utility_code.impl)
_dataclass_loader_utilitycode = TempitaUtilityCode.load(
"SpecificModuleLoader", "Dataclasses.c",
context={'cname': "dataclasses", 'py_code': python_utility_code.as_c_string_literal()})
return ExprNodes.PythonCapiCallNode(
pos, "__Pyx_Load_dataclasses_Module",
PyrexTypes.CFuncType(PyrexTypes.py_object_type, []),
utility_code=_dataclass_loader_utilitycode,
args=[],
)
def make_dataclass_call_helper(pos, callable, kwds):
utility_code = UtilityCode.load_cached("DataclassesCallHelper", "Dataclasses.c")
func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("callable", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("kwds", PyrexTypes.py_object_type, None)
],
)
return ExprNodes.PythonCapiCallNode(
pos,
function_name="__Pyx_DataclassesCallHelper",
func_type=func_type,
utility_code=utility_code,
args=[callable, kwds],
)
class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
"""
Cython (and Python) normally treats
class A:
x = 1
as generating a class attribute. However for dataclasses the `= 1` should be interpreted as
a default value to initialize an instance attribute with.
This transform therefore removes the `x=1` assignment so that the class attribute isn't
generated, while recording what it has removed so that it can be used in the initialization.
"""
def __init__(self, names):
super(RemoveAssignmentsToNames, self).__init__()
self.names = names
self.removed_assignments = {}
def visit_CClassNode(self, node):
self.visitchildren(node)
return node
def visit_PyClassNode(self, node):
return node # go no further
def visit_FuncDefNode(self, node):
return node # go no further
def visit_SingleAssignmentNode(self, node):
if node.lhs.is_name and node.lhs.name in self.names:
if node.lhs.name in self.removed_assignments:
warning(node.pos, ("Multiple assignments for '%s' in dataclass; "
"using most recent") % node.lhs.name, 1)
self.removed_assignments[node.lhs.name] = node.rhs
return []
return node
# I believe cascaded assignment is always a syntax error with annotations
# so there's no need to define visit_CascadedAssignmentNode
def visit_Node(self, node):
self.visitchildren(node)
return node
class TemplateCode(object):
"""
Adds the ability to keep track of placeholder argument names to PyxCodeWriter.
Also adds extra_stats which are nodes bundled at the end when this
is converted to a tree.
"""
_placeholder_count = 0
def __init__(self, writer=None, placeholders=None, extra_stats=None):
self.writer = PyxCodeWriter() if writer is None else writer
self.placeholders = {} if placeholders is None else placeholders
self.extra_stats = [] if extra_stats is None else extra_stats
def add_code_line(self, code_line):
self.writer.putln(code_line)
def add_code_lines(self, code_lines):
for line in code_lines:
self.writer.putln(line)
def reset(self):
# don't attempt to reset placeholders - it really doesn't matter if
# we have unused placeholders
self.writer.reset()
def empty(self):
return self.writer.empty()
def indenter(self):
return self.writer.indenter()
def new_placeholder(self, field_names, value):
name = self._new_placeholder_name(field_names)
self.placeholders[name] = value
return name
def add_extra_statements(self, statements):
if self.extra_stats is None:
assert False, "Can only use add_extra_statements on top-level writer"
self.extra_stats.extend(statements)
def _new_placeholder_name(self, field_names):
while True:
name = "DATACLASS_PLACEHOLDER_%d" % self._placeholder_count
if (name not in self.placeholders
and name not in field_names):
# make sure name isn't already used and doesn't
# conflict with a variable name (which is unlikely but possible)
break
self._placeholder_count += 1
return name
def generate_tree(self, level='c_class'):
stat_list_node = TreeFragment(
self.writer.getvalue(),
level=level,
pipeline=[NormalizeTree(None)],
).substitute(self.placeholders)
stat_list_node.stats += self.extra_stats
return stat_list_node
def insertion_point(self):
new_writer = self.writer.insertion_point()
return TemplateCode(
writer=new_writer,
placeholders=self.placeholders,
extra_stats=self.extra_stats
)
class _MISSING_TYPE(object):
pass
MISSING = _MISSING_TYPE()
class Field(object):
"""
Field is based on the dataclasses.field class from the standard library module.
It is used internally during the generation of Cython dataclasses to keep track
of the settings for individual attributes.
Attributes of this class are stored as nodes so they can be used in code construction
more readily (i.e. we store BoolNode rather than bool)
"""
default = MISSING
default_factory = MISSING
private = False
literal_keys = ("repr", "hash", "init", "compare", "metadata")
# default values are defined by the CPython dataclasses.field
def __init__(self, pos, default=MISSING, default_factory=MISSING,
repr=None, hash=None, init=None,
compare=None, metadata=None,
is_initvar=False, is_classvar=False,
**additional_kwds):
if default is not MISSING:
self.default = default
if default_factory is not MISSING:
self.default_factory = default_factory
self.repr = repr or ExprNodes.BoolNode(pos, value=True)
self.hash = hash or ExprNodes.NoneNode(pos)
self.init = init or ExprNodes.BoolNode(pos, value=True)
self.compare = compare or ExprNodes.BoolNode(pos, value=True)
self.metadata = metadata or ExprNodes.NoneNode(pos)
self.is_initvar = is_initvar
self.is_classvar = is_classvar
for k, v in additional_kwds.items():
# There should not be any additional keywords!
error(v.pos, "cython.dataclasses.field() got an unexpected keyword argument '%s'" % k)
for field_name in self.literal_keys:
field_value = getattr(self, field_name)
if not field_value.is_literal:
error(field_value.pos,
"cython.dataclasses.field parameter '%s' must be a literal value" % field_name)
def iterate_record_node_arguments(self):
for key in (self.literal_keys + ('default', 'default_factory')):
value = getattr(self, key)
if value is not MISSING:
yield key, value
def process_class_get_fields(node):
var_entries = node.scope.var_entries
# order of definition is used in the dataclass
var_entries = sorted(var_entries, key=operator.attrgetter('pos'))
var_names = [entry.name for entry in var_entries]
# don't treat `x = 1` as an assignment of a class attribute within the dataclass
transform = RemoveAssignmentsToNames(var_names)
transform(node)
default_value_assignments = transform.removed_assignments
base_type = node.base_type
fields = OrderedDict()
while base_type:
if base_type.is_external or not base_type.scope.implemented:
warning(node.pos, "Cannot reliably handle Cython dataclasses with base types "
"in external modules since it is not possible to tell what fields they have", 2)
if base_type.dataclass_fields:
fields = base_type.dataclass_fields.copy()
break
base_type = base_type.base_type
for entry in var_entries:
name = entry.name
is_initvar = entry.declared_with_pytyping_modifier("dataclasses.InitVar")
# TODO - classvars aren't included in "var_entries" so are missed here
# and thus this code is never triggered
is_classvar = entry.declared_with_pytyping_modifier("typing.ClassVar")
if name in default_value_assignments:
assignment = default_value_assignments[name]
if (isinstance(assignment, ExprNodes.CallNode) and (
assignment.function.as_cython_attribute() == "dataclasses.field" or
Builtin.exprnode_to_known_standard_library_name(
assignment.function, node.scope) == "dataclasses.field")):
# I believe most of this is well-enforced when it's treated as a directive
# but it doesn't hurt to make sure
valid_general_call = (isinstance(assignment, ExprNodes.GeneralCallNode)
and isinstance(assignment.positional_args, ExprNodes.TupleNode)
and not assignment.positional_args.args
and (assignment.keyword_args is None or isinstance(assignment.keyword_args, ExprNodes.DictNode)))
valid_simple_call = (isinstance(assignment, ExprNodes.SimpleCallNode) and not assignment.args)
if not (valid_general_call or valid_simple_call):
error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist "
"of compile-time keyword arguments")
continue
keyword_args = assignment.keyword_args.as_python_dict() if valid_general_call and assignment.keyword_args else {}
if 'default' in keyword_args and 'default_factory' in keyword_args:
error(assignment.pos, "cannot specify both default and default_factory")
continue
field = Field(node.pos, **keyword_args)
else:
if assignment.type in [Builtin.list_type, Builtin.dict_type, Builtin.set_type]:
# The standard library module generates a TypeError at runtime
# in this situation.
# Error message is copied from CPython
error(assignment.pos, "mutable default <class '{0}'> for field {1} is not allowed: "
"use default_factory".format(assignment.type.name, name))
field = Field(node.pos, default=assignment)
else:
field = Field(node.pos)
field.is_initvar = is_initvar
field.is_classvar = is_classvar
if entry.visibility == "private":
field.private = True
fields[name] = field
node.entry.type.dataclass_fields = fields
return fields
def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
# default argument values from https://docs.python.org/3/library/dataclasses.html
kwargs = dict(init=True, repr=True, eq=True,
order=False, unsafe_hash=False,
frozen=False, kw_only=False)
if dataclass_args is not None:
if dataclass_args[0]:
error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
for k, v in dataclass_args[1].items():
if k not in kwargs:
error(node.pos,
"cython.dataclasses.dataclass() got an unexpected keyword argument '%s'" % k)
if not isinstance(v, ExprNodes.BoolNode):
error(node.pos,
"Arguments passed to cython.dataclasses.dataclass must be True or False")
kwargs[k] = v.value
kw_only = kwargs['kw_only']
fields = process_class_get_fields(node)
dataclass_module = make_dataclasses_module_callnode(node.pos)
# create __dataclass_params__ attribute. I try to use the exact
# `_DataclassParams` class defined in the standard library module if at all possible
# for maximum duck-typing compatibility.
dataclass_params_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
attribute=EncodedString("_DataclassParams"))
dataclass_params_keywords = ExprNodes.DictNode.from_pairs(
node.pos,
[ (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
ExprNodes.BoolNode(node.pos, value=v))
for k, v in kwargs.items() ] +
[ (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
ExprNodes.BoolNode(node.pos, value=v))
for k, v in [('kw_only', kw_only), ('match_args', False),
('slots', False), ('weakref_slot', False)]
])
dataclass_params = make_dataclass_call_helper(
node.pos, dataclass_params_func, dataclass_params_keywords)
dataclass_params_assignment = Nodes.SingleAssignmentNode(
node.pos,
lhs = ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_params__")),
rhs = dataclass_params)
dataclass_fields_stats = _set_up_dataclass_fields(node, fields, dataclass_module)
stats = Nodes.StatListNode(node.pos,
stats=[dataclass_params_assignment] + dataclass_fields_stats)
code = TemplateCode()
generate_init_code(code, kwargs['init'], node, fields, kw_only)
generate_repr_code(code, kwargs['repr'], node, fields)
generate_eq_code(code, kwargs['eq'], node, fields)
generate_order_code(code, kwargs['order'], node, fields)
generate_hash_code(code, kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields)
stats.stats += code.generate_tree().stats
# turn off annotation typing, so all arguments to __init__ are accepted as
# generic objects and thus can accept _HAS_DEFAULT_FACTORY.
# Type conversion comes later
comp_directives = Nodes.CompilerDirectivesNode(node.pos,
directives=copy_inherited_directives(node.scope.directives, annotation_typing=False),
body=stats)
comp_directives.analyse_declarations(node.scope)
# probably already in this scope, but it doesn't hurt to make sure
analyse_decs_transform.enter_scope(node, node.scope)
analyse_decs_transform.visit(comp_directives)
analyse_decs_transform.exit_scope()
node.body.stats.append(comp_directives)
def generate_init_code(code, init, node, fields, kw_only):
"""
Notes on CPython generated "__init__":
* Implemented in `_init_fn`.
* The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
the default argument for fields that need constructing with a factory
function is copied from the CPython implementation. (`None` isn't
suitable because it could also be a value for the user to pass.)
There's no real reason why it needs importing from the dataclasses module
though - it could equally be a value generated by Cython when the module loads.
* seen_default and the associated error message are copied directly from Python
* Call to user-defined __post_init__ function (if it exists) is copied from
CPython.
Cython behaviour deviates a little here (to be decided if this is right...)
Because the class variable from the assignment does not exist Cython fields will
return None (or whatever their type default is) if not initialized while Python
dataclasses will fall back to looking up the class variable.
"""
if not init or node.scope.lookup_here("__init__"):
return
# selfname behaviour copied from the cpython module
selfname = "__dataclass_self__" if "self" in fields else "self"
args = [selfname]
if kw_only:
args.append("*")
function_start_point = code.insertion_point()
code = code.insertion_point()
# create a temp to get _HAS_DEFAULT_FACTORY
dataclass_module = make_dataclasses_module_callnode(node.pos)
has_default_factory = ExprNodes.AttributeNode(
node.pos,
obj=dataclass_module,
attribute=EncodedString("_HAS_DEFAULT_FACTORY")
)
default_factory_placeholder = code.new_placeholder(fields, has_default_factory)
seen_default = False
for name, field in fields.items():
entry = node.scope.lookup(name)
if entry.annotation:
annotation = u": %s" % entry.annotation.string.value
else:
annotation = u""
assignment = u''
if field.default is not MISSING or field.default_factory is not MISSING:
seen_default = True
if field.default_factory is not MISSING:
ph_name = default_factory_placeholder
else:
ph_name = code.new_placeholder(fields, field.default) # 'default' should be a node
assignment = u" = %s" % ph_name
elif seen_default and not kw_only and field.init.value:
error(entry.pos, ("non-default argument '%s' follows default argument "
"in dataclass __init__") % name)
code.reset()
return
if field.init.value:
args.append(u"%s%s%s" % (name, annotation, assignment))
if field.is_initvar:
continue
elif field.default_factory is MISSING:
if field.init.value:
code.add_code_line(u" %s.%s = %s" % (selfname, name, name))
elif assignment:
# not an argument to the function, but is still initialized
code.add_code_line(u" %s.%s%s" % (selfname, name, assignment))
else:
ph_name = code.new_placeholder(fields, field.default_factory)
if field.init.value:
# close to:
# def __init__(self, name=_PLACEHOLDER_VALUE):
# self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
code.add_code_line(u" %s.%s = %s() if %s is %s else %s" % (
selfname, name, ph_name, name, default_factory_placeholder, name))
else:
# still need to use the default factory to initialize
code.add_code_line(u" %s.%s = %s()" % (
selfname, name, ph_name))
if node.scope.lookup("__post_init__"):
post_init_vars = ", ".join(name for name, field in fields.items()
if field.is_initvar)
code.add_code_line(" %s.__post_init__(%s)" % (selfname, post_init_vars))
if code.empty():
code.add_code_line(" pass")
args = u", ".join(args)
function_start_point.add_code_line(u"def __init__(%s):" % args)
def generate_repr_code(code, repr, node, fields):
"""
The core of the CPython implementation is just:
['return self.__class__.__qualname__ + f"(' +
', '.join([f"{f.name}={{self.{f.name}!r}}"
for f in fields]) +
')"'],
The only notable difference here is self.__class__.__qualname__ -> type(self).__name__
which is because Cython currently supports Python 2.
However, it also has some guards for recursive repr invocations. In the standard
library implementation they're done with a wrapper decorator that captures a set
(with the set keyed by id and thread). Here we create a set as a thread local
variable and key only by id.
"""
if not repr or node.scope.lookup("__repr__"):
return
# The recursive guard is likely a little costly, so skip it if possible.
# is_gc_simple defines where it can contain recursive objects
needs_recursive_guard = False
for name in fields.keys():
entry = node.scope.lookup(name)
type_ = entry.type
if type_.is_memoryviewslice:
type_ = type_.dtype
if not type_.is_pyobject:
continue # no GC
if not type_.is_gc_simple:
needs_recursive_guard = True
break
if needs_recursive_guard:
code.add_code_line("__pyx_recursive_repr_guard = __import__('threading').local()")
code.add_code_line("__pyx_recursive_repr_guard.running = set()")
code.add_code_line("def __repr__(self):")
if needs_recursive_guard:
code.add_code_line(" key = id(self)")
code.add_code_line(" guard_set = self.__pyx_recursive_repr_guard.running")
code.add_code_line(" if key in guard_set: return '...'")
code.add_code_line(" guard_set.add(key)")
code.add_code_line(" try:")
strs = [u"%s={self.%s!r}" % (name, name)
for name, field in fields.items()
if field.repr.value and not field.is_initvar]
format_string = u", ".join(strs)
code.add_code_line(u' name = getattr(type(self), "__qualname__", type(self).__name__)')
code.add_code_line(u" return f'{name}(%s)'" % format_string)
if needs_recursive_guard:
code.add_code_line(" finally:")
code.add_code_line(" guard_set.remove(key)")
def generate_cmp_code(code, op, funcname, node, fields):
if node.scope.lookup_here(funcname):
return
names = [name for name, field in fields.items() if (field.compare.value and not field.is_initvar)]
code.add_code_lines([
"def %s(self, other):" % funcname,
" if other.__class__ is not self.__class__:"
" return NotImplemented",
#
" cdef %s other_cast" % node.class_name,
" other_cast = <%s>other" % node.class_name,
])
# The Python implementation of dataclasses.py does a tuple comparison
# (roughly):
# return self._attributes_to_tuple() {op} other._attributes_to_tuple()
#
# For the Cython implementation a tuple comparison isn't an option because
# not all attributes can be converted to Python objects and stored in a tuple
#
# TODO - better diagnostics of whether the types support comparison before
# generating the code. Plus, do we want to convert C structs to dicts and
# compare them that way (I think not, but it might be in demand)?
checks = []
op_without_equals = op.replace('=', '')
for name in names:
if op != '==':
# tuple comparison rules - early elements take precedence
code.add_code_line(" if self.%s %s other_cast.%s: return True" % (
name, op_without_equals, name))
code.add_code_line(" if self.%s != other_cast.%s: return False" % (
name, name))
if "=" in op:
code.add_code_line(" return True") # "() == ()" is True
else:
code.add_code_line(" return False")
def generate_eq_code(code, eq, node, fields):
if not eq:
return
generate_cmp_code(code, "==", "__eq__", node, fields)
def generate_order_code(code, order, node, fields):
if not order:
return
for op, name in [("<", "__lt__"),
("<=", "__le__"),
(">", "__gt__"),
(">=", "__ge__")]:
generate_cmp_code(code, op, name, node, fields)
def generate_hash_code(code, unsafe_hash, eq, frozen, node, fields):
"""
Copied from CPython implementation - the intention is to follow this as far as
is possible:
# +------------------- unsafe_hash= parameter
# | +----------- eq= parameter
# | | +--- frozen= parameter
# | | |
# v v v | | |
# | no | yes | <--- class has explicitly defined __hash__
# +=======+=======+=======+========+========+
# | False | False | False | | | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | False | True | | | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | True | False | None | | <-- the default, not hashable
# +-------+-------+-------+--------+--------+
# | False | True | True | add | | Frozen, so hashable, allows override
# +-------+-------+-------+--------+--------+
# | True | False | False | add | raise | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True | False | True | add | raise | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True | True | False | add | raise | Not frozen, but hashable
# +-------+-------+-------+--------+--------+
# | True | True | True | add | raise | Frozen, so hashable
# +=======+=======+=======+========+========+
# For boxes that are blank, __hash__ is untouched and therefore
# inherited from the base class. If the base is object, then
# id-based hashing is used.
The Python implementation creates a tuple of all the fields, then hashes them.
This implementation creates a tuple of all the hashes of all the fields and hashes that.
The reason for this slight difference is to avoid to-Python conversions for anything
that Cython knows how to hash directly (It doesn't look like this currently applies to
anything though...).
"""
hash_entry = node.scope.lookup_here("__hash__")
if hash_entry:
# TODO ideally assignment of __hash__ to None shouldn't trigger this
# but difficult to get the right information here
if unsafe_hash:
# error message taken from CPython dataclasses module
error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
return
if not unsafe_hash:
if not eq:
return
if not frozen:
code.add_extra_statements([
Nodes.SingleAssignmentNode(
node.pos,
lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
rhs=ExprNodes.NoneNode(node.pos),
)
])
return
names = [
name for name, field in fields.items()
if not field.is_initvar and (
field.compare.value if field.hash.value is None else field.hash.value)
]
# make a tuple of the hashes
hash_tuple_items = u", ".join(u"self.%s" % name for name in names)
if hash_tuple_items:
hash_tuple_items += u"," # ensure that one arg form is a tuple
# if we're here we want to generate a hash
code.add_code_lines([
"def __hash__(self):",
" return hash((%s))" % hash_tuple_items,
])
def get_field_type(pos, entry):
"""
sets the .type attribute for a field
Returns the annotation if possible (since this is what the dataclasses
module does). If not (for example, attributes defined with cdef) then
it creates a string fallback.
"""
if entry.annotation:
# Right now it doesn't look like cdef classes generate an
# __annotations__ dict, therefore it's safe to just return
# entry.annotation
# (TODO: remove .string if we ditch PEP563)
return entry.annotation.string
# If they do in future then we may need to look up into that
# to duplicating the node. The code below should do this:
#class_name_node = ExprNodes.NameNode(pos, name=entry.scope.name)
#annotations = ExprNodes.AttributeNode(
# pos, obj=class_name_node,
# attribute=EncodedString("__annotations__")
#)
#return ExprNodes.IndexNode(
# pos, base=annotations,
# index=ExprNodes.StringNode(pos, value=entry.name)
#)
else:
# it's slightly unclear what the best option is here - we could
# try to return PyType_Type. This case should only happen with
# attributes defined with cdef so Cython is free to make it's own
# decision
s = EncodedString(entry.type.declaration_code("", for_display=1))
return ExprNodes.StringNode(pos, value=s)
class FieldRecordNode(ExprNodes.ExprNode):
"""
__dataclass_fields__ contains a bunch of field objects recording how each field
of the dataclass was initialized (mainly corresponding to the arguments passed to
the "field" function). This node is used for the attributes of these field objects.
If possible, coerces `arg` to a Python object.
Otherwise, generates a sensible backup string.
"""
subexprs = ['arg']
def __init__(self, pos, arg):
super(FieldRecordNode, self).__init__(pos, arg=arg)
def analyse_types(self, env):
self.arg.analyse_types(env)
self.type = self.arg.type
return self
def coerce_to_pyobject(self, env):
if self.arg.type.can_coerce_to_pyobject(env):
return self.arg.coerce_to_pyobject(env)
else:
# A string representation of the code that gave the field seems like a reasonable
# fallback. This'll mostly happen for "default" and "default_factory" where the
# type may be a C-type that can't be converted to Python.
return self._make_string()
def _make_string(self):
from .AutoDocTransforms import AnnotationWriter
writer = AnnotationWriter(description="Dataclass field")
string = writer.write(self.arg)
return ExprNodes.StringNode(self.pos, value=EncodedString(string))
def generate_evaluation_code(self, code):
return self.arg.generate_evaluation_code(code)
def _set_up_dataclass_fields(node, fields, dataclass_module):
# For defaults and default_factories containing things like lambda,
# they're already declared in the class scope, and it creates a big
# problem if multiple copies are floating around in both the __init__
# function, and in the __dataclass_fields__ structure.
# Therefore, create module-level constants holding these values and
# pass those around instead
#
# If possible we use the `Field` class defined in the standard library
# module so that the information stored here is as close to a regular
# dataclass as is possible.
variables_assignment_stats = []
for name, field in fields.items():
if field.private:
continue # doesn't appear in the public interface
for attrname in [ "default", "default_factory" ]:
field_default = getattr(field, attrname)
if field_default is MISSING or field_default.is_literal or field_default.is_name:
# some simple cases where we don't need to set up
# the variable as a module-level constant
continue
global_scope = node.scope.global_scope()
module_field_name = global_scope.mangle(
global_scope.mangle(Naming.dataclass_field_default_cname, node.class_name),
name)
# create an entry in the global scope for this variable to live
field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name))
field_node.entry = global_scope.declare_var(
field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
pos=field_default.pos, cname=field_node.name, is_cdef=True,
# TODO: do we need to set 'pytyping_modifiers' here?
)
# replace the field so that future users just receive the namenode
setattr(field, attrname, field_node)
variables_assignment_stats.append(
Nodes.SingleAssignmentNode(field_default.pos, lhs=field_node, rhs=field_default))
placeholders = {}
field_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
attribute=EncodedString("field"))
dc_fields = ExprNodes.DictNode(node.pos, key_value_pairs=[])
dc_fields_namevalue_assignments = []
for name, field in fields.items():
if field.private:
continue # doesn't appear in the public interface
type_placeholder_name = "PLACEHOLDER_%s" % name
placeholders[type_placeholder_name] = get_field_type(
node.pos, node.scope.entries[name]
)
# defining these make the fields introspect more like a Python dataclass
field_type_placeholder_name = "PLACEHOLDER_FIELD_TYPE_%s" % name
if field.is_initvar:
placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
node.pos, obj=dataclass_module,
attribute=EncodedString("_FIELD_INITVAR")
)
elif field.is_classvar:
# TODO - currently this isn't triggered
placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
node.pos, obj=dataclass_module,
attribute=EncodedString("_FIELD_CLASSVAR")
)
else:
placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
node.pos, obj=dataclass_module,
attribute=EncodedString("_FIELD")
)
dc_field_keywords = ExprNodes.DictNode.from_pairs(
node.pos,
[(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
FieldRecordNode(node.pos, arg=v))
for k, v in field.iterate_record_node_arguments()]
)
dc_field_call = make_dataclass_call_helper(
node.pos, field_func, dc_field_keywords
)
dc_fields.key_value_pairs.append(
ExprNodes.DictItemNode(
node.pos,
key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
value=dc_field_call))
dc_fields_namevalue_assignments.append(
dedent(u"""\
__dataclass_fields__[{0!r}].name = {0!r}
__dataclass_fields__[{0!r}].type = {1}
__dataclass_fields__[{0!r}]._field_type = {2}
""").format(name, type_placeholder_name, field_type_placeholder_name))
dataclass_fields_assignment = \
Nodes.SingleAssignmentNode(node.pos,
lhs = ExprNodes.NameNode(node.pos,
name=EncodedString("__dataclass_fields__")),
rhs = dc_fields)
dc_fields_namevalue_assignments = u"\n".join(dc_fields_namevalue_assignments)
dc_fields_namevalue_assignments = TreeFragment(dc_fields_namevalue_assignments,
level="c_class",
pipeline=[NormalizeTree(None)])
dc_fields_namevalue_assignments = dc_fields_namevalue_assignments.substitute(placeholders)
return (variables_assignment_stats
+ [dataclass_fields_assignment]
+ dc_fields_namevalue_assignments.stats)