import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension


@dataclass
class SDEParameterizedBaseLineConfig:
    """Configuration for a scalar SDE with zero drift and a geometric diffusion.

    All settings are plain (unannotated) class attributes, so ``@dataclass``
    generates no fields and the generated ``__init__`` takes no arguments;
    the settings are read as attributes of the class or of an instance.

    The diffusion coefficient defined below is::

        g(t) = sigma_min * (sigma_max / sigma_min) ** t
               * sqrt(2 * |log(sigma_max / sigma_min)|)

    which, together with ``drift = 0``, has the form of the
    variance-exploding (VE) SDE.
    """

    name = "Custom"
    # Independent (time) variable of the SDE.
    variable = Symbol('t', nonnegative=True, real=True)

    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # TODO (KLAUS): HANDLE THE PARAMETERS BEING Ø (an empty parameter set).
    # NOTE(review): ``f1`` does not appear in ``drift`` below; it looks like a
    # placeholder until empty parameter sets are supported — confirm.
    drift_parameters = Matrix([sympy.symbols("f1", real=True)])
    # Matrix built from the symbol tuple (sigma_min, sigma_max).
    diffusion_parameters = Matrix([sympy.symbols("sigma_min sigma_max", real=True)])

    drift = 0
    # The symbols are only declared real, so Abs() makes them enter the
    # expression as magnitudes.
    sigma_min = sympy.Abs(diffusion_parameters[0])  # 0.002
    sigma_max = sympy.Abs(diffusion_parameters[1])  # 80
    diffusion = (
        sigma_min
        * (sigma_max / sigma_min) ** variable
        * sympy.sqrt(2 * sympy.Abs(sympy.log(sigma_max / sigma_min)))
    )

    # TODO (KLAUS): in the SDE SAMPLING, CHANGING Q impacts how we sample
    # z ~ N(0, Q * (delta t)).
    diffusion_matrix = 1

    initial_variable_value = 0
    max_variable_value = 1  # alternative: math.inf
    min_sample_value = 0

    module = 'jax'
    drift_integral_form = False
    diffusion_integral_form = False
    diffusion_integral_decomposition = 'cholesky'  # alternative: ldl

    # Numeric values matching ``diffusion_parameters`` (sigma_min=0.002,
    # sigma_max=80); presumably substituted for the symbols by the consumer —
    # confirm. This class-level dict is shared; ``__post_init__`` gives each
    # instance its own copy.
    non_symbolic_parameters = {'diffusion': torch.tensor([0.002, 80.])}
    target = "epsilon"  # alternative: x0

    def __post_init__(self):
        """Give this instance private copies of the mutable parameter values.

        ``non_symbolic_parameters`` is a class-level dict of tensors, so
        without this hook an in-place change made through one instance would
        leak into every other instance and into the class itself. The class
        attribute is left untouched, so class-level access keeps working and
        subclasses that override it are still honoured (via ``type(self)``).
        Tensors are copied as fresh leaves with the same ``requires_grad``
        flag; non-tensor values are copied by reference.
        """
        self.non_symbolic_parameters = {
            key: (
                value.detach().clone().requires_grad_(value.requires_grad)
                if isinstance(value, torch.Tensor)
                else value
            )
            for key, value in type(self).non_symbolic_parameters.items()
        }