import os
import time
import argparse

import numpy as np
import torch
import torch.functional as F
import torchvision
import torchvision.utils as tvu
from torchvision import transforms
from PIL import Image
import cv2

from model.ESTNet import Network


def inference_img(img_path, Net):
    """Run the restoration network on a single image file.

    Args:
        img_path: Path to the input image. It is converted to 3-channel RGB
            before being fed to the network.
        Net: The model to apply. It is called once with a float tensor of
            shape (1, 3, H, W) whose values are scaled to [0, 1].

    Returns:
        A tuple ``(restored, elapsed)``: ``restored`` is the network output
        clamped to [0, 1], and ``elapsed`` is the wall-clock time in seconds
        spent in the forward pass only (preprocessing is excluded).
    """
    # Context manager releases the file handle; convert() loads the pixel
    # data, so the resulting image stays usable after the file is closed.
    with Image.open(img_path) as img:
        low_image = img.convert('RGB')

    # HWC uint8 array -> float values in [0, 1] -> CHW -> add batch dim.
    data_lowlight = np.asarray(low_image) / 255.0
    data_lowlight = torch.from_numpy(data_lowlight).float()
    data_lowlight = data_lowlight.permute(2, 0, 1)
    low_tensor = data_lowlight.unsqueeze(0)

    with torch.no_grad():
        # perf_counter is monotonic and intended for measuring intervals.
        start = time.perf_counter()
        restored = Net(low_tensor)
        end = time.perf_counter()

    restored = torch.clamp(restored, 0, 1)
    return restored, end - start


def main():
    """Parse CLI arguments, restore one image and save it as output.png."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/ESTNet.pth',
                        help='Path of the checkpoint')
    opt = parser.parse_args()

    # makedirs creates nested directories; exist_ok avoids a check-then-create
    # race (it still raises if save_path exists as a regular file).
    os.makedirs(opt.save_path, exist_ok=True)

    Net = Network(3)
    # NOTE(review): torch.load unpickles the checkpoint file; only load
    # checkpoints from trusted sources.
    Net.load_state_dict(torch.load(opt.checkpoint, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    print(f'Inference time: {time_num:.4f}s')

    # os.path.join supplies the separator; plain string concatenation wrote
    # e.g. "resultsoutput.png" next to the directory when save_path had no
    # trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))


if __name__ == '__main__':
    main()