# Source: handler.py as shown on the Hugging Face Hub
# (uploaded by justin-shopcapsule; commit 498022f, "Update handler.py"; 25 kB).
# install thing, just like in segment anything
from typing import Dict, List, Any
from PIL import Image
from io import BytesIO
from transformers import AutoModelForSemanticSegmentation, AutoFeatureExtractor
import base64
import torch
from torch import nn
# import subprocess
# result = subprocess.run(["pip", "install", "git+https://github.com/sberbank-ai/Real-ESRGAN.git"], check=True)
# print(f"git+https://github.com/sberbank-ai/Real-ESRGAN.git = {result}")
# from RealESRGAN import RealESRGAN
# no need to install, just take in all of the necessary files from the notebook
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialise the parameters of one or more modules in place.

    Every ``nn.Conv2d`` and ``nn.Linear`` found (recursively, via
    ``.modules()``) gets Kaiming-normal weights multiplied by ``scale``;
    every batch-norm layer gets its weight set to 1. In all three cases an
    existing bias is filled with ``bias_fill``. Other module types are left
    untouched.

    Args:
        module_list (list[nn.Module] | nn.Module): module(s) to initialise.
        scale (float): multiplier applied to the Kaiming-initialised weights;
            small values (e.g. 0.1) are used for residual branches. Default: 1.
        bias_fill (float): constant written into biases. Default: 0.
        **kwargs: forwarded to ``init.kaiming_normal_``.
    """
    roots = module_list if isinstance(module_list, list) else [module_list]
    for root in roots:
        for layer in root.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(layer.weight, **kwargs)
                layer.weight.data *= scale
            elif isinstance(layer, _BatchNorm):
                init.constant_(layer.weight, 1)
            else:
                continue
            # shared by all three initialised layer kinds
            if layer.bias is not None:
                layer.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` fresh instances of ``basic_block`` in sequence.

    Args:
        basic_block (type[nn.Module]): block class, instantiated as
            ``basic_block(**kwarg)`` once per repetition.
        num_basic_block (int): how many blocks to create.
        **kwarg: constructor keyword arguments shared by every block.

    Returns:
        nn.Sequential: the blocks, in creation order.
    """
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)
class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalisation.

    Computes ``x + conv2(relu(conv1(x))) * res_scale``::

        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): channels of the input/output feature map. Default: 64.
        res_scale (float): multiplier applied to the residual branch. Default: 1.
        pytorch_init (bool): keep PyTorch's default initialisation when True;
            otherwise re-initialise both convs with ``default_init_weights``
            using scale 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        # 3x3 kernels, stride 1, padding 1: spatial size is preserved.
        self.conv1 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
        use_custom_init = not pytorch_init
        if use_custom_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(self.relu(self.conv1(x)))
        return x + residual * self.res_scale
class Upsample(nn.Sequential):
    """Sub-pixel upsampling head built from conv + PixelShuffle pairs.

    Args:
        scale (int): upscaling factor; must be a power of two or 3.
        num_feat (int): channels of the feature map (unchanged by the module).

    Raises:
        ValueError: for any other ``scale``.
    """

    def __init__(self, scale, num_feat):
        is_power_of_two = (scale & (scale - 1)) == 0
        if is_power_of_two:
            layers = []
            for _ in range(int(math.log(scale, 2))):
                # 4x channels, then PixelShuffle(2) trades them for 2x height and width
                layers += [nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1), nn.PixelShuffle(2)]
        elif scale == 3:
            layers = [nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1), nn.PixelShuffle(3)]
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super().__init__(*layers)
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with a per-pixel optical flow.

    Args:
        x (Tensor): input of shape (n, c, h, w).
        flow (Tensor): pixel displacements of shape (n, h, w, 2); the last
            axis holds (dx, dy).
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros', 'border' or 'reflection'. Default: 'zeros'.
        align_corners (bool): forwarded to ``F.grid_sample``. Defaults to True,
            which was PyTorch's default before 1.3.

    Returns:
        Tensor: the warped input.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, height, width = x.size()
    # Pixel-coordinate mesh: grid_y indexes rows, grid_x indexes columns.
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, height).type_as(x),
        torch.arange(0, width).type_as(x),
    )
    base = torch.stack((grid_x, grid_y), 2).float()  # (h, w, 2) holding (x, y)
    base.requires_grad = False
    target = base + flow
    # Normalise pixel coordinates into grid_sample's [-1, 1] range.
    # NOTE: this mapping corresponds to align_corners=True; it is not
    # adjusted when align_corners=False.
    norm_x = 2.0 * target[:, :, :, 0] / max(width - 1, 1) - 1.0
    norm_y = 2.0 * target[:, :, :, 1] / max(height - 1, 1) - 1.0
    sample_grid = torch.stack((norm_x, norm_y), dim=3)
    return F.grid_sample(
        x, sample_grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners
    )
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow field by a ratio or to an explicit shape.

    The flow vectors are scaled together with the grid: the x component
    (channel 0) is multiplied by the width ratio and the y component
    (channel 1) by the height ratio. The input tensor is not modified.

    Args:
        flow (Tensor): precomputed flow of shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): for 'ratio', ``[ratio_h, ratio_w]``
            (< 1.0 downsamples, > 1.0 upsamples); for 'shape',
            ``[out_h, out_w]``.
        interp_mode (str): interpolation mode for ``F.interpolate``.
            Default: 'bilinear'.
        align_corners (bool): forwarded to ``F.interpolate``. Default: False.

    Returns:
        Tensor: the resized flow.

    Raises:
        ValueError: if ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    elif size_type == 'ratio':
        output_h = int(flow_h * sizes[0])
        output_w = int(flow_w * sizes[1])
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
    scaled = flow.clone()
    scaled[:, 0, :, :] *= output_w / flow_w  # x displacement follows the width
    scaled[:, 1, :, :] *= output_h / flow_h  # y displacement follows the height
    return F.interpolate(
        input=scaled, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners
    )
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle: the inverse of ``nn.PixelShuffle``.

    Moves every ``scale x scale`` spatial block into the channel dimension,
    turning (b, c, h*scale, w*scale) into (b, c*scale**2, h, w). Output
    channel ``ci * scale**2 + dy * scale + dx`` holds input channel ``ci`` at
    offset (dy, dx) inside each block.

    Args:
        x (Tensor): input feature of shape (b, c, hh, hw); hh and hw must be
            divisible by ``scale``.
        scale (int): downsample ratio.

    Returns:
        Tensor: the unshuffled feature of shape
        (b, c * scale**2, hh // scale, hw // scale).

    Raises:
        ValueError: if hh or hw is not divisible by ``scale``.
    """
    b, c, hh, hw = x.size()
    if hh % scale or hw % scale:
        raise ValueError(f'spatial size ({hh}, {hw}) is not divisible by scale {scale}')
    h, w = hh // scale, hw // scale
    # reshape (not view) so non-contiguous inputs are accepted too
    blocks = x.reshape(b, c, h, scale, w, scale)
    return blocks.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale ** 2, h, w)
import os
import torch
from torch.nn import functional as F
from PIL import Image
import numpy as np
from huggingface_hub import hf_hub_url, cached_download
# Pretrained Real-ESRGAN checkpoints on the Hugging Face Hub, keyed by upscaling factor.
HF_MODELS = {
    scale: {'repo_id': 'sberbank-ai/Real-ESRGAN', 'filename': filename}
    for scale, filename in (
        (2, 'RealESRGAN_x2.pth'),
        (4, 'RealESRGAN_x4.pth'),
        (8, 'RealESRGAN_x8.pth'),
    )
}
class RealESRGAN:
    """Real-ESRGAN super-resolution model with tiled (overlapping-patch) inference.

    Args:
        device: torch device the network runs on.
        scale (int): upscaling factor passed to ``RRDBNet``; also selects the
            checkpoint in ``HF_MODELS`` when weights are downloaded. Default: 4.
    """

    def __init__(self, device, scale=4):
        self.device = device
        self.scale = scale
        self.model = RRDBNet(
            num_in_ch=3, num_out_ch=3, num_feat=64,
            num_block=23, num_grow_ch=32, scale=scale
        )

    def load_weights(self, model_path, download=True):
        """Load a checkpoint into the network, then put it in eval mode on ``self.device``.

        If ``model_path`` does not exist and ``download`` is True, the
        checkpoint for ``self.scale`` (2, 4 or 8 only) is first fetched from
        the Hugging Face Hub and stored at ``model_path``.

        The checkpoint may be a raw state dict, or a dict holding it under
        ``'params'`` or ``'params_ema'`` (checked in that order). Loading is strict.
        """
        if not os.path.exists(model_path) and download:
            assert self.scale in [2, 4, 8], 'You can download models only with scales: 2, 4, 8'
            config = HF_MODELS[self.scale]
            cache_dir = os.path.dirname(model_path)
            local_filename = os.path.basename(model_path)
            config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
            # NOTE(review): `cached_download` is deprecated in huggingface_hub in favour
            # of `hf_hub_download`; confirm the pinned version still provides it.
            cached_download(config_file_url, cache_dir=cache_dir, force_filename=local_filename)
            print('Weights downloaded to:', os.path.join(cache_dir, local_filename))

        # Load onto the CPU first so a checkpoint saved from a GPU also loads on a
        # CPU-only host; the model is moved to self.device afterwards.
        loadnet = torch.load(model_path, map_location='cpu')
        if 'params' in loadnet:
            state_dict = loadnet['params']
        elif 'params_ema' in loadnet:
            state_dict = loadnet['params_ema']
        else:
            state_dict = loadnet
        self.model.load_state_dict(state_dict, strict=True)
        self.model.eval()
        self.model.to(self.device)

    @torch.cuda.amp.autocast()
    def predict(self, numpy_images, batch_size=None, patches_size=192,
                padding=24, pad_size=15):
        """Upscale a batch of images using overlapping-patch inference.

        Each image is reflect-padded by ``pad_size``, split into overlapping
        patches, and run through the network. The patches are then stitched
        back together and the padding is removed.

        Args:
            numpy_images: sequence of H x W x 3 images (anything ``np.array``
                accepts, e.g. uint8 arrays). Images may have different sizes.
            batch_size (int | None): maximum number of patches per forward
                pass. ``None`` (the default) or a non-positive value runs all
                patches of all images in a single pass.
            patches_size (int): side length of each square patch before padding.
            padding (int): overlap added around each patch.
            pad_size (int): reflect padding added around each image and
                cropped again from the upscaled result.

        Returns:
            list[PIL.Image.Image]: one upscaled image per input, in input order.
        """
        import logging
        import time

        logger = logging.getLogger(__name__)
        start = time.time()
        scale = self.scale
        device = self.device

        # Tile every image and remember its own patch count and shapes, so the
        # model output can be split back per image even when sizes differ.
        inputs = []
        metas = []
        for lr_image in numpy_images:
            lr_image = pad_reflect(np.array(lr_image), pad_size)
            patches, p_shape = split_image_into_overlapping_patches(
                lr_image, patch_size=patches_size, padding_size=padding
            )
            inputs.append(torch.FloatTensor(patches / 255).permute((0, 3, 1, 2)))
            metas.append((len(patches), p_shape, lr_image.shape))
        if not inputs:
            return []

        input_batch = torch.cat(inputs).to(device)
        n_patches = input_batch.shape[0]
        if batch_size is None or batch_size <= 0:
            batch_size = n_patches

        infer_start = time.time()
        with torch.no_grad():
            # Chunk over the full concatenated batch so every image's patches are processed.
            res = torch.cat([
                self.model(input_batch[i:i + batch_size])
                for i in range(0, n_patches, batch_size)
            ])
        logger.debug('inference alone takes %s', time.time() - infer_start)

        np_sr_image_batch = res.permute((0, 2, 3, 1)).clamp_(0, 1).cpu().numpy()

        output_images = []
        offset = 0
        for count, p_shape, lr_shape in metas:
            padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,)
            scaled_image_shape = tuple(np.multiply(lr_shape[0:2], scale)) + (3,)
            np_sr_image = stich_together(
                np_sr_image_batch[offset:offset + count],
                padded_image_shape=padded_size_scaled,
                target_shape=scaled_image_shape,
                padding_size=padding * scale,
            )
            offset += count
            sr_img = (np_sr_image * 255).astype(np.uint8)
            sr_img = unpad_image(sr_img, pad_size * scale)
            output_images.append(Image.fromarray(sr_img))

        logger.debug('predict took %s for %d images', time.time() - start, len(output_images))
        return output_images
import torch
from torch import nn as nn
from torch.nn import functional as F
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, the building unit of an ESRGAN RRDB.

    Five 3x3 convolutions with dense connectivity:
    - Each of conv1–conv4 sees the block input concatenated with every
      earlier conv's activated output, and contributes ``num_grow_ch`` channels.
    - conv5 maps the concatenation back to ``num_feat`` channels.

    The result is added to the block input with a 0.2 scale.

    Args:
        num_feat (int): channels of the block input/output. Default: 64.
        num_grow_ch (int): channels produced by each of conv1-conv4. Default: 32.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # conv k consumes num_feat + (k - 1) * num_grow_ch channels
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        fused = self.conv5(torch.cat(features, 1))
        # Empirically, scaling the residual by 0.2 gives better performance.
        return fused * 0.2 + x
class RRDB(nn.Module):
    """Residual in Residual Dense Block, used in ESRGAN's RRDBNet.

    Three ``ResidualDenseBlock``s applied one after another, wrapped in an
    outer residual connection whose branch is scaled by 0.2.

    Args:
        num_feat (int): channels of the block input/output.
        num_grow_ch (int): growth channels of each inner dense block. Default: 32.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        # Empirically, scaling the residual by 0.2 gives better performance.
        return out * 0.2 + x
class RRDBNet(nn.Module):
    """Network of Residual-in-Residual Dense Blocks: the ESRGAN generator.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The trunk upsamples by 4, or by 8 when ``scale == 8``. For ``scale`` 2
    and 1, the input is first pixel-unshuffled (the inverse of pixel shuffle)
    by 2 and 4 respectively. This shrinks the spatial size and enlarges the
    channel count, so the overall upscaling factor equals ``scale``.

    Args:
        num_in_ch (int): channels of the input image.
        num_out_ch (int): channels of the output image.
        scale (int): overall upscaling factor: 1, 2, 4 or 8. Default: 4.
        num_feat (int): channels of intermediate features. Default: 64.
        num_block (int): number of RRDB blocks in the trunk. Default: 23.
        num_grow_ch (int): channels added by each dense-layer growth. Default: 32.

    Raises:
        ValueError: if ``scale`` is not 1, 2, 4 or 8 (any other value would
            silently build a x4 network).
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        if scale not in (1, 2, 4, 8):
            raise ValueError(f'scale {scale} is not supported. Supported scales: 1, 2, 4 and 8.')
        self.scale = scale
        # forward() pixel-unshuffles x2 / x1 inputs by 2 / 4, multiplying the
        # channel count seen by conv_first by 4 / 16.
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        if scale == 8:
            self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        # global residual connection around the RRDB trunk
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # each stage: nearest-neighbour x2 interpolation, then conv + LeakyReLU
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        if self.scale == 8:
            feat = self.lrelu(self.conv_up3(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
import numpy as np
import torch
from PIL import Image
import os
import io
def pad_reflect(image, pad_size):
    """Mirror-pad an H x W x C image by ``pad_size`` pixels on every side.

    The mirror includes the edge pixel (NumPy ``'symmetric'`` mode). For
    example, a row ``[a, b, c]`` padded by 2 becomes ``[b, a, a, b, c, c, b]``.
    Corners are mirrored along both axes.

    Args:
        image (np.ndarray): array of shape (H, W, C).
        pad_size (int): pixels added on each side; 0 returns an unpadded copy.

    Returns:
        np.ndarray: uint8 array of shape (H + 2*pad_size, W + 2*pad_size, C).
    """
    pad_width = ((pad_size, pad_size), (pad_size, pad_size), (0, 0))
    # np.pad also copes with pad_size == 0 and with images smaller than pad_size.
    return np.pad(image, pad_width, mode='symmetric').astype(np.uint8)
def unpad_image(image, pad_size):
    """Crop ``pad_size`` pixels from every side of an H x W x C image.

    Explicit end indices are used instead of ``-pad_size``, so
    ``pad_size == 0`` returns the whole image rather than an empty array.
    """
    height, width = image.shape[0], image.shape[1]
    return image[pad_size:height - pad_size, pad_size:width - pad_size, :]
def process_array(image_array, expand=True):
    """Scale pixel values by 1/255 and optionally prepend a batch axis.

    With ``expand`` True, an (H, W, C) array becomes (1, H, W, C).
    """
    scaled = image_array / 255.0
    return np.expand_dims(scaled, axis=0) if expand else scaled
def process_output(output_tensor):
    """Convert values nominally in [0, 1] to uint8 pixels.

    Values are clipped to [0, 1], multiplied by 255 and cast to uint8
    (the cast truncates fractional parts).
    """
    return np.uint8(output_tensor.clip(0, 1) * 255)
def pad_patch(image_patch, padding_size, channel_last=True):
    """Pad the two spatial axes of a 3-D patch by replicating edge values.

    Args:
        image_patch (np.ndarray): (H, W, C) if ``channel_last`` else (C, H, W).
        padding_size (int): pixels added on each side of H and W.
        channel_last (bool): layout of ``image_patch``. Default: True.

    Returns:
        np.ndarray: the padded patch; the channel axis is not padded.
    """
    spatial = (padding_size, padding_size)
    untouched = (0, 0)
    pad_width = (spatial, spatial, untouched) if channel_last else (untouched, spatial, spatial)
    return np.pad(image_patch, pad_width, 'edge')
def unpad_patches(image_patches, padding_size):
    """Crop ``padding_size`` pixels from both spatial sides of every patch.

    Args:
        image_patches (np.ndarray): batch of shape (N, H, W, C).
        padding_size (int): pixels to strip; explicit end indices make
            ``padding_size == 0`` return the patches unchanged instead of empty.
    """
    height, width = image_patches.shape[1], image_patches.shape[2]
    return image_patches[:, padding_size:height - padding_size, padding_size:width - padding_size, :]
def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2):
    """Split an H x W x C image into partially overlapping square patches.

    The image is padded twice, both times by edge replication:
    - first on the bottom/right, to make both dimensions multiples of ``patch_size``;
    - then by ``padding_size`` on every side.

    Each returned patch is ``patch_size + 2 * padding_size`` pixels wide, so
    neighbouring patches overlap. Patches are listed row by row.

    Args:
        image_array: numpy array of the input image, shape (H, W, C).
        patch_size: size of a patch's core region (without padding).
        padding_size: overlap added on every side of each patch.

    Returns:
        tuple: (patches array of shape (N, P + 2p, P + 2p, C),
        shape of the doubly padded image).
    """
    rows, cols, _ = image_array.shape
    # distance to the next multiple of patch_size (0 when already a multiple)
    extra_rows = -rows % patch_size
    extra_cols = -cols % patch_size
    extended = np.pad(image_array, ((0, extra_rows), (0, extra_cols), (0, 0)), 'edge')
    padded = pad_patch(extended, padding_size, channel_last=True)
    padded_rows, padded_cols, _ = padded.shape

    window = patch_size + 2 * padding_size
    patches = []
    for row_start in range(padding_size, padded_rows - padding_size, patch_size):
        top = row_start - padding_size
        for col_start in range(padding_size, padded_cols - padding_size, patch_size):
            left = col_start - padding_size
            patches.append(padded[top:top + window, left:left + window, :])
    return np.array(patches), padded.shape
def stich_together(patches, padded_image_shape, target_shape, padding_size=4):
    """Reconstruct an image from the overlapping patches of ``split_image_into_overlapping_patches``.

    When the patches have been upscaled, ``padded_image_shape`` and
    ``padding_size`` must be scaled by the same factor.

    Args:
        patches: array of shape (N, P + 2*padding_size, P + 2*padding_size, C),
            in the row-by-row order produced by the splitter.
        padded_image_shape: (H, W, C) shape of the padded image built by the splitter.
        target_shape: shape of the final image; the result is cropped to its first two dims.
        padding_size: overlap stripped from every side of each patch.

    Returns:
        np.ndarray: float64 array of shape (target_shape[0], target_shape[1], 3).
    """
    xmax, ymax, _ = padded_image_shape
    # Strip the overlap from every patch. Explicit end indices keep
    # padding_size == 0 working (a -0 slice would be empty).
    core = patches[
        :,
        padding_size:patches.shape[1] - padding_size,
        padding_size:patches.shape[2] - padding_size,
        :,
    ]
    patch_size = core.shape[1]
    # The padded width includes padding_size on both sides. Exclude it, otherwise
    # the count is wrong whenever 2 * padding_size >= patch_size.
    n_patches_per_row = (ymax - 2 * padding_size) // patch_size
    complete_image = np.zeros((xmax, ymax, 3))
    for i, patch in enumerate(core):
        row, col = divmod(i, n_patches_per_row)
        complete_image[
            row * patch_size:(row + 1) * patch_size, col * patch_size:(col + 1) * patch_size, :
        ] = patch
    return complete_image[0:target_shape[0], 0:target_shape[1], :]
class EndpointHandler():
    """Hugging Face Inference Endpoints handler that upscales base64-encoded images x2."""

    def __init__(self, path="."):
        # NOTE(review): `path` is ignored. The x2 weights are always loaded from
        # (and, if missing, downloaded to) /repository; confirm this is intended.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = RealESRGAN(self.device, scale=2)
        self.model.load_weights('/repository/RealESRGAN_x2.pth', download=True)

    @staticmethod
    def _decode_image(base64_string):
        """Decode a base64 string into a 3-channel RGB PIL image."""
        image = Image.open(BytesIO(base64.b64decode(base64_string)))
        # The network takes 3 input channels; RGBA, greyscale or palette
        # uploads would otherwise fail inside the model.
        return image.convert('RGB')

    @staticmethod
    def _encode_image(image):
        """Encode a PIL image as a base64 PNG string."""
        buffered = BytesIO()
        image.convert('RGB').save(buffered, format="png")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Upscale one or several base64-encoded images.

        Args:
            data: request payload. The image source is ``data["inputs"]["image"]``,
                or ``data["image"]`` when there is no ``"inputs"`` key. It is
                either one base64 string or a list of base64 strings.

        Returns:
            For a list input: a list of base64 PNG strings, one per input.
            Each input is resized to 194x250 before upscaling.
            For a single string: ``{"image": <base64 PNG string>}``.
        """
        # Pop only once: a second pop would return the now-emptied dict and lose the payload.
        inputs = data.pop("inputs", data)
        images = inputs['image']
        if isinstance(images, list):
            resized = [self._decode_image(s).resize((194, 250)) for s in images]
            output_images = self.model.predict([np.array(img) for img in resized])
            return [self._encode_image(img) for img in output_images]

        # predict() takes and returns a list of images, so wrap/unwrap the single one.
        image = self._decode_image(images)
        output_image = self.model.predict([np.array(image)])[0]
        return {"image": self._encode_image(output_image)}