# Spaces:
# Runtime error
# Runtime error
import os | |
import torch | |
import numpy as np | |
from torchvision import transforms | |
from PIL import Image | |
import time | |
import torchvision | |
import cv2 | |
import torchvision.utils as tvu | |
import torch.functional as F | |
import argparse | |
from model.ESTNet import Network | |
def inference_img(img_path, Net):
    """Run ``Net`` on one image file and return the clamped output and timing.

    The image is loaded with PIL and forced to 3-channel RGB. It is then
    scaled from ``[0, 255]`` to ``[0, 1]``, reordered from HWC to CHW, and
    given a leading batch dimension. The resulting float32 tensor is passed
    to ``Net`` under ``torch.no_grad()``.

    Args:
        img_path: Path of the input image.
        Net: Callable model mapping a ``(1, 3, H, W)`` tensor to an output
            tensor. The input tensor is created on the CPU, so ``Net`` must
            accept CPU tensors. NOTE(review): any spatial-size constraints of
            the model are not handled here; confirm against model.ESTNet.

    Returns:
        A tuple ``(restored, seconds)``. ``restored`` is the model output
        clamped to ``[0, 1]``. ``seconds`` is the elapsed time of the forward
        pass only; image loading and clamping are excluded.
    """
    # Use the context manager so PIL always releases the file handle. Without
    # it, the handle can stay open for multi-frame formats such as GIF and TIFF.
    with Image.open(img_path) as pil_image:
        rgb = np.asarray(pil_image.convert('RGB'))

    # (H, W, 3) array -> (3, H, W) float32 in [0, 1] -> (1, 3, H, W) batch.
    batch = torch.from_numpy(rgb / 255.0).float().permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so it is the appropriate clock for timing a short interval.
        start = time.perf_counter()
        restored = Net(batch)
        elapsed = time.perf_counter() - start
        restored = torch.clamp(restored, 0, 1)
    return restored, elapsed
if __name__ == '__main__':
    # Command-line entry point: enhance a single image with ESTNet and write
    # the result to <save_path>/output.png.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/ESTNet.pth', help='Path of the checkpoint')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories, and exist_ok=True
    # makes an already-existing directory a no-op (bare os.mkdir did neither).
    os.makedirs(opt.save_path, exist_ok=True)

    Net = Network(3)
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only load trusted checkpoints, and consider
    # weights_only=True on torch versions that support it.
    Net.load_state_dict(torch.load(opt.checkpoint, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, _elapsed = inference_img(image, Net)
    # Join with a path separator. Plain string concatenation wrote
    # "<save_path>output.png" next to the directory whenever save_path
    # lacked a trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))