# Random seed for the run (integer scalar).
seed: 2022
# Dataset section.
# NOTE(review): the original indentation was lost; the nesting below is a
# reconstruction. In particular, confirm whether `img_channels` and
# `num_classes` belong beside `params` (as written here) or inside it.
data:
  # Dotted path of the dataset class.
  target: datasets.cifar10.CIFAR10
  # Presumably keyword arguments passed to `target` - TODO confirm.
  params:
    root: ~/data/CIFAR-10/
    img_size: 32
  img_channels: 3
  num_classes: 10
# Data-loading options. The key names match keyword arguments of
# torch.utils.data.DataLoader; presumably forwarded to it - TODO confirm.
dataloader:
  num_workers: 4
  pin_memory: true
  prefetch_factor: 2
# Network section.
# NOTE(review): the original indentation was lost; every key after `params`
# is assumed to be a child of it - confirm against the model's constructor.
model:
  # Dotted path of the model class.
  target: models.unet_categorial_adagn.UNetCategorialAdaGN
  params:
    in_channels: 3
    out_channels: 3
    dim: 128
    # `dim_mults` and `use_attn` both have four entries; presumably one entry
    # per resolution level - TODO confirm.
    dim_mults:
      - 1
      - 2
      - 2
      - 2
    use_attn:
      - false
      - true
      - true
      - false
    num_res_blocks: 2
    num_classes: 10
    attn_head_dims: 64
    resblock_updown: true
    dropout: 0.1
# Diffusion-process section.
# NOTE(review): the original indentation was lost; every key after `params`
# is assumed to be a child of it - confirm against the class's constructor.
diffusion:
  # Dotted path of the diffusion class.
  target: diffusions.cfg.ddpm_cfg.DDPMCFG
  params:
    total_steps: 1000
    beta_schedule: cosine
    # NOTE(review): `beta_start` / `beta_end` look like linear-schedule
    # endpoints; verify whether the cosine schedule actually reads them.
    beta_start: 0.0001
    beta_end: 0.02
    objective: pred_eps
    var_type: fixed_large
# Training-loop section.
# NOTE(review): the original indentation was lost; `optim` is assumed to be
# nested under `train` - confirm against the code that reads this file.
train:
  n_steps: 800000
  batch_size: 128
  micro_batch: 0
  clip_grad_norm: 1.0
  ema_decay: 0.9999
  ema_gradual: true
  print_freq: 400
  save_freq: 10000
  sample_freq: 5000
  n_samples_each_class: 10
  # Presumably the probability of dropping the class label during training
  # (classifier-free guidance) - TODO confirm.
  p_uncond: 0.2

  optim:
    # Dotted path of the optimizer class.
    target: torch.optim.AdamW
    params:
      lr: 0.0002