# Trainer selection. `target` is a dotted path to the trainer class
# (presumably imported and instantiated by the training script - verify).
trainer:
  target: 'trainer.TrainerDifIRLPIPS'

# VQ autoencoder: class path, checkpoint location and constructor parameters.
# NOTE(review): the checkpoint file name suggests a spatial downsampling
# factor of 4 ("f4"), which would map ddconfig.resolution (256) to
# model.params.image_size (64) - confirm against the model code.
autoencoder:
  target: 'ldm.models.autoencoder.VQModelTorch'
  ckpt_path: 'weights/autoencoder_vq_f4.pth'
  use_fp16: true
  params:
    # model.params.out_channels interpolates this value.
    embed_dim: 3
    n_embed: 8192
    ddconfig:
      double_z: false
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1, 2, 4]
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
      padding_mode: zeros

# Main network (UNetModelSwin): class path, optional checkpoint and
# constructor parameters.
model:
  target: 'models.unet.UNetModelSwin'
  # null: no checkpoint path is configured for this model.
  ckpt_path: null
  params:
    image_size: 64
    in_channels: 3
    model_channels: 160
    # Interpolated from the autoencoder section so the two stay in sync.
    out_channels: ${autoencoder.params.embed_dim}
    attention_resolutions: [64, 32, 16, 8]
    dropout: 0
    channel_mult: [1, 2, 2, 4]
    num_res_blocks: [2, 2, 2, 2]
    conv_resample: true
    dims: 2
    use_fp16: false
    num_head_channels: 32
    use_scale_shift_norm: true
    resblock_updown: false
    swin_depth: 2
    swin_embed_dim: 192
    window_size: 8
    mlp_ratio: 4
    cond_lq: true
    # train.val_resolution interpolates this value.
    lq_size: 128

# Diffusion process: factory function path and its parameters
# (presumably forwarded as keyword arguments - verify in the factory).
diffusion:
  target: 'models.script_util.create_gaussian_diffusion'
  params:
    # Same value as degradation.sf; presumably the two must agree - confirm.
    sf: 2
    schedule_name: exponential
    schedule_kwargs:
      power: 0.3
    etas_end: 0.99
    steps: 4
    min_noise_level: 0.2
    kappa: 2.0
    weighted_mse: false
    predict_type: xstart
    # null: no timestep respacing is configured.
    timestep_respacing: null
    scale_factor: 1.0
    normalize_input: true
    latent_flag: true

# Synthetic degradation settings, split into a first and a second process.
degradation:
  # Same value as diffusion.params.sf; presumably the two must agree - confirm.
  sf: 2

  # the first degradation process
  # resize_prob entries are ordered [up, down, keep].
  resize_prob: [0.2, 0.7, 0.1]
  resize_range: [0.15, 1.5]
  gaussian_noise_prob: 0.5
  noise_range: [1, 30]
  poisson_scale_range: [0.05, 3.0]
  gray_noise_prob: 0.4
  jpeg_range: [30, 95]

  # the second degradation process
  second_order_prob: 0.5
  second_blur_prob: 0.8
  # resize_prob2 entries are ordered [up, down, keep].
  resize_prob2: [0.3, 0.4, 0.3]
  resize_range2: [0.3, 1.2]
  gaussian_noise_prob2: 0.5
  noise_range2: [1, 25]
  poisson_scale_range2: [0.05, 2.5]
  gray_noise_prob2: 0.4
  jpeg_range2: [30, 95]

  # Same value as data.train.params.gt_size.
  gt_size: 256
  resize_back: false
  use_sharp: false

# Dataset configuration. Only a training split is defined in this section.
data:
  train:
    type: realesrgan
    params:
      dir_paths: []
      # Text files referenced by absolute, machine-specific paths; adjust
      # them for the local environment.
      txt_file_path:
        - '/mnt/sfs-common/zsyue/database/ImageNet/files_txt/path_train_all.txt'
        - '/mnt/sfs-common/zsyue/database/FFHQ/files_txt/files256.txt'
      im_exts: ['JPEG']
      io_backend:
        type: disk

      # Blur kernel settings without a numeric suffix.
      # kernel_prob has one entry per kernel_list item (presumably paired
      # by position - confirm in the dataset code).
      blur_kernel_size: 21
      kernel_list:
        - 'iso'
        - 'aniso'
        - 'generalized_iso'
        - 'generalized_aniso'
        - 'plateau_iso'
        - 'plateau_aniso'
      kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
      sinc_prob: 0.1
      blur_sigma: [0.2, 3.0]
      betag_range: [0.5, 4.0]
      betap_range: [1, 2.0]

      # Blur kernel settings with the "2" suffix (same layout as above).
      blur_kernel_size2: 15
      kernel_list2:
        - 'iso'
        - 'aniso'
        - 'generalized_iso'
        - 'generalized_aniso'
        - 'plateau_iso'
        - 'plateau_aniso'
      kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
      sinc_prob2: 0.1
      blur_sigma2: [0.2, 1.5]
      betag_range2: [0.5, 4.0]
      betap_range2: [1, 2.0]

      final_sinc_prob: 0.8

      # Same value as degradation.gt_size.
      gt_size: 256
      crop_pad_size: 300
      use_hflip: true
      use_rot: false
      rescale_gt: true

# Training-loop settings: optimizer, data loading, logging, validation.
train:
  # learning rate
  # NOTE(review): `5e-5` / `2e-5` have no decimal point. Plain YAML 1.1
  # loaders (e.g. PyYAML) read that form as a string; this is fine if the
  # file is loaded through OmegaConf, as the ${...} interpolations suggest -
  # confirm before switching loaders.
  # Initial / peak learning rate.
  lr: 5e-5
  # Minimum learning rate (presumably the floor of the schedule - confirm).
  lr_min: 2e-5
  # NOTE(review): spelled `cosin`, not `cosine`; presumably matched literally
  # by the trainer, so do not "correct" it without checking the code.
  lr_schedule: cosin
  warmup_iterations: 5000

  # dataloader
  # NOTE(review): two batch sizes, presumably [training, validation] - confirm.
  batch: [96, 8]
  microbatch: 12
  num_workers: 6
  prefetch_factor: 2

  # optimization settings
  weight_decay: 0
  ema_rate: 0.999
  # total iterations
  iterations: 400000

  # save logging
  save_freq: 10000
  # log_freq order: [training loss, training images, val images]
  log_freq: [200, 2000, 1]
  # loss_coef order: [mse, lpips]
  loss_coef: [1.0, 1.0]
  # manually save images
  local_logging: true
  # tensorboard logging
  tf_logging: false

  # validation settings
  use_ema_val: true
  # Validate as often as checkpoints are saved.
  val_freq: ${train.save_freq}
  val_y_channel: true
  # Interpolated from the model's low-quality input size.
  val_resolution: ${model.params.lq_size}
  val_padding_mode: reflect

  # training setting
  # amp training
  use_amp: true
  # random seed
  seed: 123456
  global_seeding: false

  # model compile
  compile:
    flag: false
    # mode options: default, reduce-overhead
    mode: reduce-overhead