# DiffusionSR / src/config.py
# (Repository page header converted to a comment so the module parses:
#  upload note "Commiting all the super resolution files", commit 3c45764.)
"""
Configuration file with all training, model, and data parameters.
"""
import os
import torch
from pathlib import Path
# ============================================================================
# Project Settings
# ============================================================================
_project_root = Path(__file__).parent.parent
# ============================================================================
# Device Settings
# ============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ============================================================================
# Training Parameters
# ============================================================================
# Learning rate
lr = 1e-5 # Original ResShift setting
lr_min = 1e-5 # Floor for any LR decay (equal to lr, so effectively constant)
lr_schedule = None # No scheduler configured
learning_rate = lr # Alias for backward compatibility
warmup_iterations = 100 # Linear warmup from 0 to base lr; NOTE(review): this is ~3% of the 3200 total iterations — the old "~12.5% of 800" note predates the iteration-count change
# Dataloader
batch = [64, 64] # Original ResShift: adjust based on your GPU memory
batch_size = batch[0] # Use first value from batch list
microbatch = 100 # NOTE(review): larger than batch_size (64), so micro-batching likely never splits a batch — confirm intended
num_workers = 4
prefetch_factor = 2
# Optimization settings
weight_decay = 0
ema_rate = 0.999 # Decay for the exponential moving average of model weights
iterations = 3200 # NOTE(review): at 12.5 batches/epoch (800 DIV2K images / 64) this is ~256 epochs, not the 64 previously claimed — confirm
# Save logging
save_freq = 200 # Checkpoint every 200 iterations
log_freq = [50, 100] # [training loss, training images]
local_logging = True
tf_logging = False # TensorBoard logging disabled
# Validation settings
use_ema_val = True # Validate with the EMA weights rather than the raw weights
val_freq = 100 # Run validation every 100 iterations
val_y_channel = True # Evaluate metrics on the Y (luma) channel only
val_resolution = 64 # model.params.lq_size
val_padding_mode = "reflect"
# Training setting
use_amp = True # Mixed precision training
seed = 123456
global_seeding = False # If True, presumably seeds all workers/processes identically — verify against the trainer
# Model compile
compile_flag = True # Enable torch.compile on the model
compile_mode = "reduce-overhead"
# ============================================================================
# Diffusion/Noise Schedule Parameters
# ============================================================================
sf = 4 # Super-resolution scale factor
schedule_name = "exponential"
schedule_power = 0.3 # Original ResShift setting
etas_end = 0.99 # Original ResShift setting (noise level at the final step)
T = 15 # Original ResShift: 15 timesteps
min_noise_level = 0.04 # Original ResShift setting (noise level at the first step)
eta_1 = min_noise_level # Alias for backward compatibility
eta_T = etas_end # Alias for backward compatibility
p = schedule_power # Alias for backward compatibility
kappa = 2.0 # Noise-magnitude scaling in the residual-shifting transition
k = kappa # Alias for backward compatibility
weighted_mse = False
predict_type = "xstart" # Predict x0, not noise (key difference from DDPM-style training!)
timestep_respacing = None # No respacing: sample with all T steps
scale_factor = 1.0 # Latent scaling applied around the diffusion (1.0 = none)
normalize_input = True
latent_flag = True # Diffusion operates in the autoencoder's latent space
# ============================================================================
# Model Architecture Parameters
# ============================================================================
# ResShift model architecture based on model_channels and channel_mult
# (spatial sizes below describe the 256x256 pixel-space view; the latent
#  operates at image_size = 64):
# Initial Conv: 3 → 160
# Encoder Stage 1: 160 → 320 (downsample to 128x128)
# Encoder Stage 2: 320 → 320 (downsample to 64x64)
# Encoder Stage 3: 320 → 640 (downsample to 32x32)
# Encoder Stage 4: 640 (no downsampling, stays 32x32)
# Decoder Stage 1: 640 → 320 (upsample to 64x64)
# Decoder Stage 2: 320 → 320 (upsample to 128x128)
# Decoder Stage 3: 320 → 160 (upsample to 256x256)
# Decoder Stage 4: 160 → 3 (final output)
# Model params from ResShift configuration
image_size = 64 # Latent space: 64×64 (not 256×256 pixel space)
in_channels = 3
model_channels = 160 # Original ResShift: base channels
out_channels = 3
attention_resolutions = [64, 32, 16, 8] # Latent space resolutions
dropout = 0
channel_mult = [1, 2, 2, 4] # Original ResShift: 160, 320, 320, 640 channels
num_res_blocks = [2, 2, 2, 2] # Residual blocks per resolution stage
conv_resample = True # Learned conv for up/downsampling (vs. nearest/pool)
dims = 2 # 2-D convolutions
use_fp16 = False
num_head_channels = 32 # Channels per attention head
use_scale_shift_norm = True
resblock_updown = False
swin_depth = 2
swin_embed_dim = 192 # Original ResShift setting
window_size = 8 # Original ResShift setting (not 7)
mlp_ratio = 2.0 # Original ResShift uses 2.0, not 4
cond_lq = True # Enable LR conditioning
lq_size = 64 # LR latent size (same as image_size)
# U-Net architecture parameters based on ResShift configuration
# Initial conv: 3 → model_channels * channel_mult[0] = 160
initial_conv_out_channels = model_channels * channel_mult[0] # 160
# Encoder stage channels (based on channel_mult progression)
es1_in_channels = initial_conv_out_channels # 160
es1_out_channels = model_channels * channel_mult[1] # 320
es2_in_channels = es1_out_channels # 320
es2_out_channels = model_channels * channel_mult[2] # 320
es3_in_channels = es2_out_channels # 320
es3_out_channels = model_channels * channel_mult[3] # 640
es4_in_channels = es3_out_channels # 640
es4_out_channels = es3_out_channels # 640 (no downsampling)
# Decoder stage channels (reverse of encoder)
ds1_in_channels = es4_out_channels # 640
ds1_out_channels = es2_out_channels # 320
ds2_in_channels = ds1_out_channels # 320
ds2_out_channels = es2_out_channels # 320
ds3_in_channels = ds2_out_channels # 320
ds3_out_channels = es1_out_channels # 160
ds4_in_channels = ds3_out_channels # 160
ds4_out_channels = initial_conv_out_channels # 160
# Other model parameters
n_groupnorm_groups = 8 # Standard value
shift_size = window_size // 2 # Shift size for shifted window attention (should be window_size // 2, not swin_depth)
timestep_embed_dim = model_channels * 4 # Original ResShift: 160 * 4 = 640
num_heads = num_head_channels # NOTE(review): num_head_channels (32) is channels-PER-head, not a head count — verify the model really expects 32 heads here, this looks like a naming mix-up
# ============================================================================
# Autoencoder Parameters (from YAML, for reference)
# ============================================================================
# Presumably the VQ-GAN f4 checkpoint, judging by the filename — confirm source.
autoencoder_ckpt_path = "pretrained_weights/autoencoder_vq_f4.pth"
autoencoder_use_fp16 = False # Temporarily disabled for CPU testing (FP16 is slow/hangs on CPU)
autoencoder_embed_dim = 3 # Latent channel count
autoencoder_n_embed = 8192 # VQ codebook size
autoencoder_double_z = False # No mean/logvar split (vector-quantized, not KL)
autoencoder_z_channels = 3
autoencoder_resolution = 256
autoencoder_in_channels = 3
autoencoder_out_ch = 3
autoencoder_ch = 128 # Base channel width
autoencoder_ch_mult = [1, 2, 4] # Three stages → 4x spatial downsampling
autoencoder_num_res_blocks = 2
autoencoder_attn_resolutions = [] # No attention layers in the autoencoder
autoencoder_dropout = 0.0
autoencoder_padding_mode = "zeros"
# ============================================================================
# Degradation Parameters (used by realesrgan.py)
# ============================================================================
# Blur kernel settings (used for both first and second degradation)
blur_kernel_size = 21
kernel_list = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob = [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] # Sampling probability per kernel type (parallel to kernel_list)
# First degradation stage
resize_prob = [0.2, 0.7, 0.1] # up, down, keep
resize_range = [0.15, 1.5] # Random resize scale bounds
gaussian_noise_prob = 0.5 # Probability of Gaussian (vs. Poisson) noise
noise_range = [1, 30] # Gaussian noise sigma range
poisson_scale_range = [0.05, 3.0]
gray_noise_prob = 0.4 # Probability the added noise is grayscale
jpeg_range = [30, 95] # JPEG quality factor range
data_train_blur_sigma = [0.2, 3.0]
data_train_betag_range = [0.5, 4.0] # Shape range for generalized Gaussian kernels
data_train_betap_range = [1, 2.0] # Shape range for plateau kernels
data_train_sinc_prob = 0.1 # Probability of a sinc (ringing) kernel
# Second degradation stage (milder ranges than the first)
second_order_prob = 0.5 # Probability of applying the second stage at all
second_blur_prob = 0.8
resize_prob2 = [0.3, 0.4, 0.3] # up, down, keep
resize_range2 = [0.3, 1.2]
gaussian_noise_prob2 = 0.5
noise_range2 = [1, 25]
poisson_scale_range2 = [0.05, 2.5]
gray_noise_prob2 = 0.4
jpeg_range2 = [30, 95]
data_train_blur_kernel_size2 = 15
data_train_blur_sigma2 = [0.2, 1.5]
data_train_betag_range2 = [0.5, 4.0]
data_train_betap_range2 = [1, 2.0]
data_train_sinc_prob2 = 0.1
# Final sinc filter
data_train_final_sinc_prob = 0.8
final_sinc_prob = data_train_final_sinc_prob # Alias for backward compatibility
# Other degradation settings
gt_size = 256 # Ground-truth (HR) crop size in pixels
resize_back = False
use_sharp = False
# ============================================================================
# Data Parameters
# ============================================================================
# Data paths - using defaults based on the DIV2K layout under <project>/data
dir_HR = str(_project_root / "data" / "DIV2K_train_HR")
dir_LR = str(_project_root / "data" / "DIV2K_train_LR_bicubic" / "X4")
dir_valid_HR = str(_project_root / "data" / "DIV2K_valid_HR")
dir_valid_LR = str(_project_root / "data" / "DIV2K_valid_LR_bicubic" / "X4")
# Patch size (used by dataset)
patch_size = gt_size # 256
# Scale factor (from degradation.sf)
scale = sf # 4