Maol committed
Commit 45f3819 · 1 Parent(s): f4354b0

Upload utils.py

Files changed (1)
  1. realesrgan/utils.py +280 -0
realesrgan/utils.py ADDED
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from basicsr.utils.download_util import load_file_from_url
from torch.nn import functional as F

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

class RealESRGANer():
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model_path (str): The path to the pretrained model. It can also be a URL, in which case the model is
            downloaded automatically first.
        model (nn.Module): The defined network. Default: None.
        tile (int): Since overly large images cause GPU out-of-memory issues, this tile option first crops the
            input image into tiles, processes each tile separately, and finally merges the results back into
            one image. 0 means tiling is disabled. Default: 0.
        tile_pad (int): The pad size for each tile, used to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=10, half=False):
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # if model_path starts with https, first download the weights to the folder realesrgan/weights
        if model_path.startswith('https://'):
            model_path = load_file_from_url(
                url=model_path, model_dir=os.path.join(ROOT_DIR, 'realesrgan/weights'), progress=True, file_name=None)
        loadnet = torch.load(model_path, map_location=torch.device('cpu'))
        # prefer the EMA weights if they are available
        if 'params_ema' in loadnet:
            keyname = 'params_ema'
        else:
            keyname = 'params'
        model.load_state_dict(loadnet[keyname], strict=True)
        model.eval()
        self.model = model.to(self.device)
        if self.half:
            self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process the input, e.g. with pre-padding and mod-padding, so that its size is divisible by the
        required factor.
        """
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        if self.half:
            self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            # pad height/width up to the next multiple of mod_scale
            # (e.g. with mod_scale=4, a 255-pixel side gets 4 - 255 % 4 = 1 pixel of extra reflect padding)
            if (h % self.mod_scale != 0):
                self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
            if (w % self.mod_scale != 0):
                self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    def process(self):
        # model inference
        self.output = self.model(self.img)

    def tile_process(self):
        """It first crops the input image into tiles and then processes each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with a black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                    # skip this tile (it stays black in the output) rather than
                    # referencing an undefined output_tile below
                    continue

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y,
                            output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                                                       output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use the cv2 resize for alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output, (
                    int(w_input * outscale),
                    int(h_input * outscale),
                ), interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode

class PrefetchReader(threading.Thread):
    """Prefetch images.

    Args:
        img_list (list[str]): A list of image paths to be read.
        num_prefetch_queue (int): Size of the prefetch queue.
    """

    def __init__(self, img_list, num_prefetch_queue):
        super().__init__()
        self.que = queue.Queue(num_prefetch_queue)
        self.img_list = img_list

    def run(self):
        for img_path in self.img_list:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            self.que.put(img)

        # a None sentinel marks the end of the image list
        self.que.put(None)

    def __next__(self):
        next_item = self.que.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self

class IOConsumer(threading.Thread):
    """Save images pulled from a queue until a 'quit' message arrives."""

    def __init__(self, opt, que, qid):
        super().__init__()
        self._queue = que
        self.qid = qid
        self.opt = opt

    def run(self):
        while True:
            msg = self._queue.get()
            if isinstance(msg, str) and msg == 'quit':
                break

            output = msg['output']
            save_path = msg['save_path']
            cv2.imwrite(save_path, output)
        print(f'IO worker {self.qid} is done.')
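
A minimal usage sketch for the uploaded helper (not part of the file itself): it assumes basicsr's RRDBNet architecture and the public RealESRGAN_x4plus weights URL from the upstream Real-ESRGAN releases; the input path is hypothetical.

import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.utils import RealESRGANer

# an x4 RRDBNet configuration matching the RealESRGAN_x4plus weights
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(
    scale=4,
    model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
    model=model,
    tile=400,      # process in 400-pixel tiles to bound GPU memory; 0 disables tiling
    tile_pad=10,
    pre_pad=10,
    half=False)

img = cv2.imread('inputs/input.jpg', cv2.IMREAD_UNCHANGED)  # hypothetical input path
output, img_mode = upsampler.enhance(img, outscale=4)
cv2.imwrite('output.png', output)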
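
And a sketch of how PrefetchReader and IOConsumer could be wired into a producer/consumer pipeline, based only on the queue protocol the two classes implement (None ends the reader, 'quit' stops a writer); the paths, queue sizes, and the elided upscale step are illustrative assumptions.

import queue
from realesrgan.utils import PrefetchReader, IOConsumer

img_list = ['a.png', 'b.png']  # hypothetical input paths
reader = PrefetchReader(img_list, num_prefetch_queue=4)
reader.start()

save_queue = queue.Queue()
consumers = [IOConsumer(opt=None, que=save_queue, qid=i) for i in range(2)]  # opt is stored but unused here
for c in consumers:
    c.start()

for idx, img in enumerate(reader):  # PrefetchReader iterates over the loaded images
    # ... upscale img here, e.g. with RealESRGANer.enhance ...
    save_queue.put({'output': img, 'save_path': f'out_{idx}.png'})

for _ in consumers:
    save_queue.put('quit')  # one quit message per IO worker
for c in consumers:
    c.join()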