Create scheduler_config.py
Browse files
scheduler/scheduler_config.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jax.numpy as jnp
|
2 |
+
import jax
|
3 |
+
import torch
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import sympy
|
6 |
+
import sympy as sp
|
7 |
+
from sympy import Matrix, Symbol
|
8 |
+
import math
|
9 |
+
from sde_redefined_param import SDEDimension
|
10 |
+
@dataclass
|
11 |
+
class SDEBaseLineConfig:
|
12 |
+
name = "Custom"
|
13 |
+
variable = Symbol('t', nonnegative=True, real=True)
|
14 |
+
|
15 |
+
drift_dimension = SDEDimension.SCALAR
|
16 |
+
diffusion_dimension = SDEDimension.SCALAR
|
17 |
+
diffusion_matrix_dimension = SDEDimension.SCALAR
|
18 |
+
|
19 |
+
# TODO (KLAUS): HANDLE THE PARAMETERS BEING Ø
|
20 |
+
drift_parameters = Matrix([sympy.symbols("f1")])
|
21 |
+
diffusion_parameters = Matrix([sympy.symbols("l1")])
|
22 |
+
|
23 |
+
drift = 0
|
24 |
+
|
25 |
+
sigma_min = 0.002
|
26 |
+
sigma_max = 80
|
27 |
+
diffusion = sigma_min * (sigma_max/sigma_min)**variable * sympy.sqrt(2 * sympy.log(sigma_max/sigma_min))
|
28 |
+
# TODO (KLAUS) : in the SDE SAMPLING CHANGING Q impacts how we sample z ~ N(0, Q*(delta t))
|
29 |
+
diffusion_matrix = 1
|
30 |
+
|
31 |
+
initial_variable_value = 0
|
32 |
+
max_variable_value = 1 # math.inf
|
33 |
+
min_sample_value = 1e-6
|
34 |
+
|
35 |
+
module = 'jax'
|
36 |
+
|
37 |
+
drift_integral_form=False
|
38 |
+
diffusion_integral_form=False
|
39 |
+
diffusion_integral_decomposition = 'cholesky' # ldl
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
target = "epsilon" # x0
|