mjv256's picture
add handler
c4b70cd
raw history blame
No virus
2.83 kB
from typing import Any
import torch, base64
from PIL import Image
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, DDIMScheduler
from diffusers.utils import load_image
from io import BytesIO
class EndpointHandler():
def __init__(self, path=""):
self.controlnet = ControlNetModel.from_pretrained("DionTimmer/controlnet_qrcode-control_v11p_sd21", torch_dtype=torch.float16)
self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", controlnet=self.controlnet, safety_checker=None, torch_dtype=torch.float16)
self.pipe.enable_xformers_memory_efficient_attention()
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
self.pipe.enable_model_cpu_offload()
def __call__(self, data):
"""
data args:
inputs (:obj: `str`)
date (:obj: `str`)
Return:
A :obj:`list` | `dict`: will be serialized and returned
"""
# get inputs
inputs = data.pop("inputs", data)
params = data.pop("parameters", data)
prompt = params.get("prompt")
negative_prompt = params.get("negative_prompt")
def resize_image(input_image: Image, resolution: int):
input_image = input_image.convert("RGB")
W, H = input_image.size
k = float(resolution) / min(H, W)
H *= k
W *= k
H = int(round(H / 64.0)) * 64
W = int(round(W / 64.0)) * 64
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img
orriginal_qr_code_image = load_image(inputs)
img_path = 'https://images.squarespace-cdn.com/content/v1/59413d96e6f2e1c6837c7ecd/1536503659130-R84NUPOY4QPQTEGCTSAI/15fe1e62172035.5a87280d713e4.png'
init_image = load_image(img_path)
condition_image = resize_image(orriginal_qr_code_image, 768)
init_image = resize_image(init_image, 768)
generator = torch.manual_seed(123121231)
image = self.pipe(prompt=prompt or "a bilboard in NYC with a qrcode",
negative_prompt=negative_prompt or "ugly, disfigured, low quality, blurry, nsfw, worst quality, illustration, drawing",
image=init_image,
control_image=condition_image,
width=768,
height=768,
guidance_scale=20,
controlnet_conditioning_scale=2.5,
generator=generator,
strength=0.9,
num_inference_steps=150,
)
image = image.images[0]
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue())
return {"image": img_str.decode()}