Spaces:
Build error
Build error
| from . import ( | |
| Nodes, | |
| ExprNodes, | |
| FusedNode, | |
| TreeFragment, | |
| Pipeline, | |
| ParseTreeTransforms, | |
| Naming, | |
| UtilNodes, | |
| ) | |
| from .Errors import error | |
| from . import PyrexTypes | |
| from .UtilityCode import CythonUtilityCode | |
| from .Code import TempitaUtilityCode, UtilityCode | |
| from .Visitor import PrintTree, TreeVisitor, VisitorTransform | |
numpy_int_types = [
    "NPY_BYTE",
    "NPY_INT8",
    "NPY_SHORT",
    "NPY_INT16",
    "NPY_INT",
    "NPY_INT32",
    "NPY_LONG",
    "NPY_LONGLONG",
    "NPY_INT64",
]
# Each signed constant has an unsigned counterpart spelled "NPY_U..."
# (e.g. NPY_INT -> NPY_UINT, NPY_BYTE -> NPY_UBYTE).
numpy_uint_types = [tp.replace("NPY_", "NPY_U") for tp in numpy_int_types]
# note: half float type is deliberately omitted
numpy_numeric_types = (
    numpy_int_types
    + numpy_uint_types
    + [
        "NPY_FLOAT",
        "NPY_FLOAT32",
        "NPY_DOUBLE",
        "NPY_FLOAT64",
        "NPY_LONGDOUBLE",
    ]
)

# C type spellings (upper-cased, spaces removed) whose NumPy constant is not
# simply "NPY_" + that spelling. Without this table, entries such as
# NPY_UINT / NPY_UBYTE / NPY_BYTE above could never be matched, because C
# spells those types "unsigned int", "unsigned char", "signed char", ...
_C_TYPE_TO_NUMPY_SUFFIX = {
    "SIGNEDCHAR": "BYTE",
    "UNSIGNEDCHAR": "UBYTE",
    "UNSIGNEDSHORT": "USHORT",
    "UNSIGNEDINT": "UINT",
    "UNSIGNEDLONG": "ULONG",
    "UNSIGNEDLONGLONG": "ULONGLONG",
}


def _get_type_constant(pos, type_):
    """
    Return the NumPy dtype constant name (e.g. "NPY_DOUBLE") for a Cython type.

    pos   -- source position used when reporting an error
    type_ -- PyrexType of a ufunc argument or return component

    Complex float/double/long double map to NPY_CFLOAT/NPY_CDOUBLE/
    NPY_CLONGDOUBLE. Numeric types are matched against
    ``numpy_numeric_types``, and Python objects map to NPY_OBJECT. For any
    other type a compile error is reported via ``error`` and None is returned.
    """
    if type_.is_complex:
        # 'is' checks don't seem to work for complex types
        if type_ == PyrexTypes.c_float_complex_type:
            return "NPY_CFLOAT"
        elif type_ == PyrexTypes.c_double_complex_type:
            return "NPY_CDOUBLE"
        elif type_ == PyrexTypes.c_longdouble_complex_type:
            return "NPY_CLONGDOUBLE"
    elif type_.is_numeric:
        postfix = type_.empty_declaration_code().upper().replace(" ", "")
        # Translate C spellings like "UNSIGNEDINT" to NumPy's "UINT".
        postfix = _C_TYPE_TO_NUMPY_SUFFIX.get(postfix, postfix)
        typename = "NPY_%s" % postfix
        if typename in numpy_numeric_types:
            return typename
    elif type_.is_pyobject:
        return "NPY_OBJECT"
    # TODO possible NPY_BOOL to bint but it needs a cast?
    # TODO NPY_DATETIME, NPY_TIMEDELTA, NPY_STRING, NPY_UNICODE and maybe NPY_VOID might be handleable
    error(pos, "Type '%s' cannot be used as a ufunc argument" % type_)
    return None
class _FindCFuncDefNode(TreeVisitor):
    """
    Tree visitor that records a CFuncDefNode.

    After a CFuncDefNode has been recorded, generic nodes are no longer
    descended into. The tree is assumed to contain only one CFuncDefNode.
    Calling the visitor on a tree returns the recorded node (None if no
    CFuncDefNode was visited).
    """

    found_node = None

    def visit_Node(self, node):
        # Only keep searching while nothing has been found yet.
        if not self.found_node:
            self.visitchildren(node)

    def visit_CFuncDefNode(self, node):
        self.found_node = node

    def __call__(self, tree):
        self.visit(tree)
        return self.found_node
def get_cfunc_from_tree(tree):
    """Return the CFuncDefNode located in ``tree``, or None if there is none."""
    finder = _FindCFuncDefNode()
    return finder(tree)
class _ArgumentInfo(object):
    """
    Describes one input or output argument of a ufunc.

    Attributes:
        type          -- the argument's PyrexType
        type_constant -- NumPy dtype constant name such as "NPY_INT8"
    """

    def __init__(self, type, type_constant):
        # ``type`` shadows the builtin, but the parameter name is kept so
        # keyword callers continue to work.
        self.type, self.type_constant = type, type_constant
class UFuncConversion(object):
    """
    Turns one C function definition into the pieces of a NumPy ufunc loop.

    The input and output argument descriptions (lists of _ArgumentInfo) are
    computed in ``__init__``. ``generate_cy_utility_code`` then builds the
    Cython wrapper tree for this single loop function.
    """

    def __init__(self, node):
        # node: the CFuncDefNode being converted
        self.node = node
        self.global_scope = node.local_scope.global_scope()
        self.in_definitions = self.get_in_type_info()
        self.out_definitions = self.get_out_type_info()

    def get_in_type_info(self):
        """Return an _ArgumentInfo for each argument of the C function, in order."""
        pos = self.node.pos
        return [
            _ArgumentInfo(arg.type, _get_type_constant(pos, arg.type))
            for arg in self.node.args
        ]

    def get_out_type_info(self):
        """
        Return an _ArgumentInfo for each ufunc output.

        A ctuple return type yields one output per component. Any other return
        type yields a single output.
        """
        return_type = self.node.return_type
        if return_type.is_ctuple:
            components = return_type.components
        else:
            components = [return_type]
        pos = self.node.pos
        return [
            _ArgumentInfo(component_type, _get_type_constant(pos, component_type))
            for component_type in components
        ]

    def generate_cy_utility_code(self):
        """
        Load the "UFuncDefinition" Cython utility code for this function.

        Returns the tree produced by ``get_tree(entries_only=True)``.
        Side effects: marks the function's entry as used and reserves a new
        unique "<name>_ufunc_def" id in the global scope.
        """
        arg_types = [a.type for a in self.in_definitions]
        out_types = [a.type for a in self.out_definitions]
        entry = self.node.entry
        inline_func_decl = entry.type.declaration_code(entry.cname, pyrex=True)
        entry.used = True
        ufunc_cname = self.global_scope.next_id(entry.name + "_ufunc_def")
        # The loop can be called without the GIL only if no input or output
        # is a Python object.
        will_be_called_without_gil = not any(
            t.is_pyobject for t in arg_types + out_types
        )
        context = dict(
            func_cname=ufunc_cname,
            in_types=arg_types,
            out_types=out_types,
            inline_func_call=entry.cname,
            inline_func_declaration=inline_func_decl,
            nogil=entry.type.nogil,
            will_be_called_without_gil=will_be_called_without_gil,
        )
        code = CythonUtilityCode.load(
            "UFuncDefinition",
            "UFuncs.pyx",
            context=context,
            outer_module_scope=self.global_scope,
        )
        return code.get_tree(entries_only=True)

    def use_generic_utility_code(self):
        """
        Register the C utility code shared by every ufunc with the global scope:
        "UFuncsInit" from UFuncs_C.c and "NumpyImportUFunc" from NumpyImportArray.c.
        """
        # use the invariant C utility code
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("UFuncsInit", "UFuncs_C.c")
        )
        self.global_scope.use_utility_code(
            UtilityCode.load_cached("NumpyImportUFunc", "NumpyImportArray.c")
        )
def convert_to_ufunc(node):
    """
    Convert a C function definition into a NumPy ufunc.

    ``node`` may be a CFuncDefNode, or a FusedCFuncDefNode whose ``node`` is a
    CFuncDefNode (one converter is created per entry of ``node.nodes``).
    Methods and non-C functions report a compile error and ``node`` is
    returned unchanged.

    On success, returns a list: ``node`` followed by the generated loop
    functions and the assignment statement that creates the ufunc object.
    """
    if isinstance(node, Nodes.CFuncDefNode):
        if node.local_scope.parent_scope.is_c_class_scope:
            error(node.pos, "Methods cannot currently be converted to a ufunc")
            return node
        converters = [UFuncConversion(node)]
        original_node = node
    elif isinstance(node, FusedNode.FusedCFuncDefNode) and isinstance(
        node.node, Nodes.CFuncDefNode
    ):
        if node.node.local_scope.parent_scope.is_c_class_scope:
            error(node.pos, "Methods cannot currently be converted to a ufunc")
            return node
        converters = [UFuncConversion(n) for n in node.nodes]
        original_node = node.node
    else:
        error(node.pos, "Only C functions can be converted to a ufunc")
        return node
    if not converters:
        # This path probably shouldn't happen (no fused specialisations).
        # Keep the node unchanged, like the error paths above. Returning None
        # here could remove the function from the tree when called from a
        # transform.
        return node
    # Drop the C function's module-level name. The ufunc initialization
    # re-declares that name as a Python object.
    del converters[0].global_scope.entries[original_node.entry.name]
    # the generic utility code is generic, so there's no reason to do it multiple times
    converters[0].use_generic_utility_code()
    return [node] + _generate_stats_from_converters(converters, original_node)
def generate_ufunc_initialization(converters, cfunc_nodes, original_node):
    """
    Build the module-level statement that creates the ufunc object.

    converters    -- UFuncConversion objects, one per loop implementation
                     (only the first one's global scope is used)
    cfunc_nodes   -- the generated loop function nodes. Their C names fill the
                     ufunc's function table.
    original_node -- the user's CFuncDefNode, which supplies the position,
                     the ufunc's Python name and its docstring

    Side effects: reserves three unique C names in the global scope, registers
    the "UFuncConsts" Tempita utility code, and declares the function name in
    the global scope as a Python object.

    Returns a SingleAssignmentNode that assigns the result of
    PyUFunc_FromFuncAndData(...) to that name.
    """
    global_scope = converters[0].global_scope
    # Reserve the C names in a fixed order: funcs, types, data.
    ufunc_funcs_name, ufunc_types_name, ufunc_data_name = [
        global_scope.next_id(Naming.pyrex_prefix + suffix)
        for suffix in ("funcs", "types", "data")
    ]

    # Flatten every loop's signature into a single table laid out as
    # [loop0 inputs, loop0 outputs, loop1 inputs, loop1 outputs, ...].
    # All loops must agree on the number of inputs and outputs.
    type_constants = []
    narg_in = narg_out = None
    for converter in converters:
        in_constants = [info.type_constant for info in converter.in_definitions]
        if narg_in is None:
            narg_in = len(in_constants)
        else:
            assert narg_in == len(in_constants)
        type_constants += in_constants

        out_constants = [info.type_constant for info in converter.out_definitions]
        if narg_out is None:
            narg_out = len(out_constants)
        else:
            assert narg_out == len(out_constants)
        type_constants += out_constants

    func_cnames = [cfunc_node.entry.cname for cfunc_node in cfunc_nodes]
    global_scope.use_utility_code(
        TempitaUtilityCode.load(
            "UFuncConsts",
            "UFuncs_C.c",
            context={
                "ufunc_funcs_name": ufunc_funcs_name,
                "func_cnames": func_cnames,
                "ufunc_types_name": ufunc_types_name,
                "type_constants": type_constants,
                "ufunc_data_name": ufunc_data_name,
            },
        )
    )

    pos = original_node.pos
    func_name = original_node.entry.name
    docstr = original_node.doc
    doc_literal = docstr.as_c_string_literal() if docstr else "NULL"
    # PyUFunc_FromFuncAndData arguments, emitted verbatim as one C fragment:
    # funcs, data, types, ntypes, nin, nout, identity, name, doc, unused.
    args_to_func = '%s(), %s, %s(), %s, %s, %s, PyUFunc_None, "%s", %s, 0' % (
        ufunc_funcs_name,
        ufunc_data_name,
        ufunc_types_name,
        len(func_cnames),
        narg_in,
        narg_out,
        func_name,
        doc_literal,
    )

    # use a dummy type because it's honestly too fiddly: the call is typed as
    # taking one void* and the whole argument list is passed as one constant.
    dummy_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type,
        [PyrexTypes.CFuncTypeArg("dummy", PyrexTypes.c_void_ptr_type, None)],
    )
    packed_args = ExprNodes.ConstNode(
        pos, type=PyrexTypes.c_void_ptr_type, value=args_to_func
    )
    call_node = ExprNodes.PythonCapiCallNode(
        pos,
        function_name="PyUFunc_FromFuncAndData",
        func_type=dummy_func_type,
        args=[packed_args],
    )

    lhs_entry = global_scope.declare_var(func_name, PyrexTypes.py_object_type, pos)
    lhs = ExprNodes.NameNode(
        pos, name=func_name, type=PyrexTypes.py_object_type, entry=lhs_entry
    )
    return Nodes.SingleAssignmentNode(pos, lhs=lhs, rhs=call_node)
def _generate_stats_from_converters(converters, node):
    """
    Generate the statements that implement the ufunc.

    For each converter, its Cython utility-code tree is generated, the
    CFuncDefNode inside it is extracted, and the tree's utility code is merged
    into the converter's global scope. Returns those loop function nodes (in
    converter order) followed by the statement that creates the ufunc object,
    built from ``node`` (the original C function definition).
    """

    def build_loop_function(converter):
        # Produce one converter's loop function and merge its utility code.
        tree = converter.generate_cy_utility_code()
        loop_function = get_cfunc_from_tree(tree)
        # merge in any utility code
        converter.global_scope.utility_code_list.extend(tree.scope.utility_code_list)
        return loop_function

    loop_functions = [build_loop_function(converter) for converter in converters]
    init_statement = generate_ufunc_initialization(converters, loop_functions, node)
    return loop_functions + [init_statement]