|
import os, time, torch, argparse
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
from torchvision.utils import save_image as imwrite
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
from makedataset import Dataset
|
|
from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
|
|
from model.loss import Total_loss
|
|
|
|
from PIL import Image
|
|
|
|
# Preprocessing applied in ``test`` to the degraded PIL image before it is
# handed to the embedder's image encoder: resize to 224x224, then convert
# the image to a tensor.
transform_resize = transforms.Compose(
    [transforms.Resize([224, 224]), transforms.ToTensor()]
)
|
|
|
|
def main(args):
    """Set up the models, loss and data loader, then run the training loop.

    Args:
        args: Parsed command-line namespace. This function reads
            ``embedder_model_path``, ``restore_model_path``, ``lr``,
            ``train_input``, ``num_works`` and ``bs`` directly; the whole
            namespace is also forwarded to ``Total_loss`` and ``train``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print('> Model Initialization...')

    # The embedder is loaded with freeze_model=True; the restorer (loaded with
    # freeze_model=False) comes back together with its optimizer and the
    # epoch to resume from.
    embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
    restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(
        device, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
    loss = Total_loss(args)

    print('> Loading dataset...')
    data = Dataset(args.train_input)
    dataset = DataLoader(dataset=data, num_workers=args.num_works, batch_size=args.bs, shuffle=True)

    print('> Start training...')
    start_all = time.time()
    train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
    end_all = time.time()
    print('Whole Training Time:' + str(end_all - start_all) + 's.')
|
|
|
|
def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
    """Train ``restorer`` for epochs ``cur_epoch`` .. ``args.epoch - 1``.

    Every batch is run through the embedder's text encoder and the restorer,
    the loss is back-propagated, and MSE/PSNR/SSIM of the output against the
    positive sample are printed. After each epoch the model is evaluated with
    ``test``, the accumulated metric history is handed to ``load_excel`` and
    a checkpoint is written through ``save_checkpoint``.

    Args:
        restorer: Model being trained.
        embedder: Model called as ``embedder(x, 'text_encoder')``.
        optimizer: Optimizer stepping the restorer's parameters.
        loss: Callable invoked as ``loss(inp, pos, neg, out)``.
        cur_epoch: First epoch index to run.
        args: Namespace providing ``epoch``, ``adjust_lr`` and
            ``save_model_path`` (and whatever ``data_process``/``test`` read).
        dataset: Iterable of training batches supporting ``len``.
        device: Device forwarded to ``data_process`` and ``test``.
    """
    history = []
    for epoch in range(cur_epoch, args.epoch):
        optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
        current_lr = optimizer.param_groups[-1]['lr']
        restorer.train()

        for step, batch in enumerate(dataset, start=1):
            pos, inp, neg = data_process(batch, args, device)

            text_embedding, _, _ = embedder(inp[1], 'text_encoder')
            out = restorer(inp[0], text_embedding)

            restorer.zero_grad()
            total_loss = loss(inp, pos, neg, out)
            total_loss.backward()
            optimizer.step()

            # Per-batch quality of the restored output against the target.
            mse, psnr, ssim = (
                tensor_metric(pos, out, metric_name, data_range=1)
                for metric_name in ('MSE', 'PSNR', 'SSIM')
            )

            log_values = (epoch + 1, step, len(dataset), current_lr,
                          total_loss.item(), mse, psnr, ssim)
            print("[epoch %d][%d/%d] lr :%f Floss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f" % log_values)

        # End-of-epoch evaluation, metric export and checkpointing.
        psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
        history.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
        print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f" % (epoch + 1, psnr_t1, ssim_t1))

        load_excel(history)
        checkpoint_state = {
            'epoch': epoch + 1,
            'state_dict': restorer.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint_state, args.save_model_path, epoch + 1,
                        psnr_t1, ssim_t1, psnr_t2, ssim_t2)
|
|
|
|
def _image_to_batch(image, device):
    """Turn an H x W x C PIL image into a 1 x C x H x W tensor (values / 255) on ``device``."""
    array = (np.array(image) / 255).transpose(2, 0, 1)
    return torch.Tensor(array).unsqueeze(0).to(device)


def test(args, restorer, embedder, device, epoch=-1):
    """Evaluate ``restorer`` on the degraded test folders and save comparison images.

    ``args.degr_type[0]`` names the reference folder; every later entry names
    a degraded folder under ``args.test_input``. Each degraded image is
    restored with the embedding obtained from the folder name
    (``'text_encoder'``) and, when the third value returned by the embedder's
    ``'image_encoder'`` differs from the text one, a second time with the
    image-derived embedding. A side-by-side image (input, both outputs,
    reference) is written to ``args.output``.

    Args:
        args: Namespace providing ``degr_type``, ``test_input`` and ``output``.
        restorer: Restoration model; switched to eval mode here.
        embedder: Model called with ``'text_encoder'`` / ``'image_encoder'``.
        device: Device the input tensors are moved to.
        epoch: Epoch tag embedded in the saved file names.

    Returns:
        ``(psnr_1, ssim_1, psnr_2, ssim_2)`` averaged over every evaluated
        image, where ``_1`` uses the text embedding and ``_2`` the image
        embedding. All four are ``0.0`` when no image was evaluated.
    """
    combine_type = args.degr_type
    psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0
    num_images = 0
    os.makedirs(args.output, exist_ok=True)

    restorer.eval()
    for degr_type in combine_type[1:]:
        for file_name in os.listdir(os.path.join(args.test_input, degr_type)):
            # The reference image shares its file name with the degraded one.
            hq_image = Image.open(os.path.join(args.test_input, combine_type[0], file_name))
            lq_image = Image.open(os.path.join(args.test_input, degr_type, file_name))
            stem = os.path.splitext(file_name)[0]

            with torch.no_grad():
                lq_re = _image_to_batch(lq_image, device)
                lq_em = transform_resize(lq_image).unsqueeze(0).to(device)
                hq = _image_to_batch(hq_image, device)

                starttime = time.time()

                text_embedding_1, _, text_1 = embedder([degr_type], 'text_encoder')
                text_embedding_2, _, text_2 = embedder(lq_em, 'image_encoder')
                out_1 = restorer(lq_re, text_embedding_1)
                if text_1 != text_2:
                    print(text_1, text_2)
                    out_2 = restorer(lq_re, text_embedding_2)
                else:
                    # Same prediction from both encoders: reuse the first result.
                    out_2 = out_1

                endtime1 = time.time()

                # No `range=` keyword: it only matters with normalize=True and
                # newer torchvision releases no longer accept it.
                out_name = stem + '_' + str(epoch) + '_' + degr_type + '.png'
                imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3),
                        os.path.join(args.output, out_name))

            psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
            ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
            psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
            ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
            num_images += 1
            print('The ' + stem + ' Time:' + str(endtime1 - starttime) + 's.')

    # Average over the images actually evaluated; the reference folder is not
    # one of them and folders may hold different numbers of files.
    if num_images == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (psnr_1 / num_images, ssim_1 / num_images,
            psnr_2 / num_images, ssim_2 / num_images)
|
|
|
|
# Expose only GPU 0 (numbered in PCI bus order) to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def build_arg_parser():
    """Build the command-line parser for OneRestore training.

    ``--loss-weight`` and ``--degr-type`` take space-separated values
    (e.g. ``--loss-weight 0.6 0.3 0.1``). Using ``type=tuple``/``type=list``
    would make argparse split a single string into individual characters.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser(description="OneRestore Training")

    # Checkpoint locations.
    parser.add_argument("--embedder-model-path", type=str, default="./ckpts/embedder_model.tar", help='embedder model path')
    parser.add_argument("--restore-model-path", type=str, default=None, help='restore model path')
    parser.add_argument("--save-model-path", type=str, default="./ckpts/", help='path where checkpoints are saved')

    # Optimisation settings.
    parser.add_argument("--epoch", type=int, default=300, help='epoch number')
    parser.add_argument("--bs", type=int, default=4, help='batchsize')
    parser.add_argument("--lr", type=float, default=1e-4, help='learning rate')
    parser.add_argument("--adjust-lr", type=int, default=30, help='adjust learning rate')
    parser.add_argument("--num-works", type=int, default=4, help='number works')
    parser.add_argument("--loss-weight", type=float, nargs='+', default=(0.6, 0.3, 0.1), help='loss weights')
    # The first entry is the reference folder that ``test`` compares against.
    parser.add_argument("--degr-type", type=str, nargs='+', default=[
        'clear', 'low', 'haze', 'rain', 'snow',
        'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow',
        'low_haze_rain', 'low_haze_snow',
    ], help='degradation type')

    # Data and output locations.
    parser.add_argument("--train-input", type=str, default="./dataset.h5", help='train data')
    parser.add_argument("--test-input", type=str, default="./data/CDD-11_test", help='test path')
    parser.add_argument("--output", type=str, default="./result/", help='output path')

    return parser


if __name__ == '__main__':

    argspar = build_arg_parser().parse_args()
    # Keep the tuple type of the default when weights come from the CLI.
    argspar.loss_weight = tuple(argspar.loss_weight)

    print_args(argspar)
    main(argspar)
|
|
|