Spaces:
Sleeping
Sleeping
##### | |
# From https://github.com/patrick-kidger/sympytorch | |
# Copied here to allow PySR-specific tweaks | |
##### | |
import collections as co | |
import functools as ft | |
import sympy | |
def _reduce(fn): | |
def fn_(*args): | |
return ft.reduce(fn, args) | |
return fn_ | |
# Lazily-populated module state. Everything starts unset so that importing
# this module does not require torch or sympytorch; `_initialize_torch`
# rebinds these globals when a torch export is requested.
torch_initialized = False
torch = sympytorch = PySRTorchModule = None
def _initialize_torch():
    """Import torch/sympytorch on first use and define ``PySRTorchModule``.

    Torch is loaded lazily so that this module can still be imported (e.g.
    from the package ``__init__``) when torch/sympytorch are not installed.
    The first successful call binds the module-level ``torch``,
    ``sympytorch`` and ``PySRTorchModule`` globals and sets
    ``torch_initialized``; subsequent calls return immediately.

    Raises:
        ImportError: if ``torch`` or ``sympytorch`` cannot be imported.
    """
    global torch_initialized
    global torch
    global sympytorch
    global PySRTorchModule

    # Way to lazy load torch, only if this is called,
    # but still allow this module to be loaded in __init__
    if torch_initialized:
        return

    try:
        # `global` above makes these imports bind the module-level names.
        import torch
        import sympytorch
    except ImportError as err:
        raise ImportError(
            "You need to pip install `torch` and `sympytorch` before exporting to pytorch."
        ) from err

    torch_initialized = True

    class PySRTorchModule(torch.nn.Module):
        """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""

        def __init__(self, *, expression, symbols_in,
                     selection=None, extra_funcs=None, **kwargs):
            """Wrap a single sympy ``expression`` in a sympytorch module.

            Args:
                expression: sympy expression to evaluate.
                symbols_in: symbols matched, in order, to the columns of the
                    input ``X`` in ``forward``.
                selection: optional column index applied to ``X`` before the
                    columns are matched to ``symbols_in``.
                extra_funcs: forwarded unchanged to
                    ``sympytorch.SymPyModule``.
            """
            super().__init__(**kwargs)
            self._module = sympytorch.SymPyModule(
                expressions=[expression],
                extra_funcs=extra_funcs)
            # Needed by __repr__; it was previously never assigned, so
            # repr() on an exported module raised AttributeError.
            self._expression_string = str(expression)
            self._selection = selection
            self._symbols = symbols_in

        def __repr__(self):
            return f"{type(self).__name__}(expression={self._expression_string})"

        def forward(self, X):
            """Evaluate the expression on ``X``.

            ``X`` is indexed as ``X[:, ...]``, so it must be at least 2-D
            (presumably (n_samples, n_features) -- confirm with callers).
            """
            if self._selection is not None:
                X = X[:, self._selection]
            # Column i feeds the i-th symbol, keyed by the symbol's name.
            symbols = {str(symbol): X[:, i]
                       for i, symbol in enumerate(self._symbols)}
            # Only one expression was given to SymPyModule; [..., 0]
            # presumably selects its output from the trailing axis.
            return self._module(**symbols)[..., 0]
def sympy2torch(expression, symbols_in,
                selection=None, extra_torch_mappings=None):
    """Return a torch module with trainable parameters for a sympy expression.

    The module expects a matrix ``X`` whose columns correspond, in order, to
    the symbols in ``symbols_in`` (after the optional column ``selection``
    has been applied).

    Args:
        expression: sympy expression to convert.
        symbols_in: symbols bound, in order, to the columns of ``X``.
        selection: optional column index applied to ``X`` before its columns
            are matched to ``symbols_in``.
        extra_torch_mappings: forwarded to sympytorch as ``extra_funcs``;
            presumably maps sympy functions to torch callables -- confirm.

    Raises:
        ImportError: if torch or sympytorch is not installed.
    """
    # Must run first: it is what defines the module-level PySRTorchModule.
    _initialize_torch()
    module_kwargs = dict(
        expression=expression,
        symbols_in=symbols_in,
        selection=selection,
        extra_funcs=extra_torch_mappings,
    )
    return PySRTorchModule(**module_kwargs)