import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

# Hub identifier of the model. It is not referenced anywhere else in this
# file as shown; the handler loads from a hard-coded local directory instead.
REPO_ID = "runwayml/stable-diffusion-v1-5"


def decode_base64_image(image_string) -> Image.Image:
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        image_string: Base64 text (or bytes) holding an encoded image file.

    Returns:
        The image opened by ``PIL.Image.open`` from the decoded bytes.
    """
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))


class EndpointHandler:
    """Callable that runs a Stable Diffusion pipeline on a request dict."""

    def __init__(self, path=""):
        """Load the pipeline.

        Args:
            path: Accepted but currently unused; the weights are always read
                from the relative directory ``./stable-diffusion-v1-5``.
        """
        self.pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
        # Moving the pipeline to the GPU is disabled; line kept for reference.
        # self.pipe = self.pipe.to("cuda")

    def __call__(self, data: Any) -> Dict[str, bytes]:
        """Generate one image for the request and return it as PNG bytes.

        Args:
            data: Request mapping. ``"inputs"`` holds the prompt(s) and the
                optional ``"image"`` holds a base64-encoded initial image.
                Both keys are removed from ``data`` by this call.

        Returns:
            A dict ``{"image": <PNG file bytes>}``. The bytes are returned
            as-is; they are NOT base64-encoded.
            NOTE(review): confirm the consumer / serving layer expects raw
            bytes rather than a base64 string.

        Raises:
            ValueError: If ``data`` carries no ``"inputs"`` entry.
        """
        prompts = data.pop("inputs", None)
        encoded_image = data.pop("image", None)

        # Fail early with a readable message instead of an error raised from
        # deep inside the pipeline.
        if prompts is None:
            raise ValueError('request is missing the required "inputs" field')

        # Only forward `init_image` when the request actually carries one, so
        # plain text-to-image calls never pass an unused keyword.
        pipe_kwargs = {}
        if encoded_image:
            init_image = decode_base64_image(encoded_image)
            # In-place downscale so neither side exceeds 768 px.
            init_image.thumbnail((768, 768))
            # NOTE(review): verify that the loaded pipeline class consumes
            # `init_image`; image-to-image generation in diffusers is exposed
            # through a separate pipeline class -- TODO confirm.
            pipe_kwargs["init_image"] = init_image

        image = self.pipe(prompts, **pipe_kwargs).images[0]

        # Serialise the generated image to an in-memory PNG file.
        buffered = BytesIO()
        image.save(buffered, format="png")

        return {"image": buffered.getvalue()}