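# NOTE: the following lines appear to be file-viewer page residue captured
# together with the source; they are kept verbatim as comments.
# AltLuv's picture
# End of training
# 46439a4
# raw
# history blame contribute delete
# No virus
# 1.39 kB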
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEParameterizedBaseLineConfig:
    """Symbolic configuration of a custom SDE with zero drift.

    The diffusion coefficient is written in the symbolic variable ``t`` as
    ``sigma_min * (sigma_max / sigma_min)**t
    * sqrt(2 * |log(sigma_max / sigma_min)|)``, where ``sigma_min`` and
    ``sigma_max`` are the absolute values of the two diffusion parameters.

    None of the names below carries a type annotation, so ``@dataclass``
    registers no fields: each one is a plain class attribute shared by every
    instance.
    """

    name = "Custom"

    # Symbolic independent variable, declared real and non-negative.
    variable = Symbol("t", nonnegative=True, real=True)

    # Dimensionality tags for the drift, the diffusion and the diffusion
    # matrix; all three are scalar in this configuration.
    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # TODO (KLAUS): handle the case where a parameter set is empty (Ø).
    # The drift below is the constant 0 and never references ``f1``.
    drift_parameters = Matrix([Symbol("f1", real=True)])
    diffusion_parameters = Matrix(
        [sympy.symbols("sigma_min sigma_max", real=True)]
    )

    drift = 0

    # Absolute values of the two symbolic diffusion parameters.
    # The original inline notes give 0.002 and 80 as example values.
    sigma_min = sympy.Abs(diffusion_parameters[0])
    sigma_max = sympy.Abs(diffusion_parameters[1])

    diffusion = (
        sigma_min
        * (sigma_max / sigma_min) ** variable
        * sympy.sqrt(2 * sympy.Abs(sympy.log(sigma_max / sigma_min)))
    )

    # TODO (KLAUS): in the SDE sampling, changing Q impacts how we sample
    # z ~ N(0, Q * (delta t)).
    diffusion_matrix = 1

    # Range of the variable; the original note lists math.inf as an
    # alternative upper bound.
    initial_variable_value = 0
    max_variable_value = 1
    min_sample_value = 0

    # NOTE(review): presumably selects the numeric backend used to evaluate
    # the symbolic expressions; note that ``non_symbolic_parameters`` below
    # holds a torch tensor - confirm this mix is intended.
    module = "jax"

    drift_integral_form = False
    diffusion_integral_form = False

    # The original note lists 'ldl' as the alternative decomposition.
    diffusion_integral_decomposition = "cholesky"

    # NOTE(review): presumably the numeric values for (sigma_min, sigma_max),
    # in that order - verify against the consumer of this config.
    non_symbolic_parameters = {"diffusion": torch.tensor([0.002, 80.0])}

    # The original note lists "x0" as the alternative target.
    target = "epsilon"