File size: 1,099 Bytes
37aeb5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TrainerSubConfig:
    """Configuration record for one trainer entry.

    Both fields are optional at construction time; an instance created
    with no arguments has an empty type string and its own empty dict.
    """

    # Trainer type identifier; empty string when unset.
    # NOTE(review): presumably selects which trainer implementation to
    # build -- confirm against the code that consumes this config.
    trainer_type: str = field(default="")
    # Free-form trainer options. A factory is used so each instance gets
    # its own dict rather than sharing one mutable default.
    trainer: dict = field(default_factory=dict)


@dataclass
class ExprimentConfig:
    """Top-level experiment configuration with defaults for every field.

    NOTE: the class name is a misspelling of "ExperimentConfig"; it is
    kept as-is because it is the public name callers import.

    All fields have defaults, so ``ExprimentConfig()`` is valid. Mutable
    defaults (``trainers``, ``init_config``) are built per instance via
    ``default_factory`` and are never shared between instances.
    """

    # List of per-trainer config dicts; empty by default.
    # NOTE(review): entries presumably follow the TrainerSubConfig layout
    # (trainer_type / trainer) -- confirm against the consumer.
    trainers: List[dict] = field(default_factory=list)
    # Free-form initialization options; empty dict by default.
    init_config: dict = field(default_factory=dict)
    # Model locations; empty string means "not set".
    pretrained_model_name_or_path: str = ""
    pretrained_unet_state_dict_path: str = ""

    # Experiment-related parameters.
    # Noise-schedule / objective switches (semantics defined by consumer).
    linear_beta_schedule: bool = False
    zero_snr: bool = False
    prediction_type: Optional[str] = None
    # Random seed; None leaves it unspecified.
    seed: Optional[int] = None

    # Training-loop length and gradient accumulation.
    max_train_steps: int = 1000000
    gradient_accumulation_steps: int = 1

    # Learning-rate settings.
    learning_rate: float = 1e-4
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 500

    # Optimizer hyperparameters (field names refer to Adam).
    use_8bit_adam: bool = False
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-2
    adam_epsilon: float = 1e-08
    max_grad_norm: float = 1.0

    # Expected values: "no", "fp16", "bf16", "fp8"; None when unset.
    mixed_precision: Optional[str] = None

    # Run-control flags.
    skip_training: bool = False
    debug: bool = False