# Image_Restoration / Underwater.py
# Uploaded by BraveLizzy ("Upload Underwater.py", commit d85556c).
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.functional as F
import argparse
from model.Underwater import GeneratorFunieGAN
def _model_device(net):
    """Return the device of ``net``'s first parameter, or CPU if it has none."""
    parameters = getattr(net, 'parameters', None)
    if parameters is None:
        return torch.device('cpu')
    first = next(iter(parameters()), None)
    return first.device if first is not None else torch.device('cpu')


def inference_img(img_path, Net):
    """Run the restoration network on a single image file.

    The image is loaded with PIL, forced to 3-channel RGB, scaled from the
    0..255 pixel range to [0, 1], and reshaped into a float32 batch of one
    image in channel-first layout before being passed to ``Net``.

    Args:
        img_path: Path of the input image (anything ``PIL.Image.open`` accepts).
        Net: Model invoked as ``Net(batch)``; presumably an ``nn.Module``
            already switched to eval mode by the caller. The batch is moved to
            the device of the model's first parameter (CPU if none is found).

    Returns:
        Tuple ``(restored, seconds)``: the network output clamped to [0, 1]
        (left on the model's device) and the wall-clock duration, in seconds,
        of the forward pass alone.
    """
    # The context manager closes the file even for multi-frame formats, where
    # PIL keeps the handle open after load(); convert() yields an in-memory copy.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')

    # HWC uint8 -> HWC float in [0, 1] -> float32 -> CHW -> batch of one.
    pixels = np.asarray(rgb) / 255.0
    batch = torch.from_numpy(pixels).float().permute(2, 0, 1).unsqueeze(0)

    device = _model_device(Net)
    batch = batch.to(device)

    with torch.no_grad():
        start = time.time()
        restored = Net(batch)
        # CUDA kernels run asynchronously; wait for them so the measured time
        # covers the computation rather than only the kernel launch.
        if device.type == 'cuda':
            torch.cuda.synchronize(device)
        end = time.time()
        restored = torch.clamp(restored, 0, 1)
    return restored, end - start
def main():
    """Command-line entry point: restore one underwater image and save it.

    Parses ``--test_path`` (input image), ``--save_path`` (output directory)
    and ``--checkpoint`` (generator weights), loads the weights onto CPU,
    runs ``inference_img`` and writes the result to
    ``<save_path>/output.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/Underwater.pth', help='Path of the checkpoint')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists.
    os.makedirs(opt.save_path, exist_ok=True)

    Net = GeneratorFunieGAN()
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load trusted checkpoints (torch>=1.13 offers weights_only=True).
    Net.load_state_dict(torch.load(opt.checkpoint, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored, elapsed = inference_img(image, Net)
    print(f'Inference time: {elapsed:.4f}s')
    # os.path.join (not string +) so "--save_path results" writes
    # results/output.png instead of a sibling file "resultsoutput.png".
    torchvision.utils.save_image(restored, os.path.join(opt.save_path, 'output.png'))


if __name__ == '__main__':
    main()