Spaces:
Running
Running
File size: 5,753 Bytes
a06bfc4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
#####
# From https://github.com/patrick-kidger/sympytorch
# Copied here to allow PySR-specific tweaks
#####
import collections as co
import functools as ft
import sympy
import torch
def _reduce(fn):
def fn_(*args):
return ft.reduce(fn, args)
return fn_
def _as_tensor_args(fn):
    """Wrap the binary torch callable ``fn`` so Python-number operands become tensors.

    Integer leaves of the expression tree evaluate to plain Python ints, but
    e.g. ``torch.maximum``/``torch.minimum`` accept only tensors (and
    ``torch.max(x, 0)`` would treat the ``0`` as a reduction ``dim``).
    ``torch.as_tensor`` returns tensors (including ``Parameter``s) unchanged,
    so gradients still flow through tensor operands.
    """
    def fn_(a, b):
        return fn(torch.as_tensor(a), torch.as_tensor(b))
    return fn_


# Maps a SymPy expression class (``expr.func``) to the torch callable that
# evaluates it on the already-evaluated child arguments.
_global_func_lookup = {
    sympy.Mul: _reduce(torch.mul),
    sympy.Add: _reduce(torch.add),
    # NOTE(review): sympy.div and sympy.sqrt are functions rather than
    # expression classes, so `expr.func` presumably never equals them
    # (sqrt(x) is represented as Pow(x, 1/2)); kept for compatibility.
    sympy.div: torch.div,
    sympy.Abs: torch.abs,
    sympy.sign: torch.sign,
    # Note: May raise error for ints.
    sympy.ceiling: torch.ceil,
    sympy.floor: torch.floor,
    sympy.log: torch.log,
    sympy.exp: torch.exp,
    sympy.sqrt: torch.sqrt,
    sympy.cos: torch.cos,
    sympy.acos: torch.acos,
    sympy.sin: torch.sin,
    sympy.asin: torch.asin,
    sympy.tan: torch.tan,
    sympy.atan: torch.atan,
    sympy.atan2: torch.atan2,
    # Note: May give NaN for complex results.
    sympy.cosh: torch.cosh,
    sympy.acosh: torch.acosh,
    sympy.sinh: torch.sinh,
    sympy.asinh: torch.asinh,
    sympy.tanh: torch.tanh,
    sympy.atanh: torch.atanh,
    sympy.Pow: torch.pow,
    sympy.re: torch.real,
    sympy.im: torch.imag,
    sympy.arg: torch.angle,
    # Note: May raise error for ints and complexes
    sympy.erf: torch.erf,
    sympy.loggamma: torch.lgamma,
    sympy.Eq: torch.eq,
    sympy.Ne: torch.ne,
    sympy.StrictGreaterThan: torch.gt,
    sympy.StrictLessThan: torch.lt,
    sympy.LessThan: torch.le,
    sympy.GreaterThan: torch.ge,
    # SymPy flattens nested And/Or/Max/Min into a single node with any
    # number of arguments, while the torch functions are binary: fold them.
    sympy.And: _reduce(torch.logical_and),
    sympy.Or: _reduce(torch.logical_or),
    sympy.Not: torch.logical_not,
    sympy.Max: _reduce(_as_tensor_args(torch.maximum)),
    sympy.Min: _reduce(_as_tensor_args(torch.minimum)),
    # Matrices
    sympy.MatAdd: _reduce(torch.add),
    sympy.HadamardProduct: _reduce(torch.mul),
    sympy.Trace: torch.trace,
    # Note: May raise error for integer matrices.
    sympy.Determinant: torch.det,
}
class _Node(torch.nn.Module):
    """One node of a SymPy expression tree, mirrored as a torch module.

    Leaf handling:

    * ``sympy.Float`` -> trainable scalar ``torch.nn.Parameter``.
    * ``sympy.UnevaluatedExpr`` wrapping a Float -> non-trainable buffer.
    * ``sympy.Integer`` (including special cases such as NegativeOne) ->
      Python ``int`` constant.
    * other ``sympy.Rational`` (e.g. the 1/2 in ``sqrt(x)`` or ``x/2``) ->
      Python ``float`` constant; numerator and denominator are kept so
      :meth:`sympy` can rebuild the exact rational.
    * ``sympy.Symbol`` -> looked up by name in the dict passed to
      :meth:`forward`.

    Every other node is translated through ``_func_lookup`` (keyed by the
    SymPy class, ``expr.func``), and its arguments become child ``_Node``s.
    Equal subexpressions share one child module through ``_memodict``.

    Raises:
        ValueError: if an ``UnevaluatedExpr`` wraps anything but one Float.
        KeyError: if ``_func_lookup`` has no entry for a non-leaf node.
    """

    def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
        super().__init__(**kwargs)
        self._sympy_func = expr.func
        if issubclass(expr.func, sympy.Float):
            # Free float constants become trainable parameters.
            self._value = torch.nn.Parameter(torch.tensor(float(expr)))
            self._torch_func = lambda: self._value
            self._args = ()
        elif issubclass(expr.func, sympy.UnevaluatedExpr):
            # Floats wrapped in UnevaluatedExpr are stored as a buffer, so they
            # are not part of parameters() and are not trained.
            if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
                raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
            self.register_buffer('_value', torch.tensor(float(expr.args[0])))
            self._torch_func = lambda: self._value
            self._args = ()
        elif issubclass(expr.func, sympy.Integer):
            # Can get here if expr is one of the Integer special cases,
            # e.g. NegativeOne.
            self._value = int(expr)
            self._torch_func = lambda: self._value
            self._args = ()
        elif issubclass(expr.func, sympy.Rational):
            # Non-integer rationals (e.g. Half in Pow(x, 1/2)).  Must be tested
            # after Integer, which is a subclass of Rational.
            self._numerator = expr.p
            self._denominator = expr.q
            self._value = float(expr)
            self._torch_func = lambda: self._value
            self._args = ()
        elif issubclass(expr.func, sympy.Symbol):
            self._name = expr.name
            self._torch_func = lambda value: value
            # The single "argument" fetches the input bound to this symbol's
            # name from the dict handed to forward().
            self._args = ((lambda memodict: memodict[expr.name]),)
        else:
            try:
                self._torch_func = _func_lookup[expr.func]
            except KeyError:
                raise KeyError(
                    f"No torch translation for SymPy function {expr.func!r}; "
                    "supply one via extra_funcs."
                ) from None
            args = []
            for arg in expr.args:
                # Reuse the module already built for an equal subexpression.
                try:
                    arg_ = _memodict[arg]
                except KeyError:
                    arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
                    _memodict[arg] = arg_
                args.append(arg_)
            self._args = torch.nn.ModuleList(args)

    def sympy(self, _memodict):
        """Rebuild the SymPy expression for this node using current constant values.

        Args:
            _memodict: cache from child module to its converted SymPy
                expression, so shared subtrees are converted once.
        """
        if issubclass(self._sympy_func, (sympy.Float, sympy.UnevaluatedExpr)):
            return self._sympy_func(self._value.item())
        if issubclass(self._sympy_func, sympy.Integer):
            # Construct through sympy.Integer rather than self._sympy_func: the
            # singleton special cases (NegativeOne, Zero, One) do not accept a
            # value argument in recent SymPy versions.  sympy.Integer returns
            # those singletons for -1/0/1 anyway.
            return sympy.Integer(self._value)
        if issubclass(self._sympy_func, sympy.Rational):
            return sympy.Rational(self._numerator, self._denominator)
        if issubclass(self._sympy_func, sympy.Symbol):
            return self._sympy_func(self._name)
        # Min/Max are rebuilt unevaluated; all other nodes are re-evaluated.
        evaluate = not issubclass(self._sympy_func, (sympy.Min, sympy.Max))
        args = []
        for arg in self._args:
            try:
                arg_ = _memodict[arg]
            except KeyError:
                arg_ = arg.sympy(_memodict)
                _memodict[arg] = arg_
            args.append(arg_)
        return self._sympy_func(*args, evaluate=evaluate)

    def forward(self, memodict):
        """Evaluate this node.

        Args:
            memodict: maps symbol names to their input values; it is also
                used (and mutated) as a cache of child results keyed by the
                child callable, so shared subtrees are evaluated once.
        """
        args = []
        for arg in self._args:
            try:
                arg_ = memodict[arg]
            except KeyError:
                arg_ = arg(memodict)
                memodict[arg] = arg_
            args.append(arg_)
        return self._torch_func(*args)
class SingleSymPyModule(torch.nn.Module):
    """Torch module that evaluates one SymPy expression on the columns of ``X``.

    Args:
        expression: the SymPy expression to convert.
        symbols_in: input symbols in column order; ``X[:, i]`` is bound to
            the name ``str(symbols_in[i])``.
        extra_funcs: optional mapping from SymPy function classes to torch
            callables.  Entries here take precedence over the built-in
            ``_global_func_lookup``, so they can both add and override
            translations.
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, expression, symbols_in,
                 extra_funcs=None, **kwargs):
        super().__init__(**kwargs)
        if extra_funcs is None:
            extra_funcs = {}
        # ChainMap looks keys up in its maps from left to right, so the
        # caller's mappings go first to be able to override the defaults.
        _func_lookup = co.ChainMap(extra_funcs, _global_func_lookup)
        _memodict = {}
        self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
        self._expression_string = str(expression)
        self.symbols_in = [str(symbol) for symbol in symbols_in]

    def __repr__(self):
        return f"{type(self).__name__}(expression={self._expression_string})"

    def sympy(self):
        """Return the SymPy expression rebuilt with the module's current constants."""
        _memodict = {}
        return self._node.sympy(_memodict)

    def forward(self, X):
        """Evaluate the expression with column ``X[:, i]`` bound to ``symbols_in[i]``.

        NOTE(review): ``X`` presumably has shape (n_samples, n_features) with
        features ordered like ``symbols_in`` — confirm with callers.
        """
        symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
        return self._node(symbols)
def sympy2torch(expression, symbols_in, *, extra_funcs=None):
    """Convert a SymPy expression into a torch module.

    Args:
        expression: the SymPy expression to convert.
        symbols_in: input symbols in the column order of the tensor later
            passed to the module's ``forward``.
        extra_funcs: optional mapping from SymPy function classes to torch
            callables, forwarded to ``SingleSymPyModule``.

    Returns:
        A ``SingleSymPyModule`` wrapping ``expression``.
    """
    return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_funcs)
|