from base64 import b64decode, b64encode
from io import BytesIO
from pathlib import Path

import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer

# PIL modes passed to the upsampler without conversion: grayscale, colour,
# and colour + alpha. Anything else (palette, LA, CMYK, ...) is converted.
_SUPPORTED_MODES = ("L", "RGB", "RGBA")

# Default output upscale factor; matches the `scale=4` the upsampler is
# constructed with below.
_DEFAULT_OUTSCALE = 4


def _swap_red_blue(array):
    """Return a copy of *array* with channels 0 and 2 exchanged.

    Converts RGB <-> BGR and RGBA <-> BGRA, leaving any alpha channel in
    place (reversing all channels would turn RGBA into ABGR and move the red
    channel into the alpha slot). Arrays that are not at least 3-channel,
    e.g. 2-D grayscale, are returned unchanged.
    """
    if array.ndim != 3 or array.shape[2] < 3:
        return array
    swapped = array.copy()
    swapped[..., 0] = array[..., 2]
    swapped[..., 2] = array[..., 0]
    return swapped


def _decode_image(payload):
    """Turn a request payload into something ``np.array`` can consume.

    A ``str`` is treated as base64-encoded image bytes, and ``bytes`` as raw
    image file bytes. Any other value (e.g. a PIL image passed directly while
    developing locally) is passed through. PIL images in a mode outside
    ``_SUPPORTED_MODES`` are converted to RGBA when they carry transparency,
    otherwise to RGB.
    """
    if isinstance(payload, str):
        payload = b64decode(payload)
    if isinstance(payload, (bytes, bytearray)):
        payload = Image.open(BytesIO(payload))
    if isinstance(payload, Image.Image) and payload.mode not in _SUPPORTED_MODES:
        has_alpha = payload.mode in ("LA", "PA") or "transparency" in payload.info
        payload = payload.convert("RGBA" if has_alpha else "RGB")
    return payload


def _encode_png(image):
    """Serialize a PIL image to PNG and return it as a base64 ASCII string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return b64encode(buffered.getvalue()).decode("ascii")


class EndpointHandler:
    """Inference handler that upscales an image with Real-ESRGAN x4plus."""

    def __init__(self, path=""):
        """Load the model weights.

        Args:
            path: Directory containing ``RealESRGAN_x4plus.pth``.
        """
        # Architecture hyper-parameters used by the RealESRGAN_x4plus weights.
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        self.upsampler = RealESRGANer(
            scale=4,
            model_path=str(Path(path) / "RealESRGAN_x4plus.pth"),
            model=model,
            tile=0,
            pre_pad=0,
            # fp16 only when a GPU is present: half-precision inference on CPU
            # commonly fails because many CPU kernels lack Half support.
            half=torch.cuda.is_available(),
        )

    def __call__(self, data):
        """Upscale the image in ``data["inputs"]``.

        Args:
            data: Request dict. ``"inputs"`` holds the image as a base64
                string, raw bytes, or (for local development) a PIL image.
                The optional ``"parameters"`` dict may set ``"outscale"``
                (default 4). Both keys are popped from ``data``.

        Returns:
            ``{"image": <base64-encoded PNG of the upscaled image>}``.

        Raises:
            ValueError: If ``data`` has no ``"inputs"`` entry.
        """
        payload = data.pop("inputs", data)
        if payload is data:
            # Previously the request dict itself was fed to the model and
            # failed later with an obscure IndexError.
            raise ValueError('request data must contain an "inputs" entry')
        parameters = data.pop("parameters", None) or {}
        outscale = float(parameters.get("outscale", _DEFAULT_OUTSCALE))

        image = np.array(_decode_image(payload))
        # RealESRGANer.enhance works in OpenCV channel order (BGR / BGRA).
        upscaled, _ = self.upsampler.enhance(_swap_red_blue(image), outscale=outscale)
        result = Image.fromarray(_swap_red_blue(upscaled))
        return {"image": _encode_png(result)}