# Author: HieuPM
# EDIT: remove super-resolution model and add inpainting SD (Stable Diffusion)
# Commit: c6aeea5
from typing import Dict, List, Any
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, EulerAncestralDiscreteScheduler
from PIL import Image
import base64
from io import BytesIO
import numpy as np
# from RealESRGAN import RealESRGAN
# Select the torch device at import time. This handler only supports GPU
# execution, so refuse to load at all when CUDA is unavailable rather than
# failing later inside the pipeline.
if not torch.cuda.is_available():
    raise ValueError("need to run on GPU")
device = torch.device("cuda")
class EndpointHandler():
    """Inference Endpoint handler serving Stable Diffusion inpainting.

    The pipeline is loaded once in ``__init__``; each request is handled by
    ``__call__``, which returns generated images as base64-encoded PNGs.
    """

    def __init__(self, path=""):
        """Load the inpainting pipeline from ``path`` and move it to ``device``.

        :param path: Model location passed to ``from_pretrained`` (for an
            Inference Endpoint this is the repository directory).
        """
        # Load weights in half precision.
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)
        # Swap the checkpoint's default scheduler for Euler Ancestral, reusing
        # the existing scheduler config so timestep settings stay consistent.
        self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.to(device)
        # Needs the optional `xformers` package to be installed at runtime.
        self.pipe.enable_xformers_memory_efficient_attention()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Run one inpainting request.

        :param data: Request payload. Every key read here is popped from
            ``data`` (the caller's dict is mutated):

            - ``inputs``: text prompt. If absent, ``data`` itself is used.
            - ``image``: base64-encoded source image (required).
            - ``mask_image``: base64-encoded mask image (required).
            - ``num_images``: images to generate, an int in 1..4 (default 1).
            - ``num_inference_steps`` (default 50), ``guidance_scale``
              (default 7.5), ``negative_prompt``, ``height``, ``width``:
              forwarded unchanged to the pipeline.
        :return: ``{"0": <b64 PNG>, "1": <b64 PNG>, ...}``, with one entry per
            generated image, or ``{"Invalid Request": <reason>}`` when the
            request fails validation.
        """
        inputs = data.pop("inputs", data)
        encoded_image = data.pop("image", None)
        encoded_mask_image = data.pop("mask_image", None)
        # Default to a single image when the field is omitted; an explicit
        # null or non-int value is rejected below instead of raising TypeError.
        num_images = data.pop("num_images", 1)
        print(f"num_image {num_images}")
        # bool is a subclass of int, so exclude it explicitly.
        if (
            isinstance(num_images, bool)
            or not isinstance(num_images, int)
            or not 1 <= num_images <= 4
        ):
            return {"Invalid Request": "Number of generated images must be >= 1 and <=4"}
        # The inpainting pipeline cannot run without both a source image and a mask.
        if encoded_image is None or encoded_mask_image is None:
            return {"Invalid Request": "Both image and mask_image must be provided"}

        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 50)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        height = data.pop("height", None)
        width = data.pop("width", None)

        image = self.decode_base64_image(encoded_image)
        mask_image = self.decode_base64_image(encoded_mask_image)

        generated = self.pipe(
            inputs,
            image=image,
            mask_image=mask_image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
        ).images

        # Return every generated image, keyed by its index as a string.
        return {str(index): self._encode_png_base64(img) for index, img in enumerate(generated)}

    def decode_base64_image(self, image_string) -> "Image.Image":
        """Decode a base64-encoded image file into a PIL image.

        :param image_string: base64 text (``str`` or ``bytes``) of an image file.
        :return: The image opened from an in-memory buffer.
        """
        return Image.open(BytesIO(base64.b64decode(image_string)))

    @staticmethod
    def _encode_png_base64(image) -> str:
        """Serialise ``image`` (anything with a PIL-style ``save``) as base64 PNG text."""
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()