|
import jax.numpy as jnp |
|
import jax |
|
import torch |
|
from dataclasses import dataclass |
|
import sympy |
|
import sympy as sp |
|
from sympy import Matrix, Symbol |
|
import math |
|
from sde_redefined_param import SDEDimension |
|
@dataclass
class SDEPolynomialConfig:
    """Scalar SDE whose drift and diffusion are symbolic polynomials in ``variable``.

    With ``t = variable``, the coefficients are the real symbols held in
    ``drift_parameters`` (``f0, f1, ...``) and ``diffusion_parameters``
    (``l0, l1, ...``):

        drift(t)     = -| sum_{k=1}^{drift_degree}     f_{k-1}    * t**k |
        diffusion(t) =    sum_{k=1}^{diffusion_degree} l_{k-1}**2 * t**k

    Neither polynomial has a constant term. The outer ``-Abs`` keeps the drift
    non-positive. Squaring the real diffusion coefficients keeps the diffusion
    non-negative because ``t`` is declared non-negative.

    NOTE: none of the attributes below carry type annotations, so
    ``@dataclass`` generates no fields for them. They are plain class
    attributes shared by every instance, and the generated ``__init__`` takes
    no arguments. The parameter matrices are built once, when the class body
    executes, from the class-level degrees. A subclass or instance that changes
    a degree must also rebuild the matching parameter matrix; otherwise
    ``drift``/``diffusion`` raise ``sympy.ShapeError``.
    """

    name = "Custom"

    # Independent variable of the SDE (called ``t`` in the formulas above).
    variable = Symbol('t', nonnegative=True, real=True)

    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # Number of monomials t**1 .. t**degree in each polynomial.
    drift_degree = 20
    diffusion_degree = 20

    # 1 x degree row matrices of real coefficient symbols (f0.. / l0..).
    # Evaluated once from the class-level degrees above.
    drift_parameters = Matrix([sympy.symbols(f"f:{drift_degree}", real=True)])
    diffusion_parameters = Matrix([sympy.symbols(f"l:{diffusion_degree}", real=True)])

    def _monomial_series(self, coefficients, degree, attribute):
        """Return ``sum_{k=1}^{degree} coefficients[k-1] * variable**k``.

        Args:
            coefficients: iterable of SymPy expressions (a row or column
                ``Matrix`` or a plain sequence); the k-th entry multiplies
                ``variable**k``.
            degree: expected number of coefficients / highest power.
            attribute: attribute name reported in the error message.

        Raises:
            sympy.ShapeError: if ``coefficients`` does not hold exactly
                ``degree`` entries (same exception type the previous
                HadamardProduct-based implementation raised on mismatch).
        """
        coeffs = list(coefficients)
        if len(coeffs) != degree:
            raise sympy.ShapeError(
                f"{attribute} has {len(coeffs)} entries but the configured "
                f"degree is {degree}; rebuild {attribute} when changing the degree"
            )
        t = self.variable
        # A single n-ary Add instead of folding with builtin ``sum``, which
        # would build and re-canonicalise an intermediate Add per term.
        return sympy.Add(*(c * t**k for k, c in enumerate(coeffs, start=1)))

    @property
    def drift(self):
        """Symbolic drift ``-|sum_{k=1}^{drift_degree} f_{k-1} t**k|``."""
        return -sympy.Abs(
            self._monomial_series(
                self.drift_parameters, self.drift_degree, "drift_parameters"
            )
        )

    @property
    def diffusion(self):
        """Symbolic diffusion ``sum_{k=1}^{diffusion_degree} l_{k-1}**2 t**k``."""
        squared = [c**2 for c in self.diffusion_parameters]
        return self._monomial_series(
            squared, self.diffusion_degree, "diffusion_parameters"
        )

    # Constant scalar diffusion matrix; presumably consumed by the SDE
    # builder together with ``diffusion_matrix_dimension`` -- confirm.
    diffusion_matrix = 1

    # Presumably the start and end of the ``variable`` range -- confirm
    # against the consumer of this config.
    initial_variable_value = 0
    max_variable_value = 1
    # NOTE(review): meaning not visible here; presumably a small lower bound
    # used when sampling ``variable`` -- confirm.
    min_sample_value = 1e-6

    # Backend name; presumably selects the numerical library used to
    # evaluate the symbolic expressions -- confirm.
    module = 'jax'

    # NOTE(review): flags and decomposition name are interpreted elsewhere;
    # semantics cannot be determined from this file.
    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'

    # NOTE(review): presumably the training target (e.g. noise prediction)
    # -- confirm against the consumer.
    target = "epsilon"
|
|