import os

import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from data.LQGT_dataset import LQGTDataset, LQGTValDataset
from model import decoder, discriminator, encoder
from opt.option import args
from util.utils import (RandCrop, RandHorizontalFlip, RandRotate, ToTensor, RandCrop_pair,
                        VGG19PerceptualLoss)

from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

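# experiment tracking: the full args namespace is logged as the wandb run config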
wandb.init(project='SR', config=args)

if args.gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('using GPU %s' % args.gpu_id)
else:
    print('use --gpu_id to specify GPU ID to use')
    exit()

device = torch.device('cuda')

if not os.path.exists(args.snap_path):
    os.mkdir(args.snap_path)

print("Loading dataset...")

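# LQGT datasets yield {'img_LQ', 'img_GT'} pairs; training crops are augmented,
# while validation uses paired crops so LR and HR patches stay aligned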
train_dataset = LQGTDataset(
    db_path=args.dir_data,
    transform=transforms.Compose([RandCrop(args.patch_size, args.scale), RandHorizontalFlip(), RandRotate(), ToTensor()])
)

val_dataset = LQGTValDataset(
    db_path=args.dir_data,
    transform=transforms.Compose([RandCrop_pair(args.patch_size, args.scale), ToTensor()])
)

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False
)

print("Create model")

model_Disc_feat = discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).to(device)
model_Disc_img_LR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).to(device)
model_Disc_img_HR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.scale*args.patch_size).to(device)

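# generator side: a shared feature encoder, an identity decoder back to the LR
# domain, and an SR decoder that maps features to the HR domain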
model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).to(device)
model_Dec_Id = decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).to(device)
model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).to(device)

wandb.watch(model_Disc_feat)
wandb.watch(model_Disc_img_LR)
wandb.watch(model_Disc_img_HR)
wandb.watch(model_Enc)
wandb.watch(model_Dec_Id)
wandb.watch(model_Dec_SR)

print("Define Loss") |
|
|
|
loss_L1 = nn.L1Loss().to(device) |
|
loss_MSE = nn.MSELoss().to(device) |
|
loss_adversarial = nn.BCEWithLogitsLoss().to(device) |
|
loss_percept = VGG19PerceptualLoss().to(device) |
|
|
|
|
|
print("Define Optimizer") |
|
|
|
params_G = list(model_Enc.parameters()) + list(model_Dec_Id.parameters()) + list(model_Dec_SR.parameters()) |
|
optimizer_G = optim.Adam( |
|
params_G, |
|
lr=args.lr_G, |
|
betas=(args.beta1, args.beta2), |
|
weight_decay=args.weight_decay, |
|
amsgrad=True |
|
) |
|
params_D = list(model_Disc_feat.parameters()) + list(model_Disc_img_LR.parameters()) + list(model_Disc_img_HR.parameters()) |
|
optimizer_D = optim.Adam( |
|
params_D, |
|
lr=args.lr_D, |
|
betas=(args.beta1, args.beta2), |
|
weight_decay=args.weight_decay, |
|
amsgrad=True |
|
) |
|
|
|
print("Define Scheduler") |
|
|
|
iter_indices = [args.interval1, args.interval2, args.interval3] |
|
scheduler_G = optim.lr_scheduler.MultiStepLR( |
|
optimizer=optimizer_G, |
|
milestones=iter_indices, |
|
gamma=0.5 |
|
) |
|
scheduler_D = optim.lr_scheduler.MultiStepLR( |
|
optimizer=optimizer_D, |
|
milestones=iter_indices, |
|
gamma=0.5 |
|
) |
|
|
|
|
|
model_Enc = nn.DataParallel(model_Enc)
model_Dec_Id = nn.DataParallel(model_Dec_Id)
model_Dec_SR = nn.DataParallel(model_Dec_SR)

print("Load model weight") |
|
|
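# resume the full training state (models, optimizers, schedulers, epoch) if a
# checkpoint is given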
if args.checkpoint is not None:
    checkpoint = torch.load(args.checkpoint, map_location=device)

    model_Enc.load_state_dict(checkpoint['model_Enc'])
    model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
    model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
    model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
    model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
    model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])

    optimizer_D.load_state_dict(checkpoint['optimizer_D'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])

    scheduler_D.load_state_dict(checkpoint['scheduler_D'])
    scheduler_G.load_state_dict(checkpoint['scheduler_G'])

    start_epoch = checkpoint['epoch'] + 1  # the stored epoch has already been completed
else:
    start_epoch = 0

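# optional warm start for the SR decoder from pretrained RRDB weights
# (assumes a BasicSR-style checkpoint that stores parameters under 'params')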
if args.pretrained is not None:
    ckpt = torch.load(args.pretrained, map_location=device)
    # conv_first was pretrained on 3-channel RGB input; broadcast a kernel slice so it
    # accepts the feature tensor instead (assumes n_hidden_feats == 64, matching the hard-coded expand)
    ckpt["params"]["conv_first.weight"] = ckpt["params"]["conv_first.weight"][:, 0, :, :].expand(64, 64, 3, 3)
    # model_Dec_SR is wrapped in nn.DataParallel, so load into the underlying module
    model_Dec_SR.module.load_state_dict(ckpt["params"])

PSNR = PeakSignalNoiseRatio().to(device)
SSIM = StructuralSimilarityIndexMeasure().to(device)
# assuming the custom ToTensor() yields images in [0, 1]; normalize=True has the
# metric rescale inputs to the [-1, 1] range the LPIPS backbone expects
LPIPS = LearnedPerceptualImagePatchSimilarity(normalize=True).to(device)

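# training: alternate LSGAN discriminator updates and generator updates per batch,
# with a validation pass at the end of every epoch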
if args.phase == "train":
    for epoch in range(start_epoch, args.epochs):

        model_Enc.train()
        model_Dec_Id.train()
        model_Dec_SR.train()

        model_Disc_feat.train()
        model_Disc_img_LR.train()
        model_Disc_img_HR.train()

        running_loss_D_total = 0.0
        running_loss_G_total = 0.0

        running_loss_align = 0.0
        running_loss_rec = 0.0
        running_loss_res = 0.0
        running_loss_sty = 0.0
        running_loss_idt = 0.0
        running_loss_cyc = 0.0

        n_iter = 0

        for data in tqdm(train_loader):
            n_iter += 1

            # X_t: real-world LQ input (target domain); Y_s: clean HR ground truth (source domain)
            X_t, Y_s = data['img_LQ'], data['img_GT']

            # synthesize the paired source-domain LR image by bicubic downsampling
            ds4 = nn.Upsample(scale_factor=1/args.scale, mode='bicubic')
            X_s = ds4(Y_s)

            X_t = X_t.cuda(non_blocking=True)
            X_s = X_s.cuda(non_blocking=True)
            Y_s = Y_s.cuda(non_blocking=True)

            batch_size = X_t.size(0)
            # LSGAN targets: 1 for real, 0 for fake
            real_label = torch.full((batch_size, 1), 1, dtype=X_t.dtype).cuda(non_blocking=True)
            fake_label = torch.full((batch_size, 1), 0, dtype=X_t.dtype).cuda(non_blocking=True)

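            # discriminator update (generator outputs are detached so no gradients
            # reach the encoder/decoders)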
            model_Disc_feat.zero_grad()
            model_Disc_img_LR.zero_grad()
            model_Disc_img_HR.zero_grad()

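            # optionally run several discriminator steps per generator step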
            for i in range(args.n_disc):

                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)

                # feature alignment: source features are "real", target features "fake"
                output_Disc_F_t = model_Disc_feat(F_t.detach())
                output_Disc_F_s = model_Disc_feat(F_s.detach())

                loss_Disc_F_t = loss_MSE(output_Disc_F_t, fake_label)
                loss_Disc_F_s = loss_MSE(output_Disc_F_s, real_label)
                loss_Disc_feat_align = (loss_Disc_F_t + loss_Disc_F_s) / 2

                # reconstruction branch: SR output from source features vs. real HR
                Y_s_s = model_Dec_SR(F_s)

                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s.detach())
                output_Disc_Y_s = model_Disc_img_HR(Y_s)

                loss_Disc_Y_s_s = loss_MSE(output_Disc_Y_s_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_rec = (loss_Disc_Y_s_s + loss_Disc_Y_s) / 2

                # style branch: source features decoded into the LR domain vs. real target LR
                X_s_t = model_Dec_Id(F_s)

                output_Disc_X_s_t = model_Disc_img_LR(X_s_t.detach())
                output_Disc_X_t = model_Disc_img_LR(X_t)

                loss_Disc_X_s_t = loss_MSE(output_Disc_X_s_t, fake_label)
                loss_Disc_X_t = loss_MSE(output_Disc_X_t, real_label)
                loss_Disc_img_sty = (loss_Disc_X_s_t + loss_Disc_X_t) / 2

                # cycle branch: source -> LR domain -> re-encode -> SR vs. real HR
                Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))

                output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s.detach())
                output_Disc_Y_s = model_Disc_img_HR(Y_s)

                loss_Disc_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_cyc = (loss_Disc_Y_s_t_s + loss_Disc_Y_s) / 2

                loss_D_total = loss_Disc_feat_align + loss_Disc_img_rec + loss_Disc_img_sty + loss_Disc_img_cyc
                loss_D_total.backward()
                optimizer_D.step()

            scheduler_D.step()

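            # generator update: encoder and both decoders, discriminators held fixed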
            model_Enc.zero_grad()
            model_Dec_Id.zero_grad()
            model_Dec_SR.zero_grad()

            for i in range(args.n_gen):

                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)

                # alignment: push both feature distributions toward the decision boundary (0.5)
                output_Disc_F_t = model_Disc_feat(F_t)
                output_Disc_F_s = model_Disc_feat(F_s)

                loss_G_F_t = loss_MSE(output_Disc_F_t, (real_label + fake_label)/2)
                loss_G_F_s = loss_MSE(output_Disc_F_s, (real_label + fake_label)/2)
                L_align_E = loss_G_F_t + loss_G_F_s

                # reconstruction: SR of source features should match the HR ground truth
                Y_s_s = model_Dec_SR(F_s)

                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s)

                loss_L1_rec = loss_L1(Y_s.detach(), Y_s_s)
                loss_percept_rec = loss_percept(Y_s.detach(), Y_s_s)
                loss_G_Y_s_s = loss_MSE(output_Disc_Y_s_s, real_label)
                L_rec_G_SR = loss_L1_rec + args.lambda_percept*loss_percept_rec + args.lambda_adv*loss_G_Y_s_s

                # restoration: the identity decoder should reproduce the target LR input
                X_t_t = model_Dec_Id(F_t)
                L_res_G_t = loss_L1(X_t, X_t_t)

                # style: translated source LR should fool the LR-domain discriminator
                X_s_t = model_Dec_Id(F_s)

                output_Disc_X_s_t = model_Disc_img_LR(X_s_t)

                loss_G_X_s_t = loss_MSE(output_Disc_X_s_t, real_label)
                L_sty_G_t = loss_G_X_s_t

                # identity: re-encoding the translated image should preserve the features
                F_s_tilda = model_Enc(model_Dec_Id(F_s))
                L_idt_G_t = loss_L1(F_s, F_s_tilda)

                # cycle: source -> LR domain -> SR should recover the HR ground truth
                Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))

                output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s)

                loss_L1_cyc = loss_L1(Y_s.detach(), Y_s_t_s)
                loss_percept_cyc = loss_percept(Y_s.detach(), Y_s_t_s)
                loss_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, real_label)
                L_cyc_G_t_G_SR = loss_L1_cyc + args.lambda_percept*loss_percept_cyc + args.lambda_adv*loss_Y_s_t_s

                # weighted sum of all generator objectives
                loss_G_total = args.lambda_align*L_align_E + args.lambda_rec*L_rec_G_SR + args.lambda_res*L_res_G_t + args.lambda_sty*L_sty_G_t + args.lambda_idt*L_idt_G_t + args.lambda_cyc*L_cyc_G_t_G_SR
                loss_G_total.backward()
                optimizer_G.step()

            scheduler_G.step()

            # accumulate running means for logging
            running_loss_D_total += loss_D_total.item()
            running_loss_G_total += loss_G_total.item()

            running_loss_align += L_align_E.item()
            running_loss_rec += L_rec_G_SR.item()
            running_loss_res += L_res_G_t.item()
            running_loss_sty += L_sty_G_t.item()
            running_loss_idt += L_idt_G_t.item()
            running_loss_cyc += L_cyc_G_t_G_SR.item()
            if n_iter % args.log_interval == 0:
                wandb.log(
                    {
                        "loss_D_total_step": running_loss_D_total/n_iter,
                        "loss_G_total_step": running_loss_G_total/n_iter,
                        "loss_align_step": running_loss_align/n_iter,
                        "loss_rec_step": running_loss_rec/n_iter,
                        "loss_res_step": running_loss_res/n_iter,
                        "loss_sty_step": running_loss_sty/n_iter,
                        "loss_idt_step": running_loss_idt/n_iter,
                        "loss_cyc_step": running_loss_cyc/n_iter,
                    }
                )

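        # end-of-epoch validation: only the encoder + SR decoder are evaluated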
        total_PSNR = 0
        total_SSIM = 0
        total_LPIPS = 0
        val_iter = 0
        with torch.no_grad():
            model_Enc.eval()
            model_Dec_SR.eval()
            for batch_idx, batch in enumerate(val_loader):
                val_iter += 1
                source = batch["img_LQ"].to(device)
                target = batch["img_GT"].to(device)

                feat = model_Enc(source)
                out = model_Dec_SR(feat).clamp(0, 1)  # keep outputs in the valid image range for the metrics

                total_PSNR += PSNR(out, target)
                total_SSIM += SSIM(out, target)
                total_LPIPS += LPIPS(out, target)

        wandb.log(
            {
                "epoch": epoch,
                "lr": optimizer_G.param_groups[0]['lr'],
                "loss_D_total_epoch": running_loss_D_total/n_iter,
                "loss_G_total_epoch": running_loss_G_total/n_iter,
                "loss_align_epoch": running_loss_align/n_iter,
                "loss_rec_epoch": running_loss_rec/n_iter,
                "loss_res_epoch": running_loss_res/n_iter,
                "loss_sty_epoch": running_loss_sty/n_iter,
                "loss_idt_epoch": running_loss_idt/n_iter,
                "loss_cyc_epoch": running_loss_cyc/n_iter,
                "PSNR_val": total_PSNR/val_iter,
                "SSIM_val": total_SSIM/val_iter,
                "LPIPS_val": total_LPIPS/val_iter
            }
        )

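        # periodically snapshot the full training state for resuming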
        if (epoch+1) % args.save_freq == 0:
            weights_file_name = 'epoch_%d.pth' % (epoch+1)
            weights_file = os.path.join(args.snap_path, weights_file_name)
            torch.save({
                'epoch': epoch,

                'model_Enc': model_Enc.state_dict(),
                'model_Dec_Id': model_Dec_Id.state_dict(),
                'model_Dec_SR': model_Dec_SR.state_dict(),
                'model_Disc_feat': model_Disc_feat.state_dict(),
                'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
                'model_Disc_img_HR': model_Disc_img_HR.state_dict(),

                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),

                'scheduler_D': scheduler_D.state_dict(),
                'scheduler_G': scheduler_G.state_dict(),
            }, weights_file)
            print('save weights of epoch %d' % (epoch+1))

else:
    # evaluation phase: score the encoder + SR decoder on the validation set
    total_PSNR = 0
    total_SSIM = 0
    total_LPIPS = 0
    val_iter = 0
    with torch.no_grad():
        model_Enc.eval()
        model_Dec_SR.eval()
        for batch_idx, batch in enumerate(val_loader):
            val_iter += 1
            source = batch["img_LQ"].to(device)
            target = batch["img_GT"].to(device)

            feat = model_Enc(source)
            out = model_Dec_SR(feat).clamp(0, 1)  # keep outputs in the valid image range for the metrics

            total_PSNR += PSNR(out, target)
            total_SSIM += SSIM(out, target)
            total_LPIPS += LPIPS(out, target)
    print("PSNR_val: ", (total_PSNR/val_iter).item())
    print("SSIM_val: ", (total_SSIM/val_iter).item())
    print("LPIPS_val: ", (total_LPIPS/val_iter).item())