File size: 860 Bytes
92697e6
9d4f27d
92697e6
 
 
 
 
 
 
 
 
 
 
 
 
 
9d4f27d
92697e6
 
9d4f27d
92697e6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from dataclasses import dataclass
from pathlib import Path


@dataclass
class TrainingConfig:
    """Hyperparameters and output/Hub settings for a training run.

    Every setting is a type-annotated dataclass field, so any of them can be
    overridden at construction time, e.g. ``TrainingConfig(num_epochs=10)``,
    and the whole config can be serialized with ``dataclasses.asdict``.
    (Without annotations ``@dataclass`` registers no fields at all and the
    constructor accepts no arguments.)
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 4
    eval_batch_size: int = 4  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 1  # presumably: save sample images every N epochs — confirm with the training loop
    save_model_epochs: int = 3  # presumably: save the model every N epochs — confirm with the training loop
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    # Evaluated once, when the class is defined: the directory containing this module.
    output_dir: str = str(Path(__file__).parent)

    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub
    hub_model_id: str = 'jmemon/ddpm-paintings-128-finetuned-celebahq'  # the name of the repository to create on the HF Hub
    hub_private_repo: bool = False
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0