from typing import Any, Dict
import torch
from diffusers import DPMSolverMultistepScheduler, DiffusionPipeline
import base64
from io import BytesIO

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type != "cuda":
    raise ValueError("need to run on GPU")


class EndpointHandler:
    def __init__(self, path=""):
        # load the SDXL base pipeline from the repository path
        self.base = DiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        )
        # use DPMSolverMultistepScheduler
        self.base.scheduler = DPMSolverMultistepScheduler.from_config(
            self.base.scheduler.config
        )
        # move to device and compile the UNet for faster inference
        self.base = self.base.to(device)
        self.base.unet = torch.compile(
            self.base.unet, mode="reduce-overhead", fullgraph=True
        )

        # load the SDXL refiner, sharing the base model's text encoder and VAE
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.base.text_encoder_2,
            vae=self.base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        # use DPMSolverMultistepScheduler
        self.refiner.scheduler = DPMSolverMultistepScheduler.from_config(
            self.refiner.scheduler.config
        )
        self.refiner = self.refiner.to(device)
        self.refiner.unet = torch.compile(
            self.refiner.unet, mode="reduce-overhead", fullgraph=True
        )

    def __call__(self, data: Any) -> Dict[str, str]:
        """
        :param data: A dictionary with an `inputs` prompt and optional generation parameters.
        :return: A dictionary whose `image` field contains the generated image as base64.
        """
        prompt = data.pop("inputs", None)
        if prompt is None:
            return {"error": "Please provide a prompt"}

        # hyperparameters
        use_refiner = bool(data.pop("use_refiner", False))
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 8)
        negative_prompt = data.pop("negative_prompt", None)
        high_noise_frac = data.pop("high_noise_frac", 0.8)
        height = data.pop("height", None)
        width = data.pop("width", None)

        if use_refiner:
            # run the base model up to `high_noise_frac`, then hand the latents to the refiner
            image = self.base(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                height=height,
                width=width,
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images
            out = self.refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                denoising_start=high_noise_frac,
                image=image,
            )
        else:
            out = self.base(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
            )

        # encode the first generated image as base64 JPEG
        buffered = BytesIO()
        out.images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())

        # postprocess the prediction
        return {"image": img_str.decode()}
        # # return the first generated PIL image instead of base64
        # return out.images[0]
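

# Usage sketch (not part of the handler itself): a minimal local smoke test,
# assuming an SDXL base checkpoint such as "stabilityai/stable-diffusion-xl-base-1.0"
# is available. On Hugging Face Inference Endpoints the runtime instantiates
# EndpointHandler with the repository path and calls it with the parsed request
# payload, so this block is only for manual testing.
if __name__ == "__main__":
    handler = EndpointHandler(path="stabilityai/stable-diffusion-xl-base-1.0")
    result = handler(
        {
            "inputs": "an astronaut riding a horse on the moon",
            "use_refiner": True,
            "num_inference_steps": 30,
        }
    )
    # decode the base64 payload back into a JPEG file
    with open("output.jpg", "wb") as f:
        f.write(base64.b64decode(result["image"]))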