import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import argparse
from models.SCET import SCET


def inference_img(img_path, Net, device):
    """Run ``Net`` on a single image file and time the forward pass.

    Args:
        img_path: Path of the input image; it is opened with PIL and
            converted to RGB.
        Net: Callable model. It is invoked on a 4-D tensor (a batch
            dimension is added to the ``ToTensor`` output) that has been
            moved to ``device``.
        device: Device the input tensor is moved to before the call.

    Returns:
        A ``(restored, seconds)`` tuple: the raw output of ``Net`` and the
        wall-clock duration of the forward pass in seconds.
    """
    # Use a context manager so the underlying file handle is released;
    # ``convert`` returns a fully loaded copy that outlives the handle.
    with Image.open(img_path) as img:
        low_image = img.convert('RGB')

    enhance_transforms = transforms.Compose([
        transforms.ToTensor(),
    ])

    # CUDA kernels are launched asynchronously, so the clock is only
    # meaningful if we wait for the device before reading it.
    on_cuda = torch.device(device).type == 'cuda'

    with torch.no_grad():
        batch = enhance_transforms(low_image).unsqueeze(0).to(device)

        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        restored2 = Net(batch)
        if on_cuda:
            torch.cuda.synchronize()
        end = time.perf_counter()

    return restored2, end - start


def _build_model(scale, pk_path, device):
    """Instantiate SCET for ``scale``, load weights from ``pk_path``, move to ``device``."""
    # The x3 configuration is built with 63 as the first argument instead of 64
    # (presumably so the width is divisible by the scale - TODO confirm).
    first_arg = 63 if scale == 3 else 64
    net = SCET(first_arg, 128, scale).eval()

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    net.load_state_dict(torch.load(pk_path, map_location=device))
    return net.to(device)


def main():
    """Parse CLI arguments, restore one image and write ``output.png``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/SRx4.pth', help='Path of the checkpoint')
    parser.add_argument('--scale', type=int, default=4, help='scale factor')
    opt = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # makedirs handles nested output directories and avoids the
    # check-then-create race of isdir() followed by mkdir().
    os.makedirs(opt.save_path, exist_ok=True)

    Net = _build_model(opt.scale, opt.pk_path, device)

    image = opt.test_path
    print(image)
    restored2, _elapsed = inference_img(image, Net, device)

    # Join with a separator so the file lands inside save_path even when the
    # argument is given without a trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))


if __name__ == '__main__':
    main()