from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TrainerSubConfig:
    trainer_type: str = ""
    trainer: dict = field(default_factory=dict)


@dataclass
class ExperimentConfig:
    # trainer sub-configurations, stored as raw dicts
    trainers: List[dict] = field(default_factory=list)
    init_config: dict = field(default_factory=dict)

    # model / checkpoint paths
    pretrained_model_name_or_path: str = ""
    pretrained_unet_state_dict_path: str = ""

    # experiment-related parameters
    linear_beta_schedule: bool = False
    zero_snr: bool = False
    prediction_type: Optional[str] = None

    # training loop and optimizer hyperparameters
    seed: Optional[int] = None
    max_train_steps: int = 1_000_000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500
    use_8bit_adam: bool = False
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0
    mixed_precision: Optional[str] = None  # one of: "no", "fp16", "bf16", "fp8"
    skip_training: bool = False
    debug: bool = False
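

# A minimal usage sketch (an illustration, not part of the original module):
# it assumes each experiment arrives as a plain dict, e.g. parsed from a YAML
# or JSON file, and shows how such a dict maps onto the dataclasses above.
# The model path and trainer payload below are placeholder values.
if __name__ == "__main__":
    raw = {
        "pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
        "learning_rate": 5e-5,
        "mixed_precision": "fp16",
        "trainers": [{"trainer_type": "unet", "trainer": {"rank": 4}}],
    }
    # Every field has a default, so a partial dict unpacks directly
    # into keyword arguments.
    config = ExperimentConfig(**raw)
    # Entries in `trainers` stay as dicts; promote them explicitly if needed.
    sub_configs = [TrainerSubConfig(**t) for t in config.trainers]
    print(config.learning_rate, sub_configs[0].trainer_type)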