import os |
import torch |
import torch.nn as nn |
import torch.optim as optim |
import wandb |
from torch.utils.data import DataLoader |
from torchvision import transforms |
from tqdm import tqdm |
from data.LQGT_dataset import LQGTDataset, LQGTValDataset |
from model import decoder, discriminator, encoder |
from opt.option import args |
from util.utils import (RandCrop, RandHorizontalFlip, RandRotate, ToTensor, RandCrop_pair, |
VGG19PerceptualLoss) |
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure |
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity |
# --- Runtime setup: GPU selection, experiment tracking, snapshot directory ---
# CUDA_VISIBLE_DEVICES must be set before CUDA is initialised, so select the
# device(s) first. Previously the flag was checked but "0" was always used.
if args.gpu_id is None:
    print('use --gpu_id to specify GPU ID to use')
    # Non-zero status: this is a usage error, not a successful run.
    raise SystemExit(1)
if isinstance(args.gpu_id, (list, tuple)):
    _visible_gpus = ','.join(str(gpu) for gpu in args.gpu_id)
else:
    _visible_gpus = str(args.gpu_id)
os.environ['CUDA_VISIBLE_DEVICES'] = _visible_gpus
print(f'using GPU {_visible_gpus}')
wandb.init(project='SR', config=args)
device = torch.device('cuda')
# makedirs creates missing parent directories and tolerates an existing one.
os.makedirs(args.snap_path, exist_ok=True)
# --- Datasets and loaders ---
print("Loading dataset...")
# Training pairs get random crop / flip / rotation augmentation.
train_dataset = LQGTDataset(
    db_path=args.dir_data,
    transform=transforms.Compose([RandCrop(args.patch_size, args.scale), RandHorizontalFlip(), RandRotate(), ToTensor()]),
)
# NOTE(review): validation also goes through RandCrop_pair, which by its name
# presumably crops at random positions -- if so, val metrics are computed on
# different crops each epoch and are noisy. Confirm this is intended.
val_dataset = LQGTValDataset(
    db_path=args.dir_data,
    transform=transforms.Compose([RandCrop_pair(args.patch_size, args.scale), ToTensor()]),
)
# pin_memory=True: host->device copies issued with non_blocking=True are only
# asynchronous when the source tensor lives in pinned (page-locked) memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    shuffle=True,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False,
    pin_memory=True,
)
# --- Networks ---
print("Create model")
# Discriminators: one on encoder features, one on LR images, one on HR images
# (the HR one is sized for scale * patch_size inputs).
model_Disc_feat = discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).to(device)
model_Disc_img_LR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).to(device)
model_Disc_img_HR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.scale*args.patch_size).to(device)
# Generator side: shared encoder, identity (LR-domain) decoder, SR decoder.
model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).to(device)
model_Dec_Id = decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).to(device)
model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).to(device)
# Track every trained network in wandb, including the HR image discriminator,
# which is optimised alongside the others.
for _watched_model in (model_Disc_feat, model_Disc_img_LR, model_Disc_img_HR,
                       model_Enc, model_Dec_Id, model_Dec_SR):
    wandb.watch(_watched_model)
# --- Loss functions ---
print("Define Loss")
# Elementwise criteria, each moved to the training device. The adversarial
# terms in the visible training code use loss_MSE (least-squares GAN targets);
# loss_adversarial is not used by that code.
loss_L1, loss_MSE, loss_adversarial = (
    criterion.to(device)
    for criterion in (nn.L1Loss(), nn.MSELoss(), nn.BCEWithLogitsLoss())
)
# Perceptual loss helper from util.utils.
loss_percept = VGG19PerceptualLoss().to(device)
# --- Optimizers and LR schedules ---
print("Define Optimizer")
# Encoder + both decoders form the "generator" and share one optimizer; the
# three discriminators share the other. Both use AMSGrad Adam.
params_G = [p for net in (model_Enc, model_Dec_Id, model_Dec_SR) for p in net.parameters()]
params_D = [p for net in (model_Disc_feat, model_Disc_img_LR, model_Disc_img_HR) for p in net.parameters()]
_adam_kwargs = dict(betas=(args.beta1, args.beta2), weight_decay=args.weight_decay, amsgrad=True)
optimizer_G = optim.Adam(params_G, lr=args.lr_G, **_adam_kwargs)
optimizer_D = optim.Adam(params_D, lr=args.lr_D, **_adam_kwargs)
print("Define Scheduler")
# Both learning rates are halved at each of the three milestones.
iter_indices = [args.interval1, args.interval2, args.interval3]
scheduler_G, scheduler_D = (
    optim.lr_scheduler.MultiStepLR(optimizer=opt, milestones=iter_indices, gamma=0.5)
    for opt in (optimizer_G, optimizer_D)
)
# --- Multi-GPU wrapping and weight loading ---
# Only the generator networks are wrapped; their state_dict keys therefore carry
# a "module." prefix in checkpoints saved by this script.
model_Enc = nn.DataParallel(model_Enc)
model_Dec_Id = nn.DataParallel(model_Dec_Id)
model_Dec_SR = nn.DataParallel(model_Dec_SR)
print("Load model weight")
if args.checkpoint is not None:
    # Resume the full training state saved by this script.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model_Enc.load_state_dict(checkpoint['model_Enc'])
    model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
    model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
    model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
    model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
    model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])
    scheduler_D.load_state_dict(checkpoint['scheduler_D'])
    scheduler_G.load_state_dict(checkpoint['scheduler_G'])
    # 'epoch' is saved after that epoch finished training, so continue with the
    # next one instead of repeating it.
    start_epoch = checkpoint['epoch'] + 1
    if args.pretrained is not None:
        # Loading --pretrained here would overwrite the resumed SR decoder.
        print('--checkpoint given: ignoring --pretrained')
else:
    start_epoch = 0
    if args.pretrained is not None:
        # Initialise the SR decoder from a pretrained checkpoint whose weights
        # live under "params". Its conv_first takes image channels, while
        # Decoder_SR_RRDB takes args.n_hidden_feats feature channels, so the
        # kernel slice for input channel 0 is replicated across every input
        # channel (each output channel keeps its own kernel).
        ckpt = torch.load(args.pretrained, map_location=device)
        conv_first = ckpt["params"]["conv_first.weight"]
        ckpt["params"]["conv_first.weight"] = conv_first[:, :1].expand(-1, args.n_hidden_feats, -1, -1)
        # model_Dec_SR is wrapped in DataParallel; load into the inner module so
        # the un-prefixed keys of the pretrained file match.
        model_Dec_SR.module.load_state_dict(ckpt["params"])
# --- Validation metrics (torchmetrics), placed on the training device ---
# NOTE(review): PeakSignalNoiseRatio is built without data_range, and
# LearnedPerceptualImagePatchSimilarity defaults to normalize=False (inputs in
# [-1, 1]); confirm both match the value range produced by the dataset's ToTensor.
PSNR, SSIM, LPIPS = (
    metric.to(device)
    for metric in (
        PeakSignalNoiseRatio(),
        StructuralSimilarityIndexMeasure(),
        LearnedPerceptualImagePatchSimilarity(),
    )
)
# --- Training / evaluation ---

# Loss names in logging order; wandb keys are "loss_<name>_step" / "loss_<name>_epoch".
_LOSS_KEYS = ('D_total', 'G_total', 'align', 'rec', 'res', 'sty', 'idt', 'cyc')

# Bicubic downsampler that synthesises the source-domain LR input X_s from the
# HR target Y_s. Built once instead of on every iteration.
downsample_HR = nn.Upsample(scale_factor=1 / args.scale, mode='bicubic')


def _discriminator_step(X_t, X_s, Y_s, real_label, fake_label):
    """Run one least-squares-GAN update of the three discriminators.

    Generator outputs are computed under ``torch.no_grad()``: every
    discriminator input derived from them is treated as a constant, so no
    generator graph is needed.

    Returns:
        The total discriminator loss tensor of this update.
    """
    # Zero on every update; otherwise gradients accumulate across n_disc steps.
    optimizer_D.zero_grad()
    with torch.no_grad():
        F_t = model_Enc(X_t)
        F_s = model_Enc(X_s)
        Y_s_s = model_Dec_SR(F_s)
        X_s_t = model_Dec_Id(F_s)
        Y_s_t_s = model_Dec_SR(model_Enc(X_s_t))
    # Feature alignment: target-domain features labelled fake, source real.
    loss_feat_align = (loss_MSE(model_Disc_feat(F_t), fake_label)
                       + loss_MSE(model_Disc_feat(F_s), real_label)) / 2
    # Real HR images are scored once and shared by the rec and cyc terms
    # (same value/gradient as scoring twice for a deterministic discriminator).
    loss_real_HR = loss_MSE(model_Disc_img_HR(Y_s), real_label)
    loss_img_rec = (loss_MSE(model_Disc_img_HR(Y_s_s), fake_label) + loss_real_HR) / 2
    loss_img_sty = (loss_MSE(model_Disc_img_LR(X_s_t), fake_label)
                    + loss_MSE(model_Disc_img_LR(X_t), real_label)) / 2
    loss_img_cyc = (loss_MSE(model_Disc_img_HR(Y_s_t_s), fake_label) + loss_real_HR) / 2
    loss_D_total = loss_feat_align + loss_img_rec + loss_img_sty + loss_img_cyc
    loss_D_total.backward()
    optimizer_D.step()
    return loss_D_total


def _generator_step(X_t, X_s, Y_s, real_label, half_label):
    """Run one update of encoder + decoders on the weighted generator objective.

    Returns:
        Dict of loss tensors keyed by 'G_total', 'align', 'rec', 'res', 'sty',
        'idt', 'cyc'.
    """
    # Zero on every update; otherwise gradients accumulate across n_gen steps.
    optimizer_G.zero_grad()
    F_t = model_Enc(X_t)
    F_s = model_Enc(X_s)
    # Push both domains' features toward the feature discriminator's 0.5 target.
    L_align_E = (loss_MSE(model_Disc_feat(F_t), half_label)
                 + loss_MSE(model_Disc_feat(F_s), half_label))
    # Source-domain SR reconstruction: L1 + perceptual + adversarial.
    Y_s_s = model_Dec_SR(F_s)
    L_rec_G_SR = (loss_L1(Y_s, Y_s_s)
                  + args.lambda_percept * loss_percept(Y_s, Y_s_s)
                  + args.lambda_adv * loss_MSE(model_Disc_img_HR(Y_s_s), real_label))
    # Target-domain identity reconstruction in LR space.
    X_t_t = model_Dec_Id(F_t)
    L_res_G_t = loss_L1(X_t, X_t_t)
    # Source features decoded into the target LR style should fool the LR discriminator.
    X_s_t = model_Dec_Id(F_s)
    L_sty_G_t = loss_MSE(model_Disc_img_LR(X_s_t), real_label)
    # Re-encoding the stylised image should reproduce the source features.
    F_s_tilda = model_Enc(X_s_t)
    L_idt_G_t = loss_L1(F_s, F_s_tilda)
    # Cycle: SR of the re-encoded features should match the HR target.
    Y_s_t_s = model_Dec_SR(F_s_tilda)
    L_cyc_G_t_G_SR = (loss_L1(Y_s, Y_s_t_s)
                      + args.lambda_percept * loss_percept(Y_s, Y_s_t_s)
                      + args.lambda_adv * loss_MSE(model_Disc_img_HR(Y_s_t_s), real_label))
    loss_G_total = (args.lambda_align * L_align_E
                    + args.lambda_rec * L_rec_G_SR
                    + args.lambda_res * L_res_G_t
                    + args.lambda_sty * L_sty_G_t
                    + args.lambda_idt * L_idt_G_t
                    + args.lambda_cyc * L_cyc_G_t_G_SR)
    loss_G_total.backward()
    optimizer_G.step()
    return {
        'G_total': loss_G_total,
        'align': L_align_E,
        'rec': L_rec_G_SR,
        'res': L_res_G_t,
        'sty': L_sty_G_t,
        'idt': L_idt_G_t,
        'cyc': L_cyc_G_t_G_SR,
    }


def _validate():
    """Evaluate the LQ -> Enc -> Dec_SR path on ``val_loader``.

    Returns:
        (psnr, ssim, lpips) as floats, each the mean of the per-batch values
        (all 0.0 if the loader yields no batches).
    """
    model_Enc.eval()
    model_Dec_SR.eval()
    total_psnr = total_ssim = total_lpips = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in val_loader:
            source = batch["img_LQ"].to(device)
            target = batch["img_GT"].to(device)
            out = model_Dec_SR(model_Enc(source))
            total_psnr += PSNR(out, target).item()
            total_ssim += SSIM(out, target).item()
            total_lpips += LPIPS(out, target).item()
            n_batches += 1
    # Calling a torchmetrics metric also folds the batch into its accumulated
    # state, which is never read here; clear it so it does not grow across epochs.
    for metric in (PSNR, SSIM, LPIPS):
        metric.reset()
    n_batches = max(n_batches, 1)
    return total_psnr / n_batches, total_ssim / n_batches, total_lpips / n_batches


if args.phase == "train":
    for epoch in range(start_epoch, args.epochs):
        for net in (model_Enc, model_Dec_Id, model_Dec_SR,
                    model_Disc_feat, model_Disc_img_LR, model_Disc_img_HR):
            net.train()
        running = dict.fromkeys(_LOSS_KEYS, 0.0)
        n_iter = 0
        for data in tqdm(train_loader):
            n_iter += 1
            X_t = data['img_LQ'].to(device, non_blocking=True)
            Y_s = data['img_GT'].to(device, non_blocking=True)
            X_s = downsample_HR(Y_s)
            batch_size = X_t.size(0)
            real_label = torch.ones(batch_size, 1, dtype=X_t.dtype, device=device)
            fake_label = torch.zeros(batch_size, 1, dtype=X_t.dtype, device=device)
            half_label = (real_label + fake_label) / 2

            for _ in range(args.n_disc):
                loss_D_total = _discriminator_step(X_t, X_s, Y_s, real_label, fake_label)
            # Schedulers are stepped once per batch (milestones in iter_indices
            # are treated as iteration counts), not once per inner update.
            scheduler_D.step()

            for _ in range(args.n_gen):
                losses_G = _generator_step(X_t, X_s, Y_s, real_label, half_label)
            scheduler_G.step()

            running['D_total'] += loss_D_total.item()
            for key, value in losses_G.items():
                running[key] += value.item()
            if n_iter % args.log_interval == 0:
                wandb.log({f"loss_{key}_step": total / n_iter for key, total in running.items()})

        psnr_val, ssim_val, lpips_val = _validate()
        n_iter = max(n_iter, 1)
        wandb.log({
            "epoch": epoch,
            "lr": optimizer_G.param_groups[0]['lr'],
            **{f"loss_{key}_epoch": total / n_iter for key, total in running.items()},
            "PSNR_val": psnr_val,
            "SSIM_val": ssim_val,
            "LPIPS_val": lpips_val,
        })
        if (epoch + 1) % args.save_freq == 0:
            weights_file = os.path.join(args.snap_path, 'epoch_%d.pth' % (epoch + 1))
            # 'epoch' is the index of the epoch that just finished training.
            torch.save({
                'epoch': epoch,
                'model_Enc': model_Enc.state_dict(),
                'model_Dec_Id': model_Dec_Id.state_dict(),
                'model_Dec_SR': model_Dec_SR.state_dict(),
                'model_Disc_feat': model_Disc_feat.state_dict(),
                'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
                'model_Disc_img_HR': model_Disc_img_HR.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'scheduler_D': scheduler_D.state_dict(),
                'scheduler_G': scheduler_G.state_dict(),
            }, weights_file)
            print('save weights of epoch %d' % (epoch + 1))
else:
    psnr_val, ssim_val, lpips_val = _validate()
    print("PSNR_val: ", psnr_val)
    print("SSIM_val: ", ssim_val)
    print("LPIPS_val: ", lpips_val)