zhengrongzhang committed on
Commit da9195c · 1 Parent(s): 01afa9a

init model

RCAN_int8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f80a5945e9d7bd9da2625aeec430dad3ba1123788edf36416f80ef59207c804
+ size 445505
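The three added lines are a Git LFS pointer rather than the ONNX weights themselves: `oid` records the SHA-256 of the real file and `size` its length in bytes. A minimal sketch, assuming the weights have already been downloaded to a local path, of checking a file against the pointer. `parse_lfs_pointer` and `matches_pointer` are hypothetical helper names, not part of this repository.

```python
import hashlib
import os

# Pointer text copied from the diff above.
POINTER = """\
version https://git-lfs.github.com/spec/v1
oid sha256:1f80a5945e9d7bd9da2625aeec430dad3ba1123788edf36416f80ef59207c804
size 445505
"""


def parse_lfs_pointer(text):
    """Split each 'key value' line of a Git LFS pointer into a dict."""
    fields = dict(line.split(" ", 1) for line in text.splitlines() if line.strip())
    algo, digest = fields["oid"].split(":", 1)
    return {"algo": algo, "digest": digest, "size": int(fields["size"])}


def matches_pointer(path, pointer):
    """Return True if the file at `path` has the recorded size and hash."""
    if os.path.getsize(path) != pointer["size"]:
        return False
    h = hashlib.new(pointer["algo"])
    with open(path, "rb") as f:
        # Hash in 1 MiB chunks so large model files are not loaded at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest() == pointer["digest"]
```

Calling `matches_pointer("RCAN_int8.onnx", parse_lfs_pointer(POINTER))` on a fresh download is a cheap way to notice a truncated file or a pointer that was checked out instead of the weights.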
README.md ADDED
@@ -0,0 +1,124 @@
+ ---
+ license: apache-2.0
+ tags:
+ - RyzenAI
+ - Super Resolution
+ - Pytorch
+ - Vision
+ - SISR
+ datasets:
+ - Set5
+ - Div2k
+ language:
+ - en
+ metrics:
+ - PSNR
+ ---
+
+ # RCAN model trained on DIV2K
+
+ RCAN is a very deep residual channel attention network for super resolution, trained on DIV2K. It was introduced in 2018 in the paper [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](https://arxiv.org/abs/1807.02758) by Yulun Zhang et al. and first released in [this repository](https://github.com/yulunzhang/RCAN).
+
+ We developed a modified version that can run on [AMD Ryzen AI](https://ryzenai.docs.amd.com).
+
+
+ ## Model description
+ RCAN is an algorithm for single image super resolution (SISR). Our modified version is smaller than the original. It is based on deep learning techniques and performs X2 super resolution.
+
+
+ ## Intended uses & limitations
+
+ You can use the raw model for super resolution. See the [model hub](https://huggingface.co/models?sort=trending&search=amd%2Frcan) to look for all available RCAN models.
+
+
+ ## How to use
+
+ ### Installation
+
+ Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
+ Run the following command to install the prerequisites for this model.
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+
+ ### Data Preparation (optional: for accuracy evaluation)
+
+ 1. Download the [benchmark](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) dataset.
+ 2. Organize the dataset directory as follows:
+ ```Plain
+ └── dataset
+      └── benchmark
+           ├── Set5
+           │   ├── HR
+           │   │   ├── baby.png
+           │   │   ├── ...
+           │   └── LR_bicubic
+           │       └── X2
+           │           ├── babyx2.png
+           │           ├── ...
+           ├── Set14
+           ├── ...
+ ```
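The layout above pairs every HR image with an LR file named `<name>x<scale>.png` under `LR_bicubic/X<scale>`, which is the pattern `data/srdata.py` builds in `_scan`. A small sketch, assuming the default `dataset/` root, for spotting missing LR files before running the evaluation. `expected_lr_path` and `missing_lr_images` are illustrative helpers that do not exist in the repository.

```python
from pathlib import Path


def expected_lr_path(root, dataset, hr_name, scale=2):
    """Path of the low-resolution image paired with the HR file `hr_name`."""
    stem = Path(hr_name).stem
    return (Path(root) / "benchmark" / dataset / "LR_bicubic"
            / f"X{scale}" / f"{stem}x{scale}.png")


def missing_lr_images(root, dataset, scale=2):
    """List LR files implied by the HR folder that are not present on disk."""
    hr_dir = Path(root) / "benchmark" / dataset / "HR"
    expected = (expected_lr_path(root, dataset, hr.name, scale)
                for hr in sorted(hr_dir.glob("*")))
    return [str(p) for p in expected if not p.is_file()]
```

For example, `missing_lr_images("dataset", "Set5")` returns an empty list when the Set5 folder is complete.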
+
+ ### Test & Evaluation
+
+ - Code snippet from [`infer_onnx.py`](infer_onnx.py) showing how to use the model
+ ```python
+ parser = argparse.ArgumentParser(description='RCAN SISR')
+ parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
+                     help='onnx path')
+ parser.add_argument('--image_path', default='test_data/test.png',
+                     help='path of your image')
+ parser.add_argument('--output_path', default='test_data/sr.png',
+                     help='path of the output image')
+ parser.add_argument('--ipu', action='store_true',
+                     help='use ipu')
+ parser.add_argument('--provider_config', type=str, default=None,
+                     help='provider config path')
+ args = parser.parse_args()
+
+ if args.ipu:
+     providers = ["VitisAIExecutionProvider"]
+     provider_options = [{"config_file": args.provider_config}]
+ else:
+     providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+     provider_options = None
+ onnx_file_name = args.onnx_path
+ image_path = args.image_path
+ output_path = args.output_path
+
+ ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
+ lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
+ sr = tiling_inference(ort_session, lr, 8, (56, 56))
+ sr = np.clip(sr, 0, 255)
+ sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
+ cv2.imwrite(output_path, sr)
+ ```
+
+ - Run inference for a single image
+ ```bash
+ python infer_onnx.py --onnx_path RCAN_int8.onnx --image_path /Path/To/Your/Image --ipu --provider_config Path/To/vaip_config.json
+ ```
+
+ - Test accuracy of the quantized model
+ ```bash
+ python eval_onnx.py --onnx_path RCAN_int8.onnx --data_test Set5 --ipu --provider_config Path/To/vaip_config.json
+ ```
+ ### Performance
+
+ | Method         | Scale | FLOPs | Set5 (PSNR / SSIM) |
+ |----------------|-------|-------|--------------------|
+ | RCAN-S (float) | X2    | 24.5G | 37.531 / 0.958     |
+ | RCAN-S (INT8)  | X2    | 24.5G | 37.150 / 0.955     |
+ - Note: FLOPs are calculated for an output resolution of 360x640.
+
+ ```bibtex
+ @inproceedings{zhang2018image,
+   title={Image super-resolution using very deep residual channel attention networks},
+   author={Zhang, Yulun and Li, Kunpeng and Li, Kai and Wang, Lichen and Zhong, Bineng and Fu, Yun},
+   booktitle={Proceedings of the European conference on computer vision (ECCV)},
+   pages={286--301},
+   year={2018}
+ }
+ ```
data/__init__.py ADDED
@@ -0,0 +1,36 @@
+ from importlib import import_module
+ #from dataloader import MSDataLoader
+ from torch.utils.data import dataloader
+ from torch.utils.data import ConcatDataset
+ import torch
+ import random
+ # This is a simple wrapper class for ConcatDataset
+ class MyConcatDataset(ConcatDataset):
+     def __init__(self, datasets):
+         super(MyConcatDataset, self).__init__(datasets)
+
+
+     def set_scale(self, idx_scale):
+         for d in self.datasets:
+             if hasattr(d, 'set_scale'): d.set_scale(idx_scale)
+
+ class Data:
+     def __init__(self, args):
+         self.loader_train = None
+         self.loader_test = []
+         for d in args.data_test:
+             if d in ['Set5', 'Set14', 'B100', 'Urban100']:
+                 m = import_module('data.benchmark')
+                 testset = getattr(m, 'Benchmark')(args, name=d)
+             else:
+                 raise NotImplementedError('unsupported test set: {}'.format(d))
+
+             self.loader_test.append(
+                 dataloader.DataLoader(
+                     testset,
+                     batch_size=1,
+                     shuffle=False,
+                     pin_memory=False,
+                     num_workers=args.n_threads,
+                 )
+             )
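A loader like this one has to reject unknown dataset names explicitly. A statement such as `assert NotImplementedError` cannot do that, because it only tests the truthiness of the exception class, which is always true. The sketch below contrasts the two forms; `select_asserting` and `select_raising` are illustrative names, not functions from this repository.

```python
SUPPORTED = ("Set5", "Set14", "B100", "Urban100")


def select_asserting(name):
    """Uses `assert <exception class>`, which is always truthy and never fails."""
    if name in SUPPORTED:
        return name
    assert NotImplementedError
    return None  # still reached for unknown names


def select_raising(name):
    """Raises, so an unknown name stops execution with a clear message."""
    if name in SUPPORTED:
        return name
    raise NotImplementedError(f"unsupported test set: {name}")
```

With the asserting form, a typo in `--data_test` would only surface later as an unrelated error about an undefined test set, so the raising form is the safer choice.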
data/benchmark.py ADDED
@@ -0,0 +1,23 @@
+ import os
+
+ #from data import common
+ from data import srdata
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+
+ class Benchmark(srdata.SRData):
+     def __init__(self, args, name='', benchmark=True):
+         super(Benchmark, self).__init__(
+             args, name=name, benchmark=True
+         )
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, 'benchmark', self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         if self.input_large:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
+         else:
+             self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         self.ext = ('', '.png')
+
data/common.py ADDED
@@ -0,0 +1,34 @@
+ import random
+
+ import numpy as np
+ import skimage.color as sc
+
+ import torch
+
+ def set_channel(*args, n_channels=3):
+     def _set_channel(img):
+         if img.ndim == 2:
+             img = np.expand_dims(img, axis=2)
+
+         c = img.shape[2]
+         if n_channels == 1 and c == 3:
+             img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
+         elif n_channels == 3 and c == 1:
+             img = np.concatenate([img] * n_channels, 2)
+
+         return img
+
+     return [_set_channel(a) for a in args]
+
+ def np2Tensor(*args, rgb_range=255):
+     def _np2Tensor(img):
+         np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
+         tensor = torch.from_numpy(np_transpose).float()
+         tensor.mul_(rgb_range / 255)
+
+         return tensor
+
+     return [_np2Tensor(a) for a in args]
+
+
+
data/data_tiling.py ADDED
@@ -0,0 +1,40 @@
+ import os
+ import onnxruntime
+ import numpy as np
+ import math
+
+
+ def tiling_inference(session, lr, overlapping, patch_size):
+     _, _, h, w = lr.shape
+     sr = np.zeros((1, 3, 2*h, 2*w))
+     n_h = math.ceil(h / float(patch_size[0] - overlapping))
+     n_w = math.ceil(w / float(patch_size[1] - overlapping))
+     # every tile fed to the model has the same size, patch_size
+     for ih in range(n_h):
+         h_idx = ih * (patch_size[0] - overlapping)
+         h_idx = h_idx if h_idx + patch_size[0] <= h else h - patch_size[0]
+         for iw in range(n_w):
+             w_idx = iw * (patch_size[1] - overlapping)
+             w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]
+
+             tilling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
+             sr_tiling = session.run(None, {session.get_inputs()[0].name: tilling_lr})[0]
+
+             left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
+             left += overlapping//2
+             right -= overlapping//2
+             top += overlapping//2
+             bottom -= overlapping//2
+             # tiles on the image border keep their outer margin
+             if w_idx == 0:
+                 left -= overlapping//2
+             if h_idx == 0:
+                 top -= overlapping//2
+             if h_idx+patch_size[0]>=h:
+                 bottom += overlapping//2
+             if w_idx+patch_size[1]>=w:
+                 right += overlapping//2
+
+             # copy the kept region of the prediction into the x2 output
+             sr[... , 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right)] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right]
+     return sr
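The tiling logic is easier to follow along a single axis: tiles advance by `patch - overlap` pixels, the last tile is clamped to the border, and each tile contributes its interior (half the overlap trimmed from each side) except where it touches the image edge. The sketch below reproduces only that index arithmetic in pure Python. `tile_spans` and `covered` are hypothetical helpers, and the sketch assumes the axis is at least one patch long, a case the function above does not guard against.

```python
import math


def tile_spans(length, patch, overlap):
    """Return (start, keep_from, keep_to) per tile along one axis, in input pixels."""
    if length < patch:
        raise ValueError("axis is shorter than one patch")
    stride = patch - overlap
    half = overlap // 2
    spans = []
    for i in range(math.ceil(length / stride)):
        start = i * stride
        if start + patch > length:
            start = length - patch  # clamp tiles that would run past the border
        keep_from = 0 if start == 0 else half
        keep_to = patch if start + patch >= length else patch - half
        spans.append((start, start + keep_from, start + keep_to))
    return spans


def covered(length, patch, overlap):
    """Check that the kept regions of all tiles cover every input pixel."""
    seen = set()
    for _, lo, hi in tile_spans(length, patch, overlap):
        seen.update(range(lo, hi))
    return seen == set(range(length))
```

With the values used by the scripts (patch 56, overlap 8), a 100-pixel axis yields one tile at 0 and a clamped tile at 44, visited twice because the tile count is rounded up. For a x2 model, the output span of each tile is simply twice these input coordinates.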
data/srdata.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ import glob
+ import random
+ import pickle
+
+ from data import common
+
+ import numpy as np
+ import imageio
+ import torch
+ import torch.utils.data as data
+
+ class SRData(data.Dataset):
+     def __init__(self, args, name='', benchmark=False):
+         self.args = args
+         self.name = name
+         self.split = 'test'
+         self.do_eval = True
+         self.benchmark = benchmark
+         self.input_large = False
+         self.scale = args.scale
+         self.idx_scale = 0
+         self._set_filesystem(args.dir_data)
+         list_hr, list_lr = self._scan()
+         self.images_hr, self.images_lr = list_hr, list_lr
+
+     # The functions below are used to prepare images
+     def _scan(self):
+         names_hr = sorted(
+             glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
+         )
+         names_lr = [[] for _ in self.scale]
+         for f in names_hr:
+             filename, _ = os.path.splitext(os.path.basename(f))
+             for si, s in enumerate(self.scale):
+                 names_lr[si].append(os.path.join(
+                     self.dir_lr, 'X{}/{}x{}{}'.format(
+                         s, filename, s, self.ext[1]
+                     )
+                 ))
+
+         return names_hr, names_lr
+
+     def _set_filesystem(self, dir_data):
+         self.apath = os.path.join(dir_data, self.name)
+         self.dir_hr = os.path.join(self.apath, 'HR')
+         self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
+         self.ext = ('.png', '.png')
+
+     def __getitem__(self, idx):
+         lr, hr, filename = self._load_file(idx)
+         pair = self.get_patch(lr, hr)
+         pair = common.set_channel(*pair, n_channels=3)
+         pair_t = common.np2Tensor(*pair, rgb_range=255)
+
+         return pair_t[0], pair_t[1], filename
+
+     def __len__(self):
+         return len(self.images_hr)
+
+     def _get_index(self, idx):
+         return idx
+
+     def _load_file(self, idx):
+         idx = self._get_index(idx)
+         f_hr = self.images_hr[idx]
+         f_lr = self.images_lr[self.idx_scale][idx]
+
+         filename, _ = os.path.splitext(os.path.basename(f_hr))
+         hr = imageio.imread(f_hr)
+         lr = imageio.imread(f_lr)
+         return lr, hr, filename
+
+     def get_patch(self, lr, hr):
+         scale = self.scale[self.idx_scale]
+         ih, iw = lr.shape[:2]
+         hr = hr[0:ih * scale, 0:iw * scale]
+         return lr, hr
+
+     def set_scale(self, idx_scale):
+         if not self.input_large:
+             self.idx_scale = idx_scale
+         else:
+             self.idx_scale = random.randint(0, len(self.scale) - 1)
+
eval_onnx.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ import os
+ import sys
+ import pathlib
+ CURRENT_DIR = pathlib.Path(__file__).parent
+ sys.path.append(str(CURRENT_DIR))
+ from tqdm import tqdm
+ import data
+ import metric
+ import onnxruntime
+ import cv2
+ from data.data_tiling import tiling_inference
+ import argparse
+
+ class Configs():
+     def __init__(self):
+         parser = argparse.ArgumentParser(description='SR')
+
+         # Run on the IPU or on CPU/GPU; an ONNX model path is required either way
+         parser.add_argument('--ipu', action='store_true',
+                             help='use ipu')
+         parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
+                             help='onnx path')
+         parser.add_argument('--provider_config', type=str, default=None,
+                             help='provider config path')
+         # Data specifications; the defaults usually work
+         parser.add_argument('--dir_data', type=str, default='dataset/',
+                             help='dataset directory')
+         parser.add_argument('--data_test', type=str, default='Set5',
+                             help='test dataset name')
+
+         parser.add_argument('--n_threads', type=int, default=6,
+                             help='number of threads for data loading')
+         parser.add_argument('--scale', type=str, default='2',
+                             help='super resolution scale; only x2 is currently supported')
+         self.parser = parser
+
+     def parse(self):
+         args = self.parser.parse_args()
+         args.scale = [int(s) for s in args.scale.split('+')]
+         args.data_test = args.data_test.split('+')
+         print(args)
+         return args
+
+
+
+ def quantize(img, rgb_range):  # clamp and round pixels to the rgb range
+     pixel_range = 255 / rgb_range
+     return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
+
+ def test_model(session, loader, device):
+     torch.set_grad_enabled(False)
+     self_scale = [2]
+     for idx_data, d in enumerate(loader.loader_test):
+         eval_ssim = 0
+         eval_psnr = 0
+         for idx_scale, scale in enumerate(self_scale):
+             d.dataset.set_scale(idx_scale)
+             for lr, hr, filename in tqdm(d, ncols=80):
+                 sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
+                 sr = torch.from_numpy(sr).to(device)
+                 sr = quantize(sr, 255)
+                 eval_psnr += metric.calc_psnr(
+                     sr, hr, scale, 255, benchmark=d)
+                 eval_ssim += metric.calc_ssim(
+                     sr, hr, scale, 255, dataset=d)
+         mean_ssim = eval_ssim / len(d)
+         mean_psnr = eval_psnr / len(d)
+         print("psnr: %s, ssim: %s"%(mean_psnr, mean_ssim))
+     return mean_psnr, mean_ssim
+
+ def main(args):
+     loader = data.Data(args)
+     if args.ipu:
+         providers = ["VitisAIExecutionProvider"]
+         provider_options = [{"config_file": args.provider_config}]
+     else:
+         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+         provider_options = None
+     onnx_file_name = args.onnx_path
+     ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
+     test_model(ort_session, loader, device="cpu")
+
+
+ if __name__ == '__main__':
+     args = Configs().parse()
+     main(args)
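Both `--scale` and `--data_test` are read as plain strings and split on `+`, so several test sets can be passed in one flag, for example `--data_test Set5+Set14`. A trimmed-down sketch of that parsing, keeping only the two relevant flags; `build_parser` and `parse` are illustrative stand-ins for the `Configs` class.

```python
import argparse


def build_parser():
    """Parser with only the two '+'-separated options from eval_onnx.py."""
    parser = argparse.ArgumentParser(description="SR")
    parser.add_argument("--data_test", type=str, default="Set5")
    parser.add_argument("--scale", type=str, default="2")
    return parser


def parse(argv):
    """Parse `argv` and expand the '+'-separated values into lists."""
    args = build_parser().parse_args(argv)
    args.scale = [int(s) for s in args.scale.split("+")]
    args.data_test = args.data_test.split("+")
    return args
```

Note that the evaluation loop itself iterates over a fixed scale list of `[2]`, so passing additional scales on the command line does not change which scale is evaluated.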
infer_onnx.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import onnxruntime
+ import cv2
+ import sys
+ import pathlib
+ CURRENT_DIR = pathlib.Path(__file__).parent
+ sys.path.append(str(CURRENT_DIR))
+ import numpy as np
+ from data.data_tiling import tiling_inference
+ import argparse
+
+
+ def main(args):
+     if args.ipu:
+         providers = ["VitisAIExecutionProvider"]
+         provider_options = [{"config_file": args.provider_config}]
+     else:
+         providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
+         provider_options = None
+     onnx_file_name = args.onnx_path
+     image_path = args.image_path
+     output_path = args.output_path
+
+     ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
+     lr = cv2.imread(image_path)[np.newaxis,:,:,:].transpose((0,3,1,2)).astype(np.float32)
+     sr = tiling_inference(ort_session, lr, 8, (56, 56))
+     sr = np.clip(sr, 0, 255)
+     sr = sr.squeeze().transpose((1,2,0)).astype(np.uint8)
+     cv2.imwrite(output_path, sr)
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description='RCAN SISR')
+     parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
+                         help='onnx path')
+     parser.add_argument('--image_path', default='test_data/test.png',
+                         help='path of your image')
+     parser.add_argument('--output_path', default='test_data/sr.png',
+                         help='path of the output image')
+     parser.add_argument('--ipu', action='store_true',
+                         help='use ipu')
+     parser.add_argument('--provider_config', type=str, default=None,
+                         help='provider config path')
+     args = parser.parse_args()
+     main(args)
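Both scripts choose the execution providers the same way, and both pass `config_file=None` to the Vitis AI provider when `--ipu` is given without `--provider_config`. The sketch below factors that choice into one function and fails early in that case. Treating a missing config as an error is an assumption made here, not documented behaviour of the provider, and `select_providers` is a hypothetical helper.

```python
def select_providers(use_ipu, provider_config=None):
    """Return (providers, provider_options) as passed to onnxruntime.InferenceSession."""
    if use_ipu:
        if not provider_config:
            # Assumption: running on the IPU without a config file is unintended.
            raise ValueError("--ipu requires --provider_config (path to vaip_config.json)")
        return ["VitisAIExecutionProvider"], [{"config_file": provider_config}]
    # Without --ipu the scripts list CUDA first and then CPU.
    return ["CUDAExecutionProvider", "CPUExecutionProvider"], None
```

The returned pair maps directly onto the `providers=` and `provider_options=` keyword arguments used in the scripts.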
metric.py ADDED
@@ -0,0 +1,93 @@
+ import math
+ import numpy as np
+ import torch
+ import skimage.measure
+ import skimage.color
+
+ def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
+     if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
+         print("the sr image is larger than the hr image; cropping sr to match")
+         sr = sr[:,:,:hr.size(-2),:hr.size(-1)]
+     diff = (sr - hr).data.div(rgb_range)
+
+     if benchmark:
+         shave = scale
+         if diff.size(1) > 1:
+             convert = diff.new(1, 3, 1, 1)
+             convert[0, 0, 0, 0] = 65.738
+             convert[0, 1, 0, 0] = 129.057
+             convert[0, 2, 0, 0] = 25.064
+             diff.mul_(convert).div_(256)
+             diff = diff.sum(dim=1, keepdim=True)
+     else:
+         shave = scale + 6
+     valid = diff[:, :, shave:-shave, shave:-shave]
+     mse = valid.pow(2).mean()
+
+     return -10 * math.log10(mse)
+
+
+ import numpy as np
+ from scipy import signal
+
+
+ def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
+     """
+     2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma])
+     Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
+     """
+     m,n = [(ss-1.)/2. for ss in shape]
+     y,x = np.ogrid[-m:m+1,-n:n+1]
+     h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
+     h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
+     sumh = h.sum()
+     if sumh != 0:
+         h /= sumh
+     return h
+
+ def calc_ssim(X, Y, scale, rgb_range, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
+     '''
+     X : y channel (i.e., luminance) of transformed YCbCr space of X
+     Y : y channel (i.e., luminance) of transformed YCbCr space of Y
+     Please follow the setting of psnr_ssim.m in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution CVPRW2017).
+     Official Link : https://github.com/LimBee/NTIRE2017/tree/db34606c2844e89317aac8728a2de562ef1f8aba
+     The authors of EDSR use MATLAB's ssim as the evaluation tool,
+     thus this function is the same as ssim.m in MATLAB with C(3) == C(2)/2.
+     '''
+     gaussian_filter = matlab_style_gauss2D((11, 11), sigma)
+
+     if True:#dataset and dataset.dataset.benchmark:
+         shave = scale
+         if X.size(1) > 1:
+             gray_coeffs = [65.738, 129.057, 25.064]
+             convert = X.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
+             X = X.mul(convert).sum(dim=1)
+             Y = Y.mul(convert).sum(dim=1)
+     else:
+         shave = scale + 6
+
+     X = X[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
+     Y = Y[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
+
+     window = gaussian_filter
+
+     ux = signal.convolve2d(X, window, mode='same', boundary='symm')
+     uy = signal.convolve2d(Y, window, mode='same', boundary='symm')
+
+     uxx = signal.convolve2d(X*X, window, mode='same', boundary='symm')
+     uyy = signal.convolve2d(Y*Y, window, mode='same', boundary='symm')
+     uxy = signal.convolve2d(X*Y, window, mode='same', boundary='symm')
+
+     vx = uxx - ux * ux
+     vy = uyy - uy * uy
+     vxy = uxy - ux * uy
+
+     C1 = (K1 * R) ** 2
+     C2 = (K2 * R) ** 2
+
+     A1, A2, B1, B2 = ((2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2))
+     D = B1 * B2
+     S = (A1 * A2) / D
+     mssim = S.mean()
+
+     return mssim
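The PSNR path reduces to three steps: weight the RGB difference with the luminance coefficients, average the squared result, and convert to decibels. A dependency-free sketch of those steps on flat lists of pixels, written for illustration only. It omits the border shaving done in `calc_psnr`, and `to_luma` and `psnr` are hypothetical names.

```python
import math

# Luminance weights as used in metric.py; divided by 256 in to_luma.
Y_COEFFS = (65.738, 129.057, 25.064)


def to_luma(rgb):
    """Weight an (r, g, b) pixel with the luminance coefficients."""
    return sum(c * v for c, v in zip(Y_COEFFS, rgb)) / 256


def psnr(sr, hr, rgb_range=255):
    """PSNR in dB between two equally sized lists of RGB pixels."""
    diffs = [(to_luma(a) - to_luma(b)) / rgb_range for a, b in zip(sr, hr)]
    mse = sum(d * d for d in diffs) / len(diffs)
    return -10 * math.log10(mse)  # raises ValueError when the images are identical
```

Because the weighting is linear, applying it to the difference (as `calc_psnr` does) or to each image before subtracting (as here) gives the same value. Identical images produce a zero error and therefore an undefined logarithm, which callers may want to handle separately.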
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch==1.13.1
+ numpy>=1.23.5
+ scipy>=1.9
+ opencv-python
+ pandas
+ pillow
+ scikit-image
+ tqdm
test_data/test.png ADDED