Deploy_Restoration / Lowlight.py
AlexZou's picture
Update Lowlight.py
816136c
raw history blame
No virus
1.43 kB
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.functional as F
import argparse
from model.IAT_main import IAT
def inference_img(img_path, Net, size=(256, 256)):
    """Load one image, resize it, and run it through ``Net`` without gradients.

    Args:
        img_path: Path to the input image file; it is converted to RGB.
        Net: Model called on the preprocessed batch.
        size: Target size passed to ``transforms.Resize``; defaults to the
            previously hard-coded ``(256, 256)``.

    Returns:
        A tuple ``(restored2, elapsed)``. ``restored2`` is whatever ``Net``
        returns for the batch. ``elapsed`` is the time in seconds spent in the
        forward call only; loading and preprocessing are excluded.
    """
    # Close the file handle PIL opens lazily. convert('RGB') loads the pixel
    # data, so the returned image stays usable after the file is closed.
    with Image.open(img_path) as img:
        low_image = img.convert('RGB')
    enhance_transforms = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),  # PIL RGB image -> (3, H, W) float tensor in [0, 1]
    ])
    with torch.no_grad():
        # Add a leading batch dimension: (3, H, W) -> (1, 3, H, W).
        low_image = enhance_transforms(low_image).unsqueeze(0)
        # perf_counter is monotonic and high-resolution; time.time() is not
        # meant for measuring short intervals.
        start = time.perf_counter()
        restored2 = Net(low_image)
        end = time.perf_counter()
    return restored2, end - start
if __name__ == '__main__':
    # Command-line entry point: enhance a single image and save the result as
    # output.png inside --save_path.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    # NOTE(review): this is the low-light script, but the default checkpoint is
    # named underwater.pth. Confirm the intended default.
    parser.add_argument('--pk_path', type=str, default='model_zoo/underwater.pth', help='Path of the checkpoint')
    opt = parser.parse_args()
    # makedirs also creates missing parent directories and is a no-op when the
    # directory already exists; os.mkdir failed on missing parents.
    os.makedirs(opt.save_path, exist_ok=True)
    Net = IAT()
    # Map weights to CPU so the script runs on machines without a GPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
    Net = Net.eval()
    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    # os.path.join inserts the path separator. Plain concatenation produced e.g.
    # "resultsoutput.png" when --save_path had no trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))