# Scraped file header (not code): predict.py by 52Hz, commit 44b3c84 "Create predict.py", 3.04 kB.
import cog
import tempfile
from pathlib import Path
import argparse
import shutil
import os
import glob
import torch
from skimage import img_as_ubyte
from PIL import Image
from model.SRMNet import SRMNet
from main_test_SRMNet import save_img, setup
import torchvision.transforms.functional as TF
import torch.nn.functional as F
class Predictor(cog.Predictor):
    """Cog predictor that denoises a single uploaded image with SRMNet."""

    def setup(self):
        """Parse configuration, pick a device, and build the model once.

        The network is constructed and its weights are loaded here, rather
        than inside ``predict``, so every prediction reuses one loaded model.

        Raises:
            FileNotFoundError: if none of the candidate checkpoint files exist.
        """
        # NOTE(review): this path was defined but never used in the previous
        # version; it is tried before --weights. Confirm which checkpoint this
        # demo is meant to ship with.
        model_dir = 'experiments/pretrained_models/AWGN_denoising_SRMNet.pth'
        parser = argparse.ArgumentParser(description='Demo Image Denoising')
        parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
        parser.add_argument('--result_dir', default='./result/', type=str, help='Directory for results')
        parser.add_argument('--weights',
                            default='./checkpoints/SRMNet_real_denoise/models/model_bestPSNR.pth', type=str,
                            help='Path to weights')
        # parse_known_args: the hosting server process may carry its own
        # command-line flags, which parse_args() would reject with SystemExit.
        self.args, _ = parser.parse_known_args()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model = SRMNet()
        self._load_weights(model, [model_dir, self.args.weights])
        self.model = model.to(self.device).eval()

    def _load_weights(self, model, candidates):
        """Load the first existing checkpoint among *candidates* into *model*.

        Args:
            model: freshly constructed SRMNet instance.
            candidates: checkpoint paths, tried in order.

        Raises:
            FileNotFoundError: if no candidate path is an existing file.

        NOTE(review): the checkpoint layout is assumed, not confirmed. This
        accepts a raw state dict, or one nested under 'state_dict' or
        'params', and strips a 'module.' key prefix (as left by
        DataParallel). Verify against the real .pth files.
        """
        for path in candidates:
            if os.path.isfile(path):
                break
        else:
            raise FileNotFoundError(f'No SRMNet weights found; tried: {candidates}')

        checkpoint = torch.load(path, map_location=self.device)
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for key in ('state_dict', 'params'):
                if key in checkpoint:
                    state_dict = checkpoint[key]
                    break
        prefix = 'module.'
        state_dict = {
            (name[len(prefix):] if name.startswith(prefix) else name): tensor
            for name, tensor in state_dict.items()
        }
        model.load_state_dict(state_dict)

    @cog.input("image", type=Path, help="input image")
    def predict(self, image):
        """Denoise *image* and return the path of the restored PNG.

        Args:
            image: path of the uploaded input image.

        Returns:
            Path to ``out.png`` inside a freshly created temporary directory.
        """
        # Read the upload directly. The previous version copied it to a scratch
        # folder but then globbed '*.PNG' in a different folder, so the upload
        # could be skipped and a path to an unwritten file returned.
        img = Image.open(str(image)).convert('RGB')
        input_ = TF.to_tensor(img).unsqueeze(0).to(self.device)

        # Pad bottom/right so height and width are multiples of `mul`; the
        # output is cropped back to (h, w) after inference.
        mul = 16
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % mul
        padw = (-w) % mul
        # 'reflect' requires pad < dimension; fall back to 'replicate' for tiny images.
        mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), mode)

        with torch.no_grad():
            restored = self.model(input_)
        restored = torch.clamp(restored, 0, 1)[:, :, :h, :w]
        restored = restored.permute(0, 2, 3, 1).cpu().numpy()
        restored = img_as_ubyte(restored[0])

        out_path = Path(tempfile.mkdtemp()) / "out.png"
        save_img(str(out_path), restored)
        return out_path
def clean_folder(folder):
    """Remove every entry inside *folder*, leaving *folder* itself in place.

    Regular files and symlinks are unlinked; directories are deleted
    recursively; any other entry type is left alone. Deletion is
    best-effort: a failure on one entry is printed and the remaining
    entries are still processed.
    """
    for entry_name in os.listdir(folder):
        entry_path = os.path.join(folder, entry_name)
        try:
            # Checking islink before isdir means a symlink to a directory is
            # unlinked rather than recursed into by rmtree.
            unlinkable = os.path.isfile(entry_path) or os.path.islink(entry_path)
            if unlinkable:
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as err:
            print('Failed to delete %s. Reason: %s' % (entry_path, err))