from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TrainerSubConfig:
    """Configuration entry for a single trainer: its type and its settings dict."""

    trainer_type: str = ""
    trainer: dict = field(default_factory=dict)


@dataclass
class ExperimentConfig:
    """Top-level experiment configuration: model sources, noise-schedule options,
    optimization hyperparameters, and run control flags."""

    # Trainer configuration dicts and shared initialization kwargs.
    trainers: List[dict] = field(default_factory=list)
    init_config: dict = field(default_factory=dict)

    # Pretrained model sources.
    pretrained_model_name_or_path: str = ""
    pretrained_unet_state_dict_path: str = ""

    # Noise-schedule and prediction settings.
    linear_beta_schedule: bool = False
    zero_snr: bool = False
    prediction_type: Optional[str] = None

    # Training-loop settings.
    seed: Optional[int] = None
    max_train_steps: int = 1000000
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500

    # Optimizer (Adam) settings.
    use_8bit_adam: bool = False
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0

    # Precision and debugging.
    mixed_precision: Optional[str] = None
    skip_training: bool = False
    debug: bool = False
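

# A minimal usage sketch (illustrative only, not part of the training pipeline):
# construct a config from the defaults with a few overrides. All values below are
# placeholders, and the shape of the `trainers` entry is assumed to mirror
# TrainerSubConfig rather than being required by this module.
if __name__ == "__main__":
    example = ExperimentConfig(
        pretrained_model_name_or_path="path/to/pretrained-model",  # placeholder path
        learning_rate=5e-5,
        mixed_precision="fp16",
        trainers=[{"trainer_type": "example_trainer", "trainer": {}}],
    )
    print(example)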