AK391 committed
Commit 810c8ea
1 Parent(s): db58d3d
Files changed (45)
  1. VERSION +1 -1
  2. experiments/.DS_Store +0 -0
  3. inference_realesrgan.py +68 -19
  4. inference_realesrgan_video.py +199 -0
  5. options/finetune_realesrgan_x4plus.yml +188 -0
  6. options/finetune_realesrgan_x4plus_pairdata.yml +150 -0
  7. options/setup.cfg +33 -0
  8. options/train_realesrgan_x2plus.yml +186 -0
  9. options/train_realesrgan_x4plus.yml +5 -6
  10. options/train_realesrnet_x2plus.yml +145 -0
  11. options/train_realesrnet_x4plus.yml +4 -4
  12. realesrgan/__init__.py +1 -1
  13. realesrgan/archs/discriminator_arch.py +14 -7
  14. realesrgan/archs/srvgg_arch.py +69 -0
  15. realesrgan/data/realesrgan_dataset.py +29 -12
  16. realesrgan/data/realesrgan_paired_dataset.py +108 -0
  17. realesrgan/models/realesrgan_model.py +29 -13
  18. realesrgan/models/realesrnet_model.py +30 -14
  19. realesrgan/utils.py +79 -30
  20. scripts/extract_subimages.py +135 -0
  21. scripts/generate_meta_info.py +58 -0
  22. scripts/generate_meta_info_pairdata.py +49 -0
  23. scripts/generate_multiscale_DF2K.py +48 -0
  24. scripts/pytorch2onnx.py +30 -11
  25. setup.py +1 -7
  26. tests/data/gt.lmdb/data.mdb +0 -0
  27. tests/data/gt.lmdb/lock.mdb +0 -0
  28. tests/data/gt.lmdb/meta_info.txt +2 -0
  29. tests/data/gt/baboon.png +0 -0
  30. tests/data/gt/comic.png +0 -0
  31. tests/data/lq.lmdb/data.mdb +0 -0
  32. tests/data/lq.lmdb/lock.mdb +0 -0
  33. tests/data/lq.lmdb/meta_info.txt +2 -0
  34. tests/data/lq/baboon.png +0 -0
  35. tests/data/lq/comic.png +0 -0
  36. tests/data/meta_info_gt.txt +2 -0
  37. tests/data/meta_info_pair.txt +2 -0
  38. tests/data/test_realesrgan_dataset.yml +28 -0
  39. tests/data/test_realesrgan_model.yml +115 -0
  40. tests/data/test_realesrgan_paired_dataset.yml +13 -0
  41. tests/data/test_realesrnet_model.yml +75 -0
  42. tests/test_dataset.py +151 -0
  43. tests/test_discriminator_arch.py +19 -0
  44. tests/test_model.py +126 -0
  45. tests/test_utils.py +87 -0
VERSION CHANGED
@@ -1 +1 @@
-0.2.1
+0.2.3.0
experiments/.DS_Store ADDED
Binary file (6.15 kB)
inference_realesrgan.py CHANGED
@@ -2,25 +2,32 @@ import argparse
 import cv2
 import glob
 import os
+from basicsr.archs.rrdbnet_arch import RRDBNet
 
 from realesrgan import RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
 
 
 def main():
+    """Inference demo for Real-ESRGAN.
+    """
     parser = argparse.ArgumentParser()
-    parser.add_argument('--input', type=str, default='inputs', help='Input image or folder')
+    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
     parser.add_argument(
-        '--model_path',
+        '-n',
+        '--model_name',
         type=str,
-        default='RealESRGAN_x4plus.pth',
-        help='Path to the pre-trained model')
-    parser.add_argument('--output', type=str, default='results', help='Output folder')
-    parser.add_argument('--netscale', type=int, default=4, help='Upsample scale factor of the network')
-    parser.add_argument('--outscale', type=float, default=4, help='The final upsampling scale of the image')
+        default='RealESRGAN_x4plus',
+        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus'
+              'RealESRGANv2-anime-xsx2 | RealESRGANv2-animevideo-xsx2-nousm | RealESRGANv2-animevideo-xsx2'
+              'RealESRGANv2-anime-xsx4 | RealESRGANv2-animevideo-xsx4-nousm | RealESRGANv2-animevideo-xsx4'))
+    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
     parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored image')
-    parser.add_argument('--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
     parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
     parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
     parser.add_argument('--half', action='store_true', help='Use half precision during inference')
     parser.add_argument(
         '--alpha_upsampler',
@@ -34,14 +41,55 @@ def main():
         help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
     args = parser.parse_args()
 
+    # determine models according to model names
+    args.model_name = args.model_name.split('.')[0]
+    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
+    ]:  # x2 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
+    ]:  # x4 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+        netscale = 4
+
+    # determine model paths
+    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        raise ValueError(f'Model {args.model_name} does not exist.')
+
+    # restorer
     upsampler = RealESRGANer(
-        scale=args.netscale,
-        model_path=args.model_path,
+        scale=netscale,
+        model_path=model_path,
+        model=model,
         tile=args.tile,
        tile_pad=args.tile_pad,
        pre_pad=args.pre_pad,
        half=args.half)
+
+    if args.face_enhance:  # Use GFPGAN for face enhancement
+        from gfpgan import GFPGANer
+        face_enhancer = GFPGANer(
+            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
+            upscale=args.outscale,
+            arch='clean',
+            channel_multiplier=2,
+            bg_upsampler=upsampler)
     os.makedirs(args.output, exist_ok=True)
+
     if os.path.isfile(args.input):
         paths = [args.input]
     else:
@@ -52,18 +100,19 @@ def main():
         print('Testing', idx, imgname)
 
         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
-        h, w = img.shape[0:2]
-        if max(h, w) > 1000 and args.netscale == 4:
-            import warnings
-            warnings.warn('The input image is large, try X2 model for better performace.')
-        if max(h, w) < 500 and args.netscale == 2:
-            import warnings
-            warnings.warn('The input image is small, try X4 model for better performace.')
+        if len(img.shape) == 3 and img.shape[2] == 4:
+            img_mode = 'RGBA'
+        else:
+            img_mode = None
 
         try:
-            output, img_mode = upsampler.enhance(img, outscale=args.outscale)
-        except Exception as error:
+            if args.face_enhance:
+                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+            else:
+                output, _ = upsampler.enhance(img, outscale=args.outscale)
+        except RuntimeError as error:
             print('Error', error)
+            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
         else:
             if args.ext == 'auto':
                 extension = extension[1:]
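
For quick reference, the new interface above reduces to a minimal Python sketch like the following; the calls mirror the diff (RRDBNet and RealESRGANer signatures, and enhance returning the output plus an image mode), while the input/output file names are placeholders:

import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

# x4 RRDBNet model, as selected for the default 'RealESRGAN_x4plus' above
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
    model=model,
    tile=0, tile_pad=10, pre_pad=0, half=False)

img = cv2.imread('inputs/example.jpg', cv2.IMREAD_UNCHANGED)  # placeholder input path
output, _ = upsampler.enhance(img, outscale=4)
cv2.imwrite('results/example_out.jpg', output)  # placeholder output path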
inference_realesrgan_video.py ADDED
@@ -0,0 +1,199 @@
+import argparse
+import glob
+import mimetypes
+import os
+import queue
+import shutil
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.logger import AvgTimer
+from tqdm import tqdm
+
+from realesrgan import IOConsumer, PrefetchReader, RealESRGANer
+from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+
+
+def main():
+    """Inference demo for Real-ESRGAN.
+    It is mainly for restoring anime videos.
+
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
+    parser.add_argument(
+        '-n',
+        '--model_name',
+        type=str,
+        default='RealESRGAN_x4plus',
+        help=('Model names: RealESRGAN_x4plus | RealESRNet_x4plus | RealESRGAN_x4plus_anime_6B | RealESRGAN_x2plus'
+              'RealESRGANv2-anime-xsx2 | RealESRGANv2-animevideo-xsx2-nousm | RealESRGANv2-animevideo-xsx2'
+              'RealESRGANv2-anime-xsx4 | RealESRGANv2-animevideo-xsx4-nousm | RealESRGANv2-animevideo-xsx4'))
+    parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
+    parser.add_argument('-s', '--outscale', type=float, default=4, help='The final upsampling scale of the image')
+    parser.add_argument('--suffix', type=str, default='out', help='Suffix of the restored video')
+    parser.add_argument('-t', '--tile', type=int, default=0, help='Tile size, 0 for no tile during testing')
+    parser.add_argument('--tile_pad', type=int, default=10, help='Tile padding')
+    parser.add_argument('--pre_pad', type=int, default=0, help='Pre padding size at each border')
+    parser.add_argument('--face_enhance', action='store_true', help='Use GFPGAN to enhance face')
+    parser.add_argument('--half', action='store_true', help='Use half precision during inference')
+    parser.add_argument('-v', '--video', action='store_true', help='Output a video using ffmpeg')
+    parser.add_argument('-a', '--audio', action='store_true', help='Keep audio')
+    parser.add_argument('--fps', type=float, default=None, help='FPS of the output video')
+    parser.add_argument('--consumer', type=int, default=4, help='Number of IO consumers')
+
+    parser.add_argument(
+        '--alpha_upsampler',
+        type=str,
+        default='realesrgan',
+        help='The upsampler for the alpha channels. Options: realesrgan | bicubic')
+    parser.add_argument(
+        '--ext',
+        type=str,
+        default='auto',
+        help='Image extension. Options: auto | jpg | png, auto means using the same extension as inputs')
+    args = parser.parse_args()
+
+    # ---------------------- determine models according to model names ---------------------- #
+    args.model_name = args.model_name.split('.')[0]
+    if args.model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']:  # x4 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x4plus_anime_6B']:  # x4 RRDBNet model with 6 blocks
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+        netscale = 4
+    elif args.model_name in ['RealESRGAN_x2plus']:  # x2 RRDBNet model
+        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
+    ]:  # x2 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
+        netscale = 2
+    elif args.model_name in [
+            'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
+    ]:  # x4 VGG-style model (XS size)
+        model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+        netscale = 4
+
+    # ---------------------- determine model paths ---------------------- #
+    model_path = os.path.join('experiments/pretrained_models', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        model_path = os.path.join('realesrgan/weights', args.model_name + '.pth')
+    if not os.path.isfile(model_path):
+        raise ValueError(f'Model {args.model_name} does not exist.')
+
+    # restorer
+    upsampler = RealESRGANer(
+        scale=netscale,
+        model_path=model_path,
+        model=model,
+        tile=args.tile,
+        tile_pad=args.tile_pad,
+        pre_pad=args.pre_pad,
+        half=args.half)
+
+    if args.face_enhance:  # Use GFPGAN for face enhancement
+        from gfpgan import GFPGANer
+        face_enhancer = GFPGANer(
+            model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
+            upscale=args.outscale,
+            arch='clean',
+            channel_multiplier=2,
+            bg_upsampler=upsampler)
+    os.makedirs(args.output, exist_ok=True)
+    # for saving restored frames
+    save_frame_folder = os.path.join(args.output, 'frames_tmpout')
+    os.makedirs(save_frame_folder, exist_ok=True)
+
+    if mimetypes.guess_type(args.input)[0].startswith('video'):  # is a video file
+        video_name = os.path.splitext(os.path.basename(args.input))[0]
+        frame_folder = os.path.join('tmp_frames', video_name)
+        os.makedirs(frame_folder, exist_ok=True)
+        # use ffmpeg to extract frames
+        os.system(f'ffmpeg -i {args.input} -qscale:v 1 -qmin 1 -qmax 1 -vsync 0 {frame_folder}/frame%08d.png')
+        # get image path list
+        paths = sorted(glob.glob(os.path.join(frame_folder, '*')))
+        if args.video:
+            if args.fps is None:
+                # get input video fps
+                import ffmpeg
+                probe = ffmpeg.probe(args.input)
+                video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
+                args.fps = eval(video_streams[0]['avg_frame_rate'])
+    elif mimetypes.guess_type(args.input)[0].startswith('image'):  # is an image file
+        paths = [args.input]
+        video_name = 'video'
+    else:
+        paths = sorted(glob.glob(os.path.join(args.input, '*')))
+        video_name = 'video'
+
+    timer = AvgTimer()
+    timer.start()
+    pbar = tqdm(total=len(paths), unit='frame', desc='inference')
+    # set up prefetch reader
+    reader = PrefetchReader(paths, num_prefetch_queue=4)
+    reader.start()
+
+    que = queue.Queue()
+    consumers = [IOConsumer(args, que, f'IO_{i}') for i in range(args.consumer)]
+    for consumer in consumers:
+        consumer.start()
+
+    for idx, (path, img) in enumerate(zip(paths, reader)):
+        imgname, extension = os.path.splitext(os.path.basename(path))
+        if len(img.shape) == 3 and img.shape[2] == 4:
+            img_mode = 'RGBA'
+        else:
+            img_mode = None
+
+        try:
+            if args.face_enhance:
+                _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+            else:
+                output, _ = upsampler.enhance(img, outscale=args.outscale)
+        except RuntimeError as error:
+            print('Error', error)
+            print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
+
+        else:
+            if args.ext == 'auto':
+                extension = extension[1:]
+            else:
+                extension = args.ext
+            if img_mode == 'RGBA':  # RGBA images should be saved in png format
+                extension = 'png'
+            save_path = os.path.join(save_frame_folder, f'{imgname}_out.{extension}')
+
+            que.put({'output': output, 'save_path': save_path})
+
+        pbar.update(1)
+        torch.cuda.synchronize()
+        timer.record()
+        avg_fps = 1. / (timer.get_avg_time() + 1e-7)
+        pbar.set_description(f'idx {idx}, fps {avg_fps:.2f}')
+
+    for _ in range(args.consumer):
+        que.put('quit')
+    for consumer in consumers:
+        consumer.join()
+    pbar.close()
+
+    # merge frames to video
+    if args.video:
+        video_save_path = os.path.join(args.output, f'{video_name}_{args.suffix}.mp4')
+        if args.audio:
+            os.system(
+                f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} -i {args.input}'
+                f' -map 0:v:0 -map 1:a:0 -c:a copy -c:v libx264 -r {args.fps} -pix_fmt yuv420p {video_save_path}')
+        else:
+            os.system(f'ffmpeg -r {args.fps} -i {save_frame_folder}/frame%08d_out.{extension} '
+                      f'-c:v libx264 -r {args.fps} -pix_fmt yuv420p {video_save_path}')
+
+    # delete tmp file
+    shutil.rmtree(save_frame_folder)
+    if os.path.isdir(frame_folder):
+        shutil.rmtree(frame_folder)
+
+
+if __name__ == '__main__':
+    main()
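
The saving side of this script relies on PrefetchReader and IOConsumer from the realesrgan package, whose internals are not part of this commit. A generic sketch of the queue/worker pattern they imply, under the assumption that each consumer simply writes images until it sees a 'quit' sentinel (IOConsumerSketch is a hypothetical stand-in, not the package's class):

import queue
import threading

import cv2


class IOConsumerSketch(threading.Thread):
    """Hypothetical stand-in for realesrgan.IOConsumer: drain a queue and write images to disk."""

    def __init__(self, que):
        super().__init__()
        self.que = que

    def run(self):
        while True:
            msg = self.que.get()
            if msg == 'quit':  # sentinel, mirroring the que.put('quit') shutdown above
                break
            cv2.imwrite(msg['save_path'], msg['output'])


que = queue.Queue()
workers = [IOConsumerSketch(que) for _ in range(4)]
for w in workers:
    w.start()
# ... the producer puts {'output': img, 'save_path': path} dicts, then one sentinel per worker:
for _ in range(4):
    que.put('quit')
for w in workers:
    w.join()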
options/finetune_realesrgan_x4plus.yml ADDED
@@ -0,0 +1,188 @@
+# general settings
+name: finetune_RealESRGANx4plus_400k
+model_type: RealESRGANModel
+scale: 4
+num_gpu: auto
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: DF2K+OST
+    type: RealESRGANDataset
+    dataroot_gt: datasets/DF2K
+    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+    io_backend:
+      type: disk
+
+    blur_kernel_size: 21
+    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+    sinc_prob: 0.1
+    blur_sigma: [0.2, 3]
+    betag_range: [0.5, 4]
+    betap_range: [1, 2]
+
+    blur_kernel_size2: 21
+    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+    sinc_prob2: 0.1
+    blur_sigma2: [0.2, 1.5]
+    betag_range2: [0.5, 4]
+    betap_range2: [1, 2]
+
+    final_sinc_prob: 0.8
+
+    gt_size: 256
+    use_hflip: True
+    use_rot: False
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 5
+    batch_size_per_gpu: 12
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+  # Uncomment these for validation
+  # val:
+  #   name: validation
+  #   type: PairedImageDataset
+  #   dataroot_gt: path_to_gt
+  #   dataroot_lq: path_to_lq
+  #   io_backend:
+  #     type: disk
+
+# network structures
+network_g:
+  type: RRDBNet
+  num_in_ch: 3
+  num_out_ch: 3
+  num_feat: 64
+  num_block: 23
+  num_grow_ch: 32
+
+network_d:
+  type: UNetDiscriminatorSN
+  num_in_ch: 3
+  num_feat: 64
+  skip_connection: True
+
+# path
+path:
+  # use the pre-trained Real-ESRNet model
+  pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
+  param_key_g: params_ema
+  strict_load_g: true
+  pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
+  param_key_d: params
+  strict_load_d: true
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+  optim_d:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [400000]
+    gamma: 0.5
+
+  total_iter: 400000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+  # perceptual loss (content and style losses)
+  perceptual_opt:
+    type: PerceptualLoss
+    layer_weights:
+      # before relu
+      'conv1_2': 0.1
+      'conv2_2': 0.1
+      'conv3_4': 1
+      'conv4_4': 1
+      'conv5_4': 1
+    vgg_type: vgg19
+    use_input_norm: true
+    perceptual_weight: !!float 1.0
+    style_weight: 0
+    range_norm: false
+    criterion: l1
+  # gan loss
+  gan_opt:
+    type: GANLoss
+    gan_type: vanilla
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    loss_weight: !!float 1e-1
+
+  net_d_iters: 1
+  net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+#   val_freq: !!float 5e3
+#   save_img: True
+
+#   metrics:
+#     psnr: # metric name
+#       type: calculate_psnr
+#       crop_border: 4
+#       test_y_channel: false
+
+# logging settings
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
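
A side note on the `!!float` tags that recur throughout these configs: under PyYAML's YAML 1.1 rules, a bare `1e-4` (no decimal point) loads as a string, so the explicit tag is what keeps the learning rates and frequencies numeric. A quick check:

import yaml

print(type(yaml.safe_load('lr: 1e-4')['lr']))          # <class 'str'>: missing the dot YAML 1.1 requires
print(type(yaml.safe_load('lr: !!float 1e-4')['lr']))  # <class 'float'>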
options/finetune_realesrgan_x4plus_pairdata.yml ADDED
@@ -0,0 +1,150 @@
+# general settings
+name: finetune_RealESRGANx4plus_400k_pairdata
+model_type: RealESRGANModel
+scale: 4
+num_gpu: auto
+manual_seed: 0
+
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+high_order_degradation: False  # do not use the high-order degradation generation process
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: DIV2K
+    type: RealESRGANPairedDataset
+    dataroot_gt: datasets/DF2K
+    dataroot_lq: datasets/DF2K
+    meta_info: datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt
+    io_backend:
+      type: disk
+
+    gt_size: 256
+    use_hflip: True
+    use_rot: False
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 5
+    batch_size_per_gpu: 12
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+  # Uncomment these for validation
+  # val:
+  #   name: validation
+  #   type: PairedImageDataset
+  #   dataroot_gt: path_to_gt
+  #   dataroot_lq: path_to_lq
+  #   io_backend:
+  #     type: disk
+
+# network structures
+network_g:
+  type: RRDBNet
+  num_in_ch: 3
+  num_out_ch: 3
+  num_feat: 64
+  num_block: 23
+  num_grow_ch: 32
+
+network_d:
+  type: UNetDiscriminatorSN
+  num_in_ch: 3
+  num_feat: 64
+  skip_connection: True
+
+# path
+path:
+  # use the pre-trained Real-ESRNet model
+  pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
+  param_key_g: params_ema
+  strict_load_g: true
+  pretrain_network_d: experiments/pretrained_models/RealESRGAN_x4plus_netD.pth
+  param_key_d: params
+  strict_load_d: true
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+  optim_d:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [400000]
+    gamma: 0.5
+
+  total_iter: 400000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+  # perceptual loss (content and style losses)
+  perceptual_opt:
+    type: PerceptualLoss
+    layer_weights:
+      # before relu
+      'conv1_2': 0.1
+      'conv2_2': 0.1
+      'conv3_4': 1
+      'conv4_4': 1
+      'conv5_4': 1
+    vgg_type: vgg19
+    use_input_norm: true
+    perceptual_weight: !!float 1.0
+    style_weight: 0
+    range_norm: false
+    criterion: l1
+  # gan loss
+  gan_opt:
+    type: GANLoss
+    gan_type: vanilla
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    loss_weight: !!float 1e-1
+
+  net_d_iters: 1
+  net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+#   val_freq: !!float 5e3
+#   save_img: True
+
+#   metrics:
+#     psnr: # metric name
+#       type: calculate_psnr
+#       crop_border: 4
+#       test_y_channel: false
+
+# logging settings
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
options/setup.cfg ADDED
@@ -0,0 +1,33 @@
+[flake8]
+ignore =
+    # line break before binary operator (W503)
+    W503,
+    # line break after binary operator (W504)
+    W504,
+max-line-length=120
+
+[yapf]
+based_on_style = pep8
+column_limit = 120
+blank_line_before_nested_class_or_def = true
+split_before_expression_after_opening_paren = true
+
+[isort]
+line_length = 120
+multi_line_output = 0
+known_standard_library = pkg_resources,setuptools
+known_first_party = realesrgan
+known_third_party = PIL,basicsr,cv2,numpy,pytest,torch,torchvision,tqdm,yaml
+no_lines_before = STDLIB,LOCALFOLDER
+default_section = THIRDPARTY
+
+[codespell]
+skip = .git,./docs/build
+count =
+quiet-level = 3
+
+[aliases]
+test=pytest
+
+[tool:pytest]
+addopts=tests/
options/train_realesrgan_x2plus.yml ADDED
@@ -0,0 +1,186 @@
+# general settings
+name: train_RealESRGANx2plus_400k_B12G4
+model_type: RealESRGANModel
+scale: 2
+num_gpu: auto  # auto: can infer from your visible devices automatically. official: 4 GPUs
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
+# USM the ground-truth
+l1_gt_usm: True
+percep_gt_usm: True
+gan_gt_usm: False
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: DF2K+OST
+    type: RealESRGANDataset
+    dataroot_gt: datasets/DF2K
+    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+    io_backend:
+      type: disk
+
+    blur_kernel_size: 21
+    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+    sinc_prob: 0.1
+    blur_sigma: [0.2, 3]
+    betag_range: [0.5, 4]
+    betap_range: [1, 2]
+
+    blur_kernel_size2: 21
+    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+    sinc_prob2: 0.1
+    blur_sigma2: [0.2, 1.5]
+    betag_range2: [0.5, 4]
+    betap_range2: [1, 2]
+
+    final_sinc_prob: 0.8
+
+    gt_size: 256
+    use_hflip: True
+    use_rot: False
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 5
+    batch_size_per_gpu: 12
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+  # Uncomment these for validation
+  # val:
+  #   name: validation
+  #   type: PairedImageDataset
+  #   dataroot_gt: path_to_gt
+  #   dataroot_lq: path_to_lq
+  #   io_backend:
+  #     type: disk
+
+# network structures
+network_g:
+  type: RRDBNet
+  num_in_ch: 3
+  num_out_ch: 3
+  num_feat: 64
+  num_block: 23
+  num_grow_ch: 32
+  scale: 2
+
+network_d:
+  type: UNetDiscriminatorSN
+  num_in_ch: 3
+  num_feat: 64
+  skip_connection: True
+
+# path
+path:
+  # use the pre-trained Real-ESRNet model
+  pretrain_network_g: experiments/pretrained_models/RealESRNet_x2plus.pth
+  param_key_g: params_ema
+  strict_load_g: true
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+  optim_d:
+    type: Adam
+    lr: !!float 1e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [400000]
+    gamma: 0.5
+
+  total_iter: 400000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+  # perceptual loss (content and style losses)
+  perceptual_opt:
+    type: PerceptualLoss
+    layer_weights:
+      # before relu
+      'conv1_2': 0.1
+      'conv2_2': 0.1
+      'conv3_4': 1
+      'conv4_4': 1
+      'conv5_4': 1
+    vgg_type: vgg19
+    use_input_norm: true
+    perceptual_weight: !!float 1.0
+    style_weight: 0
+    range_norm: false
+    criterion: l1
+  # gan loss
+  gan_opt:
+    type: GANLoss
+    gan_type: vanilla
+    real_label_val: 1.0
+    fake_label_val: 0.0
+    loss_weight: !!float 1e-1
+
+  net_d_iters: 1
+  net_d_init_iters: 0
+
+# Uncomment these for validation
+# validation settings
+# val:
+#   val_freq: !!float 5e3
+#   save_img: True
+
+#   metrics:
+#     psnr: # metric name
+#       type: calculate_psnr
+#       crop_border: 4
+#       test_y_channel: false
+
+# logging settings
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
options/train_realesrgan_x4plus.yml CHANGED
@@ -1,8 +1,8 @@
 # general settings
-name: train_RealESRGANx4plus_400k_B12G4_fromRealESRNet
+name: train_RealESRGANx4plus_400k_B12G4
 model_type: RealESRGANModel
 scale: 4
-num_gpu: 4
+num_gpu: auto  # auto: can infer from your visible devices automatically. official: 4 GPUs
 manual_seed: 0
 
 # ----------------- options for synthesizing training data in RealESRGANModel ----------------- #
@@ -39,7 +39,7 @@ datasets:
     name: DF2K+OST
     type: RealESRGANDataset
     dataroot_gt: datasets/DF2K
-    meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
     io_backend:
       type: disk
 
@@ -90,7 +90,6 @@ network_g:
   num_block: 23
   num_grow_ch: 32
 
-
 network_d:
   type: UNetDiscriminatorSN
   num_in_ch: 3
@@ -100,7 +99,7 @@ network_d:
 # path
 path:
   # use the pre-trained Real-ESRNet model
-  pretrain_network_g: experiments/train_RealESRNetx4plus_1000k_B12G4_fromESRGAN/models/net_g_1000000.pth
+  pretrain_network_g: experiments/pretrained_models/RealESRNet_x4plus.pth
   param_key_g: params_ema
   strict_load_g: true
   resume_state: ~
@@ -166,7 +165,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false
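
The `num_gpu: auto` change above delegates GPU counting to the framework. A sketch of what that resolution presumably amounts to (the inline comment says basicsr infers the count from visible devices; the exact lookup is an assumption here):

import torch

# respects CUDA_VISIBLE_DEVICES, so 'auto' scales with whatever devices are exposed
num_gpu = torch.cuda.device_count()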
options/train_realesrnet_x2plus.yml ADDED
@@ -0,0 +1,145 @@
+# general settings
+name: train_RealESRNetx2plus_1000k_B12G4
+model_type: RealESRNetModel
+scale: 2
+num_gpu: auto  # auto: can infer from your visible devices automatically. official: 4 GPUs
+manual_seed: 0
+
+# ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
+gt_usm: True  # USM the ground-truth
+
+# the first degradation process
+resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
+resize_range: [0.15, 1.5]
+gaussian_noise_prob: 0.5
+noise_range: [1, 30]
+poisson_scale_range: [0.05, 3]
+gray_noise_prob: 0.4
+jpeg_range: [30, 95]
+
+# the second degradation process
+second_blur_prob: 0.8
+resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
+resize_range2: [0.3, 1.2]
+gaussian_noise_prob2: 0.5
+noise_range2: [1, 25]
+poisson_scale_range2: [0.05, 2.5]
+gray_noise_prob2: 0.4
+jpeg_range2: [30, 95]
+
+gt_size: 256
+queue_size: 180
+
+# dataset and data loader settings
+datasets:
+  train:
+    name: DF2K+OST
+    type: RealESRGANDataset
+    dataroot_gt: datasets/DF2K
+    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+    io_backend:
+      type: disk
+
+    blur_kernel_size: 21
+    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+    sinc_prob: 0.1
+    blur_sigma: [0.2, 3]
+    betag_range: [0.5, 4]
+    betap_range: [1, 2]
+
+    blur_kernel_size2: 21
+    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
+    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
+    sinc_prob2: 0.1
+    blur_sigma2: [0.2, 1.5]
+    betag_range2: [0.5, 4]
+    betap_range2: [1, 2]
+
+    final_sinc_prob: 0.8
+
+    gt_size: 256
+    use_hflip: True
+    use_rot: False
+
+    # data loader
+    use_shuffle: true
+    num_worker_per_gpu: 5
+    batch_size_per_gpu: 12
+    dataset_enlarge_ratio: 1
+    prefetch_mode: ~
+
+  # Uncomment these for validation
+  # val:
+  #   name: validation
+  #   type: PairedImageDataset
+  #   dataroot_gt: path_to_gt
+  #   dataroot_lq: path_to_lq
+  #   io_backend:
+  #     type: disk
+
+# network structures
+network_g:
+  type: RRDBNet
+  num_in_ch: 3
+  num_out_ch: 3
+  num_feat: 64
+  num_block: 23
+  num_grow_ch: 32
+  scale: 2
+
+# path
+path:
+  pretrain_network_g: experiments/pretrained_models/RealESRGAN_x4plus.pth
+  param_key_g: params_ema
+  strict_load_g: False
+  resume_state: ~
+
+# training settings
+train:
+  ema_decay: 0.999
+  optim_g:
+    type: Adam
+    lr: !!float 2e-4
+    weight_decay: 0
+    betas: [0.9, 0.99]
+
+  scheduler:
+    type: MultiStepLR
+    milestones: [1000000]
+    gamma: 0.5
+
+  total_iter: 1000000
+  warmup_iter: -1  # no warm up
+
+  # losses
+  pixel_opt:
+    type: L1Loss
+    loss_weight: 1.0
+    reduction: mean
+
+# Uncomment these for validation
+# validation settings
+# val:
+#   val_freq: !!float 5e3
+#   save_img: True
+
+#   metrics:
+#     psnr: # metric name
+#       type: calculate_psnr
+#       crop_border: 4
+#       test_y_channel: false
+
+# logging settings
+logger:
+  print_freq: 100
+  save_checkpoint_freq: !!float 5e3
+  use_tb_logger: true
+  wandb:
+    project: ~
+    resume_id: ~
+
+# dist training settings
+dist_params:
+  backend: nccl
+  port: 29500
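
Note `strict_load_g: False` in this config: the x2 generator is initialized from the x4 `RealESRGAN_x4plus.pth` weights, so keys tied to the x4 input path will not match and have to be skipped. A sketch of the non-strict load this implies (the checkpoint key `params_ema` comes from `param_key_g` above; the shape filtering is an assumption about how mismatched keys get dropped):

import torch
from basicsr.archs.rrdbnet_arch import RRDBNet

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
state = torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema']
# keep only keys whose shapes match the x2 net, roughly what a non-strict load amounts to here
own = model.state_dict()
filtered = {k: v for k, v in state.items() if k in own and own[k].shape == v.shape}
model.load_state_dict(filtered, strict=False)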
options/train_realesrnet_x4plus.yml CHANGED
@@ -1,8 +1,8 @@
 # general settings
-name: train_RealESRNetx4plus_1000k_B12G4_fromESRGAN
+name: train_RealESRNetx4plus_1000k_B12G4
 model_type: RealESRNetModel
 scale: 4
-num_gpu: 4
+num_gpu: auto  # auto: can infer from your visible devices automatically. official: 4 GPUs
 manual_seed: 0
 
 # ----------------- options for synthesizing training data in RealESRNetModel ----------------- #
@@ -36,7 +36,7 @@ datasets:
     name: DF2K+OST
     type: RealESRGANDataset
     dataroot_gt: datasets/DF2K
-    meta_info: realesrgan/data/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
+    meta_info: datasets/DF2K/meta_info/meta_info_DF2Kmultiscale+OST_sub.txt
     io_backend:
       type: disk
 
@@ -124,7 +124,7 @@ train:
 #   save_img: True
 
 #   metrics:
-#     psnr: # metric name, can be arbitrary
+#     psnr: # metric name
 #       type: calculate_psnr
 #       crop_border: 4
 #       test_y_channel: false
realesrgan/__init__.py CHANGED
@@ -3,4 +3,4 @@ from .archs import *
 from .data import *
 from .models import *
 from .utils import *
-#from .version import __gitsha__, __version__
+from .version import *
realesrgan/archs/discriminator_arch.py CHANGED
@@ -6,15 +6,23 @@ from torch.nn.utils import spectral_norm
 
 @ARCH_REGISTRY.register()
 class UNetDiscriminatorSN(nn.Module):
-    """Defines a U-Net discriminator with spectral normalization (SN)"""
+    """Defines a U-Net discriminator with spectral normalization (SN)
+
+    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+        skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
+    """
 
     def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
         super(UNetDiscriminatorSN, self).__init__()
         self.skip_connection = skip_connection
         norm = spectral_norm
-
+        # the first convolution
         self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
-
+        # downsample
         self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
         self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
         self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
@@ -22,14 +30,13 @@ class UNetDiscriminatorSN(nn.Module):
         self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
         self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
         self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
-
-        # extra
+        # extra convolutions
         self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
         self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
-
         self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
 
     def forward(self, x):
+        # downsample
         x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
         x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
         x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
@@ -52,7 +59,7 @@ class UNetDiscriminatorSN(nn.Module):
         if self.skip_connection:
             x6 = x6 + x0
 
-        # extra
+        # extra convolutions
         out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
         out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
         out = self.conv9(out)
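
For orientation: the three stride-2 convolutions downsample, the forward pass upsamples back (not shown in this hunk), and `conv9` reduces to a single channel, so the discriminator maps an image to a one-channel realness map at the input's spatial size. A quick shape check, assuming this commit is installed as the `realesrgan` package:

import torch
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN

net = UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
pred = net(torch.rand(1, 3, 256, 256))
print(pred.shape)  # torch.Size([1, 1, 256, 256]): a per-pixel realness map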
realesrgan/archs/srvgg_arch.py ADDED
@@ -0,0 +1,69 @@
+from basicsr.utils.registry import ARCH_REGISTRY
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@ARCH_REGISTRY.register()
+class SRVGGNetCompact(nn.Module):
+    """A compact VGG-style network structure for super-resolution.
+
+    It is a compact network structure, which performs upsampling in the last layer and no convolution is
+    conducted on the HR feature space.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_out_ch (int): Channel number of outputs. Default: 3.
+        num_feat (int): Channel number of intermediate features. Default: 64.
+        num_conv (int): Number of convolution layers in the body network. Default: 16.
+        upscale (int): Upsampling factor. Default: 4.
+        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
+    """
+
+    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
+        super(SRVGGNetCompact, self).__init__()
+        self.num_in_ch = num_in_ch
+        self.num_out_ch = num_out_ch
+        self.num_feat = num_feat
+        self.num_conv = num_conv
+        self.upscale = upscale
+        self.act_type = act_type
+
+        self.body = nn.ModuleList()
+        # the first conv
+        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+        # the first activation
+        if act_type == 'relu':
+            activation = nn.ReLU(inplace=True)
+        elif act_type == 'prelu':
+            activation = nn.PReLU(num_parameters=num_feat)
+        elif act_type == 'leakyrelu':
+            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.body.append(activation)
+
+        # the body structure
+        for _ in range(num_conv):
+            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+            # activation
+            if act_type == 'relu':
+                activation = nn.ReLU(inplace=True)
+            elif act_type == 'prelu':
+                activation = nn.PReLU(num_parameters=num_feat)
+            elif act_type == 'leakyrelu':
+                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+            self.body.append(activation)
+
+        # the last conv
+        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+        # upsample
+        self.upsampler = nn.PixelShuffle(upscale)
+
+    def forward(self, x):
+        out = x
+        for i in range(0, len(self.body)):
+            out = self.body[i](out)
+
+        out = self.upsampler(out)
+        # add the nearest upsampled image, so that the network learns the residual
+        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
+        out += base
+        return out
+
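
Because the last conv emits `num_out_ch * upscale * upscale` channels, `PixelShuffle(upscale)` rearranges them into the upscaled image, and the nearest-neighbour base means the body only has to learn a residual. A quick usage check of the class as added above:

import torch
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
sr = model(torch.rand(1, 3, 64, 64))
print(sr.shape)  # torch.Size([1, 3, 256, 256]): 48 channels pixel-shuffled into a 4x larger RGB image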
realesrgan/data/realesrgan_dataset.py CHANGED
@@ -15,18 +15,31 @@ from torch.utils import data as data
 
 @DATASET_REGISTRY.register()
 class RealESRGANDataset(data.Dataset):
-    """
-    Dataset used for Real-ESRGAN model.
+    """Dataset used for Real-ESRGAN model:
+    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    It loads gt (Ground-Truth) images, and augments them.
+    It also generates blur kernels and sinc kernels for generating low-quality images.
+    Note that the low-quality images are processed in tensors on GPUs for faster processing.
+
+    Args:
+        opt (dict): Config for train datasets. It contains the following keys:
+            dataroot_gt (str): Data root path for gt.
+            meta_info (str): Path for meta information file.
+            io_backend (dict): IO backend type and other kwarg.
+            use_hflip (bool): Use horizontal flips.
+            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
+            Please see more options in the codes.
     """
 
     def __init__(self, opt):
         super(RealESRGANDataset, self).__init__()
         self.opt = opt
-        # file client (io backend)
         self.file_client = None
         self.io_backend_opt = opt['io_backend']
         self.gt_folder = opt['dataroot_gt']
 
+        # file client (lmdb io backend)
         if self.io_backend_opt['type'] == 'lmdb':
             self.io_backend_opt['db_paths'] = [self.gt_folder]
             self.io_backend_opt['client_keys'] = ['gt']
@@ -35,18 +48,20 @@ class RealESRGANDataset(data.Dataset):
             with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                 self.paths = [line.split('.')[0] for line in fin]
         else:
+            # disk backend with meta_info
+            # Each line in the meta_info describes the relative path to an image
             with open(self.opt['meta_info']) as fin:
-                paths = [line.strip() for line in fin]
+                paths = [line.strip().split(' ')[0] for line in fin]
                 self.paths = [os.path.join(self.gt_folder, v) for v in paths]
 
         # blur settings for the first degradation
         self.blur_kernel_size = opt['blur_kernel_size']
         self.kernel_list = opt['kernel_list']
-        self.kernel_prob = opt['kernel_prob']
+        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
         self.blur_sigma = opt['blur_sigma']
-        self.betag_range = opt['betag_range']
-        self.betap_range = opt['betap_range']
-        self.sinc_prob = opt['sinc_prob']
+        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
+        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
+        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters
 
         # blur settings for the second degradation
         self.blur_kernel_size2 = opt['blur_kernel_size2']
@@ -61,6 +76,7 @@ class RealESRGANDataset(data.Dataset):
         self.final_sinc_prob = opt['final_sinc_prob']
 
         self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
+        # TODO: kernel range is now hard-coded, should be in the configure file
         self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
         self.pulse_tensor[10, 10] = 1
 
@@ -76,7 +92,7 @@ class RealESRGANDataset(data.Dataset):
         while retry > 0:
             try:
                 img_bytes = self.file_client.get(gt_path, 'gt')
-            except Exception as e:
+            except (IOError, OSError) as e:
                 logger = get_root_logger()
                 logger.warn(f'File client error: {e}, remaining retry times: {retry - 1}')
                 # change another file to read
@@ -89,10 +105,11 @@ class RealESRGANDataset(data.Dataset):
             retry -= 1
         img_gt = imfrombytes(img_bytes, float32=True)
 
-        # -------------------- augmentation for training: flip, rotation -------------------- #
+        # -------------------- Do augmentation for training: flip, rotation -------------------- #
         img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
 
-        # crop or pad to 400: 400 is hard-coded. You may change it accordingly
+        # crop or pad to 400
+        # TODO: 400 is hard-coded. You may change it accordingly
         h, w = img_gt.shape[0:2]
         crop_pad_size = 400
         # pad
@@ -154,7 +171,7 @@ class RealESRGANDataset(data.Dataset):
             pad_size = (21 - kernel_size) // 2
             kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
 
-        # ------------------------------------- sinc kernel ------------------------------------- #
+        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
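
The switch to `line.strip().split(' ')[0]` above means each meta_info line may now carry extra whitespace-separated tokens after the relative path; only the first token is used as the image path. A hypothetical line and its parsing:

# hypothetical meta_info entry with a trailing shape annotation
line = 'DF2K_HR_sub/000001_s001.png (400,400,3)\n'
path = line.strip().split(' ')[0]
print(path)  # DF2K_HR_sub/000001_s001.png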
realesrgan/data/realesrgan_paired_dataset.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb
3
+ from basicsr.data.transforms import augment, paired_random_crop
4
+ from basicsr.utils import FileClient, imfrombytes, img2tensor
5
+ from basicsr.utils.registry import DATASET_REGISTRY
6
+ from torch.utils import data as data
7
+ from torchvision.transforms.functional import normalize
8
+
9
+
10
+ @DATASET_REGISTRY.register()
11
+ class RealESRGANPairedDataset(data.Dataset):
12
+ """Paired image dataset for image restoration.
13
+
14
+ Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
15
+
16
+ There are three modes:
17
+ 1. 'lmdb': Use lmdb files.
18
+ If opt['io_backend'] == lmdb.
19
+ 2. 'meta_info': Use meta information file to generate paths.
20
+ If opt['io_backend'] != lmdb and opt['meta_info'] is not None.
21
+ 3. 'folder': Scan folders to generate paths.
22
+ The rest.
23
+
24
+ Args:
25
+ opt (dict): Config for train datasets. It contains the following keys:
26
+ dataroot_gt (str): Data root path for gt.
27
+ dataroot_lq (str): Data root path for lq.
28
+ meta_info (str): Path for meta information file.
29
+ io_backend (dict): IO backend type and other kwarg.
30
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
31
+ Default: '{}'.
32
+ gt_size (int): Cropped patched size for gt patches.
33
+ use_hflip (bool): Use horizontal flips.
34
+ use_rot (bool): Use rotation (use vertical flip and transposing h
35
+ and w for implementation).
36
+
37
+ scale (bool): Scale, which will be added automatically.
38
+ phase (str): 'train' or 'val'.
39
+ """
40
+
41
+ def __init__(self, opt):
42
+ super(RealESRGANPairedDataset, self).__init__()
43
+ self.opt = opt
44
+ self.file_client = None
45
+ self.io_backend_opt = opt['io_backend']
46
+ # mean and std for normalizing the input images
47
+ self.mean = opt['mean'] if 'mean' in opt else None
48
+ self.std = opt['std'] if 'std' in opt else None
49
+
50
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
51
+ self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}'
52
+
53
+ # file client (lmdb io backend)
54
+ if self.io_backend_opt['type'] == 'lmdb':
55
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
56
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
57
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
58
+ elif 'meta_info' in self.opt and self.opt['meta_info'] is not None:
59
+ # disk backend with meta_info
60
+ # Each line in the meta_info describes the relative path to an image
61
+ with open(self.opt['meta_info']) as fin:
62
+ paths = [line.strip() for line in fin]
63
+ self.paths = []
64
+ for path in paths:
65
+ gt_path, lq_path = path.split(', ')
66
+ gt_path = os.path.join(self.gt_folder, gt_path)
67
+ lq_path = os.path.join(self.lq_folder, lq_path)
68
+ self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)]))
69
+ else:
70
+ # disk backend
71
+ # it will scan the whole folder to get meta info
72
+ # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file
73
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
74
+
75
+ def __getitem__(self, index):
76
+ if self.file_client is None:
77
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
78
+
79
+ scale = self.opt['scale']
80
+
81
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
82
+ # image range: [0, 1], float32.
83
+ gt_path = self.paths[index]['gt_path']
84
+ img_bytes = self.file_client.get(gt_path, 'gt')
85
+ img_gt = imfrombytes(img_bytes, float32=True)
86
+ lq_path = self.paths[index]['lq_path']
87
+ img_bytes = self.file_client.get(lq_path, 'lq')
88
+ img_lq = imfrombytes(img_bytes, float32=True)
89
+
90
+ # augmentation for training
91
+ if self.opt['phase'] == 'train':
92
+ gt_size = self.opt['gt_size']
93
+ # random crop
94
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
95
+ # flip, rotation
96
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
97
+
98
+ # BGR to RGB, HWC to CHW, numpy to tensor
99
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
100
+ # normalize
101
+ if self.mean is not None or self.std is not None:
102
+ normalize(img_lq, self.mean, self.std, inplace=True)
103
+ normalize(img_gt, self.mean, self.std, inplace=True)
104
+
105
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
106
+
107
+ def __len__(self):
108
+ return len(self.paths)
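
For reference, the dataset can be exercised on its own with a plain option dict carrying the keys documented above. A minimal sketch, using the illustrative test-data paths that appear later in this commit:

from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset

opt = {
    'phase': 'train',
    'scale': 4,
    'gt_size': 128,
    'use_hflip': True,
    'use_rot': False,
    'dataroot_gt': 'tests/data',
    'dataroot_lq': 'tests/data',
    'meta_info': 'tests/data/meta_info_pair.txt',  # lines like: gt/baboon.png, lq/baboon.png
    'io_backend': {'type': 'disk'},
}
dataset = RealESRGANPairedDataset(opt)
sample = dataset[0]
print(sample['gt'].shape, sample['lq'].shape)  # torch.Size([3, 128, 128]) torch.Size([3, 32, 32])
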
realesrgan/models/realesrgan_model.py CHANGED
@@ -13,35 +13,45 @@ from torch.nn import functional as F
13
 
14
  @MODEL_REGISTRY.register()
15
  class RealESRGANModel(SRGANModel):
16
- """RealESRGAN Model"""
 
 
 
 
 
17
 
18
  def __init__(self, opt):
19
  super(RealESRGANModel, self).__init__(opt)
20
- self.jpeger = DiffJPEG(differentiable=False).cuda()
21
- self.usm_sharpener = USMSharp().cuda()
22
- self.queue_size = opt['queue_size']
23
 
24
  @torch.no_grad()
25
  def _dequeue_and_enqueue(self):
26
- # training pair pool
 
 
 
 
 
27
  # initialize
28
  b, c, h, w = self.lq.size()
29
  if not hasattr(self, 'queue_lr'):
30
- assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
31
  self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
32
  _, c, h, w = self.gt.size()
33
  self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
34
  self.queue_ptr = 0
35
- if self.queue_ptr == self.queue_size: # full
36
  # do dequeue and enqueue
37
  # shuffle
38
  idx = torch.randperm(self.queue_size)
39
  self.queue_lr = self.queue_lr[idx]
40
  self.queue_gt = self.queue_gt[idx]
41
- # get
42
  lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
43
  gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
44
- # update
45
  self.queue_lr[0:b, :, :, :] = self.lq.clone()
46
  self.queue_gt[0:b, :, :, :] = self.gt.clone()
47
 
@@ -55,7 +65,9 @@ class RealESRGANModel(SRGANModel):
55
 
56
  @torch.no_grad()
57
  def feed_data(self, data):
58
- if self.is_train:
 
 
59
  # training data synthesis
60
  self.gt = data['gt'].to(self.device)
61
  self.gt_usm = self.usm_sharpener(self.gt)
@@ -79,7 +91,7 @@ class RealESRGANModel(SRGANModel):
79
  scale = 1
80
  mode = random.choice(['area', 'bilinear', 'bicubic'])
81
  out = F.interpolate(out, scale_factor=scale, mode=mode)
82
- # noise
83
  gray_noise_prob = self.opt['gray_noise_prob']
84
  if np.random.uniform() < self.opt['gaussian_noise_prob']:
85
  out = random_add_gaussian_noise_pt(
@@ -93,7 +105,7 @@ class RealESRGANModel(SRGANModel):
93
  rounds=False)
94
  # JPEG compression
95
  jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
96
- out = torch.clamp(out, 0, 1)
97
  out = self.jpeger(out, quality=jpeg_p)
98
 
99
  # ----------------------- The second degradation process ----------------------- #
@@ -111,7 +123,7 @@ class RealESRGANModel(SRGANModel):
111
  mode = random.choice(['area', 'bilinear', 'bicubic'])
112
  out = F.interpolate(
113
  out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
114
- # noise
115
  gray_noise_prob = self.opt['gray_noise_prob2']
116
  if np.random.uniform() < self.opt['gaussian_noise_prob2']:
117
  out = random_add_gaussian_noise_pt(
@@ -162,10 +174,13 @@ class RealESRGANModel(SRGANModel):
162
  self._dequeue_and_enqueue()
163
  # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
164
  self.gt_usm = self.usm_sharpener(self.gt)
 
165
  else:
 
166
  self.lq = data['lq'].to(self.device)
167
  if 'gt' in data:
168
  self.gt = data['gt'].to(self.device)
 
169
 
170
  def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
171
  # do not use the synthetic process during validation
@@ -174,6 +189,7 @@ class RealESRGANModel(SRGANModel):
174
  self.is_train = True
175
 
176
  def optimize_parameters(self, current_iter):
 
177
  l1_gt = self.gt_usm
178
  percep_gt = self.gt_usm
179
  gan_gt = self.gt_usm
13
 
14
  @MODEL_REGISTRY.register()
15
  class RealESRGANModel(SRGANModel):
16
+ """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
17
+
18
+ It mainly performs:
19
+ 1. randomly synthesize LQ images in GPU tensors
20
+ 2. optimize the networks with GAN training.
21
+ """
22
 
23
  def __init__(self, opt):
24
  super(RealESRGANModel, self).__init__(opt)
25
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27
+ self.queue_size = opt.get('queue_size', 180)
28
 
29
  @torch.no_grad()
30
  def _dequeue_and_enqueue(self):
31
+ """It is the training pair pool for increasing the diversity in a batch.
32
+
33
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35
+ to increase the degradation diversity in a batch.
36
+ """
37
  # initialize
38
  b, c, h, w = self.lq.size()
39
  if not hasattr(self, 'queue_lr'):
40
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41
  self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42
  _, c, h, w = self.gt.size()
43
  self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44
  self.queue_ptr = 0
45
+ if self.queue_ptr == self.queue_size: # the pool is full
46
  # do dequeue and enqueue
47
  # shuffle
48
  idx = torch.randperm(self.queue_size)
49
  self.queue_lr = self.queue_lr[idx]
50
  self.queue_gt = self.queue_gt[idx]
51
+ # get first b samples
52
  lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53
  gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54
+ # update the queue
55
  self.queue_lr[0:b, :, :, :] = self.lq.clone()
56
  self.queue_gt[0:b, :, :, :] = self.gt.clone()
57
 
65
 
66
  @torch.no_grad()
67
  def feed_data(self, data):
68
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69
+ """
70
+ if self.is_train and self.opt.get('high_order_degradation', True):
71
  # training data synthesis
72
  self.gt = data['gt'].to(self.device)
73
  self.gt_usm = self.usm_sharpener(self.gt)
91
  scale = 1
92
  mode = random.choice(['area', 'bilinear', 'bicubic'])
93
  out = F.interpolate(out, scale_factor=scale, mode=mode)
94
+ # add noise
95
  gray_noise_prob = self.opt['gray_noise_prob']
96
  if np.random.uniform() < self.opt['gaussian_noise_prob']:
97
  out = random_add_gaussian_noise_pt(
105
  rounds=False)
106
  # JPEG compression
107
  jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
108
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
109
  out = self.jpeger(out, quality=jpeg_p)
110
 
111
  # ----------------------- The second degradation process ----------------------- #
123
  mode = random.choice(['area', 'bilinear', 'bicubic'])
124
  out = F.interpolate(
125
  out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
126
+ # add noise
127
  gray_noise_prob = self.opt['gray_noise_prob2']
128
  if np.random.uniform() < self.opt['gaussian_noise_prob2']:
129
  out = random_add_gaussian_noise_pt(
174
  self._dequeue_and_enqueue()
175
  # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
176
  self.gt_usm = self.usm_sharpener(self.gt)
177
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
178
  else:
179
+ # for paired training or validation
180
  self.lq = data['lq'].to(self.device)
181
  if 'gt' in data:
182
  self.gt = data['gt'].to(self.device)
183
+ self.gt_usm = self.usm_sharpener(self.gt)
184
 
185
  def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
186
  # do not use the synthetic process during validation
189
  self.is_train = True
190
 
191
  def optimize_parameters(self, current_iter):
192
+ # usm sharpening
193
  l1_gt = self.gt_usm
194
  percep_gt = self.gt_usm
195
  gan_gt = self.gt_usm
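
The pool logic above is easier to see in isolation. A toy, CPU-only sketch of the same dequeue/enqueue idea (queue_size must be divisible by the batch size, as the assert enforces; all names here are illustrative):

import torch

queue_size, b, c, h, w = 4, 2, 3, 8, 8
queue_lr = torch.zeros(queue_size, c, h, w)
queue_ptr = 0

def dequeue_and_enqueue(lq):
    global queue_ptr
    if queue_ptr == queue_size:              # the pool is full
        idx = torch.randperm(queue_size)     # shuffle the pool
        queue_lr.copy_(queue_lr[idx])
        out = queue_lr[:b].clone()           # dequeue the first b samples
        queue_lr[:b] = lq                    # enqueue the current batch in their place
        return out
    queue_lr[queue_ptr:queue_ptr + b] = lq   # pool not full yet: enqueue only
    queue_ptr += b
    return lq

for _ in range(4):
    batch = dequeue_and_enqueue(torch.rand(b, c, h, w))  # returns pooled samples once full
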
realesrgan/models/realesrnet_model.py CHANGED
@@ -12,35 +12,46 @@ from torch.nn import functional as F
12
 
13
  @MODEL_REGISTRY.register()
14
  class RealESRNetModel(SRModel):
15
- """RealESRNet Model"""
 
 
 
 
 
 
16
 
17
  def __init__(self, opt):
18
  super(RealESRNetModel, self).__init__(opt)
19
- self.jpeger = DiffJPEG(differentiable=False).cuda()
20
- self.usm_sharpener = USMSharp().cuda()
21
- self.queue_size = opt['queue_size']
22
 
23
  @torch.no_grad()
24
  def _dequeue_and_enqueue(self):
25
- # training pair pool
 
 
 
 
 
26
  # initialize
27
  b, c, h, w = self.lq.size()
28
  if not hasattr(self, 'queue_lr'):
29
- assert self.queue_size % b == 0, 'queue size should be divisible by batch size'
30
  self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
31
  _, c, h, w = self.gt.size()
32
  self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
33
  self.queue_ptr = 0
34
- if self.queue_ptr == self.queue_size: # full
35
  # do dequeue and enqueue
36
  # shuffle
37
  idx = torch.randperm(self.queue_size)
38
  self.queue_lr = self.queue_lr[idx]
39
  self.queue_gt = self.queue_gt[idx]
40
- # get
41
  lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
42
  gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
43
- # update
44
  self.queue_lr[0:b, :, :, :] = self.lq.clone()
45
  self.queue_gt[0:b, :, :, :] = self.gt.clone()
46
 
@@ -54,10 +65,12 @@ class RealESRNetModel(SRModel):
54
 
55
  @torch.no_grad()
56
  def feed_data(self, data):
57
- if self.is_train:
 
 
58
  # training data synthesis
59
  self.gt = data['gt'].to(self.device)
60
- # USM the GT images
61
  if self.opt['gt_usm'] is True:
62
  self.gt = self.usm_sharpener(self.gt)
63
 
@@ -80,7 +93,7 @@ class RealESRNetModel(SRModel):
80
  scale = 1
81
  mode = random.choice(['area', 'bilinear', 'bicubic'])
82
  out = F.interpolate(out, scale_factor=scale, mode=mode)
83
- # noise
84
  gray_noise_prob = self.opt['gray_noise_prob']
85
  if np.random.uniform() < self.opt['gaussian_noise_prob']:
86
  out = random_add_gaussian_noise_pt(
@@ -94,7 +107,7 @@ class RealESRNetModel(SRModel):
94
  rounds=False)
95
  # JPEG compression
96
  jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
97
- out = torch.clamp(out, 0, 1)
98
  out = self.jpeger(out, quality=jpeg_p)
99
 
100
  # ----------------------- The second degradation process ----------------------- #
@@ -112,7 +125,7 @@ class RealESRNetModel(SRModel):
112
  mode = random.choice(['area', 'bilinear', 'bicubic'])
113
  out = F.interpolate(
114
  out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
115
- # noise
116
  gray_noise_prob = self.opt['gray_noise_prob2']
117
  if np.random.uniform() < self.opt['gaussian_noise_prob2']:
118
  out = random_add_gaussian_noise_pt(
@@ -160,10 +173,13 @@ class RealESRNetModel(SRModel):
160
 
161
  # training pair pool
162
  self._dequeue_and_enqueue()
 
163
  else:
 
164
  self.lq = data['lq'].to(self.device)
165
  if 'gt' in data:
166
  self.gt = data['gt'].to(self.device)
 
167
 
168
  def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
169
  # do not use the synthetic process during validation
12
 
13
  @MODEL_REGISTRY.register()
14
  class RealESRNetModel(SRModel):
15
+ """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
16
+
17
+ It is trained without GAN losses.
18
+ It mainly performs:
19
+ 1. randomly synthesize LQ images in GPU tensors
20
+ 2. optimize the networks with pixel-wise losses (no GAN training).
21
+ """
22
 
23
  def __init__(self, opt):
24
  super(RealESRNetModel, self).__init__(opt)
25
+ self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
26
+ self.usm_sharpener = USMSharp().cuda() # do usm sharpening
27
+ self.queue_size = opt.get('queue_size', 180)
28
 
29
  @torch.no_grad()
30
  def _dequeue_and_enqueue(self):
31
+ """It is the training pair pool for increasing the diversity in a batch.
32
+
33
+ Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
34
+ batch could not have different resize scaling factors. Therefore, we employ this training pair pool
35
+ to increase the degradation diversity in a batch.
36
+ """
37
  # initialize
38
  b, c, h, w = self.lq.size()
39
  if not hasattr(self, 'queue_lr'):
40
+ assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
41
  self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
42
  _, c, h, w = self.gt.size()
43
  self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
44
  self.queue_ptr = 0
45
+ if self.queue_ptr == self.queue_size: # the pool is full
46
  # do dequeue and enqueue
47
  # shuffle
48
  idx = torch.randperm(self.queue_size)
49
  self.queue_lr = self.queue_lr[idx]
50
  self.queue_gt = self.queue_gt[idx]
51
+ # get first b samples
52
  lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
53
  gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
54
+ # update the queue
55
  self.queue_lr[0:b, :, :, :] = self.lq.clone()
56
  self.queue_gt[0:b, :, :, :] = self.gt.clone()
57
 
65
 
66
  @torch.no_grad()
67
  def feed_data(self, data):
68
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
69
+ """
70
+ if self.is_train and self.opt.get('high_order_degradation', True):
71
  # training data synthesis
72
  self.gt = data['gt'].to(self.device)
73
+ # USM sharpen the GT images
74
  if self.opt['gt_usm'] is True:
75
  self.gt = self.usm_sharpener(self.gt)
76
 
93
  scale = 1
94
  mode = random.choice(['area', 'bilinear', 'bicubic'])
95
  out = F.interpolate(out, scale_factor=scale, mode=mode)
96
+ # add noise
97
  gray_noise_prob = self.opt['gray_noise_prob']
98
  if np.random.uniform() < self.opt['gaussian_noise_prob']:
99
  out = random_add_gaussian_noise_pt(
107
  rounds=False)
108
  # JPEG compression
109
  jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
110
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
111
  out = self.jpeger(out, quality=jpeg_p)
112
 
113
  # ----------------------- The second degradation process ----------------------- #
125
  mode = random.choice(['area', 'bilinear', 'bicubic'])
126
  out = F.interpolate(
127
  out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
128
+ # add noise
129
  gray_noise_prob = self.opt['gray_noise_prob2']
130
  if np.random.uniform() < self.opt['gaussian_noise_prob2']:
131
  out = random_add_gaussian_noise_pt(
173
 
174
  # training pair pool
175
  self._dequeue_and_enqueue()
176
+ self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
177
  else:
178
+ # for paired training or validation
179
  self.lq = data['lq'].to(self.device)
180
  if 'gt' in data:
181
  self.gt = data['gt'].to(self.device)
182
+ self.gt_usm = self.usm_sharpener(self.gt)
183
 
184
  def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
185
  # do not use the synthetic process during validation
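
Both models share the same synthesis recipe: two rounds of blur -> random resize -> noise -> JPEG. A stripped-down, pure-PyTorch sketch of a single round (plain Gaussian noise stands in for basicsr's noise utilities, and the mean kernel is only a placeholder for the sampled blur kernels):

import random
import torch
import torch.nn.functional as F

def degrade_once(img, kernel, resize_range=(0.15, 1.5), noise_sigma=0.05):
    # blur: depthwise convolution with the sampled kernel
    c, k = img.size(1), kernel.size(-1)
    weight = kernel.view(1, 1, k, k).repeat(c, 1, 1, 1)
    out = F.conv2d(F.pad(img, [k // 2] * 4, mode='reflect'), weight, groups=c)
    # random resize with a random interpolation mode
    scale = random.uniform(*resize_range)
    out = F.interpolate(out, scale_factor=scale, mode=random.choice(['area', 'bilinear', 'bicubic']))
    # additive noise, then clamp to [0, 1] (a JPEG compression step would follow here)
    out = out + noise_sigma * torch.randn_like(out)
    return torch.clamp(out, 0, 1)

x = torch.rand(1, 3, 64, 64)
mean_kernel = torch.ones(7, 7) / 49
y = degrade_once(x, mean_kernel)
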
realesrgan/utils.py CHANGED
@@ -2,18 +2,31 @@ import cv2
2
  import math
3
  import numpy as np
4
  import os
 
 
5
  import torch
6
- from basicsr.archs.rrdbnet_arch import RRDBNet
7
- from torch.hub import download_url_to_file, get_dir
8
  from torch.nn import functional as F
9
- from urllib.parse import urlparse
10
 
11
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
 
13
 
14
  class RealESRGANer():
 
15
 
16
- def __init__(self, scale, model_path, tile=0, tile_pad=10, pre_pad=10, half=False):
 
 
 
 
 
 
 
 
 
 
 
 
17
  self.scale = scale
18
  self.tile_size = tile
19
  self.tile_pad = tile_pad
@@ -23,12 +36,12 @@ class RealESRGANer():
23
 
24
  # initialize model
25
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
27
-
28
  if model_path.startswith('https://'):
29
  model_path = load_file_from_url(
30
- url=model_path, model_dir='realesrgan/weights', progress=True, file_name=None)
31
- loadnet = torch.load(model_path)
 
32
  if 'params_ema' in loadnet:
33
  keyname = 'params_ema'
34
  else:
@@ -40,6 +53,8 @@ class RealESRGANer():
40
  self.model = self.model.half()
41
 
42
  def pre_process(self, img):
 
 
43
  img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
44
  self.img = img.unsqueeze(0).to(self.device)
45
  if self.half:
@@ -48,7 +63,7 @@ class RealESRGANer():
48
  # pre_pad
49
  if self.pre_pad != 0:
50
  self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
51
- # mod pad
52
  if self.scale == 2:
53
  self.mod_scale = 2
54
  elif self.scale == 1:
@@ -63,10 +78,14 @@ class RealESRGANer():
63
  self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
64
 
65
  def process(self):
 
66
  self.output = self.model(self.img)
67
 
68
  def tile_process(self):
69
- """Modified from: https://github.com/ata4/esrgan-launcher
 
 
 
70
  """
71
  batch, channel, height, width = self.img.shape
72
  output_height = height * self.scale
@@ -106,7 +125,7 @@ class RealESRGANer():
106
  try:
107
  with torch.no_grad():
108
  output_tile = self.model(input_tile)
109
- except Exception as error:
110
  print('Error', error)
111
  print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
112
 
@@ -143,7 +162,7 @@ class RealESRGANer():
143
  h_input, w_input = img.shape[0:2]
144
  # img: numpy
145
  img = img.astype(np.float32)
146
- if np.max(img) > 255: # 16-bit image
147
  max_range = 65535
148
  print('\tInput is a 16-bit image')
149
  else:
@@ -187,7 +206,7 @@ class RealESRGANer():
187
  output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
188
  output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
189
  output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
190
- else:
191
  h, w = alpha.shape[0:2]
192
  output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
193
 
@@ -211,21 +230,51 @@ class RealESRGANer():
211
  return output, img_mode
212
 
213
 
214
- def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
215
- """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
 
 
 
 
216
  """
217
- if model_dir is None:
218
- hub_dir = get_dir()
219
- model_dir = os.path.join(hub_dir, 'checkpoints')
220
-
221
- os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
222
-
223
- parts = urlparse(url)
224
- filename = os.path.basename(parts.path)
225
- if file_name is not None:
226
- filename = file_name
227
- cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
228
- if not os.path.exists(cached_file):
229
- print(f'Downloading: "{url}" to {cached_file}\n')
230
- download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
231
- return cached_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import math
3
  import numpy as np
4
  import os
5
+ import queue
6
+ import threading
7
  import torch
8
+ from basicsr.utils.download_util import load_file_from_url
 
9
  from torch.nn import functional as F
 
10
 
11
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
 
13
 
14
  class RealESRGANer():
15
+ """A helper class for upsampling images with RealESRGAN.
16
 
17
+ Args:
18
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
19
+ model_path (str): The path to the pretrained model. It can be a URL (the weights are downloaded automatically).
20
+ model (nn.Module): The defined network. Default: None.
21
+ tile (int): Since large input images can cause GPU out-of-memory errors, this tile option first crops the
22
+ input image into tiles and processes each of them separately. Finally, the tiles are merged into one image.
23
+ 0 means no tiling. Default: 0.
24
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
25
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
26
+ half (bool): Whether to use half precision during inference. Default: False.
27
+ """
28
+
29
+ def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
30
  self.scale = scale
31
  self.tile_size = tile
32
  self.tile_pad = tile_pad
36
 
37
  # initialize model
38
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ # if the model_path starts with https, it will first download models to the folder: realesrgan/weights
 
40
  if model_path.startswith('https://'):
41
  model_path = load_file_from_url(
42
+ url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
43
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
44
+ # prefer to use params_ema
45
  if 'params_ema' in loadnet:
46
  keyname = 'params_ema'
47
  else:
53
  self.model = self.model.half()
54
 
55
  def pre_process(self, img):
56
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
57
+ """
58
  img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
59
  self.img = img.unsqueeze(0).to(self.device)
60
  if self.half:
63
  # pre_pad
64
  if self.pre_pad != 0:
65
  self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
66
+ # mod pad for divisible borders
67
  if self.scale == 2:
68
  self.mod_scale = 2
69
  elif self.scale == 1:
78
  self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
79
 
80
  def process(self):
81
+ # model inference
82
  self.output = self.model(self.img)
83
 
84
  def tile_process(self):
85
+ """It will first crop input images to tiles, and then process each tile.
86
+ Finally, all the processed tiles are merged into one image.
87
+
88
+ Modified from: https://github.com/ata4/esrgan-launcher
89
  """
90
  batch, channel, height, width = self.img.shape
91
  output_height = height * self.scale
125
  try:
126
  with torch.no_grad():
127
  output_tile = self.model(input_tile)
128
+ except RuntimeError as error:
129
  print('Error', error)
130
  print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
131
 
162
  h_input, w_input = img.shape[0:2]
163
  # img: numpy
164
  img = img.astype(np.float32)
165
+ if np.max(img) > 256: # 16-bit image
166
  max_range = 65535
167
  print('\tInput is a 16-bit image')
168
  else:
206
  output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
207
  output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
208
  output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
209
+ else: # use the cv2 resize for alpha channel
210
  h, w = alpha.shape[0:2]
211
  output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
212
 
230
  return output, img_mode
231
 
232
 
233
+ class PrefetchReader(threading.Thread):
234
+ """Prefetch images.
235
+
236
+ Args:
237
+ img_list (list[str]): A list of image paths to be read.
238
+ num_prefetch_queue (int): Number of prefetch queue.
239
  """
240
+
241
+ def __init__(self, img_list, num_prefetch_queue):
242
+ super().__init__()
243
+ self.que = queue.Queue(num_prefetch_queue)
244
+ self.img_list = img_list
245
+
246
+ def run(self):
247
+ for img_path in self.img_list:
248
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
249
+ self.que.put(img)
250
+
251
+ self.que.put(None)
252
+
253
+ def __next__(self):
254
+ next_item = self.que.get()
255
+ if next_item is None:
256
+ raise StopIteration
257
+ return next_item
258
+
259
+ def __iter__(self):
260
+ return self
261
+
262
+
263
+ class IOConsumer(threading.Thread):
264
+
265
+ def __init__(self, opt, que, qid):
266
+ super().__init__()
267
+ self._queue = que
268
+ self.qid = qid
269
+ self.opt = opt
270
+
271
+ def run(self):
272
+ while True:
273
+ msg = self._queue.get()
274
+ if isinstance(msg, str) and msg == 'quit':
275
+ break
276
+
277
+ output = msg['output']
278
+ save_path = msg['save_path']
279
+ cv2.imwrite(save_path, output)
280
+ print(f'IO worker {self.qid} is done.')
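
PrefetchReader and IOConsumer above are convenience threads for batch/video inference: the former overlaps disk reads with GPU work through a bounded queue, and the latter drains a result queue and writes images until it receives the 'quit' sentinel. End to end, the main helper is used roughly as follows (a sketch mirroring inference_realesrgan.py in this commit; the weight path and image paths are illustrative):

import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',  # or an https:// URL
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=False)
img = cv2.imread('inputs/example.jpg', cv2.IMREAD_UNCHANGED)
output, img_mode = upsampler.enhance(img, outscale=4)
cv2.imwrite('results/example_out.png', output)
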
scripts/extract_subimages.py ADDED
@@ -0,0 +1,135 @@
1
+ import argparse
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+ import sys
6
+ from basicsr.utils import scandir
7
+ from multiprocessing import Pool
8
+ from os import path as osp
9
+ from tqdm import tqdm
10
+
11
+
12
+ def main(args):
13
+ """A multi-thread tool to crop large images to sub-images for faster IO.
14
+
15
+ opt (dict): Configuration dict. It contains:
16
+ n_thread (int): Thread number.
17
+ compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size
18
+ and longer compression time. Use 0 for faster CPU decompression. Default: 3, same as in cv2.
19
+ input_folder (str): Path to the input folder.
20
+ save_folder (str): Path to save folder.
21
+ crop_size (int): Crop size.
22
+ step (int): Step for overlapped sliding window.
23
+ thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
24
+
25
+ Usage:
26
+ For each folder, run this script.
27
+ Typically, there are a GT folder and an LQ folder to be processed for the DIV2K dataset.
28
+ After processing, each sub-folder should have the same number of sub-images.
29
+ Remember to modify opt configurations according to your settings.
30
+ """
31
+
32
+ opt = {}
33
+ opt['n_thread'] = args.n_thread
34
+ opt['compression_level'] = args.compression_level
35
+ opt['input_folder'] = args.input
36
+ opt['save_folder'] = args.output
37
+ opt['crop_size'] = args.crop_size
38
+ opt['step'] = args.step
39
+ opt['thresh_size'] = args.thresh_size
40
+ extract_subimages(opt)
41
+
42
+
43
+ def extract_subimages(opt):
44
+ """Crop images to subimages.
45
+
46
+ Args:
47
+ opt (dict): Configuration dict. It contains:
48
+ input_folder (str): Path to the input folder.
49
+ save_folder (str): Path to save folder.
50
+ n_thread (int): Thread number.
51
+ """
52
+ input_folder = opt['input_folder']
53
+ save_folder = opt['save_folder']
54
+ if not osp.exists(save_folder):
55
+ os.makedirs(save_folder)
56
+ print(f'mkdir {save_folder} ...')
57
+ else:
58
+ print(f'Folder {save_folder} already exists. Exit.')
59
+ sys.exit(1)
60
+
61
+ # scan all images
62
+ img_list = list(scandir(input_folder, full_path=True))
63
+
64
+ pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
65
+ pool = Pool(opt['n_thread'])
66
+ for path in img_list:
67
+ pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
68
+ pool.close()
69
+ pool.join()
70
+ pbar.close()
71
+ print('All processes done.')
72
+
73
+
74
+ def worker(path, opt):
75
+ """Worker for each process.
76
+
77
+ Args:
78
+ path (str): Image path.
79
+ opt (dict): Configuration dict. It contains:
80
+ crop_size (int): Crop size.
81
+ step (int): Step for overlapped sliding window.
82
+ thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
83
+ save_folder (str): Path to save folder.
84
+ compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
85
+
86
+ Returns:
87
+ process_info (str): Process information displayed in progress bar.
88
+ """
89
+ crop_size = opt['crop_size']
90
+ step = opt['step']
91
+ thresh_size = opt['thresh_size']
92
+ img_name, extension = osp.splitext(osp.basename(path))
93
+
94
+ # remove the x2, x3, x4 and x8 in the filename for DIV2K
95
+ img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
96
+
97
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
98
+
99
+ h, w = img.shape[0:2]
100
+ h_space = np.arange(0, h - crop_size + 1, step)
101
+ if h - (h_space[-1] + crop_size) > thresh_size:
102
+ h_space = np.append(h_space, h - crop_size)
103
+ w_space = np.arange(0, w - crop_size + 1, step)
104
+ if w - (w_space[-1] + crop_size) > thresh_size:
105
+ w_space = np.append(w_space, w - crop_size)
106
+
107
+ index = 0
108
+ for x in h_space:
109
+ for y in w_space:
110
+ index += 1
111
+ cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
112
+ cropped_img = np.ascontiguousarray(cropped_img)
113
+ cv2.imwrite(
114
+ osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
115
+ [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
116
+ process_info = f'Processing {img_name} ...'
117
+ return process_info
118
+
119
+
120
+ if __name__ == '__main__':
121
+ parser = argparse.ArgumentParser()
122
+ parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
123
+ parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_HR_sub', help='Output folder')
124
+ parser.add_argument('--crop_size', type=int, default=480, help='Crop size')
125
+ parser.add_argument('--step', type=int, default=240, help='Step for overlapped sliding window')
126
+ parser.add_argument(
127
+ '--thresh_size',
128
+ type=int,
129
+ default=0,
130
+ help='Threshold size. Patches whose size is lower than thresh_size will be dropped.')
131
+ parser.add_argument('--n_thread', type=int, default=20, help='Thread number.')
132
+ parser.add_argument('--compression_level', type=int, default=3, help='Compression level')
133
+ args = parser.parse_args()
134
+
135
+ main(args)
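
To make the sliding-window arithmetic concrete, here is what h_space looks like for an illustrative 1356-pixel-high image with the default crop_size/step:

import numpy as np

h, crop_size, step, thresh_size = 1356, 480, 240, 0
h_space = np.arange(0, h - crop_size + 1, step)    # [0, 240, 480, 720]
if h - (h_space[-1] + crop_size) > thresh_size:    # 1356 - (720 + 480) = 156 > 0
    h_space = np.append(h_space, h - crop_size)    # add 876 so the bottom border is covered
print(h_space)  # [  0 240 480 720 876]
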
scripts/generate_meta_info.py ADDED
@@ -0,0 +1,58 @@
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import os
5
+
6
+
7
+ def main(args):
8
+ txt_file = open(args.meta_info, 'w')
9
+ for folder, root in zip(args.input, args.root):
10
+ img_paths = sorted(glob.glob(os.path.join(folder, '*')))
11
+ for img_path in img_paths:
12
+ status = True
13
+ if args.check:
14
+ # read the image once for check, as some images may have errors
15
+ try:
16
+ img = cv2.imread(img_path)
17
+ except (IOError, OSError) as error:
18
+ print(f'Read {img_path} error: {error}')
19
+ status = False
20
+ if img is None:
21
+ status = False
22
+ print(f'Img is None: {img_path}')
23
+ if status:
24
+ # get the relative path
25
+ img_name = os.path.relpath(img_path, root)
26
+ print(img_name)
27
+ txt_file.write(f'{img_name}\n')
28
+
29
+
30
+ if __name__ == '__main__':
31
+ """Generate meta info (txt file) for only Ground-Truth images.
32
+
33
+ It can also generate meta info from several folders into one txt file.
34
+ """
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument(
37
+ '--input',
38
+ nargs='+',
39
+ default=['datasets/DF2K/DF2K_HR', 'datasets/DF2K/DF2K_multiscale'],
40
+ help='Input folder, can be a list')
41
+ parser.add_argument(
42
+ '--root',
43
+ nargs='+',
44
+ default=['datasets/DF2K', 'datasets/DF2K'],
45
+ help='Folder root, should have the length as input folders')
46
+ parser.add_argument(
47
+ '--meta_info',
48
+ type=str,
49
+ default='datasets/DF2K/meta_info/meta_info_DF2Kmultiscale.txt',
50
+ help='txt path for meta info')
51
+ parser.add_argument('--check', action='store_true', help='Read image to check whether it is ok')
52
+ args = parser.parse_args()
53
+
54
+ assert len(args.input) == len(args.root), ('Input folder and folder root should have the same length, but got '
55
+ f'{len(args.input)} and {len(args.root)}.')
56
+ os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
57
+
58
+ main(args)
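
The meta_info file ends up holding one root-relative path per line; the relevant relpath behavior, shown on illustrative paths:

import os

print(os.path.relpath('datasets/DF2K/DF2K_HR/0001.png', 'datasets/DF2K'))
# -> DF2K_HR/0001.png, which is exactly one line of the generated txt
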
scripts/generate_meta_info_pairdata.py ADDED
@@ -0,0 +1,49 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+
5
+
6
+ def main(args):
7
+ txt_file = open(args.meta_info, 'w')
8
+ # scan images
9
+ img_paths_gt = sorted(glob.glob(os.path.join(args.input[0], '*')))
10
+ img_paths_lq = sorted(glob.glob(os.path.join(args.input[1], '*')))
11
+
12
+ assert len(img_paths_gt) == len(img_paths_lq), ('GT folder and LQ folder should have the same length, but got '
13
+ f'{len(img_paths_gt)} and {len(img_paths_lq)}.')
14
+
15
+ for img_path_gt, img_path_lq in zip(img_paths_gt, img_paths_lq):
16
+ # get the relative paths
17
+ img_name_gt = os.path.relpath(img_path_gt, args.root[0])
18
+ img_name_lq = os.path.relpath(img_path_lq, args.root[1])
19
+ print(f'{img_name_gt}, {img_name_lq}')
20
+ txt_file.write(f'{img_name_gt}, {img_name_lq}\n')
21
+
22
+
23
+ if __name__ == '__main__':
24
+ """This script is used to generate meta info (txt file) for paired images.
25
+ """
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument(
28
+ '--input',
29
+ nargs='+',
30
+ default=['datasets/DF2K/DIV2K_train_HR_sub', 'datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub'],
31
+ help='Input folder, should be [gt_folder, lq_folder]')
32
+ parser.add_argument('--root', nargs='+', default=[None, None], help='Folder root; defaults to the parent folder of each input')
33
+ parser.add_argument(
34
+ '--meta_info',
35
+ type=str,
36
+ default='datasets/DF2K/meta_info/meta_info_DIV2K_sub_pair.txt',
37
+ help='txt path for meta info')
38
+ args = parser.parse_args()
39
+
40
+ assert len(args.input) == 2, 'Input folder should have two elements: gt folder and lq folder'
41
+ assert len(args.root) == 2, 'Root path should have two elements: root for gt folder and lq folder'
42
+ os.makedirs(os.path.dirname(args.meta_info), exist_ok=True)
43
+ for i in range(2):
44
+ if args.input[i].endswith('/'):
45
+ args.input[i] = args.input[i][:-1]
46
+ if args.root[i] is None:
47
+ args.root[i] = os.path.dirname(args.input[i])
48
+
49
+ main(args)
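
When --root is left at its None default, the parent folder of each input is used, so the written names keep their sub-folder prefixes. A quick check on illustrative paths:

import os

inp = 'datasets/DF2K/DIV2K_train_HR_sub'
root = os.path.dirname(inp)  # datasets/DF2K, the fallback computed above
print(os.path.relpath(os.path.join(inp, '0001_s001.png'), root))
# -> DIV2K_train_HR_sub/0001_s001.png
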
scripts/generate_multiscale_DF2K.py ADDED
@@ -0,0 +1,48 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ from PIL import Image
5
+
6
+
7
+ def main(args):
8
+ # For DF2K, we consider the following three scales,
9
+ # plus one extra smallest version whose shortest edge is 400
10
+ scale_list = [0.75, 0.5, 1 / 3]
11
+ shortest_edge = 400
12
+
13
+ path_list = sorted(glob.glob(os.path.join(args.input, '*')))
14
+ for path in path_list:
15
+ print(path)
16
+ basename = os.path.splitext(os.path.basename(path))[0]
17
+
18
+ img = Image.open(path)
19
+ width, height = img.size
20
+ for idx, scale in enumerate(scale_list):
21
+ print(f'\t{scale:.2f}')
22
+ rlt = img.resize((int(width * scale), int(height * scale)), resample=Image.LANCZOS)
23
+ rlt.save(os.path.join(args.output, f'{basename}T{idx}.png'))
24
+
25
+ # save the smallest version, whose shortest edge is 400
26
+ if width < height:
27
+ ratio = height / width
28
+ width = shortest_edge
29
+ height = int(width * ratio)
30
+ else:
31
+ ratio = width / height
32
+ height = shortest_edge
33
+ width = int(height * ratio)
34
+ rlt = img.resize((int(width), int(height)), resample=Image.LANCZOS)
35
+ rlt.save(os.path.join(args.output, f'{basename}T{idx+1}.png'))
36
+
37
+
38
+ if __name__ == '__main__':
39
+ """Generate multi-scale versions for GT images with LANCZOS resampling.
40
+ It is now used for DF2K dataset (DIV2K + Flickr 2K)
41
+ """
42
+ parser = argparse.ArgumentParser()
43
+ parser.add_argument('--input', type=str, default='datasets/DF2K/DF2K_HR', help='Input folder')
44
+ parser.add_argument('--output', type=str, default='datasets/DF2K/DF2K_multiscale', help='Output folder')
45
+ args = parser.parse_args()
46
+
47
+ os.makedirs(args.output, exist_ok=True)
48
+ main(args)
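
The final branch rescales so the shorter edge becomes exactly 400 while preserving aspect ratio; worked through on an illustrative 2040x1356 image (height is the shorter edge, so the else branch runs):

width, height, shortest_edge = 2040, 1356, 400
ratio = width / height        # ~1.5044
height = shortest_edge        # 400
width = int(height * ratio)   # 601
print(width, height)          # 601 400
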
scripts/pytorch2onnx.py CHANGED
@@ -1,17 +1,36 @@
 
1
  import torch
2
  import torch.onnx
3
  from basicsr.archs.rrdbnet_arch import RRDBNet
4
 
5
- # An instance of your model
6
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32)
7
- model.load_state_dict(torch.load('experiments/pretrained_models/RealESRGAN_x4plus.pth')['params_ema'])
8
- # set the train mode to false since we will only run the forward pass.
9
- model.train(False)
10
- model.cpu().eval()
11
 
12
- # An example input you would normally provide to your model's forward() method
13
- x = torch.rand(1, 3, 64, 64)
 
 
 
 
 
 
 
 
 
14
 
15
- # Export the model
16
- with torch.no_grad():
17
- torch_out = torch.onnx._export(model, x, 'realesrgan-x4.onnx', opset_version=11, export_params=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
  import torch
3
  import torch.onnx
4
  from basicsr.archs.rrdbnet_arch import RRDBNet
5
 
 
 
 
 
 
 
6
 
7
+ def main(args):
8
+ # An instance of the model
9
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
10
+ if args.params:
11
+ keyname = 'params'
12
+ else:
13
+ keyname = 'params_ema'
14
+ model.load_state_dict(torch.load(args.input)[keyname])
15
+ # set the train mode to false since we will only run the forward pass.
16
+ model.train(False)
17
+ model.cpu().eval()
18
 
19
+ # An example input
20
+ x = torch.rand(1, 3, 64, 64)
21
+ # Export the model
22
+ with torch.no_grad():
23
+ torch_out = torch.onnx._export(model, x, args.output, opset_version=11, export_params=True)
24
+ print(torch_out.shape)
25
+
26
+
27
+ if __name__ == '__main__':
28
+ """Convert pytorch model to onnx models"""
29
+ parser = argparse.ArgumentParser()
30
+ parser.add_argument(
31
+ '--input', type=str, default='experiments/pretrained_models/RealESRGAN_x4plus.pth', help='Input model path')
32
+ parser.add_argument('--output', type=str, default='realesrgan-x4.onnx', help='Output onnx path')
33
+ parser.add_argument('--params', action='store_true', help='Use params instead of params_ema')
34
+ args = parser.parse_args()
35
+
36
+ main(args)
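
After export, the graph can be sanity-checked against a dummy input with onnxruntime (a sketch; onnxruntime is an extra dependency, not required by this repo):

import onnxruntime as ort
import torch

x = torch.rand(1, 3, 64, 64)
sess = ort.InferenceSession('realesrgan-x4.onnx', providers=['CPUExecutionProvider'])
onnx_out = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})[0]
print(onnx_out.shape)  # (1, 3, 256, 256) for the x4 model
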
setup.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import subprocess
7
  import time
8
 
9
- version_file = 'version.py'
10
 
11
 
12
  def readme():
@@ -43,12 +43,6 @@ def get_git_hash():
43
  def get_hash():
44
  if os.path.exists('.git'):
45
  sha = get_git_hash()[:7]
46
- elif os.path.exists(version_file):
47
- try:
48
- from facexlib.version import __version__
49
- sha = __version__.split('+')[-1]
50
- except ImportError:
51
- raise ImportError('Unable to get git version')
52
  else:
53
  sha = 'unknown'
54
 
6
  import subprocess
7
  import time
8
 
9
+ version_file = 'realesrgan/version.py'
10
 
11
 
12
  def readme():
43
  def get_hash():
44
  if os.path.exists('.git'):
45
  sha = get_git_hash()[:7]
 
 
 
 
 
 
46
  else:
47
  sha = 'unknown'
48
 
tests/data/gt.lmdb/data.mdb ADDED
Binary file (758 kB). View file
tests/data/gt.lmdb/lock.mdb ADDED
Binary file (8.19 kB). View file
tests/data/gt.lmdb/meta_info.txt ADDED
@@ -0,0 +1,2 @@
1
+ baboon.png (480,500,3) 1
2
+ comic.png (360,240,3) 1
tests/data/gt/baboon.png ADDED
tests/data/gt/comic.png ADDED
tests/data/lq.lmdb/data.mdb ADDED
Binary file (65.5 kB). View file
tests/data/lq.lmdb/lock.mdb ADDED
Binary file (8.19 kB). View file
tests/data/lq.lmdb/meta_info.txt ADDED
@@ -0,0 +1,2 @@
1
+ baboon.png (120,125,3) 1
2
+ comic.png (80,60,3) 1
tests/data/lq/baboon.png ADDED
tests/data/lq/comic.png ADDED
tests/data/meta_info_gt.txt ADDED
@@ -0,0 +1,2 @@
1
+ baboon.png
2
+ comic.png
tests/data/meta_info_pair.txt ADDED
@@ -0,0 +1,2 @@
1
+ gt/baboon.png, lq/baboon.png
2
+ gt/comic.png, lq/comic.png
tests/data/test_realesrgan_dataset.yml ADDED
@@ -0,0 +1,28 @@
1
+ name: Demo
2
+ type: RealESRGANDataset
3
+ dataroot_gt: tests/data/gt
4
+ meta_info: tests/data/meta_info_gt.txt
5
+ io_backend:
6
+ type: disk
7
+
8
+ blur_kernel_size: 21
9
+ kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
10
+ kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
11
+ sinc_prob: 1
12
+ blur_sigma: [0.2, 3]
13
+ betag_range: [0.5, 4]
14
+ betap_range: [1, 2]
15
+
16
+ blur_kernel_size2: 21
17
+ kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
18
+ kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
19
+ sinc_prob2: 1
20
+ blur_sigma2: [0.2, 1.5]
21
+ betag_range2: [0.5, 4]
22
+ betap_range2: [1, 2]
23
+
24
+ final_sinc_prob: 1
25
+
26
+ gt_size: 128
27
+ use_hflip: True
28
+ use_rot: False
tests/data/test_realesrgan_model.yml ADDED
@@ -0,0 +1,115 @@
1
+ scale: 4
2
+ num_gpu: 1
3
+ manual_seed: 0
4
+ is_train: True
5
+ dist: False
6
+
7
+ # ----------------- options for synthesizing training data ----------------- #
8
+ # USM the ground-truth
9
+ l1_gt_usm: True
10
+ percep_gt_usm: True
11
+ gan_gt_usm: False
12
+
13
+ # the first degradation process
14
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
15
+ resize_range: [0.15, 1.5]
16
+ gaussian_noise_prob: 1
17
+ noise_range: [1, 30]
18
+ poisson_scale_range: [0.05, 3]
19
+ gray_noise_prob: 1
20
+ jpeg_range: [30, 95]
21
+
22
+ # the second degradation process
23
+ second_blur_prob: 1
24
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
25
+ resize_range2: [0.3, 1.2]
26
+ gaussian_noise_prob2: 1
27
+ noise_range2: [1, 25]
28
+ poisson_scale_range2: [0.05, 2.5]
29
+ gray_noise_prob2: 1
30
+ jpeg_range2: [30, 95]
31
+
32
+ gt_size: 32
33
+ queue_size: 1
34
+
35
+ # network structures
36
+ network_g:
37
+ type: RRDBNet
38
+ num_in_ch: 3
39
+ num_out_ch: 3
40
+ num_feat: 4
41
+ num_block: 1
42
+ num_grow_ch: 2
43
+
44
+ network_d:
45
+ type: UNetDiscriminatorSN
46
+ num_in_ch: 3
47
+ num_feat: 2
48
+ skip_connection: True
49
+
50
+ # path
51
+ path:
52
+ pretrain_network_g: ~
53
+ param_key_g: params_ema
54
+ strict_load_g: true
55
+ resume_state: ~
56
+
57
+ # training settings
58
+ train:
59
+ ema_decay: 0.999
60
+ optim_g:
61
+ type: Adam
62
+ lr: !!float 1e-4
63
+ weight_decay: 0
64
+ betas: [0.9, 0.99]
65
+ optim_d:
66
+ type: Adam
67
+ lr: !!float 1e-4
68
+ weight_decay: 0
69
+ betas: [0.9, 0.99]
70
+
71
+ scheduler:
72
+ type: MultiStepLR
73
+ milestones: [400000]
74
+ gamma: 0.5
75
+
76
+ total_iter: 400000
77
+ warmup_iter: -1 # no warm up
78
+
79
+ # losses
80
+ pixel_opt:
81
+ type: L1Loss
82
+ loss_weight: 1.0
83
+ reduction: mean
84
+ # perceptual loss (content and style losses)
85
+ perceptual_opt:
86
+ type: PerceptualLoss
87
+ layer_weights:
88
+ # before relu
89
+ 'conv1_2': 0.1
90
+ 'conv2_2': 0.1
91
+ 'conv3_4': 1
92
+ 'conv4_4': 1
93
+ 'conv5_4': 1
94
+ vgg_type: vgg19
95
+ use_input_norm: true
96
+ perceptual_weight: !!float 1.0
97
+ style_weight: 0
98
+ range_norm: false
99
+ criterion: l1
100
+ # gan loss
101
+ gan_opt:
102
+ type: GANLoss
103
+ gan_type: vanilla
104
+ real_label_val: 1.0
105
+ fake_label_val: 0.0
106
+ loss_weight: !!float 1e-1
107
+
108
+ net_d_iters: 1
109
+ net_d_init_iters: 0
110
+
111
+
112
+ # validation settings
113
+ val:
114
+ val_freq: !!float 5e3
115
+ save_img: False
tests/data/test_realesrgan_paired_dataset.yml ADDED
@@ -0,0 +1,13 @@
1
+ name: Demo
2
+ type: RealESRGANPairedDataset
3
+ scale: 4
4
+ dataroot_gt: tests/data
5
+ dataroot_lq: tests/data
6
+ meta_info: tests/data/meta_info_pair.txt
7
+ io_backend:
8
+ type: disk
9
+
10
+ phase: train
11
+ gt_size: 128
12
+ use_hflip: True
13
+ use_rot: False
tests/data/test_realesrnet_model.yml ADDED
@@ -0,0 +1,75 @@
1
+ scale: 4
2
+ num_gpu: 1
3
+ manual_seed: 0
4
+ is_train: True
5
+ dist: False
6
+
7
+ # ----------------- options for synthesizing training data ----------------- #
8
+ gt_usm: True # USM the ground-truth
9
+
10
+ # the first degradation process
11
+ resize_prob: [0.2, 0.7, 0.1] # up, down, keep
12
+ resize_range: [0.15, 1.5]
13
+ gaussian_noise_prob: 1
14
+ noise_range: [1, 30]
15
+ poisson_scale_range: [0.05, 3]
16
+ gray_noise_prob: 1
17
+ jpeg_range: [30, 95]
18
+
19
+ # the second degradation process
20
+ second_blur_prob: 1
21
+ resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
22
+ resize_range2: [0.3, 1.2]
23
+ gaussian_noise_prob2: 1
24
+ noise_range2: [1, 25]
25
+ poisson_scale_range2: [0.05, 2.5]
26
+ gray_noise_prob2: 1
27
+ jpeg_range2: [30, 95]
28
+
29
+ gt_size: 32
30
+ queue_size: 1
31
+
32
+ # network structures
33
+ network_g:
34
+ type: RRDBNet
35
+ num_in_ch: 3
36
+ num_out_ch: 3
37
+ num_feat: 4
38
+ num_block: 1
39
+ num_grow_ch: 2
40
+
41
+ # path
42
+ path:
43
+ pretrain_network_g: ~
44
+ param_key_g: params_ema
45
+ strict_load_g: true
46
+ resume_state: ~
47
+
48
+ # training settings
49
+ train:
50
+ ema_decay: 0.999
51
+ optim_g:
52
+ type: Adam
53
+ lr: !!float 2e-4
54
+ weight_decay: 0
55
+ betas: [0.9, 0.99]
56
+
57
+ scheduler:
58
+ type: MultiStepLR
59
+ milestones: [1000000]
60
+ gamma: 0.5
61
+
62
+ total_iter: 1000000
63
+ warmup_iter: -1 # no warm up
64
+
65
+ # losses
66
+ pixel_opt:
67
+ type: L1Loss
68
+ loss_weight: 1.0
69
+ reduction: mean
70
+
71
+
72
+ # validation settings
73
+ val:
74
+ val_freq: !!float 5e3
75
+ save_img: False
tests/test_dataset.py ADDED
@@ -0,0 +1,151 @@
1
+ import pytest
2
+ import yaml
3
+
4
+ from realesrgan.data.realesrgan_dataset import RealESRGANDataset
5
+ from realesrgan.data.realesrgan_paired_dataset import RealESRGANPairedDataset
6
+
7
+
8
+ def test_realesrgan_dataset():
9
+
10
+ with open('tests/data/test_realesrgan_dataset.yml', mode='r') as f:
11
+ opt = yaml.load(f, Loader=yaml.FullLoader)
12
+
13
+ dataset = RealESRGANDataset(opt)
14
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
15
+ assert len(dataset) == 2 # whether to read correct meta info
16
+ assert dataset.kernel_list == [
17
+ 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
18
+ ] # correct initialization of the degradation configurations
19
+ assert dataset.betag_range2 == [0.5, 4]
20
+
21
+ # test __getitem__
22
+ result = dataset.__getitem__(0)
23
+ # check returned keys
24
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
25
+ assert set(expected_keys).issubset(set(result.keys()))
26
+ # check shape and contents
27
+ assert result['gt'].shape == (3, 400, 400)
28
+ assert result['kernel1'].shape == (21, 21)
29
+ assert result['kernel2'].shape == (21, 21)
30
+ assert result['sinc_kernel'].shape == (21, 21)
31
+ assert result['gt_path'] == 'tests/data/gt/baboon.png'
32
+
33
+ # ------------------ test lmdb backend -------------------- #
34
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
35
+ opt['io_backend']['type'] = 'lmdb'
36
+
37
+ dataset = RealESRGANDataset(opt)
38
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
39
+ assert len(dataset.paths) == 2 # whether to read correct meta info
40
+ assert dataset.kernel_list == [
41
+ 'iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'
42
+ ] # correct initialization of the degradation configurations
43
+ assert dataset.betag_range2 == [0.5, 4]
44
+
45
+ # test __getitem__
46
+ result = dataset.__getitem__(1)
47
+ # check returned keys
48
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
49
+ assert set(expected_keys).issubset(set(result.keys()))
50
+ # check shape and contents
51
+ assert result['gt'].shape == (3, 400, 400)
52
+ assert result['kernel1'].shape == (21, 21)
53
+ assert result['kernel2'].shape == (21, 21)
54
+ assert result['sinc_kernel'].shape == (21, 21)
55
+ assert result['gt_path'] == 'comic'
56
+
57
+ # ------------------ test with sinc_prob = 0 -------------------- #
58
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
59
+ opt['io_backend']['type'] = 'lmdb'
60
+ opt['sinc_prob'] = 0
61
+ opt['sinc_prob2'] = 0
62
+ opt['final_sinc_prob'] = 0
63
+ dataset = RealESRGANDataset(opt)
64
+ result = dataset.__getitem__(0)
65
+ # check returned keys
66
+ expected_keys = ['gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path']
67
+ assert set(expected_keys).issubset(set(result.keys()))
68
+ # check shape and contents
69
+ assert result['gt'].shape == (3, 400, 400)
70
+ assert result['kernel1'].shape == (21, 21)
71
+ assert result['kernel2'].shape == (21, 21)
72
+ assert result['sinc_kernel'].shape == (21, 21)
73
+ assert result['gt_path'] == 'baboon'
74
+
75
+ # ------------------ lmdb backend should have paths ending with lmdb -------------------- #
76
+ with pytest.raises(ValueError):
77
+ opt['dataroot_gt'] = 'tests/data/gt'
78
+ opt['io_backend']['type'] = 'lmdb'
79
+ dataset = RealESRGANDataset(opt)
80
+
81
+
82
+ def test_realesrgan_paired_dataset():
83
+
84
+ with open('tests/data/test_realesrgan_paired_dataset.yml', mode='r') as f:
85
+ opt = yaml.load(f, Loader=yaml.FullLoader)
86
+
87
+ dataset = RealESRGANPairedDataset(opt)
88
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
89
+ assert len(dataset) == 2 # whether to read correct meta info
90
+
91
+ # test __getitem__
92
+ result = dataset.__getitem__(0)
93
+ # check returned keys
94
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
95
+ assert set(expected_keys).issubset(set(result.keys()))
96
+ # check shape and contents
97
+ assert result['gt'].shape == (3, 128, 128)
98
+ assert result['lq'].shape == (3, 32, 32)
99
+ assert result['gt_path'] == 'tests/data/gt/baboon.png'
100
+ assert result['lq_path'] == 'tests/data/lq/baboon.png'
101
+
102
+ # ------------------ test lmdb backend -------------------- #
103
+ opt['dataroot_gt'] = 'tests/data/gt.lmdb'
104
+ opt['dataroot_lq'] = 'tests/data/lq.lmdb'
105
+ opt['io_backend']['type'] = 'lmdb'
106
+
107
+ dataset = RealESRGANPairedDataset(opt)
108
+ assert dataset.io_backend_opt['type'] == 'lmdb' # io backend
109
+ assert len(dataset) == 2 # whether to read correct meta info
110
+
111
+ # test __getitem__
112
+ result = dataset.__getitem__(1)
113
+ # check returned keys
114
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
115
+ assert set(expected_keys).issubset(set(result.keys()))
116
+ # check shape and contents
117
+ assert result['gt'].shape == (3, 128, 128)
118
+ assert result['lq'].shape == (3, 32, 32)
119
+ assert result['gt_path'] == 'comic'
120
+ assert result['lq_path'] == 'comic'
121
+
122
+ # ------------------ test paired_paths_from_folder -------------------- #
123
+ opt['dataroot_gt'] = 'tests/data/gt'
124
+ opt['dataroot_lq'] = 'tests/data/lq'
125
+ opt['io_backend'] = dict(type='disk')
126
+ opt['meta_info'] = None
127
+
128
+ dataset = RealESRGANPairedDataset(opt)
129
+ assert dataset.io_backend_opt['type'] == 'disk' # io backend
130
+ assert len(dataset) == 2 # whether to read correct meta info
131
+
132
+ # test __getitem__
133
+ result = dataset.__getitem__(0)
134
+ # check returned keys
135
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
136
+ assert set(expected_keys).issubset(set(result.keys()))
137
+ # check shape and contents
138
+ assert result['gt'].shape == (3, 128, 128)
139
+ assert result['lq'].shape == (3, 32, 32)
140
+
141
+ # ------------------ test normalization -------------------- #
142
+ dataset.mean = [0.5, 0.5, 0.5]
143
+ dataset.std = [0.5, 0.5, 0.5]
144
+ # test __getitem__
145
+ result = dataset.__getitem__(0)
146
+ # check returned keys
147
+ expected_keys = ['gt', 'lq', 'gt_path', 'lq_path']
148
+ assert set(expected_keys).issubset(set(result.keys()))
149
+ # check shape and contents
150
+ assert result['gt'].shape == (3, 128, 128)
151
+ assert result['lq'].shape == (3, 32, 32)
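
As a usage note: the suite under tests/ is self-contained (the lmdb and meta_info files above are its fixtures) and can be run with python -m pytest tests/. Unlike these dataset tests and the discriminator test, the model tests instantiate DiffJPEG and USMSharp on CUDA, so they assume a GPU is available.
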
tests/test_discriminator_arch.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+
3
+ from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
4
+
5
+
6
+ def test_unetdiscriminatorsn():
7
+ """Test arch: UNetDiscriminatorSN."""
8
+
9
+ # model init and forward (cpu)
10
+ net = UNetDiscriminatorSN(num_in_ch=3, num_feat=4, skip_connection=True)
11
+ img = torch.rand((1, 3, 32, 32), dtype=torch.float32)
12
+ output = net(img)
13
+ assert output.shape == (1, 1, 32, 32)
14
+
15
+ # model init and forward (gpu)
16
+ if torch.cuda.is_available():
17
+ net.cuda()
18
+ output = net(img.cuda())
19
+ assert output.shape == (1, 1, 32, 32)
tests/test_model.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ import yaml
3
+ from basicsr.archs.rrdbnet_arch import RRDBNet
4
+ from basicsr.data.paired_image_dataset import PairedImageDataset
5
+ from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
6
+
7
+ from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
8
+ from realesrgan.models.realesrgan_model import RealESRGANModel
9
+ from realesrgan.models.realesrnet_model import RealESRNetModel
10
+
11
+
12
+ def test_realesrnet_model():
13
+ with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
14
+ opt = yaml.load(f, Loader=yaml.FullLoader)
15
+
16
+ # build model
17
+ model = RealESRNetModel(opt)
18
+ # test attributes
19
+ assert model.__class__.__name__ == 'RealESRNetModel'
20
+ assert isinstance(model.net_g, RRDBNet)
21
+ assert isinstance(model.cri_pix, L1Loss)
22
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
23
+
24
+ # prepare data
25
+ gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
26
+ kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
27
+ kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
28
+ sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
29
+ data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
30
+ model.feed_data(data)
31
+ # check dequeue
32
+ model.feed_data(data)
33
+ # check data shape
34
+ assert model.lq.shape == (1, 3, 8, 8)
35
+ assert model.gt.shape == (1, 3, 32, 32)
36
+
37
+ # change probability to test if-else
38
+ model.opt['gaussian_noise_prob'] = 0
39
+ model.opt['gray_noise_prob'] = 0
40
+ model.opt['second_blur_prob'] = 0
41
+ model.opt['gaussian_noise_prob2'] = 0
42
+ model.opt['gray_noise_prob2'] = 0
43
+ model.feed_data(data)
44
+ # check data shape
45
+ assert model.lq.shape == (1, 3, 8, 8)
46
+ assert model.gt.shape == (1, 3, 32, 32)
47
+
48
+ # ----------------- test nondist_validation -------------------- #
49
+ # construct dataloader
50
+ dataset_opt = dict(
51
+ name='Demo',
52
+ dataroot_gt='tests/data/gt',
53
+ dataroot_lq='tests/data/lq',
54
+ io_backend=dict(type='disk'),
55
+ scale=4,
56
+ phase='val')
57
+ dataset = PairedImageDataset(dataset_opt)
58
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
59
+ assert model.is_train is True
60
+ model.nondist_validation(dataloader, 1, None, False)
61
+ assert model.is_train is True
62
+
63
+
64
+ def test_realesrgan_model():
65
+ with open('tests/data/test_realesrgan_model.yml', mode='r') as f:
66
+ opt = yaml.load(f, Loader=yaml.FullLoader)
67
+
68
+ # build model
69
+ model = RealESRGANModel(opt)
70
+ # test attributes
71
+ assert model.__class__.__name__ == 'RealESRGANModel'
72
+ assert isinstance(model.net_g, RRDBNet) # generator
73
+ assert isinstance(model.net_d, UNetDiscriminatorSN) # discriminator
74
+ assert isinstance(model.cri_pix, L1Loss)
75
+ assert isinstance(model.cri_perceptual, PerceptualLoss)
76
+ assert isinstance(model.cri_gan, GANLoss)
77
+ assert isinstance(model.optimizers[0], torch.optim.Adam)
78
+ assert isinstance(model.optimizers[1], torch.optim.Adam)
79
+
80
+ # prepare data
81
+ gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
82
+ kernel1 = torch.rand((1, 5, 5), dtype=torch.float32)
83
+ kernel2 = torch.rand((1, 5, 5), dtype=torch.float32)
84
+ sinc_kernel = torch.rand((1, 5, 5), dtype=torch.float32)
85
+ data = dict(gt=gt, kernel1=kernel1, kernel2=kernel2, sinc_kernel=sinc_kernel)
86
+ model.feed_data(data)
87
+ # check dequeue
88
+ model.feed_data(data)
89
+ # check data shape
90
+ assert model.lq.shape == (1, 3, 8, 8)
91
+ assert model.gt.shape == (1, 3, 32, 32)
92
+
93
+ # change probability to test if-else
94
+ model.opt['gaussian_noise_prob'] = 0
95
+ model.opt['gray_noise_prob'] = 0
96
+ model.opt['second_blur_prob'] = 0
97
+ model.opt['gaussian_noise_prob2'] = 0
98
+ model.opt['gray_noise_prob2'] = 0
99
+ model.feed_data(data)
100
+ # check data shape
101
+ assert model.lq.shape == (1, 3, 8, 8)
102
+ assert model.gt.shape == (1, 3, 32, 32)
103
+
104
+ # ----------------- test nondist_validation -------------------- #
105
+ # construct dataloader
106
+ dataset_opt = dict(
107
+ name='Demo',
108
+ dataroot_gt='tests/data/gt',
109
+ dataroot_lq='tests/data/lq',
110
+ io_backend=dict(type='disk'),
111
+ scale=4,
112
+ phase='val')
113
+ dataset = PairedImageDataset(dataset_opt)
114
+ dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
115
+ assert model.is_train is True
116
+ model.nondist_validation(dataloader, 1, None, False)
117
+ assert model.is_train is True
118
+
119
+ # ----------------- test optimize_parameters -------------------- #
120
+ model.feed_data(data)
121
+ model.optimize_parameters(1)
122
+ assert model.output.shape == (1, 3, 32, 32)
123
+ assert isinstance(model.log_dict, dict)
124
+ # check returned keys
125
+ expected_keys = ['l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake']
126
+ assert set(expected_keys).issubset(set(model.log_dict.keys()))
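A note on the shape checks: feed_data runs the on-the-fly degradation pipeline on the 32x32 gt batch, and the second feed_data call exercises the training-pair pool (the dequeue path). The 8x8 lq size follows from the configured scale; a small sketch, assuming the test YAML sets scale: 4 and gt_size: 32, which is what the assertions imply:

import yaml

with open('tests/data/test_realesrnet_model.yml', mode='r') as f:
    opt = yaml.load(f, Loader=yaml.FullLoader)
# 32 / 4 = 8, matching the model.lq.shape == (1, 3, 8, 8) assertions above.
print(opt['scale'], opt['gt_size'])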
tests/test_utils.py ADDED
@@ -0,0 +1,87 @@
+import numpy as np
+from basicsr.archs.rrdbnet_arch import RRDBNet
+
+from realesrgan.utils import RealESRGANer
+
+
+def test_realesrganer():
+    # initialize with default model
+    restorer = RealESRGANer(
+        scale=4,
+        model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
+        model=None,
+        tile=10,
+        tile_pad=10,
+        pre_pad=2,
+        half=False)
+    assert isinstance(restorer.model, RRDBNet)
+    assert restorer.half is False
+    # initialize with user-defined model
+    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
+    restorer = RealESRGANer(
+        scale=4,
+        model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',
+        model=model,
+        tile=10,
+        tile_pad=10,
+        pre_pad=2,
+        half=True)
+    # test attribute
+    assert isinstance(restorer.model, RRDBNet)
+    assert restorer.half is True
+
+    # ------------------ test pre_process ---------------- #
+    img = np.random.random((12, 12, 3)).astype(np.float32)
+    restorer.pre_process(img)
+    assert restorer.img.shape == (1, 3, 14, 14)
+    # with modcrop
+    restorer.scale = 1
+    restorer.pre_process(img)
+    assert restorer.img.shape == (1, 3, 16, 16)
+
+    # ------------------ test process ---------------- #
+    restorer.process()
+    assert restorer.output.shape == (1, 3, 64, 64)
+
+    # ------------------ test post_process ---------------- #
+    restorer.mod_scale = 4
+    output = restorer.post_process()
+    assert output.shape == (1, 3, 60, 60)
+
+    # ------------------ test tile_process ---------------- #
+    restorer.scale = 4
+    img = np.random.random((12, 12, 3)).astype(np.float32)
+    restorer.pre_process(img)
+    restorer.tile_process()
+    assert restorer.output.shape == (1, 3, 64, 64)
+
+    # ------------------ test enhance ---------------- #
+    img = np.random.random((12, 12, 3)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (24, 24, 3)
+    assert result[1] == 'RGB'
+
+    # ------------------ test enhance with 16-bit image ---------------- #
+    img = np.random.random((4, 4, 3)).astype(np.uint16) + 512
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (8, 8, 3)
+    assert result[1] == 'RGB'
+
+    # ------------------ test enhance with gray image ---------------- #
+    img = np.random.random((4, 4)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (8, 8)
+    assert result[1] == 'L'
+
+    # ------------------ test enhance with RGBA ---------------- #
+    img = np.random.random((4, 4, 4)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2)
+    assert result[0].shape == (8, 8, 4)
+    assert result[1] == 'RGBA'
+
+    # ------------------ test enhance with RGBA, alpha_upsampler ---------------- #
+    restorer.tile_size = 0
+    img = np.random.random((4, 4, 4)).astype(np.float32)
+    result = restorer.enhance(img, outscale=2, alpha_upsampler=None)
+    assert result[0].shape == (8, 8, 4)
+    assert result[1] == 'RGBA'
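Finally, a minimal end-to-end sketch of the RealESRGANer workflow these unit tests cover; the weight path and image names are illustrative and assume the x4plus model has already been downloaded:

import cv2

from realesrgan import RealESRGANer

# tile=0 disables tiled inference; a positive tile size bounds GPU memory,
# which is the code path tile_process() exercises above.
upsampler = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus.pth',
    model=None,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=False)
img = cv2.imread('inputs/0014.jpg', cv2.IMREAD_UNCHANGED)
output, img_mode = upsampler.enhance(img, outscale=4)
cv2.imwrite('results/0014_out.png', output)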