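# -----------------------------------------------------------------------------
# Distributed (multi-GPU) training script for OneRestore: an embedder turns a
# degradation description (text or image) into an embedding that conditions the
# restorer. Uses torch.distributed with the NCCL backend and a DistributedSampler.
#
# Example launch (illustrative only; adjust the script name and GPU count to
# your setup):
#   torchrun --nproc_per_node=2 train_OneRestore.py --train-input ./dataset.h5
# -----------------------------------------------------------------------------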
import os, time, torch, argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image as imwrite
import numpy as np
from torchvision import transforms
from makedataset import Dataset
from utils.utils import print_args, load_restore_ckpt_with_optim, load_embedder_ckpt, adjust_learning_rate, data_process, tensor_metric, load_excel, save_checkpoint
from model.loss import Total_loss
from model.Embedder import Embedder
from model.OneRestore import OneRestore
from torch.utils.data.distributed import DistributedSampler
from PIL import Image

torch.distributed.init_process_group(backend="nccl")
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

transform_resize = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor()
])
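# NOTE: transform_resize is only used at test time, to resize images to 224x224
# before they are passed to the embedder's image encoder; the restorer itself
# always receives the full-resolution degraded image.
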
def main(args):
    print('> Model Initialization...')

    # Load the frozen embedder and the restorer (with its optimizer and starting epoch).
    embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
    restorer, optimizer, cur_epoch = load_restore_ckpt_with_optim(device, local_rank=local_rank, freeze_model=False, ckpt_name=args.restore_model_path, lr=args.lr)
    loss = Total_loss(args)

    print('> Loading dataset...')
    data = Dataset(args.train_input)
    dataset = DataLoader(dataset=data, batch_size=args.bs,
                         shuffle=False,
                         num_workers=args.num_works,
                         pin_memory=True, drop_last=False,
                         sampler=DistributedSampler(data, shuffle=True))

    print('> Start training...')
    start_all = time.time()
    train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device)
    end_all = time.time()
    print('Whole Training Time: ' + str(end_all - start_all) + 's.')

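# Training loop: per epoch, step-adjust the learning rate, run one pass over the
# distributed dataloader (text embedding of the degradation label -> restorer ->
# total loss against positive/negative samples), then evaluate on the test set,
# record the metrics via load_excel, and save a checkpoint.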
def train(restorer, embedder, optimizer, loss, cur_epoch, args, dataset, device):
    metric = []
    for epoch in range(cur_epoch, args.epoch):
        optimizer = adjust_learning_rate(optimizer, epoch, args.adjust_lr)
        learnrate = optimizer.param_groups[-1]['lr']
        restorer.train()

        for i, data in enumerate(dataset, 0):
            pos, inp, neg = data_process(data, args, device)

            # The text embedding of the degradation description conditions the restorer.
            text_embedding, _, _ = embedder(inp[1], 'text_encoder')
            out = restorer(inp[0], text_embedding)

            restorer.zero_grad()
            total_loss = loss(inp, pos, neg, out)
            total_loss.backward()
            optimizer.step()

            mse = tensor_metric(pos, out, 'MSE', data_range=1)
            psnr = tensor_metric(pos, out, 'PSNR', data_range=1)
            ssim = tensor_metric(pos, out, 'SSIM', data_range=1)
            print("[epoch %d][%d/%d] lr: %f loss: %.4f MSE: %.4f PSNR: %.4f SSIM: %.4f" % (epoch + 1, i + 1,
                len(dataset), learnrate, total_loss.item(), mse, psnr, ssim))

        psnr_t1, ssim_t1, psnr_t2, ssim_t2 = test(args, restorer, embedder, device, epoch)
        metric.append([psnr_t1, ssim_t1, psnr_t2, ssim_t2])
        print("[epoch %d] Test images PSNR1: %.4f SSIM1: %.4f" % (epoch + 1, psnr_t1, ssim_t1))

        load_excel(metric)
        save_checkpoint({'epoch': epoch + 1, 'state_dict': restorer.state_dict(), 'optimizer': optimizer.state_dict()},
                        args.save_model_path, epoch + 1, psnr_t1, ssim_t1, psnr_t2, ssim_t2)

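# Evaluation: for every degraded subset of the test set (combine_type[1:]),
# restore each image twice, once conditioned on the ground-truth degradation
# label (text encoder) and once on the embedding predicted from the image
# itself (image encoder), save a side-by-side comparison, and return the mean
# PSNR/SSIM of both settings.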
def test(args, restorer, embedder, device, epoch=-1):
    combine_type = args.degr_type
    psnr_1, psnr_2, ssim_1, ssim_2 = 0, 0, 0, 0

    os.makedirs(args.output, exist_ok=True)

    # combine_type[0] is 'clear' (ground truth); the remaining entries are degradation folders.
    for i in range(len(combine_type) - 1):
        file_list = os.listdir(f'{args.test_input}/{combine_type[i+1]}/')
        for j in range(len(file_list)):
            hq = Image.open(f'{args.test_input}/{combine_type[0]}/{file_list[j]}')
            lq = Image.open(f'{args.test_input}/{combine_type[i+1]}/{file_list[j]}')

            restorer.eval()
            with torch.no_grad():
                lq_re = torch.Tensor((np.array(lq) / 255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
                lq_em = transform_resize(lq).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
                hq = torch.Tensor((np.array(hq) / 255).transpose(2, 0, 1)).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")

                starttime = time.time()

                # Embedding from the ground-truth degradation label vs. embedding predicted from the image.
                text_embedding_1, _, text_1 = embedder([combine_type[i+1]], 'text_encoder')
                text_embedding_2, _, text_2 = embedder(lq_em, 'image_encoder')

                out_1 = restorer(lq_re, text_embedding_1)
                if text_1 != text_2:
                    print(text_1, text_2)
                    out_2 = restorer(lq_re, text_embedding_2)
                else:
                    out_2 = out_1

                endtime1 = time.time()

                imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output
                        + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png', range=(0, 1))
                # If your torchvision version no longer supports the `range` argument,
                # replace the call above with:
                # imwrite(torch.cat((lq_re, out_1, out_2, hq), dim=3), args.output
                #         + file_list[j][:-4] + '_' + str(epoch) + '_' + combine_type[i+1] + '.png')

                psnr_1 += tensor_metric(hq, out_1, 'PSNR', data_range=1)
                ssim_1 += tensor_metric(hq, out_1, 'SSIM', data_range=1)
                psnr_2 += tensor_metric(hq, out_2, 'PSNR', data_range=1)
                ssim_2 += tensor_metric(hq, out_2, 'SSIM', data_range=1)
                print('The ' + file_list[j][:-4] + ' Time: ' + str(endtime1 - starttime) + 's.')

    # Average over the degraded subsets only ('clear' is the reference, not an input).
    num_images = len(file_list) * (len(combine_type) - 1)
    return psnr_1 / num_images, ssim_1 / num_images, psnr_2 / num_images, ssim_2 / num_images

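# Select the GPU(s) visible to CUDA (PCI bus order); these variables only take
# effect if they are set before CUDA is initialized in the process.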
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" | |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="OneRestore Training")

    # Model loading and saving
    parser.add_argument("--embedder-model-path", type=str, default="./ckpts/embedder_model.tar", help='embedder model path')
    parser.add_argument("--restore-model-path", type=str, default=None, help='restorer checkpoint to resume from')
    parser.add_argument("--save-model-path", type=str, default="./ckpts/", help='directory for saved checkpoints')

    # Training hyperparameters
    parser.add_argument("--epoch", type=int, default=300, help='number of epochs')
    parser.add_argument("--bs", type=int, default=4, help='batch size')
    parser.add_argument("--lr", type=float, default=1e-4, help='learning rate')
    parser.add_argument("--adjust-lr", type=int, default=30, help='learning-rate adjustment interval (epochs)')
    parser.add_argument("--num-works", type=int, default=4, help='number of data-loading workers')
    parser.add_argument("--loss-weight", type=tuple, default=(0.6, 0.3, 0.1), help='loss weights')
    parser.add_argument("--degr-type", type=list, default=['clear', 'low', 'haze', 'rain', 'snow',
        'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'], help='degradation types')

    # Data paths
    parser.add_argument("--train-input", type=str, default="./dataset.h5", help='training data (HDF5 file)')
    parser.add_argument("--test-input", type=str, default="./data/CDD-11_test", help='test data path')
    parser.add_argument("--output", type=str, default="./result/", help='output path')

    argspar = parser.parse_args()
    print_args(argspar)

    main(argspar)