AltLuv commited on
Commit
d270254
1 Parent(s): d2df9fb

Create scheduler_config.py

Browse files
Files changed (1) hide show
  1. scheduler/scheduler_config.py +43 -0
scheduler/scheduler_config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax.numpy as jnp
2
+ import jax
3
+ import torch
4
+ from dataclasses import dataclass
5
+ import sympy
6
+ import sympy as sp
7
+ from sympy import Matrix, Symbol
8
+ import math
9
+ from sde_redefined_param import SDEDimension
10
+ @dataclass
11
+ class SDEBaseLineConfig:
12
+ name = "Custom"
13
+ variable = Symbol('t', nonnegative=True, real=True)
14
+
15
+ drift_dimension = SDEDimension.SCALAR
16
+ diffusion_dimension = SDEDimension.SCALAR
17
+ diffusion_matrix_dimension = SDEDimension.SCALAR
18
+
19
+ # TODO (KLAUS): HANDLE THE PARAMETERS BEING Ø
20
+ drift_parameters = Matrix([sympy.symbols("f1")])
21
+ diffusion_parameters = Matrix([sympy.symbols("l1")])
22
+
23
+ drift = 0
24
+
25
+ sigma_min = 0.002
26
+ sigma_max = 80
27
+ diffusion = sigma_min * (sigma_max/sigma_min)**variable * sympy.sqrt(2 * sympy.log(sigma_max/sigma_min))
28
+ # TODO (KLAUS) : in the SDE SAMPLING CHANGING Q impacts how we sample z ~ N(0, Q*(delta t))
29
+ diffusion_matrix = 1
30
+
31
+ initial_variable_value = 0
32
+ max_variable_value = 1 # math.inf
33
+ min_sample_value = 1e-6
34
+
35
+ module = 'jax'
36
+
37
+ drift_integral_form=False
38
+ diffusion_integral_form=False
39
+ diffusion_integral_decomposition = 'cholesky' # ldl
40
+
41
+
42
+
43
+ target = "epsilon" # x0