import math
from dataclasses import dataclass

import jax.numpy as jnp
import jax
import torch
import sympy
import sympy as sp
from sympy import Matrix, Symbol

from sde_redefined_param import SDEDimension


@dataclass
class SDEPolynomialConfig:
    """Configuration for a scalar SDE with polynomial-based symbolic coefficients.

    The drift and diffusion are SymPy expressions in the single variable ``t``
    (``variable``), parameterised by the symbol vectors ``drift_parameters``
    (``f0 .. f{drift_degree-1}``) and ``diffusion_parameters``
    (``l0 .. l{diffusion_degree-1}``).

    None of the settings below carry type annotations, so ``@dataclass``
    generates no fields. The generated ``__init__`` takes no arguments, and all
    instances share these class-level values unless they are overridden per
    instance.
    """

    name = "Custom"

    # Bounds of the time variable t.
    initial_variable_value = 0
    max_variable_value = 1  # alternative value: math.inf
    min_sample_value = 1e-6

    # NOTE(review): `domain` is not a SymPy assumption keyword. SymPy appears to
    # coerce unknown assumption values with bool(), so this stores `domain=True`
    # and the Interval does not actually restrict `t`. Confirm intent before
    # changing it, because altering assumptions changes symbol identity/equality.
    variable = Symbol(
        't',
        nonnegative=True,
        real=True,
        domain=sympy.Interval(
            initial_variable_value,
            max_variable_value,
            left_open=False,
            right_open=False,
        ),
    )

    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # Number of terms in each polynomial and the matching 1 x degree row
    # vectors of real, nonzero coefficient symbols.
    drift_degree = 20
    diffusion_degree = 20
    drift_parameters = Matrix([sympy.symbols(f"f:{drift_degree}", real=True, nonzero=True)])
    diffusion_parameters = Matrix([sympy.symbols(f"l:{diffusion_degree}", real=True, nonzero=True)])

    @property
    def drift(self):
        """Return the symbolic drift ``-|sum_{i=1}^{drift_degree} f_{i-1} * t**i|``.

        The row of monomials ``[t, t**2, ..., t**drift_degree]`` is multiplied
        element-wise (Hadamard product) with ``drift_parameters``. Its entries
        are summed, and the negated absolute value is returned. As a result the
        drift is never positive.
        """
        # Presumably a hook for a change of variable; currently the identity.
        transformed_variable = self.variable
        monomials = Matrix([[transformed_variable**i for i in range(1, self.drift_degree + 1)]])
        weighted_terms = sympy.HadamardProduct(monomials, self.drift_parameters).doit()
        return -sympy.Abs(sum(weighted_terms))

    @property
    def diffusion(self):
        """Return the symbolic diffusion ``t ** (sum_{i=0}^{diffusion_degree-1} l_i**2 * t**i)``.

        The exponent is a polynomial in ``t`` whose coefficients are the squared
        ``diffusion_parameters``, so every coefficient is non-negative.
        """
        t = self.variable
        monomials = Matrix([[t**i for i in range(0, self.diffusion_degree)]])
        squared_params = self.diffusion_parameters.applyfunc(lambda x: x**2)
        exponent = sum(sympy.HadamardProduct(monomials, squared_params).doit())
        return t**exponent

    # TODO(KLAUS): in SDE sampling, changing Q impacts how we sample
    # z ~ N(0, Q * delta_t).
    diffusion_matrix = 1

    module = 'jax'
    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'  # alternative: 'ldl'
    target = "epsilon"  # alternative: "x0"

    # Numeric (non-symbolic) parameter tensors sized to match the symbolic
    # parameter vectors above.
    # NOTE(review): this dict and its tensors are class-level objects shared by
    # every instance; in-place updates affect all configs — confirm intended.
    non_symbolic_parameters = {
        'drift': torch.ones(drift_degree),
        'diffusion': torch.ones(diffusion_degree),
    }