trainer:
  target: trainer.TrainerSDTurboSR
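# Each ``target`` entry in this file is a dotted import path; blocks that also
# carry ``params`` are presumably instantiated with those keyword arguments.
# Hedged sketch of such a builder (the helper name and location in this repo
# may differ):
#   from importlib import import_module
#   def build_from_conf(conf):
#       module_path, cls_name = conf["target"].rsplit(".", 1)
#       cls = getattr(import_module(module_path), cls_name)
#       return cls(**conf.get("params", {}))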

sd_pipe:
  target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline
  num_train_steps: 1000
  enable_grad_checkpoint: True
  compile: False
  vae_split: 8
  params:
    pretrained_model_name_or_path: stabilityai/sd-turbo
    cache_dir: weights
    use_safetensors: True
    torch_dtype: torch.float16
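# Hedged sketch of how the sd_pipe block above maps onto diffusers; the string
# torch_dtype value is assumed to be resolved to the actual dtype by the trainer:
#   import torch
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "stabilityai/sd-turbo",
#       cache_dir="weights",
#       use_safetensors=True,
#       torch_dtype=torch.float16,
#   )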

llpips:
  target: latent_lpips.lpips.LPIPS
  ckpt_path: weights/vgg16_sdturbo_lpips.pth
  compile: False
  params:
    pretrained: False
    net: vgg16
    lpips: True
    spatial: False
    pnet_rand: False
    pnet_tune: True
    use_dropout: True
    eval_mode: True
    latent: True
    in_chans: 4
    verbose: True
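# This LPIPS runs in latent space: in_chans: 4 matches the 4-channel SD VAE
# latents, so perceptual distances are computed on latents rather than RGB.
# Hedged usage sketch (constructor arguments taken from the params above; the
# checkpoint-loading call is an assumption):
#   import torch
#   from latent_lpips.lpips import LPIPS
#   lpips = LPIPS(pretrained=False, net="vgg16", lpips=True, latent=True, in_chans=4)
#   lpips.load_state_dict(torch.load("weights/vgg16_sdturbo_lpips.pth"), strict=False)
#   d = lpips(pred_latent, gt_latent)   # pred_latent/gt_latent: (B, 4, H, W) tensors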

model:
  target: diffusers.models.autoencoders.NoisePredictor
  ckpt_start_path: ~     # only used for training the intermediate model
  ckpt_path: ~           # for initializing the model weights
  compile: False
  params:
    in_channels: 3
    down_block_types:
      - AttnDownBlock2D
      - AttnDownBlock2D
    up_block_types:
      - AttnUpBlock2D
      - AttnUpBlock2D
    block_out_channels:
      - 256      # 192, 256
      - 512      # 384, 512
    layers_per_block: 
      - 3
      - 3
    act_fn: silu
    latent_channels: 4
    norm_num_groups: 32
    sample_size: 128
    mid_block_add_attention: True
    resnet_time_scale_shift: default
    temb_channels: 512
    attention_head_dim: 64 
    freq_shift: 0
    flip_sin_to_cos: True
    double_z: True
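# Reading the model block above: with degradation.sf: 4 and degradation.gt_size: 512,
# the low-resolution input is 512 / 4 = 128 px, matching sample_size: 128;
# in_channels: 3 suggests the noise predictor consumes RGB inputs, while
# latent_channels: 4 matches the SD latent space it predicts for (an interpretation
# of the config, not a statement about the implementation).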

discriminator:
  target: diffusers.models.unets.unet_2d_condition_discriminator.UNet2DConditionDiscriminator
  enable_grad_checkpoint: True
  compile: False
  params:
    sample_size: 64
    in_channels: 4
    center_input_sample: False
    flip_sin_to_cos: True
    freq_shift: 0
    down_block_types:
      - DownBlock2D
      - CrossAttnDownBlock2D
      - CrossAttnDownBlock2D
    mid_block_type: UNetMidBlock2DCrossAttn
    up_block_types:
      - CrossAttnUpBlock2D
      - CrossAttnUpBlock2D
      - UpBlock2D
    only_cross_attention: False
    block_out_channels:
      - 128
      - 256
      - 512
    layers_per_block:
      - 1
      - 2
      - 2
    downsample_padding: 1
    mid_block_scale_factor: 1
    dropout: 0.0
    act_fn: silu
    norm_num_groups: 32
    norm_eps: 1e-5
    cross_attention_dim: 1024
    transformer_layers_per_block: 1
    reverse_transformer_layers_per_block: ~
    encoder_hid_dim: ~
    encoder_hid_dim_type: ~
    attention_head_dim:
      - 8 
      - 16 
      - 16 
    num_attention_heads: ~
    dual_cross_attention: False
    use_linear_projection: False
    class_embed_type: ~
    addition_embed_type: text 
    addition_time_embed_dim: 256
    num_class_embeds: ~
    upcast_attention: ~
    resnet_time_scale_shift: default
    resnet_skip_time_act: False
    resnet_out_scale_factor: 1.0
    time_embedding_type: positional
    time_embedding_dim: ~
    time_embedding_act_fn: ~
    timestep_post_act: ~
    time_cond_proj_dim: ~
    conv_in_kernel: 3
    conv_out_kernel: 3
    projection_class_embeddings_input_dim: 2560
    attention_type: default
    class_embeddings_concat: False
    mid_block_only_cross_attention: ~
    cross_attention_norm: ~
    addition_embed_type_num_heads: 64
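# Discriminator sizing check: in_channels: 4 and sample_size: 64 are consistent
# with SD VAE latents of 512 px images (512 / 8 = 64), and cross_attention_dim: 1024
# matches the sd-turbo (SD 2.1) text-encoder width, in line with use_text: True in
# the train section below (again a reading of the config, not of the code).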

degradation:
  sf: 4
  # the first degradation process
  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
  resize_range: [0.15, 1.5]
  gaussian_noise_prob: 0.5
  noise_range: [1, 30]
  poisson_scale_range: [0.05, 3.0]
  gray_noise_prob: 0.4
  jpeg_range: [30, 95]

  # the second degradation process
  second_order_prob: 0.5
  second_blur_prob: 0.8
  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
  resize_range2: [0.3, 1.2]
  gaussian_noise_prob2: 0.5
  noise_range2: [1, 25]
  poisson_scale_range2: [0.05, 2.5]
  gray_noise_prob2: 0.4
  jpeg_range2: [30, 95]

  gt_size: 512 
  resize_back: False
  use_sharp: False
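# Degradation arithmetic (Real-ESRGAN-style two-stage pipeline): with sf: 4 and
# gt_size: 512 the low-quality crop is 512 / 4 = 128 px, and second_order_prob: 0.5
# applies the second degradation stage to roughly half of the samples. Hedged gate
# sketch (second_stage below is hypothetical, not a function of this repo):
#   import random
#   if random.random() < 0.5:        # second_order_prob
#       lq = second_stage(lq)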

data:
  train:
    type: realesrgan
    params:
      data_source: 
        source1:
          root_path: /mnt/sfs-common/zsyue/database/FFHQ
          image_path: images1024
          moment_path: ~
          text_path: ~
          im_ext: png
          length: 20000
        source2:
          root_path: /mnt/sfs-common/zsyue/database/LSDIR/train
          image_path: images 
          moment_path: ~
          text_path: ~
          im_ext: png
      max_token_length: 77   # 77
      io_backend:
        type: disk
      blur_kernel_size: 21
      kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
      kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
      sinc_prob: 0.1
      blur_sigma: [0.2, 3.0]
      betag_range: [0.5, 4.0]
      betap_range: [1, 2.0]

      blur_kernel_size2: 15
      kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
      kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
      sinc_prob2: 0.1
      blur_sigma2: [0.2, 1.5]
      betag_range2: [0.5, 4.0]
      betap_range2: [1, 2.0]

      final_sinc_prob: 0.8

      gt_size: ${degradation.gt_size}
      use_hflip: True
      use_rot: False
      random_crop: True
  val:
    type: base
    params:
      dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/lq
      transform_type: default
      transform_kwargs:
          mean: 0.0
          std: 1.0
      extra_dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/gt
      extra_transform_type: default
      extra_transform_kwargs:
          mean: 0.0
          std: 1.0
      im_exts: png
      length: 16
      recursive: False
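# ${degradation.gt_size} above is an OmegaConf-style interpolation, so the training
# dataset reuses the crop size from the degradation block. Hedged sketch (assuming
# the config is loaded with omegaconf; the file name is illustrative):
#   from omegaconf import OmegaConf
#   conf = OmegaConf.load("sd_turbo_sr.yaml")
#   assert conf.data.train.params.gt_size == conf.degradation.gt_size == 512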

train:
  # predict the starting state for the inverser
  start_mode: True
  # learning rate
  lr: 5e-5                      # learning rate 
  lr_min: 5e-5                  # minimum learning rate
  lr_schedule: ~
  warmup_iterations: 2000
  # discriminator
  lr_dis: 5e-5                  # learning rate for discriminator
  weight_decay_dis: 1e-3        # weight decay for discriminator
  dis_init_iterations: 10000    # iterations used for updating the discriminator
  dis_update_freq: 1            
  # dataloader
  batch: 64               
  microbatch: 16 
  num_workers: 4
  prefetch_factor: 2            
  use_text: True
  # optimization settings
  weight_decay: 0               
  ema_rate: 0.999
  iterations: 200000            # total iterations
  # logging
  save_freq: 5000
  log_freq: [200, 5000]         # [training loss, training/val images]
  local_logging: True           # manually save images
  tf_logging: False             # tensorboard logging
  # loss 
  loss_type: L2
  loss_coef:
    ldif: 1.0
  timesteps: [200, 100]
  num_inference_steps: 5
  # mixed precision
  use_amp: True                
  use_fsdp: False                
  # random seed 
  seed: 123456                 
  global_seeding: False
  noise_detach: False
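# Batch arithmetic: with batch: 64 and microbatch: 16, a gradient-accumulation loop
# would take 64 / 16 = 4 micro-steps per optimizer step (assuming microbatch is the
# per-step chunk size; how this interacts with multiple GPUs depends on the trainer).
#   accum_steps = 64 // 16   # = 4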

validate:
  batch: 2
  use_ema: True            
  log_freq: 4      # logging frequency
  val_y_channel: True