import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension


@dataclass
class SDEPolynomialConfig:
    """Symbolic SDE configuration with polynomial drift and diffusion.

    Both coefficients are polynomials in ``variable`` with no constant term::

        drift(t)     = -| sum_{i=1}^{drift_degree}     f_{i-1}    * t**i |
        diffusion(t) =    sum_{i=1}^{diffusion_degree} l_{i-1}**2 * t**i

    Here ``f*`` and ``l*`` are the real sympy symbols held in
    ``drift_parameters`` and ``diffusion_parameters``.

    Note: no attribute below carries a type annotation. ``@dataclass``
    therefore generates no fields, and every setting is a plain class
    attribute shared by all instances; override it in a subclass or on an
    instance.

    The parameter matrices are built once, from the degrees as they are at
    class-creation time. A subclass that changes ``drift_degree`` or
    ``diffusion_degree`` must also redefine the matching ``*_parameters``
    matrix. Otherwise ``drift`` / ``diffusion`` raise ``ValueError``.
    """

    name = "Custom"
    # Independent variable of both polynomials (declared real and non-negative).
    variable = Symbol('t', nonnegative=True, real=True)
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    drift_degree = 20
    diffusion_degree = 20
    # 1 x drift_degree row of real symbols f0 .. f{drift_degree - 1}.
    drift_parameters = Matrix([sympy.symbols(f"f:{drift_degree}", real=True)])
    # 1 x diffusion_degree row of real symbols l0 .. l{diffusion_degree - 1}.
    # They are squared in ``diffusion``, so with t >= 0 every term is
    # non-negative (the intent: keep the diffusion positive).
    diffusion_parameters = Matrix([sympy.symbols(f"l:{diffusion_degree}", real=True)])

    @staticmethod
    def _polynomial(variable, coefficients, degree):
        """Return ``sum(coefficients[i - 1] * variable**i for i in 1..degree)``.

        Args:
            variable: sympy symbol the polynomial is expressed in.
            coefficients: iterable (e.g. a sympy Matrix) of exactly ``degree``
                coefficients, lowest power first.
            degree: highest power of ``variable``.

        Raises:
            ValueError: if ``coefficients`` does not hold exactly ``degree``
                entries, e.g. a degree was overridden without rebuilding the
                matching parameter matrix.
        """
        coefficients = list(coefficients)
        if len(coefficients) != degree:
            raise ValueError(
                f"expected {degree} polynomial coefficients, got "
                f"{len(coefficients)}; rebuild the parameter matrix when "
                "changing the degree"
            )
        # Powers start at 1: the polynomial has no constant term.
        return sum(
            (coeff * variable**power
             for power, coeff in enumerate(coefficients, start=1)),
            sympy.S.Zero,
        )

    @property
    def drift(self):
        """Drift expression ``-|sum_i f_{i-1} * t**i|``; never positive."""
        return -sympy.Abs(
            self._polynomial(self.variable, self.drift_parameters, self.drift_degree)
        )

    @property
    def diffusion(self):
        """Diffusion expression ``sum_i l_{i-1}**2 * t**i``."""
        squared = self.diffusion_parameters.applyfunc(lambda x: x**2)
        return self._polynomial(self.variable, squared, self.diffusion_degree)

    # TODO(KLAUS): in SDE sampling, changing Q impacts how we sample
    # z ~ N(0, Q * delta_t).
    diffusion_matrix = 1
    # NOTE(review): the settings below are presumably consumed by the SDE
    # sampler/trainer elsewhere; their semantics are not visible here.
    initial_variable_value = 0
    max_variable_value = 1  # or math.inf
    min_sample_value = 1e-6
    module = 'jax'
    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'  # or 'ldl'
    target = "epsilon"  # or "x0"