File size: 2,542 Bytes
27486b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
model:
  # Learning rate handed to the training entry point.
  # NOTE(review): confirm whether the trainer rescales it before use.
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    # Denoiser using EDMScaling. sigma_data is 1.0 here and in the loss
    # weighting below - presumably the two are meant to stay equal.
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    # Single-channel (in/out) UNet. attention_resolutions is empty, so
    # presumably no attention blocks are built and num_head_channels is unused.
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        # num_classes is the string "sequential", not a class count.
        # adm_in_channels equals the ClassEmbedder embed_dim (128) below;
        # NOTE(review): presumably these must match - confirm in UNetModel.
        num_classes: sequential
        adm_in_channels: 128

    # One trainable class embedder over the "cls" batch key (10 classes).
    # NOTE(review): ucg_rate presumably drops the label with probability 0.2
    # during training, giving the VanillaCFG guider below an unconditional
    # branch - confirm in GeneralConditioner.
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: true
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    # Identity first stage: presumably the model works directly on pixels.
    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    # L1 diffusion loss with EDMWeighting (sigma_data 1.0) and EDMSampling.
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_type: l1
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    # Euler EDM sampler: 50 steps, EDMDiscretization, VanillaCFG at scale 3.0.
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  # Data module from sgm.data.mnist: batches of 512, a single loader worker.
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  # Checkpoint callback parameters: save every 5000 training steps.
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    # Additional step-based checkpoint every 25000 training steps.
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    # Image logging every 1000 batches, at most 64 images per log.
    # NOTE(review): confirm increase_log_steps / log_first_step semantics
    # in main.ImageLogger.
    image_logger:
      target: main.ImageLogger
      params:
        disabled: false
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          use_ema_scope: false
          N: 64
          n_rows: 8

  trainer:
    # Quoted for clarity only: the unquoted `0,` already parsed as the
    # string "0,". NOTE(review): presumably read as "GPU index 0" - confirm.
    devices: "0,"
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20