Justin Dulay committed on
Commit 8b26614
1 Parent(s): 39bb242

update for batches
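
The handler now accepts a batch: inputs['image'] may be a list of base64-encoded images instead of a single string. A minimal sketch of the new payload shape (field names taken from handler.py below; the image paths are placeholders):

import base64
import io
from PIL import Image

def to_b64(img):
    # PNG-encode a PIL image and base64 it, matching what local_test.py sends
    buf = io.BytesIO()
    img.convert('RGB').save(buf, format='png')
    return base64.b64encode(buf.getvalue()).decode()

frames = [Image.open(p) for p in ('a.jpeg', 'b.jpeg')]  # placeholder inputs
payload = {'inputs': {'image': [to_b64(f) for f in frames]}}  # batch request
# a single-image request keeps the old shape: {'inputs': {'image': to_b64(frames[0])}}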

.gitignore ADDED
@@ -0,0 +1 @@
+ env/
handler.py CHANGED
@@ -10,10 +10,632 @@ import torch
  from torch import nn


- import subprocess
- result = subprocess.run(["pip", "install", "git+https://github.com/sberbank-ai/Real-ESRGAN.git"], check=True)
- print(f"git+https://github.com/sberbank-ai/Real-ESRGAN.git = {result}")
- from RealESRGAN import RealESRGAN
+ # import subprocess
+ # result = subprocess.run(["pip", "install", "git+https://github.com/sberbank-ai/Real-ESRGAN.git"], check=True)
+ # print(f"git+https://github.com/sberbank-ai/Real-ESRGAN.git = {result}")
+ # from RealESRGAN import RealESRGAN
+
+ # no need to install, just take in all of the necessary files from the notebook
+ import math
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+ from torch.nn import init as init
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+ @torch.no_grad()
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+     """Initialize network weights.
+
+     Args:
+         module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+         scale (float): Scale initialized weights, especially for residual
+             blocks. Default: 1.
+         bias_fill (float): The value to fill bias. Default: 0
+         kwargs (dict): Other arguments for initialization function.
+     """
+     if not isinstance(module_list, list):
+         module_list = [module_list]
+     for module in module_list:
+         for m in module.modules():
+             if isinstance(m, nn.Conv2d):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, nn.Linear):
+                 init.kaiming_normal_(m.weight, **kwargs)
+                 m.weight.data *= scale
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+             elif isinstance(m, _BatchNorm):
+                 init.constant_(m.weight, 1)
+                 if m.bias is not None:
+                     m.bias.data.fill_(bias_fill)
+
+
+ def make_layer(basic_block, num_basic_block, **kwarg):
+     """Make layers by stacking the same blocks.
+
+     Args:
+         basic_block (nn.module): nn.module class for basic block.
+         num_basic_block (int): number of blocks.
+
+     Returns:
+         nn.Sequential: Stacked blocks in nn.Sequential.
+     """
+     layers = []
+     for _ in range(num_basic_block):
+         layers.append(basic_block(**kwarg))
+     return nn.Sequential(*layers)
+
+
+ class ResidualBlockNoBN(nn.Module):
+     """Residual block without BN.
+
+     It has a style of:
+         ---Conv-ReLU-Conv-+-
+          |________________|
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+             Default: 64.
+         res_scale (float): Residual scale. Default: 1.
+         pytorch_init (bool): If set to True, use pytorch default init,
+             otherwise, use default_init_weights. Default: False.
+     """
+
+     def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+         super(ResidualBlockNoBN, self).__init__()
+         self.res_scale = res_scale
+         self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+         self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+         self.relu = nn.ReLU(inplace=True)
+
+         if not pytorch_init:
+             default_init_weights([self.conv1, self.conv2], 0.1)
+
+     def forward(self, x):
+         identity = x
+         out = self.conv2(self.relu(self.conv1(x)))
+         return identity + out * self.res_scale
+
+
+ class Upsample(nn.Sequential):
+     """Upsample module.
+
+     Args:
+         scale (int): Scale factor. Supported scales: 2^n and 3.
+         num_feat (int): Channel number of intermediate features.
+     """
+
+     def __init__(self, scale, num_feat):
+         m = []
+         if (scale & (scale - 1)) == 0:  # scale = 2^n
+             for _ in range(int(math.log(scale, 2))):
+                 m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+                 m.append(nn.PixelShuffle(2))
+         elif scale == 3:
+             m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+             m.append(nn.PixelShuffle(3))
+         else:
+             raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+         super(Upsample, self).__init__(*m)
+
+
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+     """Warp an image or feature map with optical flow.
+
+     Args:
+         x (Tensor): Tensor with size (n, c, h, w).
+         flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+         interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+         padding_mode (str): 'zeros' or 'border' or 'reflection'.
+             Default: 'zeros'.
+         align_corners (bool): Before pytorch 1.3, the default value is
+             align_corners=True. After pytorch 1.3, the default value is
+             align_corners=False. Here, we use True as the default.
+
+     Returns:
+         Tensor: Warped image or feature map.
+     """
+     assert x.size()[-2:] == flow.size()[1:3]
+     _, _, h, w = x.size()
+     # create mesh grid
+     grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+     grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
+     grid.requires_grad = False
+
+     vgrid = grid + flow
+     # scale grid to [-1,1]
+     vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+     vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+     vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+     output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+     # TODO, what if align_corners=False
+     return output
+
+
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+     """Resize a flow according to ratio or shape.
+
+     Args:
+         flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+         size_type (str): 'ratio' or 'shape'.
+         sizes (list[int | float]): the ratio for resizing or the final output
+             shape.
+             1) The order of ratio should be [ratio_h, ratio_w]. For
+                downsampling, the ratio should be smaller than 1.0 (i.e., ratio < 1.0).
+                For upsampling, the ratio should be larger than 1.0 (i.e., ratio > 1.0).
+             2) The order of output_size should be [out_h, out_w].
+         interp_mode (str): The mode of interpolation for resizing.
+             Default: 'bilinear'.
+         align_corners (bool): Whether align corners. Default: False.
+
+     Returns:
+         Tensor: Resized flow.
+     """
+     _, _, flow_h, flow_w = flow.size()
+     if size_type == 'ratio':
+         output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+     elif size_type == 'shape':
+         output_h, output_w = sizes[0], sizes[1]
+     else:
+         raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+     input_flow = flow.clone()
+     ratio_h = output_h / flow_h
+     ratio_w = output_w / flow_w
+     input_flow[:, 0, :, :] *= ratio_w
+     input_flow[:, 1, :, :] *= ratio_h
+     resized_flow = F.interpolate(
+         input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+     return resized_flow
+
+
+ # TODO: may write a cpp file
+ def pixel_unshuffle(x, scale):
+     """ Pixel unshuffle.
+
+     Args:
+         x (Tensor): Input feature with shape (b, c, hh, hw).
+         scale (int): Downsample ratio.
+
+     Returns:
+         Tensor: the pixel unshuffled feature.
+     """
+     print('PIXEL UNSHUFFLE X SIZE', x.size())
+     output = []
+     # new batch size for it here
+     b, c, hh, hw = x.size()
+
+     # okay ugh, what is this all doing ...
+     # i mean you could concat each of those in a loop
+     out_channel = c * (scale**2)
+     assert hh % scale == 0 and hw % scale == 0
+     h = hh // scale
+     w = hw // scale
+     x_view = x.view(b, c, h, scale, w, scale)
+     x_view = x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+     # output = torch.stack(output)
+     print('output shape', x_view.shape)
+     # 1/0
+     return x_view
+
+
+ import os
+ import torch
+ from torch.nn import functional as F
+ from PIL import Image
+ import numpy as np
+ from huggingface_hub import hf_hub_url, cached_download
+
+
+ HF_MODELS = {
+     2: dict(
+         repo_id='sberbank-ai/Real-ESRGAN',
+         filename='RealESRGAN_x2.pth',
+     ),
+     4: dict(
+         repo_id='sberbank-ai/Real-ESRGAN',
+         filename='RealESRGAN_x4.pth',
+     ),
+     8: dict(
+         repo_id='sberbank-ai/Real-ESRGAN',
+         filename='RealESRGAN_x8.pth',
+     ),
+ }
+
+
+ class RealESRGAN:
+     def __init__(self, device, scale=4):
+         self.device = device
+         self.scale = scale
+         self.model = RRDBNet(
+             num_in_ch=3, num_out_ch=3, num_feat=64,
+             num_block=23, num_grow_ch=32, scale=scale
+         )
+
+     def load_weights(self, model_path, download=True):
+         if not os.path.exists(model_path) and download:
+             assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
+             config = HF_MODELS[self.scale]
+             cache_dir = os.path.dirname(model_path)
+             local_filename = os.path.basename(model_path)
+             config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
+             cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
+             print('Weights downloaded to:', os.path.join(cache_dir, local_filename))
+
+         loadnet = torch.load(model_path)
+         if 'params' in loadnet:
+             self.model.load_state_dict(loadnet['params'], strict=True)
+         elif 'params_ema' in loadnet:
+             self.model.load_state_dict(loadnet['params_ema'], strict=True)
+         else:
+             self.model.load_state_dict(loadnet, strict=True)
+         self.model.eval()
+         self.model.to(self.device)
+
+     @torch.cuda.amp.autocast()
+     def predict(self, numpy_images, batch_size=4, patches_size=192,
+                 padding=24, pad_size=15):
+         import time
+         start = time.time()
+         # okay i think that's good with variability for now ...
+         # ***IMPORTANT VARIABLE***
+         batch_size = len(numpy_images) * 4
+         scale = self.scale
+         device = self.device
+
+         list_of_inputs = []
+         for lr_image in numpy_images:
+             lr_image = np.array(lr_image)
+             lr_image = pad_reflect(lr_image, pad_size)
+
+             patches, p_shape = split_image_into_overlapping_patches(
+                 lr_image, patch_size=patches_size, padding_size=padding
+             )
+
+             print('patches.shape', patches.shape)
+             print('p_shape', p_shape)
+
+             img = torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)).to(device).detach()
+             list_of_inputs.append(img)
+
+         input_batch = torch.concat(list_of_inputs)
+
+         print('input_batch.shape', input_batch.shape)
+
+         with torch.no_grad():
+             # res = self.model(input_batch[0:batch_size])
+
+             # okay what does the input size really need to be?
+
+             print('input_batch.shape', input_batch.shape)
+             print('input_batch[0:batch_size].shape', input_batch[0:batch_size].shape)
+             # 1/0
+             res = self.model(input_batch[0:batch_size])
+
+             print('res.shape 1', res.shape)
+             print('batch_size', batch_size)
+             # 1/0
+             for i in range(batch_size, input_batch.shape[0], batch_size):
+                 print('i is', i)
+                 # batch over the full set of patches, not just the last image's
+                 res = torch.cat((res, self.model(input_batch[i:i+batch_size])), 0)
+                 print('res.shape 2', res.shape)
+
+             print('res.shape 3', res.shape)
+
+         # 1/0
+
+         sr_image = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu()
+         np_sr_image_batch = sr_image.numpy()
+
+         print('np_sr_image_batch.shape', np_sr_image_batch.shape)
+         print('np_sr_image_batch[0].shape', np_sr_image_batch[0].shape)
+         # 1/0
+
+         padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
+
+         output_images = []
+         for i in range(0, batch_size, 4):
+             # all inputs were resized to the same shape, so the last lr_image gives the target size
+             scaled_image_shape = tuple(np.multiply(lr_image.shape[0:2], scale)) + (3,)
+             print('scaled_image_shape', scaled_image_shape)
+             print('padded_size_scaled', padded_size_scaled)
+             print("padding * scale", padding * scale)
+             np_sr_image = stich_together(
+                 np_sr_image_batch[i:i+4], padded_image_shape=padded_size_scaled,
+                 target_shape=scaled_image_shape, padding_size=padding * scale
+             )
+             sr_img = (np_sr_image * 255).astype(np.uint8)
+             print('sr_img.shape', sr_img.shape)
+             sr_img = unpad_image(sr_img, pad_size * scale)
+             sr_img = Image.fromarray(sr_img)
+             output_images.append(sr_img)
+
+         print('len of output_images', len(output_images))
+
+         # for debugging
+         # for idx, image in enumerate(output_images):
+         #     image.save(f'output_image_{idx}.png')
+
+         print("EVERYTHING TOOK", time.time() - start)
+
+         return output_images
+
+
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+
+
+ class ResidualDenseBlock(nn.Module):
+     """Residual Dense Block.
+
+     Used in RRDB block in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat=64, num_grow_ch=32):
+         super(ResidualDenseBlock, self).__init__()
+         self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+         self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+         self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+         # initialization
+         default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+     def forward(self, x):
+         x1 = self.lrelu(self.conv1(x))
+         x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return x5 * 0.2 + x
+
+
+ class RRDB(nn.Module):
+     """Residual in Residual Dense Block.
+
+     Used in RRDB-Net in ESRGAN.
+
+     Args:
+         num_feat (int): Channel number of intermediate features.
+         num_grow_ch (int): Channels for each growth.
+     """
+
+     def __init__(self, num_feat, num_grow_ch=32):
+         super(RRDB, self).__init__()
+         self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+         self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+     def forward(self, x):
+         # this part happens 23 times per pass
+         out = self.rdb1(x)
+         out = self.rdb2(out)
+         out = self.rdb3(out)
+         # Empirically, we use 0.2 to scale the residual for better performance
+         return out * 0.2 + x
+
+
+ class RRDBNet(nn.Module):
+     """Networks consisting of Residual in Residual Dense Block, which is used
+     in ESRGAN.
+
+     ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+     We extend ESRGAN for scale x2 and scale x1.
+     Note: This is one option for scale 1, scale 2 in RRDBNet.
+     We first employ the pixel-unshuffle (an inverse operation of pixel shuffle)
+     to reduce the spatial size and enlarge the channel size before feeding
+     inputs into the main ESRGAN architecture.
+
+     Args:
+         num_in_ch (int): Channel number of inputs.
+         num_out_ch (int): Channel number of outputs.
+         num_feat (int): Channel number of intermediate features.
+             Default: 64
+         num_block (int): Block number in the trunk network. Default: 23
+         num_grow_ch (int): Channels for each growth. Default: 32.
+     """
+
+     def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+         super(RRDBNet, self).__init__()
+
+         self.scale = scale
+         if scale == 2:
+             num_in_ch = num_in_ch * 4
+         elif scale == 1:
+             num_in_ch = num_in_ch * 16
+
+         print('num_in_ch', num_in_ch)
+
+         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+         self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         # upsample
+         self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         if scale == 8:
+             self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+     def forward(self, x):
+         print('IN FORWARD, X.shape is', x.shape)
+         if self.scale == 2:
+             feat = pixel_unshuffle(x, scale=2)
+         elif self.scale == 1:
+             feat = pixel_unshuffle(x, scale=4)
+         else:
+             feat = x
+         print('feat shape', feat.shape)
+         # breaks here ...
+         feat = self.conv_first(feat)
+         body_feat = self.conv_body(self.body(feat))
+         feat = feat + body_feat
+         # upsample
+         feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         if self.scale == 8:
+             feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
+         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+         return out
+
+
+ import numpy as np
+ import torch
+ from PIL import Image
+ import os
+ import io
+
+ def pad_reflect(image, pad_size):
+     imsize = image.shape
+     height, width = imsize[:2]
+     print('imsize', imsize)
+     new_img = np.zeros([height + pad_size*2, width + pad_size*2, imsize[2]]).astype(np.uint8)
+     new_img[pad_size:-pad_size, pad_size:-pad_size, :] = image
+     print('new_img.shape 1', new_img.shape)
+
+     new_img[0:pad_size, pad_size:-pad_size, :] = np.flip(image[0:pad_size, :, :], axis=0)  # top
+     new_img[-pad_size:, pad_size:-pad_size, :] = np.flip(image[-pad_size:, :, :], axis=0)  # bottom
+     new_img[:, 0:pad_size, :] = np.flip(new_img[:, pad_size:pad_size*2, :], axis=1)  # left
+     new_img[:, -pad_size:, :] = np.flip(new_img[:, -pad_size*2:-pad_size, :], axis=1)  # right
+     print('new_img.shape 2', new_img.shape)
+
+     return new_img
+
+ def unpad_image(image, pad_size):
+     return image[pad_size:-pad_size, pad_size:-pad_size, :]
+
+
+ def process_array(image_array, expand=True):
+     """ Process a 3-dimensional array into a scaled, 4-dimensional batch of size 1. """
+
+     image_batch = image_array / 255.0
+     if expand:
+         image_batch = np.expand_dims(image_batch, axis=0)
+     return image_batch
+
+
+ def process_output(output_tensor):
+     """ Transforms the 4-dimensional output tensor into a suitable image format. """
+
+     sr_img = output_tensor.clip(0, 1) * 255
+     sr_img = np.uint8(sr_img)
+     return sr_img
+
+
+ def pad_patch(image_patch, padding_size, channel_last=True):
+     """ Pads image_patch with padding_size edge values. """
+
+     if channel_last:
+         return np.pad(
+             image_patch,
+             ((padding_size, padding_size), (padding_size, padding_size), (0, 0)),
+             'edge',
+         )
+     else:
+         return np.pad(
+             image_patch,
+             ((0, 0), (padding_size, padding_size), (padding_size, padding_size)),
+             'edge',
+         )
+
+
+ def unpad_patches(image_patches, padding_size):
+     return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :]
+
+
+ def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
+     """ Splits the image into partially overlapping patches.
+
+     The patches overlap by padding_size pixels.
+     Pads the image twice:
+         - first to have a size multiple of the patch size,
+         - then to have equal padding at the borders.
+
+     Args:
+         image_array: numpy array of the input image.
+         patch_size: size of the patches from the original image (without padding).
+         padding_size: size of the overlapping area.
+     """
+
+     xmax, ymax, _ = image_array.shape
+     x_remainder = xmax % patch_size
+     y_remainder = ymax % patch_size
+
+     # modulo here is to avoid extending of patch_size instead of 0
+     x_extend = (patch_size - x_remainder) % patch_size
+     y_extend = (patch_size - y_remainder) % patch_size
+
+     # make sure the image is divisible into regular patches
+     extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge')
+
+     # add padding around the image to simplify computations
+     padded_image = pad_patch(extended_image, padding_size, channel_last=True)
+
+     xmax, ymax, _ = padded_image.shape
+     patches = []
+
+     x_lefts = range(padding_size, xmax - padding_size, patch_size)
+     y_tops = range(padding_size, ymax - padding_size, patch_size)
+
+     for x in x_lefts:
+         for y in y_tops:
+             x_left = x - padding_size
+             y_top = y - padding_size
+             x_right = x + patch_size + padding_size
+             y_bottom = y + patch_size + padding_size
+             patch = padded_image[x_left:x_right, y_top:y_bottom, :]
+             patches.append(patch)
+
+     return np.array(patches), padded_image.shape
+
+
+ def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
+     """ Reconstructs the image from overlapping patches.
+
+     After scaling, shapes and padding should be scaled too.
+
+     Args:
+         patches: patches obtained with split_image_into_overlapping_patches
+         padded_image_shape: shape of the padded image constructed in split_image_into_overlapping_patches
+         target_shape: shape of the final image
+         padding_size: size of the overlapping area.
+     """
+
+     xmax, ymax, _ = padded_image_shape
+     patches = unpad_patches(patches, padding_size)
+     patch_size = patches.shape[1]
+     n_patches_per_row = ymax // patch_size
+
+     complete_image = np.zeros((xmax, ymax, 3))
+
+     row = -1
+     col = 0
+     for i in range(len(patches)):
+         if i % n_patches_per_row == 0:
+             row += 1
+             col = 0
+         complete_image[
+             row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, :
+         ] = patches[i]
+         col += 1
+     return complete_image[0: target_shape[0], 0: target_shape[1], :]
+


  class EndpointHandler():
@@ -32,18 +654,42 @@
          A :obj:`list`: The list contains dicts like {"label": "XXX", "score": 0.82}.
          """
          inputs = data.pop("inputs", data)

-         # decode base64 image to PIL
-         image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
-
-         # forward pass
-         output_image = self.model.predict(image)
-
-         # base64 encode output
-         buffered = BytesIO()
-         output_image = output_image.convert('RGB')
-         output_image.save(buffered, format="png")
-         img_str = base64.b64encode(buffered.getvalue())
-
-         # postprocess the prediction
-         return {"image": img_str.decode()}
+         if isinstance(inputs['image'], list):
+             input_images = []
+             for base64_string in inputs['image']:
+                 image = Image.open(BytesIO(base64.b64decode(base64_string)))
+                 input_images.append(image)
+
+             for i in range(len(input_images)):
+                 input_images[i] = input_images[i].resize((194, 250))
+
+             numpy_images = [np.array(img) for img in input_images]
+             output_images = self.model.predict(numpy_images)
+
+             base64_strings = []
+             for output_image in output_images:
+                 buffered = BytesIO()
+                 output_image = output_image.convert('RGB')
+                 output_image.save(buffered, format="png")
+                 img_str = base64.b64encode(buffered.getvalue())
+                 # decode to str so the returned list is JSON-serializable
+                 base64_strings.append(img_str.decode())
+
+             return base64_strings
+
+         else:
+             # decode base64 image to PIL
+             image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
+
+             # forward pass; predict now expects a list of images, so wrap and unwrap
+             output_image = self.model.predict([np.array(image)])[0]
+
+             # base64 encode output
+             buffered = BytesIO()
+             output_image = output_image.convert('RGB')
+             output_image.save(buffered, format="png")
+             img_str = base64.b64encode(buffered.getvalue())
+
+             # postprocess the prediction
+             return {"image": img_str.decode()}
local_test.py ADDED
@@ -0,0 +1,61 @@
+ from handler import EndpointHandler
+ import json
+
+ # init handler
+ my_handler = EndpointHandler(path=".")
+
+ import base64
+
+ # prepare sample payload
+ # non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
+ # holiday_payload = {"inputs": "Today is a tough day", "date": "2022-07-04"}
+
+ # with open('sample_input.json', 'r') as file:
+ #     data = json.load(file)
+ import io
+ from PIL import Image
+ import requests
+ response = requests.get('https://mystore-12345-product-images.s3-us-east-2.amazonaws.com/0c817b58-2774-4f02-95b8-3ae379aa2e98.jpeg')
+ image = Image.open(io.BytesIO(response.content)).convert('RGB')
+
+ response2 = requests.get('https://mystore-12345-product-images.s3-us-east-2.amazonaws.com/3b9c698b-b7ae-4c5d-a978-2179ccc08d12.jpeg')
+ image2 = Image.open(io.BytesIO(response2.content)).convert('RGB')
+
+ response3 = requests.get('https://mystore-12345-product-images.s3-us-east-2.amazonaws.com/72801dfa-5d6a-442e-91cd-80bdb394a323.jpeg')
+ image3 = Image.open(io.BytesIO(response3.content)).convert('RGB')
+
+ pil_images = [image.copy() for i in range(10)]
+ pil_images[1] = image2.copy()
+ pil_images[2] = image3.copy()
+
+ base64_strings = []
+ for output_image in pil_images:
+     buffered = io.BytesIO()
+     output_image = output_image.convert('RGB')
+     output_image.save(buffered, format="png")
+     img_str = base64.b64encode(buffered.getvalue())
+     base64_strings.append(img_str)
+
+ inputs = {
+     'image': base64_strings
+ }
+
+
+ # test the handler
+ import time
+ start = time.time()
+ prediction = my_handler(inputs)
+ # holiday_payload = my_handler(holiday_payload)
+ print("inference time itself is", time.time() - start)
+
+ # show results
+ # print("prediction", prediction)
+
+ print("type of prediction", type(prediction))
+
+ data_json = prediction
+
+
+ # the batch path returns a list of base64 strings, one per input image
+ print("number of output images", len(data_json))
+
+ img_str = data_json[0]
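
Since the batch path returns a list of base64 strings, turning the handler output back into viewable images takes one decode per entry. A sketch that continues the script above (the output file names are arbitrary):

for idx, b64_str in enumerate(data_json):
    decoded = Image.open(io.BytesIO(base64.b64decode(b64_str)))
    decoded.save(f'decoded_output_{idx}.png')  # arbitrary file names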
output_image_0.png ADDED
output_image_1.png ADDED
output_image_2.png ADDED
output_image_3.png ADDED
output_image_4.png ADDED
output_image_5.png ADDED
output_image_6.png ADDED
output_image_7.png ADDED
output_image_8.png ADDED
output_image_9.png ADDED