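# NOTE: the lines below are residue from the Hugging Face file-viewer page this
# file was copied from; they are kept as comments so the module still parses.
# jmemon's picture
# Files: Epoch -1
# 92697e6
# raw
# history blame
# No virus
# 890 Bytes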
from dataclasses import dataclass
@dataclass
class TrainingConfig:
    """Training hyperparameters plus output-directory and HF Hub settings.

    Every setting is an annotated dataclass field with a default value.
    Without annotations, ``@dataclass`` would ignore these names entirely:
    the generated ``__init__`` would accept no arguments and ``__repr__`` /
    ``__eq__`` would ignore every setting.

    ``TrainingConfig()`` yields the default configuration.
    ``TrainingConfig(num_epochs=10)`` overrides individual settings.
    Class-level access such as ``TrainingConfig.seed`` still returns the
    default.
    """

    image_size: int = 128  # the generated image resolution
    train_batch_size: int = 4
    eval_batch_size: int = 4  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    save_image_epochs: int = 1
    save_model_epochs: int = 3
    # `no` for float32, `fp16` for automatic mixed precision
    mixed_precision: str = 'fp16'
    # the model name locally and on the HF Hub
    output_dir: str = 'ddpm-paintings-128-finetuned-cifar10'
    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub
    # the name of the repository to create on the HF Hub
    hub_model_id: str = 'jmemon/ddpm-paintings-128-finetuned-cifar10'
    hub_private_repo: bool = False
    # overwrite the old model when re-running the notebook
    overwrite_output_dir: bool = True
    seed: int = 0