zhengrongzhang committed
Commit da9195c · Parent(s): 01afa9a

init model

Browse files:
- RCAN_int8.onnx +3 -0
- README.md +124 -0
- data/__init__.py +36 -0
- data/benchmark.py +23 -0
- data/common.py +34 -0
- data/data_tiling.py +40 -0
- data/srdata.py +85 -0
- eval_onnx.py +87 -0
- infer_onnx.py +45 -0
- metric.py +93 -0
- requirements.txt +8 -0
- test_data/test.png +0 -0
RCAN_int8.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f80a5945e9d7bd9da2625aeec430dad3ba1123788edf36416f80ef59207c804
size 445505
README.md
ADDED
@@ -0,0 +1,124 @@
---
license: apache-2.0
tags:
- RyzenAI
- Super Resolution
- Pytorch
- Vision
- SISR
datasets:
- Set5
- Div2k
language:
- en
metrics:
- PSNR
---

# RCAN model trained on DIV2K

RCAN is a very deep residual channel attention network for super resolution trained on DIV2K. It was introduced in 2018 in the paper [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](https://arxiv.org/abs/1807.02758) by Yulun Zhang et al. and first released in [this repository](https://github.com/yulunzhang/RCAN).

We developed a modified version that is supported by [AMD Ryzen AI](https://ryzenai.docs.amd.com).


## Model description

RCAN is an advanced algorithm for single image super resolution. Our modified version is smaller than the original. It is based on deep learning techniques and is capable of X2 super resolution.


## Intended uses & limitations

You can use the raw model for super resolution. See the [model hub](https://huggingface.co/models?sort=trending&search=amd%2Frcan) to look for all available RCAN models.


## How to use

### Installation

Follow [Ryzen AI Installation](https://ryzenai.docs.amd.com/en/latest/inst.html) to prepare the environment for Ryzen AI.
Run the following script to install the prerequisites for this model.
```bash
pip install -r requirements.txt
```


### Data Preparation (optional: for accuracy evaluation)

1. Download the [benchmark](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) dataset.
2. Organize the dataset directory as follows:
```plain
└── dataset
    └── benchmark
        ├── Set5
        │   ├── HR
        │   │   ├── baby.png
        │   │   ├── ...
        │   └── LR_bicubic
        │       └── X2
        │           ├── babyx2.png
        │           ├── ...
        ├── Set14
        ├── ...
```

### Test & Evaluation

- Code snippet from [`infer_onnx.py`](infer_onnx.py) on how to use
```python
parser = argparse.ArgumentParser(description='RCAN SISR')
parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
                    help='onnx path')
parser.add_argument('--image_path', default='test_data/test.png',
                    help='path of your input image')
parser.add_argument('--output_path', default='test_data/sr.png',
                    help='path of the output image')
parser.add_argument('--ipu', action='store_true',
                    help='use ipu')
parser.add_argument('--provider_config', type=str, default=None,
                    help='provider config path')
args = parser.parse_args()

if args.ipu:
    providers = ["VitisAIExecutionProvider"]
    provider_options = [{"config_file": args.provider_config}]
else:
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    provider_options = None
onnx_file_name = args.onnx_path
image_path = args.image_path
output_path = args.output_path

ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
lr = cv2.imread(image_path)[np.newaxis, :, :, :].transpose((0, 3, 1, 2)).astype(np.float32)
sr = tiling_inference(ort_session, lr, 8, (56, 56))
sr = np.clip(sr, 0, 255)
sr = sr.squeeze().transpose((1, 2, 0)).astype(np.uint8)
cv2.imwrite(output_path, sr)
```

- Run inference on a single image
```bash
python infer_onnx.py --onnx_path RCAN_int8.onnx --image_path /Path/To/Your/Image --ipu --provider_config Path/To/vaip_config.json
```

- Test the accuracy of the quantized model
```bash
python eval_onnx.py --onnx_path RCAN_int8.onnx --data_test Set5 --ipu --provider_config Path/To/vaip_config.json
```

### Performance

| Method         | Scale | FLOPs | Set5 (PSNR / SSIM) |
|----------------|-------|-------|--------------------|
| RCAN-S (float) | X2    | 24.5G | 37.531 / 0.958     |
| RCAN-S (INT8)  | X2    | 24.5G | 37.150 / 0.955     |

- Note: FLOPs are calculated for an output resolution of 360x640.

### Citation

```bibtex
@inproceedings{zhang2018image,
  title={Image super-resolution using very deep residual channel attention networks},
  author={Zhang, Yulun and Li, Kunpeng and Li, Kai and Wang, Lichen and Zhong, Bineng and Fu, Yun},
  booktitle={Proceedings of the European conference on computer vision (ECCV)},
  pages={286--301},
  year={2018}
}
```
data/__init__.py
ADDED
@@ -0,0 +1,36 @@
from importlib import import_module

from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset


# This is a simple wrapper class around ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)


class Data:
    def __init__(self, args):
        self.loader_train = None
        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, name=d)
            else:
                raise NotImplementedError

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=False,
                    num_workers=args.n_threads,
                )
            )
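A minimal driver sketch for this wrapper (not part of the commit): the namespace fields mirror the defaults defined in `eval_onnx.py`'s `Configs`, and the dataset root is an assumption following the Data Preparation layout in the README.

```python
# Hypothetical usage of data.Data; field names mirror eval_onnx.py's Configs.
import argparse

import data

args = argparse.Namespace(
    data_test=['Set5'],   # benchmark dataset name(s)
    dir_data='dataset/',  # assumed dataset root (see README Data Preparation)
    n_threads=6,          # DataLoader worker count
    scale=[2],            # x2 super resolution only
)

loader = data.Data(args)
for lr, hr, filename in loader.loader_test[0]:
    # lr is 1x3xHxW, hr is 1x3x(2H)x(2W), filename is the image stem
    print(filename, tuple(lr.shape), tuple(hr.shape))
    break
```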
data/benchmark.py
ADDED
@@ -0,0 +1,23 @@
import os

from data import srdata


class Benchmark(srdata.SRData):
    def __init__(self, args, name='', benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        if self.input_large:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
        else:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('', '.png')
data/common.py
ADDED
@@ -0,0 +1,34 @@
import numpy as np
import skimage.color as sc

import torch


def set_channel(*args, n_channels=3):
    """Force each image to have `n_channels` channels (grayscale <-> RGB)."""
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]


def np2Tensor(*args, rgb_range=255):
    """Convert HxWxC uint8 arrays to CxHxW float tensors scaled to `rgb_range`."""
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]
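A short sketch (not in the commit) of how these two helpers compose: a raw HxWx3 uint8 image, as the datasets load with imageio, becomes the CxHxW float tensor the loaders return. The random array is a stand-in for a real image.

```python
import numpy as np

from data import common

# stand-in for an image loaded with imageio.imread(...)
img = (np.random.rand(48, 48, 3) * 255).astype(np.uint8)

img, = common.set_channel(img, n_channels=3)    # guarantees 3 channels
tensor, = common.np2Tensor(img, rgb_range=255)  # HxWxC uint8 -> CxHxW float
print(tensor.shape, tensor.dtype)               # torch.Size([3, 48, 48]) torch.float32
```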
data/data_tiling.py
ADDED
@@ -0,0 +1,40 @@
import math

import numpy as np


def tiling_inference(session, lr, overlapping, patch_size):
    """Run the ONNX session patch-by-patch and stitch the x2 outputs together."""
    _, _, h, w = lr.shape
    sr = np.zeros((1, 3, 2*h, 2*w))
    n_h = math.ceil(h / float(patch_size[0] - overlapping))
    n_w = math.ceil(w / float(patch_size[1] - overlapping))
    # every tiling input has the same size as patch_size
    for ih in range(n_h):
        h_idx = ih * (patch_size[0] - overlapping)
        h_idx = h_idx if h_idx + patch_size[0] <= h else h - patch_size[0]
        for iw in range(n_w):
            w_idx = iw * (patch_size[1] - overlapping)
            w_idx = w_idx if w_idx + patch_size[1] <= w else w - patch_size[1]

            tiling_lr = lr[..., h_idx: h_idx+patch_size[0], w_idx: w_idx+patch_size[1]]
            sr_tiling = session.run(None, {session.get_inputs()[0].name: tiling_lr})[0]

            # trim half the overlap from each interior tile edge
            left, right, top, bottom = 0, patch_size[1], 0, patch_size[0]
            left += overlapping//2
            right -= overlapping//2
            top += overlapping//2
            bottom -= overlapping//2
            # keep the full tile at the image borders
            if w_idx == 0:
                left -= overlapping//2
            if h_idx == 0:
                top -= overlapping//2
            if h_idx + patch_size[0] >= h:
                bottom += overlapping//2
            if w_idx + patch_size[1] >= w:
                right += overlapping//2

            # paste the prediction (x2 coordinates) into the output canvas
            sr[..., 2*(h_idx+top): 2*(h_idx+bottom), 2*(w_idx+left): 2*(w_idx+right)] = sr_tiling[..., 2*top:2*bottom, 2*left:2*right]
    return sr
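`tiling_inference` only relies on the `InferenceSession` interface (`get_inputs` and `run`), so its shape contract can be checked without the real model. The nearest-neighbour x2 stub below is purely illustrative (not in the commit) and stands in for the RCAN session.

```python
import numpy as np

from data.data_tiling import tiling_inference

class _StubInput:
    name = 'input'

class _StubSession:
    """Fake x2 upscaler mimicking the onnxruntime.InferenceSession interface."""
    def get_inputs(self):
        return [_StubInput()]
    def run(self, _outputs, feeds):
        x = feeds['input']
        return [x.repeat(2, axis=2).repeat(2, axis=3)]  # nearest-neighbour x2

lr = np.random.rand(1, 3, 100, 120).astype(np.float32)
sr = tiling_inference(_StubSession(), lr, 8, (56, 56))
print(sr.shape)  # (1, 3, 200, 240): 56x56 tiles stitched onto a 2x canvas
```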
data/srdata.py
ADDED
@@ -0,0 +1,85 @@
import os
import glob
import random

from data import common

import imageio
import torch.utils.data as data


class SRData(data.Dataset):
    def __init__(self, args, name='', benchmark=False):
        self.args = args
        self.name = name
        self.split = 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = False
        self.scale = args.scale
        self.idx_scale = 0
        self._set_filesystem(args.dir_data)
        list_hr, list_lr = self._scan()
        self.images_hr, self.images_lr = list_hr, list_lr

    # The functions below are used to prepare images
    def _scan(self):
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}x{}{}'.format(
                        s, filename, s, self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=3)
        pair_t = common.np2Tensor(*pair, rgb_range=255)

        return pair_t[0], pair_t[1], filename

    def __len__(self):
        return len(self.images_hr)

    def _get_index(self, idx):
        return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        hr = imageio.imread(f_hr)
        lr = imageio.imread(f_lr)
        return lr, hr, filename

    def get_patch(self, lr, hr):
        # crop hr so that it exactly matches the lr size times the scale
        scale = self.scale[self.idx_scale]
        ih, iw = lr.shape[:2]
        hr = hr[0:ih * scale, 0:iw * scale]
        return lr, hr

    def set_scale(self, idx_scale):
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
eval_onnx.py
ADDED
@@ -0,0 +1,87 @@
import torch
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
from tqdm import tqdm
import data
import metric
import onnxruntime
from data.data_tiling import tiling_inference
import argparse


class Configs():
    def __init__(self):
        parser = argparse.ArgumentParser(description='SR')

        # run on IPU or CPU; you need to provide the onnx path
        parser.add_argument('--ipu', action='store_true',
                            help='use ipu')
        parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
                            help='onnx path')
        parser.add_argument('--provider_config', type=str, default=None,
                            help='provider config path')
        # Data specifications; the defaults usually suffice
        parser.add_argument('--dir_data', type=str, default='dataset/',
                            help='dataset directory')
        parser.add_argument('--data_test', type=str, default='Set5',
                            help='test dataset name')
        parser.add_argument('--n_threads', type=int, default=6,
                            help='number of threads for data loading')
        parser.add_argument('--scale', type=str, default='2',
                            help='super resolution scale, now only support x2')
        self.parser = parser

    def parse(self):
        args = self.parser.parse_args()
        args.scale = list(map(lambda x: int(x), args.scale.split('+')))
        args.data_test = args.data_test.split('+')
        print(args)
        return args


def quantize(img, rgb_range):  # clamp pixels to the rgb range
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)


def test_model(session, loader, device):
    torch.set_grad_enabled(False)
    self_scale = [2]
    for idx_data, d in enumerate(loader.loader_test):
        eval_ssim = 0
        eval_psnr = 0
        for idx_scale, scale in enumerate(self_scale):
            d.dataset.set_scale(idx_scale)
            for lr, hr, filename in tqdm(d, ncols=80):
                sr = tiling_inference(session, lr.cpu().numpy(), 8, (56, 56))
                sr = torch.from_numpy(sr).to(device)
                sr = quantize(sr, 255)
                eval_psnr += metric.calc_psnr(
                    sr, hr, scale, 255, benchmark=d)
                eval_ssim += metric.calc_ssim(
                    sr, hr, scale, 255, dataset=d)
        mean_ssim = eval_ssim / len(d)
        mean_psnr = eval_psnr / len(d)
        print("psnr: %s, ssim: %s" % (mean_psnr, mean_ssim))
    return mean_psnr, mean_ssim


def main(args):
    loader = data.Data(args)
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    onnx_file_name = args.onnx_path
    ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
    test_model(ort_session, loader, device="cpu")


if __name__ == '__main__':
    args = Configs().parse()
    main(args)
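One detail worth seeing in isolation is the `quantize` helper, which snaps super-resolved values back onto the representable pixel grid before the metrics are computed. A quick sketch (not in the commit, and assuming it is run from the repository root so `eval_onnx` imports cleanly):

```python
import torch

from eval_onnx import quantize

x = torch.tensor([-3.2, 12.7, 254.5, 260.0])
# with rgb_range=255, values are clamped to [0, 255] and rounded
# (torch.round rounds half to even, so 254.5 -> 254.)
print(quantize(x, rgb_range=255))  # tensor([  0.,  13., 254., 255.])
```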
infer_onnx.py
ADDED
@@ -0,0 +1,45 @@
import onnxruntime
import cv2
import sys
import pathlib
CURRENT_DIR = pathlib.Path(__file__).parent
sys.path.append(str(CURRENT_DIR))
import numpy as np
from data.data_tiling import tiling_inference
import argparse


def main(args):
    if args.ipu:
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None
    onnx_file_name = args.onnx_path
    image_path = args.image_path
    output_path = args.output_path

    ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=providers, provider_options=provider_options)
    # HWC uint8 image -> NCHW float32 batch
    lr = cv2.imread(image_path)[np.newaxis, :, :, :].transpose((0, 3, 1, 2)).astype(np.float32)
    sr = tiling_inference(ort_session, lr, 8, (56, 56))
    sr = np.clip(sr, 0, 255)
    sr = sr.squeeze().transpose((1, 2, 0)).astype(np.uint8)
    cv2.imwrite(output_path, sr)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RCAN SISR')
    parser.add_argument('--onnx_path', type=str, default='RCAN_int8.onnx',
                        help='onnx path')
    parser.add_argument('--image_path', default='test_data/test.png',
                        help='path of your input image')
    parser.add_argument('--output_path', default='test_data/sr.png',
                        help='path of the output image')
    parser.add_argument('--ipu', action='store_true',
                        help='use ipu')
    parser.add_argument('--provider_config', type=str, default=None,
                        help='provider config path')
    args = parser.parse_args()
    main(args)
metric.py
ADDED
@@ -0,0 +1,93 @@
import math

import numpy as np
import torch
from scipy import signal


def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    if sr.size(-2) > hr.size(-2) or sr.size(-1) > hr.size(-1):
        print("the dimension of the sr image is not equal to the hr's!")
        sr = sr[:, :, :hr.size(-2), :hr.size(-1)]
    diff = (sr - hr).data.div(rgb_range)

    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            # convert RGB to the Y channel before measuring
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()

    return -10 * math.log10(mse)


def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):
    """
    2D gaussian mask - should give the same result as MATLAB's fspecial('gaussian',[shape],[sigma])
    Acknowledgement : https://stackoverflow.com/questions/17190649/how-to-obtain-a-gaussian-filter-in-python (Author@ali_m)
    """
    m, n = [(ss-1.)/2. for ss in shape]
    y, x = np.ogrid[-m:m+1, -n:n+1]
    h = np.exp(-(x*x + y*y) / (2.*sigma*sigma))
    h[h < np.finfo(h.dtype).eps*h.max()] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    return h


def calc_ssim(X, Y, scale, rgb_range, dataset=None, sigma=1.5, K1=0.01, K2=0.03, R=255):
    '''
    X : y channel (i.e., luminance) of transformed YCbCr space of X
    Y : y channel (i.e., luminance) of transformed YCbCr space of Y
    Please follow the setting of psnr_ssim.m in EDSR (Enhanced Deep Residual Networks for Single Image Super-Resolution CVPRW2017).
    Official Link : https://github.com/LimBee/NTIRE2017/tree/db34606c2844e89317aac8728a2de562ef1f8aba
    The authors of EDSR use MATLAB's ssim as the evaluation tool,
    thus this function is the same as ssim.m in MATLAB with C(3) == C(2)/2.
    '''
    gaussian_filter = matlab_style_gauss2D((11, 11), sigma)

    # benchmark-style shaving is always applied here; the original code gated
    # this on `dataset and dataset.dataset.benchmark`
    shave = scale
    if X.size(1) > 1:
        gray_coeffs = [65.738, 129.057, 25.064]
        convert = X.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
        X = X.mul(convert).sum(dim=1)
        Y = Y.mul(convert).sum(dim=1)

    X = X[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)
    Y = Y[..., shave:-shave, shave:-shave].squeeze().cpu().numpy().astype(np.float64)

    window = gaussian_filter

    ux = signal.convolve2d(X, window, mode='same', boundary='symm')
    uy = signal.convolve2d(Y, window, mode='same', boundary='symm')

    uxx = signal.convolve2d(X*X, window, mode='same', boundary='symm')
    uyy = signal.convolve2d(Y*Y, window, mode='same', boundary='symm')
    uxy = signal.convolve2d(X*Y, window, mode='same', boundary='symm')

    vx = uxx - ux * ux
    vy = uyy - uy * uy
    vxy = uxy - ux * uy

    C1 = (K1 * R) ** 2
    C2 = (K2 * R) ** 2

    A1, A2, B1, B2 = (2 * ux * uy + C1, 2 * vxy + C2, ux ** 2 + uy ** 2 + C1, vx + vy + C2)
    D = B1 * B2
    S = (A1 * A2) / D
    mssim = S.mean()

    return mssim
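A self-contained sanity check (not in the commit) for the two metrics, using a lightly perturbed random image pair in the 0..255 range; real evaluation goes through `eval_onnx.py` instead.

```python
import torch

import metric

hr = torch.rand(1, 3, 64, 64) * 255
sr = (hr + torch.randn_like(hr) * 2).clamp(0, 255)  # lightly degraded copy

psnr = metric.calc_psnr(sr, hr, scale=2, rgb_range=255, benchmark=True)
ssim = metric.calc_ssim(sr, hr, scale=2, rgb_range=255)
print(f"psnr: {psnr:.2f}, ssim: {ssim:.4f}")  # high PSNR, SSIM close to 1
```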
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch==1.13.1
numpy>=1.23.5
scipy>=1.9
opencv-python
pandas
pillow
scikit-image
tqdm
test_data/test.png
ADDED