| ''' |
| Use llvmlite to create executable functions from SymPy expressions |
| |
| This module requires llvmlite (https://github.com/numba/llvmlite). |
| ''' |
|
|
| import ctypes |
|
|
| from sympy.external import import_module |
| from sympy.printing.printer import Printer |
| from sympy.core.singleton import S |
| from sympy.tensor.indexed import IndexedBase |
| from sympy.utilities.decorator import doctest_depends_on |
|
|
| llvmlite = import_module('llvmlite') |
| if llvmlite: |
| ll = import_module('llvmlite.ir').ir |
| llvm = import_module('llvmlite.binding').binding |
| llvm.initialize() |
| llvm.initialize_native_target() |
| llvm.initialize_native_asmprinter() |
|
|
|
|
| __doctest_requires__ = {('llvm_callable'): ['llvmlite']} |
|
|
|
|
| class LLVMJitPrinter(Printer): |
| '''Convert expressions to LLVM IR''' |
| def __init__(self, module, builder, fn, *args, **kwargs): |
| self.func_arg_map = kwargs.pop("func_arg_map", {}) |
| if not llvmlite: |
| raise ImportError("llvmlite is required for LLVMJITPrinter") |
| super().__init__(*args, **kwargs) |
| self.fp_type = ll.DoubleType() |
| self.module = module |
| self.builder = builder |
| self.fn = fn |
| self.ext_fn = {} |
| self.tmp_var = {} |
|
|
| def _add_tmp_var(self, name, value): |
| self.tmp_var[name] = value |
|
|
| def _print_Number(self, n): |
| return ll.Constant(self.fp_type, float(n)) |
|
|
| def _print_Integer(self, expr): |
| return ll.Constant(self.fp_type, float(expr.p)) |
|
|
| def _print_Symbol(self, s): |
| val = self.tmp_var.get(s) |
| if not val: |
| |
| val = self.func_arg_map.get(s) |
| if not val: |
| raise LookupError("Symbol not found: %s" % s) |
| return val |
|
|
| def _print_Pow(self, expr): |
| base0 = self._print(expr.base) |
| if expr.exp == S.NegativeOne: |
| return self.builder.fdiv(ll.Constant(self.fp_type, 1.0), base0) |
| if expr.exp == S.Half: |
| fn = self.ext_fn.get("sqrt") |
| if not fn: |
| fn_type = ll.FunctionType(self.fp_type, [self.fp_type]) |
| fn = ll.Function(self.module, fn_type, "sqrt") |
| self.ext_fn["sqrt"] = fn |
| return self.builder.call(fn, [base0], "sqrt") |
| if expr.exp == 2: |
| return self.builder.fmul(base0, base0) |
|
|
| exp0 = self._print(expr.exp) |
| fn = self.ext_fn.get("pow") |
| if not fn: |
| fn_type = ll.FunctionType(self.fp_type, [self.fp_type, self.fp_type]) |
| fn = ll.Function(self.module, fn_type, "pow") |
| self.ext_fn["pow"] = fn |
| return self.builder.call(fn, [base0, exp0], "pow") |
|
|
| def _print_Mul(self, expr): |
| nodes = [self._print(a) for a in expr.args] |
| e = nodes[0] |
| for node in nodes[1:]: |
| e = self.builder.fmul(e, node) |
| return e |
|
|
| def _print_Add(self, expr): |
| nodes = [self._print(a) for a in expr.args] |
| e = nodes[0] |
| for node in nodes[1:]: |
| e = self.builder.fadd(e, node) |
| return e |
|
|
| |
| |
| def _print_Function(self, expr): |
| name = expr.func.__name__ |
| e0 = self._print(expr.args[0]) |
| fn = self.ext_fn.get(name) |
| if not fn: |
| fn_type = ll.FunctionType(self.fp_type, [self.fp_type]) |
| fn = ll.Function(self.module, fn_type, name) |
| self.ext_fn[name] = fn |
| return self.builder.call(fn, [e0], name) |
|
|
| def emptyPrinter(self, expr): |
| raise TypeError("Unsupported type for LLVM JIT conversion: %s" |
| % type(expr)) |
|
|
|
|
| |
| |
| class LLVMJitCallbackPrinter(LLVMJitPrinter): |
| def __init__(self, *args, **kwargs): |
| super().__init__(*args, **kwargs) |
|
|
| def _print_Indexed(self, expr): |
| array, idx = self.func_arg_map[expr.base] |
| offset = int(expr.indices[0].evalf()) |
| array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), offset)]) |
| fp_array_ptr = self.builder.bitcast(array_ptr, ll.PointerType(self.fp_type)) |
| value = self.builder.load(fp_array_ptr) |
| return value |
|
|
| def _print_Symbol(self, s): |
| val = self.tmp_var.get(s) |
| if val: |
| return val |
|
|
| array, idx = self.func_arg_map.get(s, [None, 0]) |
| if not array: |
| raise LookupError("Symbol not found: %s" % s) |
| array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), idx)]) |
| fp_array_ptr = self.builder.bitcast(array_ptr, |
| ll.PointerType(self.fp_type)) |
| value = self.builder.load(fp_array_ptr) |
| return value |
|
|
|
|
| |
| |
| exe_engines = [] |
|
|
| |
| link_names = set() |
| current_link_suffix = 0 |
|
|
|
|
| class LLVMJitCode: |
| def __init__(self, signature): |
| self.signature = signature |
| self.fp_type = ll.DoubleType() |
| self.module = ll.Module('mod1') |
| self.fn = None |
| self.llvm_arg_types = [] |
| self.llvm_ret_type = self.fp_type |
| self.param_dict = {} |
| self.link_name = '' |
|
|
| def _from_ctype(self, ctype): |
| if ctype == ctypes.c_int: |
| return ll.IntType(32) |
| if ctype == ctypes.c_double: |
| return self.fp_type |
| if ctype == ctypes.POINTER(ctypes.c_double): |
| return ll.PointerType(self.fp_type) |
| if ctype == ctypes.c_void_p: |
| return ll.PointerType(ll.IntType(32)) |
| if ctype == ctypes.py_object: |
| return ll.PointerType(ll.IntType(32)) |
|
|
| print("Unhandled ctype = %s" % str(ctype)) |
|
|
| def _create_args(self, func_args): |
| """Create types for function arguments""" |
| self.llvm_ret_type = self._from_ctype(self.signature.ret_type) |
| self.llvm_arg_types = \ |
| [self._from_ctype(a) for a in self.signature.arg_ctypes] |
|
|
| def _create_function_base(self): |
| """Create function with name and type signature""" |
| global current_link_suffix |
| default_link_name = 'jit_func' |
| current_link_suffix += 1 |
| self.link_name = default_link_name + str(current_link_suffix) |
| link_names.add(self.link_name) |
|
|
| fn_type = ll.FunctionType(self.llvm_ret_type, self.llvm_arg_types) |
| self.fn = ll.Function(self.module, fn_type, name=self.link_name) |
|
|
| def _create_param_dict(self, func_args): |
| """Mapping of symbolic values to function arguments""" |
| for i, a in enumerate(func_args): |
| self.fn.args[i].name = str(a) |
| self.param_dict[a] = self.fn.args[i] |
|
|
| def _create_function(self, expr): |
| """Create function body and return LLVM IR""" |
| bb_entry = self.fn.append_basic_block('entry') |
| builder = ll.IRBuilder(bb_entry) |
|
|
| lj = LLVMJitPrinter(self.module, builder, self.fn, |
| func_arg_map=self.param_dict) |
|
|
| ret = self._convert_expr(lj, expr) |
| lj.builder.ret(self._wrap_return(lj, ret)) |
|
|
| strmod = str(self.module) |
| return strmod |
|
|
| def _wrap_return(self, lj, vals): |
| |
| |
|
|
| |
| if self.signature.ret_type == ctypes.c_double: |
| return vals[0] |
|
|
| |
| void_ptr = ll.PointerType(ll.IntType(32)) |
|
|
| |
| wrap_type = ll.FunctionType(void_ptr, [self.fp_type]) |
| wrap_fn = ll.Function(lj.module, wrap_type, "PyFloat_FromDouble") |
|
|
| wrapped_vals = [lj.builder.call(wrap_fn, [v]) for v in vals] |
| if len(vals) == 1: |
| final_val = wrapped_vals[0] |
| else: |
| |
|
|
| |
| tuple_arg_types = [ll.IntType(32)] |
|
|
| tuple_arg_types.extend([void_ptr]*len(vals)) |
| tuple_type = ll.FunctionType(void_ptr, tuple_arg_types) |
| tuple_fn = ll.Function(lj.module, tuple_type, "PyTuple_Pack") |
|
|
| tuple_args = [ll.Constant(ll.IntType(32), len(wrapped_vals))] |
| tuple_args.extend(wrapped_vals) |
|
|
| final_val = lj.builder.call(tuple_fn, tuple_args) |
|
|
| return final_val |
|
|
| def _convert_expr(self, lj, expr): |
| try: |
| |
| if len(expr) == 2: |
| tmp_exprs = expr[0] |
| final_exprs = expr[1] |
| if len(final_exprs) != 1 and self.signature.ret_type == ctypes.c_double: |
| raise NotImplementedError("Return of multiple expressions not supported for this callback") |
| for name, e in tmp_exprs: |
| val = lj._print(e) |
| lj._add_tmp_var(name, val) |
| except TypeError: |
| final_exprs = [expr] |
|
|
| vals = [lj._print(e) for e in final_exprs] |
|
|
| return vals |
|
|
| def _compile_function(self, strmod): |
| llmod = llvm.parse_assembly(strmod) |
|
|
| pmb = llvm.create_pass_manager_builder() |
| pmb.opt_level = 2 |
| pass_manager = llvm.create_module_pass_manager() |
| pmb.populate(pass_manager) |
|
|
| pass_manager.run(llmod) |
|
|
| target_machine = \ |
| llvm.Target.from_default_triple().create_target_machine() |
| exe_eng = llvm.create_mcjit_compiler(llmod, target_machine) |
| exe_eng.finalize_object() |
| exe_engines.append(exe_eng) |
|
|
| if False: |
| print("Assembly") |
| print(target_machine.emit_assembly(llmod)) |
|
|
| fptr = exe_eng.get_function_address(self.link_name) |
|
|
| return fptr |
|
|
|
|
| class LLVMJitCodeCallback(LLVMJitCode): |
| def __init__(self, signature): |
| super().__init__(signature) |
|
|
| def _create_param_dict(self, func_args): |
| for i, a in enumerate(func_args): |
| if isinstance(a, IndexedBase): |
| self.param_dict[a] = (self.fn.args[i], i) |
| self.fn.args[i].name = str(a) |
| else: |
| self.param_dict[a] = (self.fn.args[self.signature.input_arg], |
| i) |
|
|
| def _create_function(self, expr): |
| """Create function body and return LLVM IR""" |
| bb_entry = self.fn.append_basic_block('entry') |
| builder = ll.IRBuilder(bb_entry) |
|
|
| lj = LLVMJitCallbackPrinter(self.module, builder, self.fn, |
| func_arg_map=self.param_dict) |
|
|
| ret = self._convert_expr(lj, expr) |
|
|
| if self.signature.ret_arg: |
| output_fp_ptr = builder.bitcast(self.fn.args[self.signature.ret_arg], |
| ll.PointerType(self.fp_type)) |
| for i, val in enumerate(ret): |
| index = ll.Constant(ll.IntType(32), i) |
| output_array_ptr = builder.gep(output_fp_ptr, [index]) |
| builder.store(val, output_array_ptr) |
| builder.ret(ll.Constant(ll.IntType(32), 0)) |
| else: |
| lj.builder.ret(self._wrap_return(lj, ret)) |
|
|
| strmod = str(self.module) |
| return strmod |
|
|
|
|
| class CodeSignature: |
| def __init__(self, ret_type): |
| self.ret_type = ret_type |
| self.arg_ctypes = [] |
|
|
| |
| self.input_arg = 0 |
|
|
| |
| |
| self.ret_arg = None |
|
|
|
|
| def _llvm_jit_code(args, expr, signature, callback_type): |
| """Create a native code function from a SymPy expression""" |
| if callback_type is None: |
| jit = LLVMJitCode(signature) |
| else: |
| jit = LLVMJitCodeCallback(signature) |
|
|
| jit._create_args(args) |
| jit._create_function_base() |
| jit._create_param_dict(args) |
| strmod = jit._create_function(expr) |
| if False: |
| print("LLVM IR") |
| print(strmod) |
| fptr = jit._compile_function(strmod) |
| return fptr |
|
|
|
|
| @doctest_depends_on(modules=('llvmlite', 'scipy')) |
| def llvm_callable(args, expr, callback_type=None): |
| '''Compile function from a SymPy expression |
| |
| Expressions are evaluated using double precision arithmetic. |
| Some single argument math functions (exp, sin, cos, etc.) are supported |
| in expressions. |
| |
| Parameters |
| ========== |
| |
| args : List of Symbol |
| Arguments to the generated function. Usually the free symbols in |
| the expression. Currently each one is assumed to convert to |
| a double precision scalar. |
| expr : Expr, or (Replacements, Expr) as returned from 'cse' |
| Expression to compile. |
| callback_type : string |
| Create function with signature appropriate to use as a callback. |
| Currently supported: |
| 'scipy.integrate' |
| 'scipy.integrate.test' |
| 'cubature' |
| |
| Returns |
| ======= |
| |
| Compiled function that can evaluate the expression. |
| |
| Examples |
| ======== |
| |
| >>> import sympy.printing.llvmjitcode as jit |
| >>> from sympy.abc import a |
| >>> e = a*a + a + 1 |
| >>> e1 = jit.llvm_callable([a], e) |
| >>> e.subs(a, 1.1) # Evaluate via substitution |
| 3.31000000000000 |
| >>> e1(1.1) # Evaluate using JIT-compiled code |
| 3.3100000000000005 |
| |
| |
| Callbacks for integration functions can be JIT compiled. |
| |
| >>> import sympy.printing.llvmjitcode as jit |
| >>> from sympy.abc import a |
| >>> from sympy import integrate |
| >>> from scipy.integrate import quad |
| >>> e = a*a |
| >>> e1 = jit.llvm_callable([a], e, callback_type='scipy.integrate') |
| >>> integrate(e, (a, 0.0, 2.0)) |
| 2.66666666666667 |
| >>> quad(e1, 0.0, 2.0)[0] |
| 2.66666666666667 |
| |
| The 'cubature' callback is for the Python wrapper around the |
| cubature package ( https://github.com/saullocastro/cubature ) |
| and ( http://ab-initio.mit.edu/wiki/index.php/Cubature ) |
| |
| There are two signatures for the SciPy integration callbacks. |
| The first ('scipy.integrate') is the function to be passed to the |
| integration routine, and will pass the signature checks. |
| The second ('scipy.integrate.test') is only useful for directly calling |
| the function using ctypes variables. It will not pass the signature checks |
| for scipy.integrate. |
| |
| The return value from the cse module can also be compiled. This |
| can improve the performance of the compiled function. If multiple |
| expressions are given to cse, the compiled function returns a tuple. |
| The 'cubature' callback handles multiple expressions (set `fdim` |
| to match in the integration call.) |
| |
| >>> import sympy.printing.llvmjitcode as jit |
| >>> from sympy import cse |
| >>> from sympy.abc import x,y |
| >>> e1 = x*x + y*y |
| >>> e2 = 4*(x*x + y*y) + 8.0 |
| >>> after_cse = cse([e1,e2]) |
| >>> after_cse |
| ([(x0, x**2), (x1, y**2)], [x0 + x1, 4*x0 + 4*x1 + 8.0]) |
| >>> j1 = jit.llvm_callable([x,y], after_cse) |
| >>> j1(1.0, 2.0) |
| (5.0, 28.0) |
| ''' |
|
|
| if not llvmlite: |
| raise ImportError("llvmlite is required for llvmjitcode") |
|
|
| signature = CodeSignature(ctypes.py_object) |
|
|
| arg_ctypes = [] |
| if callback_type is None: |
| for _ in args: |
| arg_ctype = ctypes.c_double |
| arg_ctypes.append(arg_ctype) |
| elif callback_type in ('scipy.integrate', 'scipy.integrate.test'): |
| signature.ret_type = ctypes.c_double |
| arg_ctypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_double)] |
| arg_ctypes_formal = [ctypes.c_int, ctypes.c_double] |
| signature.input_arg = 1 |
| elif callback_type == 'cubature': |
| arg_ctypes = [ctypes.c_int, |
| ctypes.POINTER(ctypes.c_double), |
| ctypes.c_void_p, |
| ctypes.c_int, |
| ctypes.POINTER(ctypes.c_double) |
| ] |
| signature.ret_type = ctypes.c_int |
| signature.input_arg = 1 |
| signature.ret_arg = 4 |
| else: |
| raise ValueError("Unknown callback type: %s" % callback_type) |
|
|
| signature.arg_ctypes = arg_ctypes |
|
|
| fptr = _llvm_jit_code(args, expr, signature, callback_type) |
|
|
| if callback_type and callback_type == 'scipy.integrate': |
| arg_ctypes = arg_ctypes_formal |
|
|
| |
| |
| |
| |
| |
| if signature.ret_type == ctypes.py_object: |
| FUNCTYPE = ctypes.PYFUNCTYPE |
| else: |
| FUNCTYPE = ctypes.CFUNCTYPE |
|
|
| cfunc = FUNCTYPE(signature.ret_type, *arg_ctypes)(fptr) |
| return cfunc |
|
|