Deploy_Restoration / SuperResolution.py
AlexZou's picture
Upload 4 files
7970501
raw history blame
No virus
1.48 kB
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import argparse
from models.SCET import SCET
def inference_img(img_path, Net):
    """Run ``Net`` once on the image stored at ``img_path``.

    The image is loaded with PIL, converted to RGB, turned into a tensor with
    ``transforms.ToTensor`` and given a leading batch dimension before being
    passed to ``Net`` under ``torch.no_grad()``.

    Args:
        img_path: Path of the input image file.
        Net: Model (any callable) applied to the batch-of-one input tensor.

    Returns:
        A tuple ``(restored, elapsed)``. ``restored`` is whatever ``Net``
        returns for the input; ``elapsed`` is the duration of the forward
        pass alone, in seconds.
    """
    # The context manager closes the underlying file handle deterministically;
    # convert() returns a new image, so its data outlives the `with` block.
    with Image.open(img_path) as img:
        low_image = img.convert('RGB')
    to_tensor = transforms.ToTensor()
    with torch.no_grad():
        # unsqueeze(0) prepends the batch dimension expected by the model.
        low_image = to_tensor(low_image).unsqueeze(0)
        # perf_counter is monotonic and high-resolution, which makes it the right
        # clock for measuring an interval (time.time() can jump).
        start = time.perf_counter()
        restored = Net(low_image)
        elapsed = time.perf_counter() - start
    return restored, elapsed
if __name__ == '__main__':
    # Command-line entry point: restore one image with a pretrained SCET model
    # and write the result as <save_path>/output.png.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--pk_path', type=str, default='model_zoo/SRx4.pth', help='Path of the checkpoint')
    parser.add_argument('--scale', type=int, default=4, help='scale factor')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories; exist_ok replaces the
    # separate isdir() check.
    os.makedirs(opt.save_path, exist_ok=True)

    # NOTE(review): the first SCET argument is 63 for x3 and 64 otherwise;
    # presumably a width that must be divisible by the scale -- confirm
    # against models/SCET.py.
    if opt.scale == 3:
        Net = SCET(63, 128, opt.scale)
    else:
        Net = SCET(64, 128, opt.scale)
    # NOTE(review): the default --pk_path is the x4 checkpoint; a different
    # --scale without a matching --pk_path will presumably fail to load.
    # torch.load unpickles the file -- only load checkpoints from trusted sources.
    Net.load_state_dict(torch.load(opt.pk_path, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    print(f'Inference time: {time_num:.4f}s')
    # os.path.join inserts the separator: plain concatenation wrote
    # "<save_path>output.png" outside the directory when save_path had no
    # trailing slash.
    torchvision.utils.save_image(restored2, os.path.join(opt.save_path, 'output.png'))