Spaces:
Running
Running
File size: 5,501 Bytes
61f07f3 bbc278d 61f07f3 bbc278d 61f07f3 84f440e 61f07f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
from skimage import img_as_ubyte
import os
import torch
import requests
from PIL import Image
import math
import yaml
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from natsort import natsorted
from model.SUNet import SUNet_model
# Load the model configuration once at import time; main() hands ``opt`` to
# SUNet_model. The encoding is pinned so parsing does not depend on the
# platform's locale default.
with open('training.yaml', 'r', encoding='utf-8') as config:
    opt = yaml.safe_load(config)
def clean_folder(folder):
    """Delete every entry directly inside ``folder``, keeping ``folder`` itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. Deletion is best-effort: a failure on one entry is printed
    and the remaining entries are still processed.

    Args:
        folder: Path of the directory to empty.
    """
    # Imported here because the file's import block does not provide shutil;
    # without it the rmtree branch raised NameError, which the handler below
    # swallowed, so sub-directories were never removed.
    import shutil

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))
def main():
    """Restore every image found in ``--input_dir`` and save PNGs to ``--result_dir``.

    Each image is padded onto a square canvas, cut into overlapping
    ``--size`` x ``--size`` patches, passed through the SUNet model one patch
    at a time, and stitched back together with ``F.fold``. Overlapping regions
    are averaged by dividing by a folded all-ones weight map.

    NOTE: once all images are processed, the contents of ``--input_dir`` are
    deleted via ``clean_folder``.

    Raises:
        FileNotFoundError: if ``--input_dir`` contains no entries.
    """
    parser = argparse.ArgumentParser(description='Demo Image Restoration')
    parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
    parser.add_argument('--window_size', default=8, type=int, help='window size')
    parser.add_argument('--size', default=256, type=int, help='model image patch size')
    parser.add_argument('--stride', default=128, type=int, help='reconstruction stride')
    parser.add_argument('--result_dir', default='result/', type=str, help='Directory for results')
    parser.add_argument('--weights',
                        default='experiments/pretrained_models/AWGN_denoising_SUNet.pth', type=str,
                        help='Path to weights')
    args = parser.parse_args()

    inp_dir = args.input_dir
    out_dir = args.result_dir
    os.makedirs(out_dir, exist_ok=True)

    files = natsorted(glob.glob(os.path.join(inp_dir, '*')))
    if not files:
        raise FileNotFoundError(f"No files found at {inp_dir}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the architecture from the module-level config, then load weights.
    model = SUNet_model(opt)
    model = model.to(device)
    model.eval()
    load_checkpoint(model, args.weights)

    stride = args.stride
    model_img = args.size

    for file_ in files:
        img = Image.open(file_).convert('RGB')
        input_ = TF.to_tensor(img).unsqueeze(0).to(device)

        with torch.no_grad():
            # Pad onto a square canvas whose side is a multiple of model_img
            # and split it into overlapping model_img x model_img patches.
            square_input_, mask, max_wh = overlapped_square(input_, kernel=model_img, stride=stride)

            # Run the model patch by patch and concatenate once, instead of
            # re-allocating the growing tensor on every iteration.
            # Assumes the model returns a tensor shaped like its input patch
            # -- TODO confirm against SUNet_model.
            output_patch = torch.cat([model(patch_in) for patch_in in square_input_], dim=0)

            B, C = output_patch.shape[:2]

            # Rearrange (#patches, C, K, K) into the (1, C*K*K, #patches)
            # column layout consumed by F.fold (K == model_img).
            patch = output_patch.contiguous().view(B, C, -1, model_img * model_img)
            patch = patch.permute(2, 1, 3, 0)  # 1, C, K*K, #patches
            patch = patch.contiguous().view(1, C * model_img * model_img, -1)

            # Folding an all-ones tensor of the same layout counts how many
            # patches contributed to each canvas pixel.
            weight_mask = torch.ones_like(patch)

            restored = F.fold(patch, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
            we_mk = F.fold(weight_mask, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
            restored /= we_mk  # average the overlapping contributions

            # Keep only the pixels that belong to the original, unpadded image.
            restored = torch.masked_select(restored, mask.bool()).reshape(input_.shape)
            restored = torch.clamp(restored, 0, 1)

        # NCHW float tensor in [0, 1] -> HWC uint8 array for saving.
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored = img_as_ubyte(restored[0])

        f = os.path.splitext(os.path.split(file_)[-1])[0]
        save_img((os.path.join(out_dir, f + '.png')), restored)

    # Empties the input directory once every file has been processed.
    clean_folder(inp_dir)
def save_img(filepath, img):
    """Write the RGB image array ``img`` to ``filepath`` using OpenCV.

    The channels are converted from RGB to BGR order before
    ``cv2.imwrite`` is called.
    """
    bgr_pixels = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filepath, bgr_pixels)
def load_checkpoint(model, weights):
    """Load the ``"state_dict"`` entry of a checkpoint file into ``model``.

    The state dict is first loaded as-is. If that fails with the
    ``RuntimeError`` that ``load_state_dict`` raises on a key or shape
    mismatch, a leading ``module.`` prefix is stripped from the keys that
    have one and the load is retried.

    Args:
        model: Module whose parameters are overwritten in place.
        weights: Path to the checkpoint file.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the keys still do not match after stripping.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = 'module.'
        # Strip the prefix only where it is present, rather than blindly
        # dropping the first seven characters of every key.
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
def overlapped_square(timg, kernel=256, stride=128):
    """Centre ``timg`` on a zero-padded square canvas and cut it into patches.

    The canvas side ``X`` is ``max(h, w)`` rounded up to a multiple of
    ``kernel``. Windows of ``kernel`` x ``kernel`` are taken every ``stride``
    pixels in both directions. ``stride`` should divide ``kernel`` so the
    windows reach the canvas edge exactly.

    Args:
        timg: Tensor of size ``(b, c, h, w)``.
        kernel: Side length of each square patch.
        stride: Step between neighbouring patches.

    Returns:
        tuple: ``(patch_images, mask, X)``.
            ``patch_images`` is a list of ``(b, c, kernel, kernel)`` tensors
            ordered row by row over the window grid.
            ``mask`` is a ``(1, 1, X, X)`` tensor holding 1.0 where the canvas
            contains pixels of ``timg`` and 0.0 in the padding.
            ``X`` is the canvas side length.
    """
    b, c, h, w = timg.size()
    X = int(math.ceil(max(h, w) / float(kernel)) * kernel)
    top = (X - h) // 2
    left = (X - w) // 2

    # Size the canvas from the input's own batch/channel counts; a fixed
    # (1, 3, X, X) canvas disagrees with the view() below for non-RGB input.
    img = torch.zeros(b, c, X, X).type_as(timg)
    mask = torch.zeros(1, 1, X, X).type_as(timg)
    img[:, :, top:top + h, left:left + w] = timg
    mask[:, :, top:top + h, left:left + w].fill_(1.0)

    # Unfold width, then height: (b, c, nH, nW, kW, kH).
    patch = img.unfold(3, kernel, stride).unfold(2, kernel, stride)
    patch = patch.contiguous().view(b, c, -1, kernel, kernel)  # b, c, #patches, kW, kH
    patch = patch.permute(2, 0, 1, 4, 3)  # #patches, b, c, kH, kW

    return list(patch), mask, X
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()