Spaces:
Build error
Build error
Upload 5 files
Browse files- scripts/extract_subimages.py +135 -0
- scripts/generate_meta_info.py +58 -0
- scripts/generate_meta_info_pairdata.py +49 -0
- scripts/generate_multiscale_DF2K.py +48 -0
- scripts/pytorch2onnx.py +36 -0
scripts/extract_subimages.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
from basicsr.utils import scandir
|
7 |
+
from multiprocessing import Pool
|
8 |
+
from os import path as osp
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
|
12 |
+
def main(args):
|
13 |
+
"""A multi-thread tool to crop large images to sub-images for faster IO.
|
14 |
+
|
15 |
+
opt (dict): Configuration dict. It contains:
|
16 |
+
n_thread (int): Thread number.
|
17 |
+
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
|
18 |
+
and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
|
19 |
+
input_folder (str): Path to the input folder.
|
20 |
+
save_folder (str): Path to save folder.
|
21 |
+
crop_size (int): Crop size.
|
22 |
+
step (int): Step for overlapped sliding window.
|
23 |
+
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
|
24 |
+
|
25 |
+
Usage:
|
26 |
+
For each folder, run this script.
|
27 |
+
Typically, there are GT folder and LQ folder to be processed for DIV2K dataset.
|
28 |
+
After process, each sub_folder should have the same number of subimages.
|
29 |
+
Remember to modify opt configurations according to your settings.
|
30 |
+
"""
|
31 |
+
|
32 |
+
opt = {}
|
33 |
+
opt['n_thread'] = args.n_thread
|
34 |
+
opt['compression_level'] = args.compression_level
|
35 |
+
opt['input_folder'] = args.input
|
36 |
+
opt['save_folder'] = args.output
|
37 |
+
opt['crop_size'] = args.crop_size
|
38 |
+
opt['step'] = args.step
|
39 |
+
opt['thresh_size'] = args.thresh_size
|
40 |
+
extract_subimages(opt)
|
41 |
+
|
42 |
+
|
43 |
+
def extract_subimages(opt):
|
44 |
+
"""Crop images to subimages.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
opt (dict): Configuration dict. It contains:
|
48 |
+
input_folder (str): Path to the input folder.
|
49 |
+
save_folder (str): Path to save folder.
|
50 |
+
n_thread (int): Thread number.
|
51 |
+
"""
|
52 |
+
input_folder = opt['input_folder']
|
53 |
+
save_folder = opt['save_folder']
|
54 |
+
if not osp.exists(save_folder):
|
55 |
+
os.makedirs(save_folder)
|
56 |
+
print(f'mkdir {save_folder} ...')
|
57 |
+
else:
|
58 |
+
print(f'Folder {save_folder} already exists. Exit.')
|
59 |
+
sys.exit(1)
|
60 |
+
|
61 |
+
# scan all images
|
62 |
+
img_list = list(scandir(input_folder, full_path=True))
|
63 |
+
|
64 |
+
pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
|
65 |
+
pool = Pool(opt['n_thread'])
|
66 |
+
for path in img_list:
|
67 |
+
pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
|
68 |
+
pool.close()
|
69 |
+
pool.join()
|
70 |
+
pbar.close()
|
71 |
+
print('All processes done.')
|
72 |
+
|
73 |
+
|
74 |
+
def worker(path, opt):
|
75 |
+
"""Worker for each process.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
path (str): Image path.
|
79 |
+
opt (dict): Configuration dict. It contains:
|
80 |
+
crop_size (int): Crop size.
|
81 |
+
step (int): Step for overlapped sliding window.
|
82 |
+
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
|
83 |
+
save_folder (str): Path to save folder.
|
84 |
+
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
process_info (str): Process information displayed in progress bar.
|
88 |
+
"""
|
89 |
+
crop_size = opt['crop_size']
|
90 |
+
step = opt['step']
|
91 |
+
thresh_size = opt['thresh_size']
|
92 |
+
img_name, extension = osp.splitext(osp.basename(path))
|
93 |
+
|
94 |
+
# remove the x2, x3, x4 and x8 in the filename for DIV2K
|
95 |
+
img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
|
96 |
+
|
97 |
+
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
98 |
+
|
99 |
+
h, w = img.shape[0:2]
|
100 |
+
h_space = np.arange(0, h - crop_size + 1, step)
|
101 |
+
if h - (h_space[-1] + crop_size) > thresh_size:
|
102 |
+
h_space = np.append(h_space, h - crop_size)
|
103 |
+
w_space = np.arange(0, w - crop_size + 1, step)
|
104 |
+
if w - (w_space[-1] + crop_size) > thresh_size:
|
105 |
+
w_space = np.append(w_space, w - crop_size)
|
106 |
+
|
107 |
+
index = 0
|
108 |
+
for x in h_space:
|
109 |
+
for y in w_space:
|
110 |
+
index += 1
|
111 |
+
cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
|
112 |
+
cropped_img = np.ascontiguousarray(cropped_img)
|
113 |
+
cv2.imwrite(
|
114 |
+
osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
|
115 |
+
[cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
|
116 |
+
process_info = f'Processing {img_name} ...'
|
117 |
+
return process_info
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == '__main__':
|
121 |
+
parser = argparse.ArgumentParser()
|
122 |
+
parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
|
123 |
+
parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
|
124 |
+
parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
|
125 |
+
parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
|
126 |
+
parser.add_argument(
|
127 |
+
'--thresh_size',
|
128 |
+
type=int,
|
129 |
+
default=0,
|
130 |
+
help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
|
131 |
+
parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
|
132 |
+
parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
|
133 |
+
args = parser.parse_args()
|
134 |
+
|
135 |
+
main(args)
|
scripts/generate_meta_info.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import cv2
|
3 |
+
import glob
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
def main(args):
|
8 |
+
txt_file = open(args.meta_info, 'w')
|
9 |
+
for folder, root in zip(args.input, args.root):
|
10 |
+
img_paths = sorted(glob.glob(os.path.join(folder, '*')))
|
11 |
+
for img_path in img_paths:
|
12 |
+
status = True
|
13 |
+
if args.check:
|
14 |
+
# read the image once for check, as some images may have errors
|
15 |
+
try:
|
16 |
+
img = cv2.imread(img_path)
|
17 |
+
except (IOError, OSError) as error:
|
18 |
+
print(f'Read {img_path} error: {error}')
|
19 |
+
status = False
|
20 |
+
if img is None:
|
21 |
+
status = False
|
22 |
+
print(f'Img is None: {img_path}')
|
23 |
+
if status:
|
24 |
+
# get the relative path
|
25 |
+
img_name = os.path.relpath(img_path, root)
|
26 |
+
print(img_name)
|
27 |
+
txt_file.write(f'{img_name}\n')
|
28 |
+
|
29 |
+
|
30 |
+
if __name__ == '__main__':
|
31 |
+
"""Generate meta info (txt file) for only Ground-Truth images.
|
32 |
+
|
33 |
+
It can also generate meta info from several folders into one txt file.
|
34 |
+
"""
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
parser.add_argument(
|
37 |
+
'--input',
|
38 |
+
nargs='+',
|
39 |
+
default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'],
|
40 |
+
help='Input folder, can be a list')
|
41 |
+
parser.add_argument(
|
42 |
+
'--root',
|
43 |
+
nargs='+',
|
44 |
+
default=['datasets/DF2K', 'datasets/DF2K'],
|
45 |
+
help='Folder root, should have the length as input folders')
|
46 |
+
parser.add_argument(
|
47 |
+
'--meta_info',
|
48 |
+
type=str,
|
49 |
+
default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
|
50 |
+
help='txt path for meta info')
|
51 |
+
parser.add_argument('--check', action='store_true', help='Read image to check whether it is ok')
|
52 |
+
args = parser.parse_args()
|
53 |
+
|
54 |
+
assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
|
55 |
+
f'{len(args.input)} and {len(args.root)}.')
|
56 |
+
os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
|
57 |
+
|
58 |
+
main(args)
|
scripts/generate_meta_info_pairdata.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
def main(args):
|
7 |
+
txt_file = open(args.meta_info, 'w')
|
8 |
+
# sca images
|
9 |
+
img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
|
10 |
+
img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
|
11 |
+
|
12 |
+
assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
|
13 |
+
f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
|
14 |
+
|
15 |
+
for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
|
16 |
+
# get the relative paths
|
17 |
+
img_name_gt = os.path.relpath(img_path_gt, args.root[0])
|
18 |
+
img_name_lq = os.path.relpath(img_path_lq, args.root[1])
|
19 |
+
print(f'{img_name_gt}, {img_name_lq}')
|
20 |
+
txt_file.write(f'{img_name_gt}, {img_name_lq}\n')
|
21 |
+
|
22 |
+
|
23 |
+
if __name__ == '__main__':
|
24 |
+
"""This script is used to generate meta info (txt file) for paired images.
|
25 |
+
"""
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument(
|
28 |
+
'--input',
|
29 |
+
nargs='+',
|
30 |
+
default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
|
31 |
+
help='Input folder, should be [gt_folder, lq_folder]')
|
32 |
+
parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root, will use the ')
|
33 |
+
parser.add_argument(
|
34 |
+
'--meta_info',
|
35 |
+
type=str,
|
36 |
+
default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
|
37 |
+
help='txt path for meta info')
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
|
41 |
+
assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
|
42 |
+
os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
|
43 |
+
for i in range(2):
|
44 |
+
if args.input[i].endswith('/'):
|
45 |
+
args.input[i] = args.input[i][:-1]
|
46 |
+
if args.root[i] is None:
|
47 |
+
args.root[i] = os.path.dirname(args.input[i])
|
48 |
+
|
49 |
+
main(args)
|
scripts/generate_multiscale_DF2K.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import glob
|
3 |
+
import os
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def main(args):
|
8 |
+
# For DF2K, we consider the following three scales,
|
9 |
+
# and the smallest image whose shortest edge is 400
|
10 |
+
scale_list = [0.75, 0.5, 1 / 3]
|
11 |
+
shortest_edge = 400
|
12 |
+
|
13 |
+
path_list = sorted(glob.glob(os.path.join(args.input, '*')))
|
14 |
+
for path in path_list:
|
15 |
+
print(path)
|
16 |
+
basename = os.path.splitext(os.path.basename(path))[0]
|
17 |
+
|
18 |
+
img = Image.open(path)
|
19 |
+
width, height = img.size
|
20 |
+
for idx, scale in enumerate(scale_list):
|
21 |
+
print(f'\t{scale:.2f}')
|
22 |
+
rlt = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS)
|
23 |
+
rlt.save(os.path.join(args.output, f'{basename}T{idx}.png'))
|
24 |
+
|
25 |
+
# save the smallest image which the shortest edge is 400
|
26 |
+
if width < height:
|
27 |
+
ratio = height / width
|
28 |
+
width = shortest_edge
|
29 |
+
height = int(width * ratio)
|
30 |
+
else:
|
31 |
+
ratio = width / height
|
32 |
+
height = shortest_edge
|
33 |
+
width = int(height * ratio)
|
34 |
+
rlt = img.resize((int(width), int(height)), resample=Image.LANCZOS)
|
35 |
+
rlt.save(os.path.join(args.output, f'{basename}T{idx+1}.png'))
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
"""Generate multi-scale versions for GT images with LANCZOS resampling.
|
40 |
+
It is now used for DF2K dataset (DIV2K + Flickr 2K)
|
41 |
+
"""
|
42 |
+
parser = argparse.ArgumentParser()
|
43 |
+
parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
|
44 |
+
parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
|
45 |
+
args = parser.parse_args()
|
46 |
+
|
47 |
+
os.makedirs(args.output, exist_ok=True)
|
48 |
+
main(args)
|
scripts/pytorch2onnx.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import torch.onnx
|
4 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
5 |
+
|
6 |
+
|
7 |
+
def main(args):
|
8 |
+
# An instance of the model
|
9 |
+
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
10 |
+
if args.params:
|
11 |
+
keyname = 'params'
|
12 |
+
else:
|
13 |
+
keyname = 'params_ema'
|
14 |
+
model.load_state_dict(torch.load(args.input)[keyname])
|
15 |
+
# set the train mode to false since we will only run the forward pass.
|
16 |
+
model.train(False)
|
17 |
+
model.cpu().eval()
|
18 |
+
|
19 |
+
# An example input
|
20 |
+
x = torch.rand(1, 3, 64, 64)
|
21 |
+
# Export the model
|
22 |
+
with torch.no_grad():
|
23 |
+
torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
|
24 |
+
print(torch_out.shape)
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == '__main__':
|
28 |
+
"""Convert pytorch model to onnx models"""
|
29 |
+
parser = argparse.ArgumentParser()
|
30 |
+
parser.add_argument(
|
31 |
+
'--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
|
32 |
+
parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
|
33 |
+
parser.add_argument('--params', action='store_false', help='Use params instead of params_ema')
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
main(args)
|