# Image_Restoration / Denoise_XRay.py
# Last commit: "Update Denoise_XRay.py" (8704a73) by BraveLizzy
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image
import time
import torchvision
import cv2
import torchvision.utils as tvu
import torch.functional as F
import argparse
from model.ESTNet import Network
def inference_img(img_path, Net):
    """Run a restoration network on a single image file.

    The image at ``img_path`` is loaded with PIL and converted to 3-channel
    RGB. It is scaled from uint8 to float32 in [0, 1] and rearranged from
    HWC to a (1, C, H, W) batch of one. It is then passed through ``Net``
    with autograd disabled, and the network output is clamped to [0, 1].

    Args:
        img_path: Path to the input image file.
        Net: Callable model applied to the (1, 3, H, W) float32 tensor. The
            input tensor is created on the CPU; the model is presumably
            expected to be on the CPU as well (TODO confirm for GPU use).

    Returns:
        A tuple ``(restored, elapsed)``. ``restored`` is the network output
        clamped to [0, 1]. ``elapsed`` is the time in seconds spent in the
        forward call only; image loading and pre-processing are excluded.
    """
    # The context manager releases the file handle deterministically;
    # .convert() forces the pixel data to be read before the file closes.
    with Image.open(img_path) as img:
        rgb = img.convert('RGB')

    # HWC uint8 -> HWC float32 in [0, 1] -> CHW -> NCHW (batch of one).
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    low_image = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        # perf_counter is monotonic and high-resolution. time.time() can
        # jump if the system clock is adjusted mid-measurement.
        start = time.perf_counter()
        restored = Net(low_image)
        elapsed = time.perf_counter() - start
        restored = torch.clamp(restored, 0, 1)
    return restored, elapsed
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_path', type=str, required=True, help='Path to test')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/ESTNet.pth',
                        help='Path of the checkpoint')
    opt = parser.parse_args()

    # makedirs also creates missing parent directories. exist_ok avoids the
    # check-then-create race of an isdir() test followed by mkdir().
    os.makedirs(opt.save_path, exist_ok=True)

    # Build the 3-channel network and load its weights onto the CPU.
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code -- only load checkpoints from trusted sources (newer
    # PyTorch releases accept weights_only=True to restrict this).
    Net = Network(3)
    Net.load_state_dict(torch.load(opt.checkpoint, map_location=torch.device('cpu')))
    Net = Net.eval()

    image = opt.test_path
    print(image)
    restored2, time_num = inference_img(image, Net)
    # Report the forward-pass time, which was previously computed and dropped.
    print(f'Inference time: {time_num:.4f}s')

    # Join with the OS path separator. Plain string concatenation produced
    # "<save_path>output.png" next to (not inside) the directory whenever
    # save_path lacked a trailing slash.
    out_path = os.path.join(opt.save_path, 'output.png')
    torchvision.utils.save_image(restored2, out_path)