|
|
|
from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
|
|
'''
Reference training configuration (kept here as an inert string for documentation).

network_config = {
|
"epochs": 150, |
|
"batch_size": 250, |
|
"n_steps": 16, # timestep |
|
"dataset": "CAPS", |
|
"in_channels": 1, |
|
"data_path": "./data", |
|
"lr": 0.001, |
|
"n_class": 10, |
|
"latent_dim": 128, |
|
"input_size": 32, |
|
"model": "FSVAE" ,# FSVAE or FSVAE_large |
|
"k": 20, # multiplier of channel |
|
"scheduled": True, # whether to apply scheduled sampling |
|
"loss_func": 'kld', # mmd or kld |
|
"accum_iter" : 1, |
|
"devices": [0], |
|
} |
|
|
|
hidden_dims = [32, 64, 128, 256] |
|
|
|
''' |
|
|
|
class FSAEConfig(PretrainedConfig):
    """Configuration for the ``fsae`` model type.

    Stores the model hyperparameters as plain instance attributes, then
    forwards any remaining keyword arguments to ``PretrainedConfig.__init__``.

    Args:
        in_channels: Number of input channels.
        hidden_dims: List of hidden dimensions; presumably per-layer channel
            widths (NOTE(review): confirm against the model code). Defaults
            to ``[32, 64, 128, 256]``. A copy of the given sequence is
            stored, so later mutation of the caller's list does not affect
            this config.
        k: Channel multiplier.
        n_steps: Number of timesteps.
        latent_dim: Latent dimension.
        scheduled: Whether to apply scheduled sampling.
        dt: Scalar hyperparameter stored as-is.
        a: Scalar hyperparameter stored as-is.
        aa: Scalar hyperparameter stored as-is.
        Vth: Scalar hyperparameter stored as-is.
        tau: Scalar hyperparameter stored as-is.
            NOTE(review): ``dt``, ``a``, ``aa``, ``Vth`` and ``tau`` look
            like spiking-neuron constants consumed by the model; their exact
            meaning is not established here, so confirm before relying on it.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fsae"

    def __init__(
        self,
        in_channels: int = 1,
        hidden_dims: Optional[List[int]] = None,
        k: int = 20,
        n_steps: int = 16,
        latent_dim: int = 128,
        scheduled: bool = True,
        dt: float = 5,
        a: float = 0.25,
        aa: float = 0.5,
        Vth: float = 0.2,
        tau: float = 0.25,
        **kwargs,
    ):
        # A list literal used as a default argument is evaluated once and
        # shared by every call. Build a fresh default per instance instead,
        # and copy caller-supplied sequences so the config owns its own list.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256]

        self.in_channels = in_channels
        self.hidden_dims = list(hidden_dims)
        self.k = k
        self.n_steps = n_steps
        self.latent_dim = latent_dim
        self.scheduled = scheduled

        self.dt = dt
        self.a = a
        self.aa = aa
        self.Vth = Vth
        self.tau = tau

        # Attributes are assigned before the base initializer runs; any extra
        # keyword arguments are handled by PretrainedConfig.
        super().__init__(**kwargs)