jjourney1125 committed on
Commit 5fd5bbc
1 parent: 195ab71

Add main_test_swin2sr module

Files changed (1)
  1. main_test_swin2sr.py (+302 −0)
main_test_swin2sr.py ADDED
@@ -0,0 +1,302 @@
import argparse
import cv2
import glob
import numpy as np
from collections import OrderedDict
import os
import torch
import requests

from models.network_swin2sr import Swin2SR as net
from utils import util_calculate_psnr_ssim as util


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='color_dn',
                        help='classical_sr, lightweight_sr, real_sr, gray_dn, color_dn, jpeg_car, color_jpeg_car')
    parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8')  # 1 for dn and jpeg car
    parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
    parser.add_argument('--jpeg', type=int, default=40, help='JPEG quality factor: 10, 20, 30, 40')
    parser.add_argument('--training_patch_size', type=int, default=128,
                        help='patch size used in training Swin2SR. Just used to differentiate two different '
                             'settings in Table 2 of the paper. Images are NOT tested patch by patch.')
    parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')
    parser.add_argument('--model_path', type=str,
                        default='model_zoo/swin2sr/Swin2SR_ClassicalSR_X2_64.pth')
    parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
    parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
    parser.add_argument('--tile', type=int, default=None, help='tile size; None to test the whole image at once')
    parser.add_argument('--tile_overlap', type=int, default=32, help='overlap between adjacent tiles')
    parser.add_argument('--save_img_only', default=False, action='store_true', help='save images without evaluation')
    args = parser.parse_args()
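
    # Hypothetical usage examples (the testsets/ folder layout is an assumption,
    # not part of this file; the flags and default model path are):
    #   python main_test_swin2sr.py --task classical_sr --scale 2 --training_patch_size 64 \
    #       --model_path model_zoo/swin2sr/Swin2SR_ClassicalSR_X2_64.pth \
    #       --folder_lq testsets/Set5/LR_bicubic/X2 --folder_gt testsets/Set5/HR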

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # set up model
    if os.path.exists(args.model_path):
        print(f'loading model from {args.model_path}')
    else:
        os.makedirs(os.path.dirname(args.model_path), exist_ok=True)
        url = 'https://github.com/mv-lab/swin2sr/releases/download/v0.0.1/{}'.format(os.path.basename(args.model_path))
        r = requests.get(url, allow_redirects=True)
        print(f'downloading model {args.model_path}')
        with open(args.model_path, 'wb') as f:
            f.write(r.content)

    model = define_model(args)
    model.eval()
    model = model.to(device)

    # setup folder and path
    folder, save_dir, border, window_size = setup(args)
    os.makedirs(save_dir, exist_ok=True)
    test_results = OrderedDict()
    test_results['psnr'] = []
    test_results['ssim'] = []
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []
    test_results['psnrb'] = []
    test_results['psnrb_y'] = []
    psnr, ssim, psnr_y, ssim_y, psnrb, psnrb_y = 0, 0, 0, 0, 0, 0
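    # (zero-initialized so the combined log line in the loop below prints even
    # when a task computes only a subset of these metrics)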
60
+
61
+ for idx, path in enumerate(sorted(glob.glob(os.path.join(folder, '*')))):
62
+ # read image
63
+ imgname, img_lq, img_gt = get_image_pair(args, path) # image to HWC-BGR, float32
64
+ img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)) # HCW-BGR to CHW-RGB
65
+ img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device) # CHW-RGB to NCHW-RGB
66
+
67
+ # inference
68
+ with torch.no_grad():
69
+ # pad input image to be a multiple of window_size
70
+ _, _, h_old, w_old = img_lq.size()
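            # mirror the image along each axis (torch.flip + cat), then crop so
            # height and width become the next multiple of window_size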
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
            output = test(img_lq, model, args, window_size)
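
            # the 'pixelshuffle_aux' upsampler used for compressed_sr appears to
            # return a tuple, so keep only the main SR output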
            if args.task == 'compressed_sr':
                output = output[0][..., :h_old * args.scale, :w_old * args.scale]
            else:
                output = output[..., :h_old * args.scale, :w_old * args.scale]

        # save image
        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-RGB to HWC-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        cv2.imwrite(f'{save_dir}/{imgname}_Swin2SR.png', output)

        # evaluate psnr/ssim/psnr_b
        if img_gt is not None:
            img_gt = (img_gt * 255.0).round().astype(np.uint8)  # float32 to uint8
            img_gt = img_gt[:h_old * args.scale, :w_old * args.scale, ...]  # crop gt
            img_gt = np.squeeze(img_gt)

            psnr = util.calculate_psnr(output, img_gt, crop_border=border)
            ssim = util.calculate_ssim(output, img_gt, crop_border=border)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            if img_gt.ndim == 3:  # RGB image
                psnr_y = util.calculate_psnr(output, img_gt, crop_border=border, test_y_channel=True)
                ssim_y = util.calculate_ssim(output, img_gt, crop_border=border, test_y_channel=True)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
            if args.task in ['jpeg_car', 'color_jpeg_car']:
                psnrb = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=False)
                test_results['psnrb'].append(psnrb)
            if args.task in ['color_jpeg_car']:
                psnrb_y = util.calculate_psnrb(output, img_gt, crop_border=border, test_y_channel=True)
                test_results['psnrb_y'].append(psnrb_y)
            print('Testing {:d} {:20s} - PSNR: {:.2f} dB; SSIM: {:.4f}; PSNRB: {:.2f} dB; '
                  'PSNR_Y: {:.2f} dB; SSIM_Y: {:.4f}; PSNRB_Y: {:.2f} dB.'.
                  format(idx, imgname, psnr, ssim, psnrb, psnr_y, ssim_y, psnrb_y))
        else:
            print('Testing {:d} {:20s}'.format(idx, imgname))

    # summarize psnr/ssim
    if img_gt is not None:
        ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
        ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
        print('\n{} \n-- Average PSNR/SSIM(RGB): {:.2f} dB; {:.4f}'.format(save_dir, ave_psnr, ave_ssim))
        if img_gt.ndim == 3:
            ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
            ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
            print('-- Average PSNR_Y/SSIM_Y: {:.2f} dB; {:.4f}'.format(ave_psnr_y, ave_ssim_y))
        if args.task in ['jpeg_car', 'color_jpeg_car']:
            ave_psnrb = sum(test_results['psnrb']) / len(test_results['psnrb'])
            print('-- Average PSNRB: {:.2f} dB'.format(ave_psnrb))
        if args.task in ['color_jpeg_car']:
            ave_psnrb_y = sum(test_results['psnrb_y']) / len(test_results['psnrb_y'])
            print('-- Average PSNRB_Y: {:.2f} dB'.format(ave_psnrb_y))


def define_model(args):
    # 001 classical image sr
    if args.task == 'classical_sr':
        model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv')
        param_key_g = 'params'

    # 002 lightweight image sr
    # use 'pixelshuffledirect' to save parameters
    elif args.task in ['lightweight_sr']:
        model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
        param_key_g = 'params'

    elif args.task == 'compressed_sr':
        model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
                    img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='pixelshuffle_aux', resi_connection='1conv')
        param_key_g = 'params'

    # 003 real-world image sr
    elif args.task == 'real_sr':
        if not args.large_model:
            # use 'nearest+conv' to avoid block artifacts
            model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv')
        else:
            # larger model size; use '3conv' to save parameters and memory; use EMA weights from GAN training
            model = net(upscale=args.scale, in_chans=3, img_size=64, window_size=8,
                        img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240,
                        num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
                        mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')
        param_key_g = 'params_ema'

    # 006 grayscale JPEG compression artifact reduction
    # use window_size=7 because JPEG encoding uses 8x8 blocks; use img_range=255 because it's slightly better than 1
    elif args.task == 'jpeg_car':
        model = net(upscale=1, in_chans=1, img_size=126, window_size=7,
                    img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    # 006 color JPEG compression artifact reduction
    # use window_size=7 because JPEG encoding uses 8x8 blocks; use img_range=255 because it's slightly better than 1
    elif args.task == 'color_jpeg_car':
        model = net(upscale=1, in_chans=3, img_size=126, window_size=7,
                    img_range=255., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
                    mlp_ratio=2, upsampler='', resi_connection='1conv')
        param_key_g = 'params'

    pretrained_model = torch.load(args.model_path)
    model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)

    return model


def setup(args):
    # 001 classical image sr / 002 lightweight image sr
    if args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
        save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
        if args.save_img_only:
            folder = args.folder_lq
        else:
            folder = args.folder_gt
        border = args.scale
        window_size = 8

    # 003 real-world image sr
    elif args.task in ['real_sr']:
        save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
        if args.large_model:
            save_dir += '_large'
        folder = args.folder_lq
        border = 0
        window_size = 8

    # 006 JPEG compression artifact reduction
    elif args.task in ['jpeg_car', 'color_jpeg_car']:
        save_dir = f'results/swin2sr_{args.task}_jpeg{args.jpeg}'
        folder = args.folder_gt
        border = 0
        window_size = 7

    return folder, save_dir, border, window_size


def get_image_pair(args, path):
    imgname, imgext = os.path.splitext(os.path.basename(path))

    # 001 classical image sr / 002 lightweight image sr (load lq-gt image pairs)
    if args.task in ['classical_sr', 'lightweight_sr']:
        if args.save_img_only:
            img_gt = None
            img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        else:
            img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
            img_lq = cv2.imread(f'{args.folder_lq}/{imgname}x{args.scale}{imgext}', cv2.IMREAD_COLOR).astype(
                np.float32) / 255.

    elif args.task in ['compressed_sr']:
        if args.save_img_only:
            img_gt = None
            img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
        else:
            img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
            img_lq = cv2.imread(f'{args.folder_lq}/{imgname}.jpg', cv2.IMREAD_COLOR).astype(
                np.float32) / 255.

    # 003 real-world image sr (load lq image only)
    elif args.task in ['real_sr', 'lightweight_sr_infer']:
        img_gt = None
        img_lq = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.

    # 006 grayscale JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
    elif args.task in ['jpeg_car']:
        img_gt = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img_gt.ndim != 2:
            img_gt = util.bgr2ycbcr(img_gt, y_only=True)
        result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
        img_lq = cv2.imdecode(encimg, 0)
        img_gt = np.expand_dims(img_gt, axis=2).astype(np.float32) / 255.
        img_lq = np.expand_dims(img_lq, axis=2).astype(np.float32) / 255.

    # 006 color JPEG compression artifact reduction (load gt image and generate lq image on-the-fly)
    elif args.task in ['color_jpeg_car']:
        img_gt = cv2.imread(path)
        result, encimg = cv2.imencode('.jpg', img_gt, [int(cv2.IMWRITE_JPEG_QUALITY), args.jpeg])
        img_lq = cv2.imdecode(encimg, 1)
        img_gt = img_gt.astype(np.float32) / 255.
        img_lq = img_lq.astype(np.float32) / 255.

    return imgname, img_lq, img_gt


def test(img_lq, model, args, window_size):
    if args.tile is None:
        # test the image as a whole
        output = model(img_lq)
    else:
        # test the image tile by tile
        b, c, h, w = img_lq.size()
        tile = min(args.tile, h, w)
        assert tile % window_size == 0, "tile size should be a multiple of window_size"
        tile_overlap = args.tile_overlap
        sf = args.scale

        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        E = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
        W = torch.zeros_like(E)
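
        # E accumulates the upscaled tile outputs and W counts how many tiles
        # cover each output pixel, so dividing E by W averages the overlaps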
        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
                W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
        output = E.div_(W)

    return output


if __name__ == '__main__':
    main()