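# Source: pokemon-polynomial-20 / scheduler / scheduler_config.py
# (Hugging Face file-viewer header, kept as comments so the module parses)
# AltLuv's picture
# Commit message: End of training
# Commit: 885da14
# raw
# history blame contribute delete
# No virus
# 1.53 kB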
import jax.numpy as jnp
import jax
import torch
from dataclasses import dataclass
import sympy
import sympy as sp
from sympy import Matrix, Symbol
import math
from sde_redefined_param import SDEDimension
@dataclass
class SDEPolynomialConfig:
    """Configuration of an SDE whose drift and diffusion are polynomials.

    Both polynomials are built symbolically with sympy in ``variable`` and
    contain the powers ``1 .. degree`` (there is no constant term), one
    symbolic coefficient per power.

    None of the attributes below carries a type annotation, so ``@dataclass``
    registers no fields for them; they stay plain class attributes shared by
    every instance.
    """

    name = "Custom"

    # Symbolic variable of both polynomials.
    variable = Symbol('t', nonnegative=True, real=True)

    drift_dimension = SDEDimension.SCALAR
    diffusion_dimension = SDEDimension.SCALAR
    diffusion_matrix_dimension = SDEDimension.SCALAR

    # Number of monomials (powers 1..degree) in each polynomial.
    drift_degree = 20
    diffusion_degree = 20

    # 1 x drift_degree row of real symbols f0 .. f{drift_degree - 1}.
    drift_parameters = Matrix([sympy.symbols(f"f:{drift_degree}", real=True)])

    # 1 x diffusion_degree row of real symbols l0 .. l{diffusion_degree - 1}.
    # They are squared inside ``diffusion`` to ensure positive definiteness.
    diffusion_parameters = Matrix([sympy.symbols(f"l:{diffusion_degree}", real=True)])

    @property
    def drift(self):
        """Return ``-|sum_i f_(i-1) * t**i|`` for ``i = 1 .. drift_degree``."""
        exponents = range(1, self.drift_degree + 1)
        monomials = Matrix([[self.variable**k for k in exponents]])
        # Element-wise product pairs each power of t with its coefficient.
        weighted = sympy.HadamardProduct(monomials, self.drift_parameters).doit()
        return -sympy.Abs(sum(weighted))

    @property
    def diffusion(self):
        """Return ``sum_i l_(i-1)**2 * t**i`` for ``i = 1 .. diffusion_degree``."""
        exponents = range(1, self.diffusion_degree + 1)
        monomials = Matrix([[self.variable**k for k in exponents]])
        # Each coefficient is squared before use, so every weight is >= 0.
        squared = self.diffusion_parameters.applyfunc(lambda coeff: coeff**2)
        weighted = sympy.HadamardProduct(monomials, squared).doit()
        return sum(weighted)

    # TODO (KLAUS) : in the SDE SAMPLING CHANGING Q impacts how we sample z ~ N(0, Q*(delta t))
    diffusion_matrix = 1

    initial_variable_value = 0
    max_variable_value = 1  # original comment notes: math.inf
    min_sample_value = 1e-6

    module = 'jax'

    drift_integral_form = True
    diffusion_integral_form = True
    diffusion_integral_decomposition = 'cholesky'  # original comment notes: ldl

    target = "epsilon"  # original comment notes: x0