from io import BytesIO
from typing import Any, Dict

import base64

import torch
from PIL import Image
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
)
from diffusers.utils import load_image

# Source image blended with the QR code during img2img; fixed for this endpoint.
DEFAULT_INIT_IMAGE_URL = (
    "https://images.squarespace-cdn.com/content/v1/59413d96e6f2e1c6837c7ecd/"
    "1536503659130-R84NUPOY4QPQTEGCTSAI/15fe1e62172035.5a87280d713e4.png"
)


def _resize_image(input_image: Image.Image, resolution: int) -> Image.Image:
    """Scale the image so its short side is ~`resolution`, snapping both sides
    to multiples of 64 (a safe granularity for the SD UNet), using LANCZOS
    resampling to keep QR modules crisp.
    """
    input_image = input_image.convert("RGB")
    width, height = input_image.size
    scale = float(resolution) / min(width, height)
    new_height = int(round(height * scale / 64.0)) * 64
    new_width = int(round(width * scale / 64.0)) * 64
    return input_image.resize((new_width, new_height), resample=Image.LANCZOS)


class EndpointHandler:
    """Inference-endpoint handler that stylizes a QR code with Stable
    Diffusion 2.1 guided by a QR-code ControlNet (img2img)."""

    def __init__(self, path: str = ""):
        # fp16 weights minimize VRAM; model CPU offload parks idle submodules
        # off the GPU between pipeline stages.
        self.controlnet = ControlNetModel.from_pretrained(
            "DionTimmer/controlnet_qrcode-control_v11p_sd21",
            torch_dtype=torch.float16,
        )
        self.pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            controlnet=self.controlnet,
            safety_checker=None,  # NOTE: NSFW filtering is intentionally disabled
            torch_dtype=torch.float16,
        )
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate a stylized QR-code image.

        Args:
            data: request payload with
                inputs (str | PIL.Image): the QR code, as a URL or image.
                parameters (dict, optional): may carry "prompt" and
                    "negative_prompt" strings; defaults are used when absent.

        Returns:
            dict: {"image": <base64-encoded JPEG string>}
        """
        inputs = data.pop("inputs", data)
        # Backward-compatible fallback: a missing "parameters" key falls through
        # to the flat payload (original behavior); `or {}` additionally guards
        # against an explicit `"parameters": None`, which previously crashed.
        params = data.pop("parameters", data) or {}
        prompt = params.get("prompt")
        negative_prompt = params.get("negative_prompt")

        original_qr_code_image = load_image(inputs)
        init_image = load_image(DEFAULT_INIT_IMAGE_URL)

        condition_image = _resize_image(original_qr_code_image, 768)
        init_image = _resize_image(init_image, 768)

        # Fixed seed makes output deterministic for identical inputs.
        generator = torch.manual_seed(123121231)

        # High guidance + strong conditioning scale bias the output toward a
        # scannable QR pattern at the cost of prompt fidelity.
        result = self.pipe(
            prompt=prompt or "a bilboard in NYC with a qrcode",
            negative_prompt=negative_prompt
            or "ugly, disfigured, low quality, blurry, nsfw, worst quality, illustration, drawing",
            image=init_image,
            control_image=condition_image,
            width=768,
            height=768,
            guidance_scale=20,
            controlnet_conditioning_scale=2.5,
            generator=generator,
            strength=0.9,
            num_inference_steps=150,
        )
        image = result.images[0]

        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode()}