# OneRestore / train_OneRestore_multi-gpu.py
import os, time, torch, argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image as imwrite
import numpy as np
from torchvision import transforms
from makedataset import Dataset
from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
from model.loss import Total_loss
from model.Embedder import Embedder
from model.OneRestore import OneRestore
from torch.utils.data.distributed import DistributedSampler
from PIL import Image
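
# Distributed setup: one process per GPU (NCCL backend); each process pins itself
# to the GPU indexed by its rank before any model or data is created.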
torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
# Resize + ToTensor used to prepare images for the embedder's image encoder (224x224 input).
transform_resize = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor()
])
def main(args):
    print('> Model Initialization...')

    # Load the frozen text/image embedder and the restorer (with optimizer state if resuming).
    embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
    restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(device, local_rank=local_rank, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
    loss = Total_loss(args)

    print('> Loading dataset...')
    data = Dataset(args.train_input)
    dataset = DataLoader(dataset=data, batch_size=args.bs,
                         shuffle=False,  # shuffling is handled by the DistributedSampler
                         num_workers=args.num_works,
                         pin_memory=True, drop_last=False,
                         sampler=DistributedSampler(data, shuffle=True))

    print('> Start training...')
    start_all = time.time()
    train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
    end_all = time.time()
    print('Whole Training Time: ' + str(end_all - start_all) + 's.')
def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
    metric = []
    for epoch in range(cur_epoch, args.epoch):
        optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
        learnrate = optimizer.param_groups[-1]['lr']
        restorer.train()

        for i, data in enumerate(dataset, 0):
            # Split the batch into positive (clean), input (degraded image + scene text), and negative samples.
            pos, inp, neg = data_process(data, args, device)

            # The scene descriptor from the frozen text encoder conditions the restorer.
            text_embedding, _, _ = embedder(inp[1], 'text_encoder')
            out = restorer(inp[0], text_embedding)

            restorer.zero_grad()
            total_loss = loss(inp, pos, neg, out)
            total_loss.backward()
            optimizer.step()

            mse = tensor_metric(pos, out, 'MSE', data_range=1)
            psnr = tensor_metric(pos, out, 'PSNR', data_range=1)
            ssim = tensor_metric(pos, out, 'SSIM', data_range=1)
            print("[epoch %d][%d/%d] lr: %f Loss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f" % (epoch + 1, i + 1,
                len(dataset), learnrate, total_loss.item(), mse, psnr, ssim))

        # Evaluate after every epoch, log the metrics to an Excel sheet, and save a checkpoint.
        psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
        metric.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
        print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f" % (epoch + 1, psnr_t1, ssim_t1))

        load_excel(metric)
        save_checkpoint({'epoch': epoch + 1, 'state_dict': restorer.state_dict(), 'optimizer': optimizer.state_dict()},
                        args.save_model_path, epoch + 1, psnr_t1, ssim_t1, psnr_t2, ssim_t2)
def test(args, restorer, embedder, device, epoch=-1):
    combine_type = args.degr_type
    psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0
    os.makedirs(args.output, exist_ok=True)

    # combine_type[0] is the clean reference folder; the remaining entries are degraded subsets.
    for i in range(len(combine_type) - 1):
        file_list = os.listdir(f'{args.test_input}/{combine_type[i+1]}/')
        for j in range(len(file_list)):
            hq = Image.open(f'{args.test_input}/{combine_type[0]}/{file_list[j]}')
            lq = Image.open(f'{args.test_input}/{combine_type[i+1]}/{file_list[j]}')

            restorer.eval()
            with torch.no_grad():
                lq_re = torch.Tensor((np.array(lq) / 255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
                lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
                hq = torch.Tensor((np.array(hq) / 255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")

                starttime = time.time()

                # Restore twice: once conditioned on the ground-truth degradation label (text encoder),
                # once on the embedding predicted from the degraded image itself (image encoder).
                text_embedding_1, _, text_1 = embedder([combine_type[i+1]], 'text_encoder')
                text_embedding_2, _, text_2 = embedder(lq_em, 'image_encoder')
                out_1 = restorer(lq_re, text_embedding_1)
                if text_1 != text_2:
                    print(text_1, text_2)
                    out_2 = restorer(lq_re, text_embedding_2)
                else:
                    out_2 = out_1
                endtime1 = time.time()

                imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output
                        + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png', range=(0, 1))
                # If your torchvision version no longer accepts the `range` argument, replace the line above with:
                # imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output
                #         + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png')

                psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
                ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
                psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
                ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
            print('The ' + file_list[j][:-4] + ' Time: ' + str(endtime1 - starttime) + 's.')
    # Average over every (degradation type, image) pair; the clean reference folder is excluded
    # from the count, and each degraded subset is assumed to contain the same number of images.
    num_eval = len(file_list) * (len(combine_type) - 1)
    return psnr_1 / num_eval, ssim_1 / num_eval, psnr_2 / num_eval, ssim_2 / num_eval
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="OneRestore Training")

    # Model checkpoints
    parser.add_argument("--embedder-model-path", type=str, default="./ckpts/embedder_model.tar", help='embedder model path')
    parser.add_argument("--restore-model-path", type=str, default=None, help='restorer checkpoint to resume from')
    parser.add_argument("--save-model-path", type=str, default="./ckpts/", help='directory for saved checkpoints')

    # Training hyper-parameters
    parser.add_argument("--epoch", type=int, default=300, help='number of epochs')
    parser.add_argument("--bs", type=int, default=4, help='batch size per GPU')
    parser.add_argument("--lr", type=float, default=1e-4, help='learning rate')
    parser.add_argument("--adjust-lr", type=int, default=30, help='interval (in epochs) for learning-rate decay')
    parser.add_argument("--num-works", type=int, default=4, help='number of data-loading workers')
    parser.add_argument("--loss-weight", type=tuple, default=(0.6, 0.3, 0.1), help='loss weights')
    parser.add_argument("--degr-type", type=list, default=['clear', 'low', 'haze', 'rain', 'snow',
                        'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'], help='degradation types')

    # Data paths
    parser.add_argument("--train-input", type=str, default="./dataset.h5", help='training data (HDF5 file)')
    parser.add_argument("--test-input", type=str, default="./data/CDD-11_test", help='test data path')
    parser.add_argument("--output", type=str, default="./result/", help='output path')

    argspar = parser.parse_args()
    print_args(argspar)

    main(argspar)