sunder-ali committed
Commit 944ec77
1 Parent(s): e6a6443

Upload 2 files

Files changed (2)
  1. utils/utils_image.py +734 -0
  2. utils/utils_model.py +100 -0
utils/utils_image.py ADDED
@@ -0,0 +1,734 @@
+ import os
+ import math
+ import random
+ import numpy as np
+ import torch
+ import cv2
+ from torchvision.utils import make_grid
+ from datetime import datetime
+ import matplotlib.pyplot as plt
+ # from mpl_toolkits.mplot3d import Axes3D
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def get_timestamp():
+     return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+ def imshow(x, title=None, cbar=False, figsize=None):
+     plt.figure(figsize=figsize)
+     plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+     if title:
+         plt.title(title)
+     if cbar:
+         plt.colorbar()
+     plt.show()
+
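+ # plot a 2-D array Z as a 3-D surface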
+ def surf(Z, cmap='rainbow', figsize=None):
+     plt.figure(figsize=figsize)
+     ax3 = plt.axes(projection='3d')
+
+     h, w = Z.shape[:2]
+     xx = np.arange(0, w, 1)
+     yy = np.arange(0, h, 1)
+     X, Y = np.meshgrid(xx, yy)  # X, Y have shape (h, w), matching Z
+     ax3.plot_surface(X, Y, Z, cmap=cmap)
+     plt.show()
+
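+ # ----------------------------------------
+ # collect image paths from a directory (or a list of directories)
+ # ----------------------------------------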
+ def get_image_paths(dataroot):
+     paths = None
+     if isinstance(dataroot, str):
+         paths = sorted(_get_paths_from_images(dataroot))
+     elif isinstance(dataroot, list):
+         paths = []
+         for i in dataroot:
+             paths += sorted(_get_paths_from_images(i))
+     return paths
+
+
+ def _get_paths_from_images(path):
+     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+     images = []
+     for dirpath, _, fnames in sorted(os.walk(path)):
+         for fname in sorted(fnames):
+             if is_image_file(fname):
+                 img_path = os.path.join(dirpath, fname)
+                 images.append(img_path)
+     assert images, '{:s} has no valid image file'.format(path)
+     return images
+
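+ # split a large image into overlapping p_size x p_size patches;
+ # the image is only split when both sides exceed p_max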
+ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+     w, h = img.shape[:2]
+     patches = []
+     if w > p_max and h > p_max:
+         w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))  # np.int was removed in NumPy 1.24
+         h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+         w1.append(w-p_size)
+         h1.append(h-p_size)
+         for i in w1:
+             for j in h1:
+                 patches.append(img[i:i+p_size, j:j+p_size, :])
+     else:
+         patches.append(img)
+
+     return patches
+
+
+ def imssave(imgs, img_path):
+     img_name, ext = os.path.splitext(os.path.basename(img_path))
+     for i, img in enumerate(imgs):
+         if img.ndim == 3:
+             img = img[:, :, [2, 1, 0]]
+         new_path = os.path.join(os.path.dirname(img_path), img_name+'_{:04d}'.format(i)+'.png')
+         cv2.imwrite(new_path, img)
+
+
+ def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=512, p_overlap=96, p_max=800):
+     paths = get_image_paths(original_dataroot)
+     for img_path in paths:
+         img = imread_uint(img_path, n_channels=n_channels)
+         patches = patches_from_image(img, p_size, p_overlap, p_max)
+         imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
+
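+ # ----------------------------------------
+ # directory helpers
+ # ----------------------------------------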
+ def mkdir(path):
+     if not os.path.exists(path):
+         os.makedirs(path)
+
+
+ def mkdirs(paths):
+     if isinstance(paths, str):
+         mkdir(paths)
+     else:
+         for path in paths:
+             mkdir(path)
+
+
+ def mkdir_and_rename(path):
+     if os.path.exists(path):
+         new_name = path + '_archived_' + get_timestamp()
+         print('Path already exists. Renaming it to [{:s}]'.format(new_name))
+         os.rename(path, new_name)
+     os.makedirs(path)
+
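+ # read an image from disk as a uint8 HxWxC numpy array (RGB order for n_channels=3)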
+ def imread_uint(path, n_channels=3):
+     if n_channels == 1:
+         img = cv2.imread(path, 0)
+         img = np.expand_dims(img, axis=2)
+     elif n_channels == 3:
+         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+         if img.ndim == 2:
+             img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+         else:
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     return img
+
+
+ def imsave(img, img_path):
+     img = np.squeeze(img)
+     if img.ndim == 3:
+         img = img[:, :, [2, 1, 0]]
+     cv2.imwrite(img_path, img)
+
+
+ imwrite = imsave  # imwrite was an exact duplicate of imsave
+
+
+ def read_img(path):
+     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+     img = img.astype(np.float32) / 255.
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+     if img.shape[2] > 3:
+         img = img[:, :, :3]
+     return img
+
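+ # ----------------------------------------
+ # conversions: uint8/uint16 <-> float32 ("single") <-> torch tensor
+ # ----------------------------------------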
+ def uint2single(img):
+     return np.float32(img/255.)
+
+
+ def single2uint(img):
+     return np.uint8((img.clip(0, 1)*255.).round())
+
+
+ def uint162single(img):
+     return np.float32(img/65535.)
+
+
+ def single2uint16(img):
+     return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+ def uint2tensor4(img):
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+ def uint2tensor3(img):
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+ def tensor2uint(img):
+     img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+     if img.ndim == 3:
+         img = np.transpose(img, (1, 2, 0))
+     return np.uint8((img*255.0).round())
+
+
+ def single2tensor3(img):
+     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+ def single2tensor4(img):
+     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+ def tensor2single(img):
+     img = img.data.squeeze().float().cpu().numpy()
+     if img.ndim == 3:
+         img = np.transpose(img, (1, 2, 0))
+     return img
+
+
+ def tensor2single3(img):
+     img = img.data.squeeze().float().cpu().numpy()
+     if img.ndim == 3:
+         img = np.transpose(img, (1, 2, 0))
+     elif img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+     return img
+
+
+ def single2tensor5(img):
+     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+ def single32tensor5(img):
+     return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+ def single42tensor4(img):
+     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
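+ # convert a 2/3/4-D tensor in [min_max[0], min_max[1]] to an HxWxC BGR
+ # numpy image; a 4-D batch is tiled into a grid via torchvision's make_grid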
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
+     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
+     n_dim = tensor.dim()
+     if n_dim == 4:
+         n_img = len(tensor)
+         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+     elif n_dim == 3:
+         img_np = tensor.numpy()
+         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+     elif n_dim == 2:
+         img_np = tensor.numpy()
+     else:
+         raise TypeError(
+             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+     if out_type == np.uint8:
+         img_np = (img_np * 255.0).round()
+         # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+     return img_np.astype(out_type)
+
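+ # ----------------------------------------
+ # data augmentation: the 8 flip/rotation modes of the dihedral group
+ # ----------------------------------------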
+ def augment_img(img, mode=0):
+     if mode == 0:
+         return img
+     elif mode == 1:
+         return np.flipud(np.rot90(img))
+     elif mode == 2:
+         return np.flipud(img)
+     elif mode == 3:
+         return np.rot90(img, k=3)
+     elif mode == 4:
+         return np.flipud(np.rot90(img, k=2))
+     elif mode == 5:
+         return np.rot90(img)
+     elif mode == 6:
+         return np.rot90(img, k=2)
+     elif mode == 7:
+         return np.flipud(np.rot90(img, k=3))
+
+
+ def augment_img_tensor4(img, mode=0):
+     if mode == 0:
+         return img
+     elif mode == 1:
+         return img.rot90(1, [2, 3]).flip([2])
+     elif mode == 2:
+         return img.flip([2])
+     elif mode == 3:
+         return img.rot90(3, [2, 3])
+     elif mode == 4:
+         return img.rot90(2, [2, 3]).flip([2])
+     elif mode == 5:
+         return img.rot90(1, [2, 3])
+     elif mode == 6:
+         return img.rot90(2, [2, 3])
+     elif mode == 7:
+         return img.rot90(3, [2, 3]).flip([2])
+
+
+ def augment_img_tensor(img, mode=0):
+     img_size = img.size()
+     img_np = img.data.cpu().numpy()
+     if len(img_size) == 3:
+         img_np = np.transpose(img_np, (1, 2, 0))
+     elif len(img_size) == 4:
+         img_np = np.transpose(img_np, (2, 3, 1, 0))
+     img_np = augment_img(img_np, mode=mode)
+     img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+     if len(img_size) == 3:
+         img_tensor = img_tensor.permute(2, 0, 1)
+     elif len(img_size) == 4:
+         img_tensor = img_tensor.permute(3, 2, 0, 1)
+     return img_tensor.type_as(img)
+
+
+ def augment_img_np3(img, mode=0):
+     if mode == 0:
+         return img
+     elif mode == 1:
+         return img.transpose(1, 0, 2)
+     elif mode == 2:
+         return img[::-1, :, :]
+     elif mode == 3:
+         img = img[::-1, :, :]
+         return img.transpose(1, 0, 2)
+     elif mode == 4:
+         return img[:, ::-1, :]
+     elif mode == 5:
+         img = img[:, ::-1, :]
+         return img.transpose(1, 0, 2)
+     elif mode == 6:
+         img = img[:, ::-1, :]
+         return img[::-1, :, :]
+     elif mode == 7:
+         img = img[:, ::-1, :]
+         img = img[::-1, :, :]
+         return img.transpose(1, 0, 2)
+
+
+ def augment_imgs(img_list, hflip=True, rot=True):
+     hflip = hflip and random.random() < 0.5
+     vflip = rot and random.random() < 0.5
+     rot90 = rot and random.random() < 0.5
+
+     def _augment(img):
+         if hflip:
+             img = img[:, ::-1, :]
+         if vflip:
+             img = img[::-1, :, :]
+         if rot90:
+             img = img.transpose(1, 0, 2)
+         return img
+
+     return [_augment(img) for img in img_list]
+
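+ # crop an image so that each side is a multiple of the scale factor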
+ def modcrop(img_in, scale):
+     # img_in: numpy array, HWC or HW
+     img = np.copy(img_in)
+     if img.ndim == 2:
+         H, W = img.shape
+         H_r, W_r = H % scale, W % scale
+         img = img[:H - H_r, :W - W_r]
+     elif img.ndim == 3:
+         H, W, C = img.shape
+         H_r, W_r = H % scale, W % scale
+         img = img[:H - H_r, :W - W_r, :]
+     else:
+         raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+     return img
+
+
+ def shave(img_in, border=0):
+     # img_in: numpy array, HWC or HW
+     img = np.copy(img_in)
+     h, w = img.shape[:2]
+     img = img[border:h-border, border:w-border]
+     return img
+
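+ # ----------------------------------------
+ # RGB/BGR <-> YCbCr conversions, using the same ITU-R BT.601
+ # coefficients as MATLAB's rgb2ycbcr
+ # ----------------------------------------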
+ def rgb2ycbcr(img, only_y=True):
+     in_img_type = img.dtype
+     img = img.astype(np.float32)  # astype() returns a copy; the result must be assigned
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     if only_y:
+         rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+     else:
+         rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                               [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def ycbcr2rgb(img):
+     in_img_type = img.dtype
+     img = img.astype(np.float32)
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                           [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+     rlt = np.clip(rlt, 0, 255)
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def bgr2ycbcr(img, only_y=True):
+     in_img_type = img.dtype
+     img = img.astype(np.float32)
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     if only_y:
+         rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+     else:
+         rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                               [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def channel_convert(in_c, tar_type, img_list):
+     # conversion among BGR, gray and Y
+     if in_c == 3 and tar_type == 'gray':  # BGR to gray
+         gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+         return [np.expand_dims(img, axis=2) for img in gray_list]
+     elif in_c == 3 and tar_type == 'y':  # BGR to Y
+         y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+         return [np.expand_dims(img, axis=2) for img in y_list]
+     elif in_c == 1 and tar_type == 'RGB':  # gray/Y to BGR
+         return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+     else:
+         return img_list
+
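+ # ----------------------------------------
+ # image-quality metrics: PSNR, SSIM and PSNR-B
+ # ----------------------------------------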
+ def calculate_psnr(img1, img2, border=0):
+     if not img1.shape == img2.shape:
+         raise ValueError('Input images must have the same dimensions.')
+     h, w = img1.shape[:2]
+     img1 = img1[border:h-border, border:w-border]
+     img2 = img2[border:h-border, border:w-border]
+
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+     mse = np.mean((img1 - img2)**2)
+     if mse == 0:
+         return float('inf')
+     return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+ def calculate_ssim(img1, img2, border=0):
+     if not img1.shape == img2.shape:
+         raise ValueError('Input images must have the same dimensions.')
+     h, w = img1.shape[:2]
+     img1 = img1[border:h-border, border:w-border]
+     img2 = img2[border:h-border, border:w-border]
+
+     if img1.ndim == 2:
+         return ssim(img1, img2)
+     elif img1.ndim == 3:
+         if img1.shape[2] == 3:
+             ssims = []
+             for i in range(3):
+                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
+             return np.array(ssims).mean()
+         elif img1.shape[2] == 1:
+             return ssim(np.squeeze(img1), np.squeeze(img2))
+     else:
+         raise ValueError('Wrong input image dimensions.')
+
+
+ def ssim(img1, img2):
+     C1 = (0.01 * 255)**2
+     C2 = (0.03 * 255)**2
+
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+     kernel = cv2.getGaussianKernel(11, 1.5)
+     window = np.outer(kernel, kernel.transpose())
+
+     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid region
+     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+     mu1_sq = mu1**2
+     mu2_sq = mu2**2
+     mu1_mu2 = mu1 * mu2
+     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+                                                             (sigma1_sq + sigma2_sq + C2))
+     return ssim_map.mean()
+
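+ # blocking-effect factor for PSNR-B: compares squared differences across
+ # 8x8 block boundaries against those at non-boundary positions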
+ def _blocking_effect_factor(im):
+     block_size = 8
+
+     block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
+     block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
+
+     horizontal_block_difference = (
+         (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1])**2).sum(3).sum(2).sum(1)
+     vertical_block_difference = (
+         (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :])**2).sum(3).sum(2).sum(1)
+
+     nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
+     nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
+
+     horizontal_nonblock_difference = (
+         (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1])**2).sum(3).sum(2).sum(1)
+     vertical_nonblock_difference = (
+         (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :])**2).sum(3).sum(2).sum(1)
+
+     n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
+     n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
+     boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
+         n_boundary_horiz + n_boundary_vert)
+
+     n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
+     n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
+     nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
+         n_nonboundary_horiz + n_nonboundary_vert)
+
+     scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
+     bef = scaler * (boundary_difference - nonboundary_difference)
+
+     bef[boundary_difference <= nonboundary_difference] = 0
+     return bef
+
+
+ def calculate_psnrb(img1, img2, border=0):
+     if not img1.shape == img2.shape:
+         raise ValueError('Input images must have the same dimensions.')
+
+     if img1.ndim == 2:
+         img1, img2 = np.expand_dims(img1, 2), np.expand_dims(img2, 2)
+
+     h, w = img1.shape[:2]
+     img1 = img1[border:h-border, border:w-border]
+     img2 = img2[border:h-border, border:w-border]
+
+     img1 = img1.astype(np.float64)
+     img2 = img2.astype(np.float64)
+
+     img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
+     img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
+
+     total = 0
+     for c in range(img1.shape[1]):
+         mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
+         bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
+
+         mse = mse.view(mse.shape[0], -1).mean(1)
+         total += 10 * torch.log10(1 / (mse + bef))
+
+     return float(total) / img1.shape[1]
+
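+ # MATLAB-style bicubic resizing: cubic() is the bicubic kernel and
+ # calculate_weights_indices() precomputes the contributing input pixels
+ # and their weights for every output pixel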
+ def cubic(x):
+     absx = torch.abs(x)
+     absx2 = absx**2
+     absx3 = absx**3
+     return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+         (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+     if (scale < 1) and antialiasing:
+         # widen the kernel when downscaling, to interpolate and antialias at once
+         kernel_width = kernel_width / scale
+
+     # output-space coordinates
+     x = torch.linspace(1, out_length, out_length)
+
+     # input-space coordinate of each output sample
+     u = x / scale + 0.5 * (1 - 1 / scale)
+
+     # leftmost input pixel that can contribute
+     left = torch.floor(u - kernel_width / 2)
+
+     # maximum number of contributing pixels
+     P = math.ceil(kernel_width) + 2
+
+     indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+         1, P).expand(out_length, P)
+
+     distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+
+     if (scale < 1) and antialiasing:
+         weights = scale * cubic(distance_to_center * scale)
+     else:
+         weights = cubic(distance_to_center)
+
+     # normalise so each row of weights sums to 1
+     weights_sum = torch.sum(weights, 1).view(out_length, 1)
+     weights = weights / weights_sum.expand(out_length, P)
+
+     # drop the first/last column if it is entirely zero
+     weights_zero_tmp = torch.sum((weights == 0), 0)
+     if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 1, P - 2)
+         weights = weights.narrow(1, 1, P - 2)
+     if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 0, P - 2)
+         weights = weights.narrow(1, 0, P - 2)
+     weights = weights.contiguous()
+     indices = indices.contiguous()
+     sym_len_s = -indices.min() + 1
+     sym_len_e = indices.max() - in_length
+     indices = indices + sym_len_s - 1
+     return weights, indices, int(sym_len_s), int(sym_len_e)
+
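+ # imresize operates on CxHxW torch tensors, imresize_np on HxWxC numpy arrays;
+ # both pad borders by symmetric reflection before filtering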
+ def imresize(img, scale, antialiasing=True):
+     # img: CxHxW torch tensor (or HxW, which is unsqueezed in place)
+     need_squeeze = img.dim() == 2
+     if need_squeeze:
+         img.unsqueeze_(0)
+     in_C, in_H, in_W = img.size()
+     out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+         in_H, out_H, scale, kernel, kernel_width, antialiasing)
+     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+         in_W, out_W, scale, kernel, kernel_width, antialiasing)
+
+     # pad the height dimension by symmetric reflection
+     img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+     img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+     sym_patch = img[:, :sym_len_Hs, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+     sym_patch = img[:, -sym_len_He:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+     # resize along H
+     out_1 = torch.FloatTensor(in_C, out_H, in_W)
+     kernel_width = weights_H.size(1)
+     for i in range(out_H):
+         idx = int(indices_H[i][0])
+         for j in range(out_C):
+             out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+     # pad the width dimension by symmetric reflection
+     out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+     out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+     sym_patch = out_1[:, :, :sym_len_Ws]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, :, -sym_len_We:]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+     # resize along W
+     out_2 = torch.FloatTensor(in_C, out_H, out_W)
+     kernel_width = weights_W.size(1)
+     for i in range(out_W):
+         idx = int(indices_W[i][0])
+         for j in range(out_C):
+             out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+     if need_squeeze:
+         out_2.squeeze_()
+     return out_2
+
+
+ def imresize_np(img, scale, antialiasing=True):
+     # img: HxWxC numpy array
+     img = torch.from_numpy(img)
+     need_squeeze = img.dim() == 2
+     if need_squeeze:
+         img.unsqueeze_(2)
+
+     in_H, in_W, in_C = img.size()
+     out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+         in_H, out_H, scale, kernel, kernel_width, antialiasing)
+     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+         in_W, out_W, scale, kernel, kernel_width, antialiasing)
+
+     img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+     img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+     sym_patch = img[:sym_len_Hs, :, :]
+     inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(0, inv_idx)
+     img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+     sym_patch = img[-sym_len_He:, :, :]
+     inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(0, inv_idx)
+     img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+     out_1 = torch.FloatTensor(out_H, in_W, in_C)
+     kernel_width = weights_H.size(1)
+     for i in range(out_H):
+         idx = int(indices_H[i][0])
+         for j in range(out_C):
+             out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+     out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+     out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+     sym_patch = out_1[:, :sym_len_Ws, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, -sym_len_We:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+     out_2 = torch.FloatTensor(out_H, out_W, in_C)
+     kernel_width = weights_W.size(1)
+     for i in range(out_W):
+         idx = int(indices_W[i][0])
+         for j in range(out_C):
+             out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+     if need_squeeze:
+         out_2.squeeze_()
+
+     return out_2.numpy()
+
+
+ if __name__ == '__main__':
+     img = imread_uint('test.bmp', 3)
utils/utils_model.py ADDED
@@ -0,0 +1,100 @@
+ import numpy as np
+ import torch
+ from utils import utils_image as util
+
+
+ def infer(model, L):
+     E = model(L)
+     return E
+
+
+ def inferp(model, L, modulo=16):
+     # pad H and W up to a multiple of `modulo` before inference, then crop back
+     h, w = L.size()[-2:]
+     paddingBottom = int(np.ceil(h/modulo)*modulo-h)
+     paddingRight = int(np.ceil(w/modulo)*modulo-w)
+     L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
+     E = model(L)
+     E = E[..., :h, :w]
+     return E
+
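+ # split-and-stitch inference: process four overlapping quarters (recursing
+ # while a quarter is still larger than min_size) and reassemble the output;
+ # sf is the model's scale factor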
+ def inferspfn(model, L, refield=32, min_size=256, sf=1, modulo=1):
+     h, w = L.size()[-2:]
+     if h*w <= min_size**2:
+         L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
+         E = model(L)
+         E = E[..., :h*sf, :w*sf]
+     else:
+         top = slice(0, (h//2//refield+1)*refield)
+         bottom = slice(h - (h//2//refield+1)*refield, h)
+         left = slice(0, (w//2//refield+1)*refield)
+         right = slice(w - (w//2//refield+1)*refield, w)
+         Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
+
+         if h * w <= 4*(min_size**2):
+             Es = [model(Ls[i]) for i in range(4)]
+         else:
+             Es = [inferspfn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
+
+         b, c = Es[0].size()[:2]
+         E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
+
+         E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
+         E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
+         E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
+         E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
+     return E
+
+
+ def infersp(model, L, refield=32, min_size=256, sf=1, modulo=1):
+     E = inferspfn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
+     return E
+
+
+ def inferosp(model, L, refield=32, min_size=256, sf=1, modulo=1):
+     h, w = L.size()[-2:]
+
+     top = slice(0, (h//2//refield+1)*refield)
+     bottom = slice(h - (h//2//refield+1)*refield, h)
+     left = slice(0, (w//2//refield+1)*refield)
+     right = slice(w - (w//2//refield+1)*refield, w)
+     Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
+     Es = [model(Ls[i]) for i in range(4)]
+     b, c = Es[0].size()[:2]
+     E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
+     E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
+     E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
+     E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
+     E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
+     return E
+
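+ # mode 0: direct inference; mode 1: pad-and-crop (inferp);
+ # mode 2: recursive split (infersp); mode 3: single-level split (inferosp)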
+ def inference(model, L, mode=0, refield=128, min_size=256, sf=1, modulo=1):
+     if mode == 0:
+         E = infer(model, L)
+     elif mode == 1:
+         E = inferp(model, L, modulo)
+     elif mode == 2:
+         E = infersp(model, L, refield, min_size, sf, modulo)
+     elif mode == 3:
+         E = inferosp(model, L, refield, min_size, sf, modulo)
+     return E
+
+
+ if __name__ == '__main__':
+
+     class Net(torch.nn.Module):
+         def __init__(self, in_channels=3, out_channels=3):
+             super(Net, self).__init__()
+             self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
+
+         def forward(self, x):
+             x = self.conv(x)
+             return x
+
+     start = torch.cuda.Event(enable_timing=True)
+     end = torch.cuda.Event(enable_timing=True)
+
+     model = Net()
+     model = model.eval()
+     x = torch.randn((2, 3, 400, 400))
+     torch.cuda.empty_cache()
+     with torch.no_grad():
+         for mode in range(4):  # only modes 0-3 are handled; mode 4 would leave E unbound
+             y = inference(model, x, mode)
+             print(y.shape)