Upload realesrnet_model.py
Browse files
realesrgan/models/realesrnet_model.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
|
5 |
+
from basicsr.data.transforms import paired_random_crop
|
6 |
+
from basicsr.models.sr_model import SRModel
|
7 |
+
from basicsr.utils import DiffJPEG, USMSharp
|
8 |
+
from basicsr.utils.img_process_util import filter2D
|
9 |
+
from basicsr.utils.registry import MODEL_REGISTRY
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
|
13 |
+
@MODEL_REGISTRY.register()
|
14 |
+
class RealESRNetModel(SRModel):
|
15 |
+
"""RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
|
16 |
+
|
17 |
+
It is trained without GAN losses.
|
18 |
+
It mainly performs:
|
19 |
+
1. randomly synthesize LQ images in GPU tensors
|
20 |
+
2. optimize the networks with GAN training.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, opt):
|
24 |
+
super(RealESRNetModel, self).__init__(opt)
|
25 |
+
self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts
|
26 |
+
self.usm_sharpener = USMSharp().cuda() # do usm sharpening
|
27 |
+
self.queue_size = opt.get('queue_size', 180)
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def _dequeue_and_enqueue(self):
|
31 |
+
"""It is the training pair pool for increasing the diversity in a batch.
|
32 |
+
|
33 |
+
Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
|
34 |
+
batch could not have different resize scaling factors. Therefore, we employ this training pair pool
|
35 |
+
to increase the degradation diversity in a batch.
|
36 |
+
"""
|
37 |
+
# initialize
|
38 |
+
b, c, h, w = self.lq.size()
|
39 |
+
if not hasattr(self, 'queue_lr'):
|
40 |
+
assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}'
|
41 |
+
self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda()
|
42 |
+
_, c, h, w = self.gt.size()
|
43 |
+
self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda()
|
44 |
+
self.queue_ptr = 0
|
45 |
+
if self.queue_ptr == self.queue_size: # the pool is full
|
46 |
+
# do dequeue and enqueue
|
47 |
+
# shuffle
|
48 |
+
idx = torch.randperm(self.queue_size)
|
49 |
+
self.queue_lr = self.queue_lr[idx]
|
50 |
+
self.queue_gt = self.queue_gt[idx]
|
51 |
+
# get first b samples
|
52 |
+
lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
|
53 |
+
gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
|
54 |
+
# update the queue
|
55 |
+
self.queue_lr[0:b, :, :, :] = self.lq.clone()
|
56 |
+
self.queue_gt[0:b, :, :, :] = self.gt.clone()
|
57 |
+
|
58 |
+
self.lq = lq_dequeue
|
59 |
+
self.gt = gt_dequeue
|
60 |
+
else:
|
61 |
+
# only do enqueue
|
62 |
+
self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
|
63 |
+
self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
|
64 |
+
self.queue_ptr = self.queue_ptr + b
|
65 |
+
|
66 |
+
@torch.no_grad()
|
67 |
+
def feed_data(self, data):
|
68 |
+
"""Accept data from dataloader, and then add two-order degradations to obtain LQ images.
|
69 |
+
"""
|
70 |
+
if self.is_train and self.opt.get('high_order_degradation', True):
|
71 |
+
# training data synthesis
|
72 |
+
self.gt = data['gt'].to(self.device)
|
73 |
+
# USM sharpen the GT images
|
74 |
+
if self.opt['gt_usm'] is True:
|
75 |
+
self.gt = self.usm_sharpener(self.gt)
|
76 |
+
|
77 |
+
self.kernel1 = data['kernel1'].to(self.device)
|
78 |
+
self.kernel2 = data['kernel2'].to(self.device)
|
79 |
+
self.sinc_kernel = data['sinc_kernel'].to(self.device)
|
80 |
+
|
81 |
+
ori_h, ori_w = self.gt.size()[2:4]
|
82 |
+
|
83 |
+
# ----------------------- The first degradation process ----------------------- #
|
84 |
+
# blur
|
85 |
+
out = filter2D(self.gt, self.kernel1)
|
86 |
+
# random resize
|
87 |
+
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
|
88 |
+
if updown_type == 'up':
|
89 |
+
scale = np.random.uniform(1, self.opt['resize_range'][1])
|
90 |
+
elif updown_type == 'down':
|
91 |
+
scale = np.random.uniform(self.opt['resize_range'][0], 1)
|
92 |
+
else:
|
93 |
+
scale = 1
|
94 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
95 |
+
out = F.interpolate(out, scale_factor=scale, mode=mode)
|
96 |
+
# add noise
|
97 |
+
gray_noise_prob = self.opt['gray_noise_prob']
|
98 |
+
if np.random.uniform() < self.opt['gaussian_noise_prob']:
|
99 |
+
out = random_add_gaussian_noise_pt(
|
100 |
+
out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
101 |
+
else:
|
102 |
+
out = random_add_poisson_noise_pt(
|
103 |
+
out,
|
104 |
+
scale_range=self.opt['poisson_scale_range'],
|
105 |
+
gray_prob=gray_noise_prob,
|
106 |
+
clip=True,
|
107 |
+
rounds=False)
|
108 |
+
# JPEG compression
|
109 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
|
110 |
+
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
|
111 |
+
out = self.jpeger(out, quality=jpeg_p)
|
112 |
+
|
113 |
+
# ----------------------- The second degradation process ----------------------- #
|
114 |
+
# blur
|
115 |
+
if np.random.uniform() < self.opt['second_blur_prob']:
|
116 |
+
out = filter2D(out, self.kernel2)
|
117 |
+
# random resize
|
118 |
+
updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
|
119 |
+
if updown_type == 'up':
|
120 |
+
scale = np.random.uniform(1, self.opt['resize_range2'][1])
|
121 |
+
elif updown_type == 'down':
|
122 |
+
scale = np.random.uniform(self.opt['resize_range2'][0], 1)
|
123 |
+
else:
|
124 |
+
scale = 1
|
125 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
126 |
+
out = F.interpolate(
|
127 |
+
out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
|
128 |
+
# add noise
|
129 |
+
gray_noise_prob = self.opt['gray_noise_prob2']
|
130 |
+
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
|
131 |
+
out = random_add_gaussian_noise_pt(
|
132 |
+
out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
|
133 |
+
else:
|
134 |
+
out = random_add_poisson_noise_pt(
|
135 |
+
out,
|
136 |
+
scale_range=self.opt['poisson_scale_range2'],
|
137 |
+
gray_prob=gray_noise_prob,
|
138 |
+
clip=True,
|
139 |
+
rounds=False)
|
140 |
+
|
141 |
+
# JPEG compression + the final sinc filter
|
142 |
+
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
|
143 |
+
# as one operation.
|
144 |
+
# We consider two orders:
|
145 |
+
# 1. [resize back + sinc filter] + JPEG compression
|
146 |
+
# 2. JPEG compression + [resize back + sinc filter]
|
147 |
+
# Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
|
148 |
+
if np.random.uniform() < 0.5:
|
149 |
+
# resize back + the final sinc filter
|
150 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
151 |
+
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
152 |
+
out = filter2D(out, self.sinc_kernel)
|
153 |
+
# JPEG compression
|
154 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
155 |
+
out = torch.clamp(out, 0, 1)
|
156 |
+
out = self.jpeger(out, quality=jpeg_p)
|
157 |
+
else:
|
158 |
+
# JPEG compression
|
159 |
+
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
|
160 |
+
out = torch.clamp(out, 0, 1)
|
161 |
+
out = self.jpeger(out, quality=jpeg_p)
|
162 |
+
# resize back + the final sinc filter
|
163 |
+
mode = random.choice(['area', 'bilinear', 'bicubic'])
|
164 |
+
out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
|
165 |
+
out = filter2D(out, self.sinc_kernel)
|
166 |
+
|
167 |
+
# clamp and round
|
168 |
+
self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
|
169 |
+
|
170 |
+
# random crop
|
171 |
+
gt_size = self.opt['gt_size']
|
172 |
+
self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale'])
|
173 |
+
|
174 |
+
# training pair pool
|
175 |
+
self._dequeue_and_enqueue()
|
176 |
+
self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
|
177 |
+
else:
|
178 |
+
# for paired training or validation
|
179 |
+
self.lq = data['lq'].to(self.device)
|
180 |
+
if 'gt' in data:
|
181 |
+
self.gt = data['gt'].to(self.device)
|
182 |
+
self.gt_usm = self.usm_sharpener(self.gt)
|
183 |
+
|
184 |
+
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
|
185 |
+
# do not use the synthetic process during validation
|
186 |
+
self.is_train = False
|
187 |
+
super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img)
|
188 |
+
self.is_train = True
|