# mjv256's picture
# add handler
# c4b70cd
from typing import Any
import torch, base64
from PIL import Image
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.utils import load_image
from io import BytesIO
class EndpointHandler():
    """Custom inference handler that blends a QR code into a generated image.

    The handler wraps a ``StableDiffusionControlNetImg2ImgPipeline`` whose
    ControlNet is conditioned on the QR code image supplied with each request.
    A fixed background picture is used as the img2img starting image.
    """

    # Length (in pixels) the shorter image side is scaled to; also the
    # width/height passed to the pipeline.
    RESOLUTION = 768
    # Seed used when the request does not provide one.
    DEFAULT_SEED = 123121231
    # Picture used as the img2img starting image for every request.
    DEFAULT_INIT_IMAGE_URL = 'https://images.squarespace-cdn.com/content/v1/59413d96e6f2e1c6837c7ecd/1536503659130-R84NUPOY4QPQTEGCTSAI/15fe1e62172035.5a87280d713e4.png'
    # Prompts used when the request omits them (or sends empty values).
    DEFAULT_PROMPT = "a bilboard in NYC with a qrcode"
    DEFAULT_NEGATIVE_PROMPT = "ugly, disfigured, low quality, blurry, nsfw, worst quality, illustration, drawing"

    def __init__(self, path=""):
        """Load the ControlNet and the Stable Diffusion pipeline.

        Args:
            path: Accepted for compatibility with the handler interface; it is
                not read, the models are loaded from fixed model ids.
        """
        self.controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v11p_sd21",
            torch_dtype=torch.float16,
        )
        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()
        # Filled lazily by _get_init_image() so the background picture is
        # downloaded and resized once instead of on every request.
        self._init_image = None

    @staticmethod
    def _scaled_size(size, resolution):
        """Return the ``(width, height)`` an image of ``size`` is resized to.

        The shorter side is scaled to ``resolution`` (aspect ratio preserved)
        and both sides are then rounded to the nearest multiple of 64.

        Raises:
            ValueError: if either side of ``size`` is not positive.
        """
        width, height = size
        shorter = min(height, width)
        if shorter <= 0:
            raise ValueError(f"Cannot resize an image of size {width}x{height}")
        scale = float(resolution) / shorter
        height = int(round(height * scale / 64.0)) * 64
        width = int(round(width * scale / 64.0)) * 64
        return width, height

    @classmethod
    def _resize_image(cls, input_image: "Image.Image", resolution: int) -> "Image.Image":
        """Convert ``input_image`` to RGB and resize it (see ``_scaled_size``)."""
        rgb_image = input_image.convert("RGB")
        target_size = cls._scaled_size(rgb_image.size, resolution)
        return rgb_image.resize(target_size, resample=Image.LANCZOS)

    def _get_init_image(self):
        """Return the resized background image, loading it on first use."""
        cached = getattr(self, "_init_image", None)
        if cached is None:
            cached = self._resize_image(
                load_image(self.DEFAULT_INIT_IMAGE_URL), self.RESOLUTION
            )
            self._init_image = cached
        return cached

    def __call__(self, data):
        """Generate an image conditioned on the QR code given in ``data``.

        Args:
            data: Request payload (the ``inputs`` and ``parameters`` keys are
                removed from it). Recognised keys:

                * ``inputs``: the QR code image, in any form accepted by
                  ``diffusers.utils.load_image``. Required.
                * ``parameters``: optional dict with ``prompt``,
                  ``negative_prompt`` and ``seed``. When absent or ``None``,
                  these keys are looked up on ``data`` itself.

        Returns:
            A dict ``{"image": <str>}`` holding the generated picture encoded
            as JPEG and then base64.

        Raises:
            ValueError: if ``inputs`` is missing.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            # Without this check the whole request dict would be handed to
            # load_image, which fails with a message unrelated to the cause.
            raise ValueError("Request is missing the 'inputs' QR code image")

        params = data.pop("parameters", None)
        if params is None:
            params = data

        prompt = params.get("prompt") or self.DEFAULT_PROMPT
        negative_prompt = params.get("negative_prompt") or self.DEFAULT_NEGATIVE_PROMPT
        seed = params.get("seed")
        if seed is None:
            seed = self.DEFAULT_SEED

        condition_image = self._resize_image(load_image(inputs), self.RESOLUTION)
        init_image = self._get_init_image()

        generator = torch.manual_seed(int(seed))
        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=init_image,
            control_image=condition_image,
            width=self.RESOLUTION,
            height=self.RESOLUTION,
            guidance_scale=20,
            controlnet_conditioning_scale=2.5,
            generator=generator,
            strength=0.9,
            num_inference_steps=150,
        )
        image = result.images[0]

        # Serialise as JPEG and base64 so the response is plain text.
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode()}