52Hz commited on
Commit
61f07f3
1 Parent(s): bd78b3e

Create main_test_SUNet.py

Browse files
Files changed (1) hide show
  1. main_test_SUNet.py +136 -0
main_test_SUNet.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import numpy as np
5
+ from collections import OrderedDict
6
+ from skimage import img_as_ubyte
7
+ import os
8
+ import torch
9
+ import requests
10
+ from PIL import Image
11
+ import math
12
+ import yaml
13
+ import torchvision.transforms.functional as TF
14
+ import torch.nn.functional as F
15
+ from natsort import natsorted
16
+ from model.SUNet import SUNet_model
17
+
18
+ with open('training.yaml', 'r') as config:
19
+ opt = yaml.safe_load(config)
20
+
21
+
22
+
23
+ def main():
24
+ parser = argparse.ArgumentParser(description='Demo Image Restoration')
25
+ parser.add_argument('--input_dir', default='test/', type=str, help='Input images')
26
+ parser.add_argument('--window_size', default=8, type=int, help='window size')
27
+ parser.add_argument('--size', default=256, type=int, help='model image patch size')
28
+ parser.add_argument('--stride', default=128, type=int, help='reconstruction stride')
29
+ parser.add_argument('--result_dir', default='result/', type=str, help='Directory for results')
30
+ parser.add_argument('--weights',
31
+ default='experiments/pretrained_models/AWGN_denoising_SUNet.pth', type=str,
32
+ help='Path to weights')
33
+
34
+ args = parser.parse_args()
35
+
36
+ inp_dir = args.input_dir
37
+ out_dir = args.result_dir
38
+
39
+ os.makedirs(out_dir, exist_ok=True)
40
+
41
+ files = natsorted(glob.glob(os.path.join(inp_dir, '*')))
42
+
43
+ if len(files) == 0:
44
+ raise Exception(f"No files found at {inp_dir}")
45
+
46
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47
+
48
+ # Load corresponding models architecture and weights
49
+ model = SUNet_model(opt)
50
+ model = model.to(device)
51
+ model.eval()
52
+ load_checkpoint(model, args.weights)
53
+ stride = args.stride
54
+ model_img = args.size
55
+
56
+ for file_ in files:
57
+ img = Image.open(file_).convert('RGB')
58
+ input_ = TF.to_tensor(img).unsqueeze(0).to(device)
59
+ with torch.no_grad():
60
+ # pad to multiple of 256
61
+ square_input_, mask, max_wh = overlapped_square(input_.to(device), kernel=model_img, stride=stride)
62
+ output_patch = torch.zeros(square_input_[0].shape).type_as(square_input_[0])
63
+ for i, data in enumerate(square_input_):
64
+ s = time.time()
65
+ restored = model(square_input_[i])
66
+ f = time.time()
67
+ print(f-s)
68
+ if i == 0:
69
+ output_patch += restored
70
+ else:
71
+ output_patch = torch.cat([output_patch, restored], dim=0)
72
+
73
+ B, C, PH, PW = output_patch.shape
74
+ weight = torch.ones(B, C, PH, PH).type_as(output_patch) # weight_mask
75
+
76
+ patch = output_patch.contiguous().view(B, C, -1, model_img*model_img)
77
+ patch = patch.permute(2, 1, 3, 0) # B, C, K*K, #patches
78
+ patch = patch.contiguous().view(1, C*model_img*model_img, -1)
79
+
80
+ weight_mask = weight.contiguous().view(B, C, -1, model_img * model_img)
81
+ weight_mask = weight_mask.permute(2, 1, 3, 0) # B, C, K*K, #patches
82
+ weight_mask = weight_mask.contiguous().view(1, C * model_img * model_img, -1)
83
+
84
+ restored = F.fold(patch, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
85
+ we_mk = F.fold(weight_mask, output_size=(max_wh, max_wh), kernel_size=model_img, stride=stride)
86
+ restored /= we_mk
87
+
88
+ restored = torch.masked_select(restored, mask.bool()).reshape(input_.shape)
89
+ restored = torch.clamp(restored, 0, 1)
90
+
91
+ restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
92
+ restored = img_as_ubyte(restored[0])
93
+
94
+ f = os.path.splitext(os.path.split(file_)[-1])[0]
95
+ save_img((os.path.join(out_dir, f + '.png')), restored)
96
+
97
+ def save_img(filepath, img):
98
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
99
+
100
+
101
+ def load_checkpoint(model, weights):
102
+ checkpoint = torch.load(weights, map_location=torch.device('cpu'))
103
+ try:
104
+ model.load_state_dict(checkpoint["state_dict"])
105
+ except:
106
+ state_dict = checkpoint["state_dict"]
107
+ new_state_dict = OrderedDict()
108
+ for k, v in state_dict.items():
109
+ name = k[7:] # remove `module.`
110
+ new_state_dict[name] = v
111
+ model.load_state_dict(new_state_dict)
112
+
113
+ def overlapped_square(timg, kernel=256, stride=128):
114
+ patch_images = []
115
+ b, c, h, w = timg.size()
116
+ # 321, 481
117
+ X = int(math.ceil(max(h, w) / float(kernel)) * kernel)
118
+ img = torch.zeros(1, 3, X, X).type_as(timg) # 3, h, w
119
+ mask = torch.zeros(1, 1, X, X).type_as(timg)
120
+
121
+ img[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)] = timg
122
+ mask[:, :, ((X - h) // 2):((X - h) // 2 + h), ((X - w) // 2):((X - w) // 2 + w)].fill_(1.0)
123
+
124
+ patch = img.unfold(3, kernel, stride).unfold(2, kernel, stride)
125
+ patch = patch.contiguous().view(b, c, -1, kernel, kernel) # B, C, #patches, K, K
126
+ patch = patch.permute(2, 0, 1, 4, 3) # patches, B, C, K, K
127
+
128
+ for each in range(len(patch)):
129
+ patch_images.append(patch[each])
130
+
131
+ return patch_images, mask, X
132
+
133
+
134
+
135
+ if __name__ == '__main__':
136
+ main()