|
import cog |
|
import tempfile |
|
from pathlib import Path |
|
import argparse |
|
import shutil |
|
import os |
|
import glob |
|
import torch |
|
from skimage import img_as_ubyte |
|
from PIL import Image |
|
from model.SRMNet import SRMNet |
|
from main_test_SRMNet import save_img, setup |
|
import torchvision.transforms.functional as TF |
|
import torch.nn.functional as F |
|
|
|
class Predictor(cog.Predictor):
    """Cog predictor that runs an SRMNet network over PNG images."""

    def setup(self):
        """Parse the demo options, select the device and build the network.

        Stores ``self.args`` (the parsed argparse namespace), ``self.device``
        (CUDA when available, otherwise CPU) and ``self.model`` (an SRMNet
        instance in eval mode on that device).
        """
        # NOTE(review): no checkpoint is loaded anywhere in this class. The
        # original code named
        # 'experiments/pretrained_models/AWGN_denoising_SRMNet.pth' in an
        # unused local and declares --weights below, but neither is ever read
        # here -- confirm where the pretrained weights are meant to be loaded.
        parser = argparse.ArgumentParser(description='Demo Image Denoising')
        parser.add_argument('--input_dir', default='./test/', type=str, help='Input images')
        parser.add_argument('--result_dir', default='./result/', type=str, help='Directory for results')
        parser.add_argument('--weights',
                            default='./checkpoints/SRMNet_real_denoise/models/model_bestPSNR.pth', type=str,
                            help='Path to weights')

        self.args = parser.parse_args()

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Build the network once here instead of on every predict() call.
        self.model = SRMNet().to(self.device)
        self.model.eval()

    @cog.input("image", type=Path, help="input image")
    def predict(self, image):
        """Run the network and return the path of the written result.

        Args:
            image: Path of the uploaded input image. It is copied into the
                local ``input_cog_temp`` folder for the duration of the call.

        Returns:
            Path to ``out.png`` inside a fresh temporary directory. When
            several files match the glob, the last one in sorted order is the
            one saved.

        Raises:
            FileNotFoundError: If no ``*.PNG`` file is found in the folder
                returned by ``setup(self.args)``.
        """
        input_dir = 'input_cog_temp'
        os.makedirs(input_dir, exist_ok=True)

        try:
            input_path = os.path.join(input_dir, os.path.basename(image))
            shutil.copy(str(image), input_path)

            folder, save_dir = setup(self.args)
            os.makedirs(save_dir, exist_ok=True)

            out_path = Path(tempfile.mkdtemp()) / "out.png"
            mul = 16  # spatial dimensions are padded up to a multiple of this
            restored = None

            # NOTE(review): images are globbed from the folder returned by
            # setup(), not from input_dir, and only the upper-case '.PNG'
            # pattern is matched -- confirm the uploaded file is picked up.
            for file_ in sorted(glob.glob(os.path.join(folder, '*.PNG'))):
                img = Image.open(file_).convert('RGB')
                # Send the input to the same device as the model; a hard-coded
                # .cuda() fails on CPU-only hosts.
                input_ = TF.to_tensor(img).unsqueeze(0).to(self.device)

                # Pad bottom/right so height and width are multiples of `mul`;
                # (-x) % mul is 0 when x is already a multiple.
                h, w = input_.shape[2], input_.shape[3]
                padh = (-h) % mul
                padw = (-w) % mul
                input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

                with torch.no_grad():
                    restored = self.model(input_)

                restored = torch.clamp(restored, 0, 1)
                # Crop the padding away again.
                restored = restored[:, :, :h, :w]
                restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
                restored = img_as_ubyte(restored[0])

            if restored is None:
                raise FileNotFoundError(
                    'No *.PNG image found in %s' % (folder,))

            save_img(str(out_path), restored)
        finally:
            # Remove the copied upload even when inference fails.
            clean_folder(input_dir)

        return out_path
|
|
|
|
|
def clean_folder(folder):
    """Delete every entry inside *folder* while keeping the folder itself.

    Regular files and symbolic links are unlinked; directories are removed
    recursively. A failure on one entry is printed and the remaining entries
    are still processed.
    """
    for entry_name in os.listdir(folder):
        target = os.path.join(folder, entry_name)
        try:
            # A symlink takes this branch whatever it points at, so only the
            # link itself is removed.
            if os.path.islink(target) or os.path.isfile(target):
                os.unlink(target)
                continue
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print('Failed to delete %s. Reason: %s' % (target, err))