# USR-DA / train.py
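# Training script: a shared encoder (model_Enc) with two decoders
# (model_Dec_SR for the HR estimate, model_Dec_Id for the target-domain LR
# estimate) is trained against three discriminators (feature space, LR image,
# HR image) using the adversarial, reconstruction, style, identity, and cycle
# losses defined below. Training progress and validation metrics (PSNR, SSIM,
# LPIPS) are logged to wandb.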
import os
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from data.LQGT_dataset import LQGTDataset, LQGTValDataset
from model import decoder, discriminator, encoder
from opt.option import args
from util.utils import (RandCrop, RandHorizontalFlip, RandRotate, ToTensor, RandCrop_pair,
                        VGG19PerceptualLoss)
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
wandb.init(project='SR', config=args)
# device setting
if args.gpu_id is not None:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('using GPU %s' % args.gpu_id)
else:
    print('use --gpu_id to specify GPU ID to use')
    exit()
device = torch.device('cuda')
# make directory for saving weights
if not os.path.exists(args.snap_path):
    os.mkdir(args.snap_path)
print("Loading dataset...")
# load training dataset
train_dataset = LQGTDataset(
    db_path=args.dir_data,
    transform=transforms.Compose([RandCrop(args.patch_size, args.scale), RandHorizontalFlip(), RandRotate(), ToTensor()])
)
val_dataset = LQGTValDataset(
    db_path=args.dir_data,
    transform=transforms.Compose([RandCrop_pair(args.patch_size, args.scale), ToTensor()])
)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    shuffle=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    shuffle=False
)
print("Create model")
# define model (discriminator)
model_Disc_feat = discriminator.DiscriminatorVGG(in_ch=args.n_hidden_feats, image_size=args.patch_size).to(device)
model_Disc_img_LR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.patch_size).to(device)
model_Disc_img_HR = discriminator.DiscriminatorVGG(in_ch=3, image_size=args.scale*args.patch_size).to(device)
# define model (generator)
model_Enc = encoder.Encoder_RRDB(num_feat=args.n_hidden_feats).to(device)
model_Dec_Id = decoder.Decoder_Id_RRDB(num_in_ch=args.n_hidden_feats).to(device)
model_Dec_SR = decoder.Decoder_SR_RRDB(num_in_ch=args.n_hidden_feats).to(device)
# alternative U-Net based discriminators (unused)
# model_Disc_feat = discriminator.UNetDiscriminator(num_in_ch=64).to(device)
# model_Disc_img_LR = discriminator.UNetDiscriminator(num_in_ch=3).to(device)
# model_Disc_img_HR = discriminator.UNetDiscriminator(num_in_ch=3).to(device)
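# Roles as used below: model_Enc maps LR images into a shared feature space,
# model_Dec_SR decodes those features to the HR estimate, and model_Dec_Id
# decodes them back to a target-domain LR image (degradation/identity branch).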
# wandb logging
wandb.watch(model_Disc_feat)
wandb.watch(model_Disc_img_LR)
wandb.watch(model_Disc_img_HR)
wandb.watch(model_Enc)
wandb.watch(model_Dec_Id)
wandb.watch(model_Dec_SR)
print("Define Loss")
# loss
loss_L1 = nn.L1Loss().to(device)
loss_MSE = nn.MSELoss().to(device)
loss_adversarial = nn.BCEWithLogitsLoss().to(device)
loss_percept = VGG19PerceptualLoss().to(device)
print("Define Optimizer")
# optimizer
params_G = list(model_Enc.parameters()) + list(model_Dec_Id.parameters()) + list(model_Dec_SR.parameters())
optimizer_G = optim.Adam(
    params_G,
    lr=args.lr_G,
    betas=(args.beta1, args.beta2),
    weight_decay=args.weight_decay,
    amsgrad=True
)
params_D = list(model_Disc_feat.parameters()) + list(model_Disc_img_LR.parameters()) + list(model_Disc_img_HR.parameters())
optimizer_D = optim.Adam(
    params_D,
    lr=args.lr_D,
    betas=(args.beta1, args.beta2),
    weight_decay=args.weight_decay,
    amsgrad=True
)
print("Define Scheduler")
# Scheduler
iter_indices = [args.interval1, args.interval2, args.interval3]
scheduler_G = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer_G,
    milestones=iter_indices,
    gamma=0.5
)
scheduler_D = optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer_D,
    milestones=iter_indices,
    gamma=0.5
)
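# note: both schedulers are stepped once per training iteration below, so the
# interval1/interval2/interval3 milestones are iteration counts, not epochs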
# print("Data Parallel")
model_Enc = nn.DataParallel(model_Enc)
model_Dec_Id = nn.DataParallel(model_Dec_Id)
model_Dec_SR = nn.DataParallel(model_Dec_SR)
# discriminators are left unwrapped (single GPU)
#model_Disc_feat = nn.DataParallel(model_Disc_feat)
#model_Disc_img_LR = nn.DataParallel(model_Disc_img_LR)
#model_Disc_img_HR = nn.DataParallel(model_Disc_img_HR)
print("Load model weight")
# load model weights, optimizer & scheduler states
if args.checkpoint is not None:
    checkpoint = torch.load(args.checkpoint)
    model_Enc.load_state_dict(checkpoint['model_Enc'])
    model_Dec_Id.load_state_dict(checkpoint['model_Dec_Id'])
    model_Dec_SR.load_state_dict(checkpoint['model_Dec_SR'])
    model_Disc_feat.load_state_dict(checkpoint['model_Disc_feat'])
    model_Disc_img_LR.load_state_dict(checkpoint['model_Disc_img_LR'])
    model_Disc_img_HR.load_state_dict(checkpoint['model_Disc_img_HR'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D'])
    optimizer_G.load_state_dict(checkpoint['optimizer_G'])
    scheduler_D.load_state_dict(checkpoint['scheduler_D'])
    scheduler_G.load_state_dict(checkpoint['scheduler_G'])
    start_epoch = checkpoint['epoch']
else:
    start_epoch = 0
    if args.pretrained is not None:
        ckpt = torch.load(args.pretrained)
        # adapt the pretrained 3-channel conv_first to the 64-channel feature input:
        # keep each output filter's first input-channel kernel and repeat it across
        # all 64 input channels (assumes BasicSR-style 'params' weights of shape (64, 3, 3, 3))
        ckpt["params"]["conv_first.weight"] = ckpt["params"]["conv_first.weight"][:, 0:1, :, :].expand(-1, 64, -1, -1)
        # model_Dec_SR is wrapped in DataParallel above, so load into .module
        model_Dec_SR.module.load_state_dict(ckpt["params"])
# model_Enc = model_Enc.to(device)
# model_Dec_Id = model_Dec_Id.to(device)
# model_Dec_SR = model_Dec_SR.to(device)
# # define model (discriminator)
# model_Disc_feat = model_Disc_feat.to(device)
# model_Disc_img_LR = model_Disc_img_LR.to(device)
# model_Disc_img_HR =model_Disc_img_HR.to(device)
# evaluation metrics (torchmetrics); LPIPS expects inputs in [-1, 1] unless
# constructed with normalize=True (consider passing it if the tensors here are in [0, 1])
PSNR = PeakSignalNoiseRatio().to(device)
SSIM = StructuralSimilarityIndexMeasure().to(device)
LPIPS = LearnedPerceptualImagePatchSimilarity().to(device)

# training
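# The generator objective below combines six weighted terms:
#   L_align_E       feature alignment (adversarial, feature-space discriminator)
#   L_rec_G_SR      SR reconstruction (L1 + perceptual + adversarial)
#   L_res_G_t       target LR restoration (L1)
#   L_sty_G_t       target degradation style (adversarial, LR discriminator)
#   L_idt_G_t       feature identity (L1 in feature space)
#   L_cyc_G_t_G_SR  cycle reconstruction (L1 + perceptual + adversarial)
# weights come from the corresponding args.lambda_* options.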
if args.phase == "train":
    # bicubic downsampler used to synthesize source-domain LR images from the HR GT
    ds4 = nn.Upsample(scale_factor=1/args.scale, mode='bicubic')
    for epoch in range(start_epoch, args.epochs):
        # generator
        model_Enc.train()
        model_Dec_Id.train()
        model_Dec_SR.train()
        # discriminator
        model_Disc_feat.train()
        model_Disc_img_LR.train()
        model_Disc_img_HR.train()
        running_loss_D_total = 0.0
        running_loss_G_total = 0.0
        running_loss_align = 0.0
        running_loss_rec = 0.0
        running_loss_res = 0.0
        running_loss_sty = 0.0
        running_loss_idt = 0.0
        running_loss_cyc = 0.0
        iter = 0
        for data in tqdm(train_loader):
            iter += 1
            ########################
            #      data load       #
            ########################
            X_t, Y_s = data['img_LQ'], data['img_GT']
            # X_t: real-world (target-domain) LR input, Y_s: HR ground truth,
            # X_s: bicubic-downsampled (source-domain) LR generated from Y_s
            X_s = ds4(Y_s)
            X_t = X_t.cuda(non_blocking=True)
            X_s = X_s.cuda(non_blocking=True)
            Y_s = Y_s.cuda(non_blocking=True)
            # real label and fake label
            batch_size = X_t.size(0)
            real_label = torch.full((batch_size, 1), 1, dtype=X_t.dtype).cuda(non_blocking=True)
            fake_label = torch.full((batch_size, 1), 0, dtype=X_t.dtype).cuda(non_blocking=True)
            ########################
            # (1) Update D network #
            ########################
            for i in range(args.n_disc):
                # clear discriminator gradients for this update
                model_Disc_feat.zero_grad()
                model_Disc_img_LR.zero_grad()
                model_Disc_img_HR.zero_grad()
                # generator output (feature domain)
                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)
                # 1. feature alignment loss (discriminator)
                # output of discriminator (feature domain)
                output_Disc_F_t = model_Disc_feat(F_t.detach())
                output_Disc_F_s = model_Disc_feat(F_s.detach())
                # discriminator loss (feature domain)
                loss_Disc_F_t = loss_MSE(output_Disc_F_t, fake_label)
                loss_Disc_F_s = loss_MSE(output_Disc_F_s, real_label)
                loss_Disc_feat_align = (loss_Disc_F_t + loss_Disc_F_s) / 2
                # 2. SR reconstruction loss (discriminator)
                # generator output (image domain)
                Y_s_s = model_Dec_SR(F_s)
                # output of discriminator (image domain)
                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s.detach())
                output_Disc_Y_s = model_Disc_img_HR(Y_s)
                # discriminator loss (image domain)
                loss_Disc_Y_s_s = loss_MSE(output_Disc_Y_s_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_rec = (loss_Disc_Y_s_s + loss_Disc_Y_s) / 2
                # 4. Target degradation style loss (discriminator)
                # generator output (image domain)
                X_s_t = model_Dec_Id(F_s)
                # output of discriminator (image domain)
                output_Disc_X_s_t = model_Disc_img_LR(X_s_t.detach())
                output_Disc_X_t = model_Disc_img_LR(X_t)
                # discriminator loss (image domain)
                loss_Disc_X_s_t = loss_MSE(output_Disc_X_s_t, fake_label)
                loss_Disc_X_t = loss_MSE(output_Disc_X_t, real_label)
                loss_Disc_img_sty = (loss_Disc_X_s_t + loss_Disc_X_t) / 2
                # 6. Cycle loss (discriminator)
                # generator output (image domain)
                Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
                # output of discriminator (image domain)
                output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s.detach())
                output_Disc_Y_s = model_Disc_img_HR(Y_s)
                # discriminator loss (image domain)
                loss_Disc_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, fake_label)
                loss_Disc_Y_s = loss_MSE(output_Disc_Y_s, real_label)
                loss_Disc_img_cyc = (loss_Disc_Y_s_t_s + loss_Disc_Y_s) / 2
                # discriminator weight update
                loss_D_total = loss_Disc_feat_align + loss_Disc_img_rec + loss_Disc_img_sty + loss_Disc_img_cyc
                loss_D_total.backward()
                optimizer_D.step()
            scheduler_D.step()
            ########################
            # (2) Update G network #
            ########################
            for i in range(args.n_gen):
                # clear generator gradients for this update
                model_Enc.zero_grad()
                model_Dec_Id.zero_grad()
                model_Dec_SR.zero_grad()
                # generator output (feature domain)
                F_t = model_Enc(X_t)
                F_s = model_Enc(X_s)
                # 1. feature alignment loss (generator)
                # output of discriminator (feature domain)
                output_Disc_F_t = model_Disc_feat(F_t)
                output_Disc_F_s = model_Disc_feat(F_s)
                # generator loss (feature domain)
                loss_G_F_t = loss_MSE(output_Disc_F_t, (real_label + fake_label) / 2)
                loss_G_F_s = loss_MSE(output_Disc_F_s, (real_label + fake_label) / 2)
                L_align_E = loss_G_F_t + loss_G_F_s
                # 2. SR reconstruction loss
                # generator output (image domain)
                Y_s_s = model_Dec_SR(F_s)
                # output of discriminator (image domain)
                output_Disc_Y_s_s = model_Disc_img_HR(Y_s_s)
                # L1 loss
                loss_L1_rec = loss_L1(Y_s.detach(), Y_s_s)
                # perceptual loss
                loss_percept_rec = loss_percept(Y_s.detach(), Y_s_s)
                # adversarial loss
                loss_G_Y_s_s = loss_MSE(output_Disc_Y_s_s, real_label)
                L_rec_G_SR = loss_L1_rec + args.lambda_percept*loss_percept_rec + args.lambda_adv*loss_G_Y_s_s
                # 3. Target LR restoration loss
                X_t_t = model_Dec_Id(F_t)
                L_res_G_t = loss_L1(X_t, X_t_t)
                # 4. Target degradation style loss
                # generator output (image domain)
                X_s_t = model_Dec_Id(F_s)
                # output of discriminator (image domain)
                output_Disc_X_s_t = model_Disc_img_LR(X_s_t)
                # generator loss (image domain)
                loss_G_X_s_t = loss_MSE(output_Disc_X_s_t, real_label)
                L_sty_G_t = loss_G_X_s_t
                # 5. Feature identity loss
                F_s_tilda = model_Enc(model_Dec_Id(F_s))
                L_idt_G_t = loss_L1(F_s, F_s_tilda)
                # 6. Cycle loss
                # generator output (image domain)
                Y_s_t_s = model_Dec_SR(model_Enc(model_Dec_Id(F_s)))
                # output of discriminator (image domain)
                output_Disc_Y_s_t_s = model_Disc_img_HR(Y_s_t_s)
                # L1 loss
                loss_L1_cyc = loss_L1(Y_s.detach(), Y_s_t_s)
                # perceptual loss
                loss_percept_cyc = loss_percept(Y_s.detach(), Y_s_t_s)
                # adversarial loss
                loss_Y_s_t_s = loss_MSE(output_Disc_Y_s_t_s, real_label)
                L_cyc_G_t_G_SR = loss_L1_cyc + args.lambda_percept*loss_percept_cyc + args.lambda_adv*loss_Y_s_t_s
                # generator weight update
                loss_G_total = (args.lambda_align*L_align_E + args.lambda_rec*L_rec_G_SR
                                + args.lambda_res*L_res_G_t + args.lambda_sty*L_sty_G_t
                                + args.lambda_idt*L_idt_G_t + args.lambda_cyc*L_cyc_G_t_G_SR)
                loss_G_total.backward()
                optimizer_G.step()
            scheduler_G.step()
            ########################
            #     compute loss     #
            ########################
            running_loss_D_total += loss_D_total.item()
            running_loss_G_total += loss_G_total.item()
            running_loss_align += L_align_E.item()
            running_loss_rec += L_rec_G_SR.item()
            running_loss_res += L_res_G_t.item()
            running_loss_sty += L_sty_G_t.item()
            running_loss_idt += L_idt_G_t.item()
            running_loss_cyc += L_cyc_G_t_G_SR.item()
            if iter % args.log_interval == 0:
                wandb.log(
                    {
                        "loss_D_total_step": running_loss_D_total/iter,
                        "loss_G_total_step": running_loss_G_total/iter,
                        "loss_align_step": running_loss_align/iter,
                        "loss_rec_step": running_loss_rec/iter,
                        "loss_res_step": running_loss_res/iter,
                        "loss_sty_step": running_loss_sty/iter,
                        "loss_idt_step": running_loss_idt/iter,
                        "loss_cyc_step": running_loss_cyc/iter,
                    }
                )
        ### EVALUATE ###
        total_PSNR = 0
        total_SSIM = 0
        total_LPIPS = 0
        val_iter = 0
        with torch.no_grad():
            model_Enc.eval()
            model_Dec_SR.eval()
            for batch_idx, batch in enumerate(val_loader):
                val_iter += 1
                source = batch["img_LQ"].to(device)
                target = batch["img_GT"].to(device)
                feat = model_Enc(source)
                out = model_Dec_SR(feat)
                total_PSNR += PSNR(out, target)
                total_SSIM += SSIM(out, target)
                total_LPIPS += LPIPS(out, target)
        wandb.log(
            {
                "epoch": epoch,
                "lr": optimizer_G.param_groups[0]['lr'],
                "loss_D_total_epoch": running_loss_D_total/iter,
                "loss_G_total_epoch": running_loss_G_total/iter,
                "loss_align_epoch": running_loss_align/iter,
                "loss_rec_epoch": running_loss_rec/iter,
                "loss_res_epoch": running_loss_res/iter,
                "loss_sty_epoch": running_loss_sty/iter,
                "loss_idt_epoch": running_loss_idt/iter,
                "loss_cyc_epoch": running_loss_cyc/iter,
                "PSNR_val": total_PSNR/val_iter,
                "SSIM_val": total_SSIM/val_iter,
                "LPIPS_val": total_LPIPS/val_iter
            }
        )
        if (epoch+1) % args.save_freq == 0:
            weights_file_name = 'epoch_%d.pth' % (epoch+1)
            weights_file = os.path.join(args.snap_path, weights_file_name)
            torch.save({
                'epoch': epoch + 1,  # index of the next epoch, so resuming does not repeat this one
                'model_Enc': model_Enc.state_dict(),
                'model_Dec_Id': model_Dec_Id.state_dict(),
                'model_Dec_SR': model_Dec_SR.state_dict(),
                'model_Disc_feat': model_Disc_feat.state_dict(),
                'model_Disc_img_LR': model_Disc_img_LR.state_dict(),
                'model_Disc_img_HR': model_Disc_img_HR.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'scheduler_D': scheduler_D.state_dict(),
                'scheduler_G': scheduler_G.state_dict(),
            }, weights_file)
            print('save weights of epoch %d' % (epoch+1))
else:
    ### EVALUATE ###
    total_PSNR = 0
    total_SSIM = 0
    total_LPIPS = 0
    val_iter = 0
    with torch.no_grad():
        model_Enc.eval()
        model_Dec_SR.eval()
        for batch_idx, batch in enumerate(val_loader):
            val_iter += 1
            source = batch["img_LQ"].to(device)
            target = batch["img_GT"].to(device)
            feat = model_Enc(source)
            out = model_Dec_SR(feat)
            total_PSNR += PSNR(out, target)
            total_SSIM += SSIM(out, target)
            total_LPIPS += LPIPS(out, target)
    print("PSNR_val: ", (total_PSNR/val_iter).item())
    print("SSIM_val: ", (total_SSIM/val_iter).item())
    print("LPIPS_val: ", (total_LPIPS/val_iter).item())