# Deploy_Restoration / Underwater.py
# Uploaded by AlexZou ("Upload 4 files", commit 7970501; 1.43 kB).
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.functional as F
import argparse
from net.Ushape_Trans import *
def inference_img(img_path, Net, size=(256, 256)):
    """Run the enhancement network on a single image file.

    The image is loaded as RGB, resized to ``size`` and converted to a
    tensor with a leading batch dimension before being passed to ``Net``.
    Gradient tracking is disabled for preprocessing and the forward pass.

    Args:
        img_path: Path to the input image (anything ``PIL.Image.open`` accepts).
        Net: Callable model, e.g. the ``Generator`` built in ``__main__``.
            The caller is responsible for putting it in eval mode.
        size: Target size passed to ``transforms.Resize``; defaults to the
            previously hard-coded ``(256, 256)``.

    Returns:
        Tuple ``(restored, seconds)``: whatever ``Net`` returns for the
        batched input, and the elapsed time of the ``Net`` call alone.
    """
    enhance_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])

    # Close the file handle promptly; convert() returns an independent,
    # fully loaded image, so it stays valid after the file is closed.
    with Image.open(img_path) as raw_image:
        low_image = raw_image.convert('RGB')

    with torch.no_grad():
        # ToTensor yields (C, H, W); unsqueeze adds the batch dim -> (1, C, H, W).
        low_image = enhance_transforms(low_image).unsqueeze(0)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        restored2 = Net(low_image)
        end = time.perf_counter()
    return restored2, end - start
if __name__ == '__main__':
    # Command-line entry point: restore one image and save it as
    # <save_path>/output.png.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/underwater.pth', help='Path of the checkpoint')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories and is a no-op when
    # the directory already exists (os.mkdir failed on nested paths).
    os.makedirs(opt.save_path, exist_ok=True)

    # Generator is provided by the star import of net.Ushape_Trans.
    Net = Generator()
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from trusted sources.
    Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    print('Inference time: {:.4f}s'.format(time_num))
    # Join with the proper separator: plain string concatenation wrote
    # "<save_path>output.png" outside the directory whenever save_path had
    # no trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))