SUNet_AWGN_denoising / main_test_SUNet.py
52Hz's picture
Update main_test_SUNet.py
b532ce8 verified
raw
history blame
5.57 kB
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
from skimage import img_as_ubyte
import os
import torch
import requests
from PIL import Image
import math
import yaml
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from natsort import natsorted
from model.SUNet import SUNet_model
with open('training.yaml', 'r') as config:
opt = yaml.safe_load(config)
def clean_folder(folder):
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print('Failed to delete %s. Reason: %s' % (file_path, e))
def main():
parser = argparse.ArgumentParser(description='Demo Image Restoration')
parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
parser.add_argument('--window_size', default=8, type=int, help='window size')
parser.add_argument('--size', default=256, type=int, help='model image patch size')
parser.add_argument('--stride', default=128, type=int, help='reconstruction stride')
parser.add_argument('--result_dir', default='result/', type=str, help='Directory for results')
parser.add_argument('--weights',
default='experiments/pretrained_models/AWGN_denoising_SUNet.pth', type=str,
help='Path to weights')
args = parser.parse_args()
inp_dir = args.input_dir
out_dir = args.result_dir
if os.path.exists(out_dir):
shutil.rmtree(out_dir)
os.makedirs(out_dir, exist_ok=True)
files = natsorted(glob.glob(os.path.join(inp_dir, '*')))
if len(files) == 0:
raise Exception(f"No files found at {inp_dir}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load corresponding models architecture and weights
model = SUNet_model(opt)
model = model.to(device)
model.eval()
load_checkpoint(model, args.weights)
stride = args.stride
model_img = args.size
for file_ in files:
img = Image.open(file_).convert('RGB')
input_ = TF.to_tensor(img).unsqueeze(0).to(device)
with torch.no_grad():
# pad to multiple of 256
square_input_, mask, max_wh = overlapped_square(input_.to(device), kernel=model_img, stride=stride)
output_patch = torch.zeros(square_input_[0].shape).type_as(square_input_[0])
for i, data in enumerate(square_input_):
restored = model(square_input_[i])
if i == 0:
output_patch += restored
else:
output_patch = torch.cat([output_patch, restored], dim=0)
B, C, PH, PW = output_patch.shape
weight = torch.ones(B, C, PH, PH).type_as(output_patch) # weight_mask
patch = output_patch.contiguous().view(B, C, -1, model_img*model_img)
patch = patch.permute(2, 1, 3, 0) # B, C, K*K, #patches
patch = patch.contiguous().view(1, C*model_img*model_img, -1)
weight_mask = weight.contiguous().view(B, C, -1, model_img * model_img)
weight_mask = weight_mask.permute(2, 1, 3, 0) # B, C, K*K, #patches
weight_mask = weight_mask.contiguous().view(1, C * model_img * model_img, -1)
restored = F.fold(patch, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
we_mk = F.fold(weight_mask, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
restored /= we_mk
restored = torch.masked_select(restored, mask.bool()).reshape(input_.shape)
restored = torch.clamp(restored, 0, 1)
restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
restored = img_as_ubyte(restored[0])
f = os.path.splitext(os.path.split(file_)[-1])[0]
save_img((os.path.join(out_dir, f + '.png')), restored)
clean_folder(inp_dir)
def save_img(filepath, img):#
cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
def load_checkpoint(model, weights):
checkpoint = torch.load(weights, map_location=torch.device('cpu'))
try:
model.load_state_dict(checkpoint["state_dict"])
except:
state_dict = checkpoint["state_dict"]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
def overlapped_square(timg, kernel=256, stride=128):
patch_images = []
b, c, h, w = timg.size()
# 321, 481
X = int(math.ceil(max(h, w) / float(kernel)) * kernel)
img = torch.zeros(1, 3, X, X).type_as(timg) # 3, h, w
mask = torch.zeros(1, 1, X, X).type_as(timg)
img[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1.0)
patch = img.unfold(3, kernel, stride).unfold(2, kernel, stride)
patch = patch.contiguous().view(b, c, -1, kernel, kernel) # B, C, #patches, K, K
patch = patch.permute(2, 0, 1, 4, 3) # patches, B, C, K, K
for each in range(len(patch)):
patch_images.append(patch[each])
return patch_images, mask, X
if __name__ == '__main__':
main()