Suraj Narayanan Sasikumar
bug fix
b034b83
from typing import Dict, List, Any
import torch
from diffusers import DPMSolverMultistepScheduler, DiffusionPipeline
from PIL import Image
import base64
from io import BytesIO
# Pick the compute device once, at import time. The handler refuses to run
# without CUDA, so fail fast here rather than on the first request.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
if device.type == "cpu":
    raise ValueError("need to run on GPU")
class EndpointHandler:
    """Text-to-image request handler: a diffusers base pipeline plus a refiner.

    The refiner is built with the base pipeline's ``text_encoder_2`` and ``vae``,
    which presumes ``path`` points at an SDXL base checkpoint — TODO confirm.
    """

    def __init__(
        self,
        path: str = "",
        refiner_path: str = "stabilityai/stable-diffusion-xl-refiner-1.0",
    ) -> None:
        """Load both pipelines in fp16 onto ``device`` and compile their UNets.

        :param path: model id or local directory of the base pipeline.
        :param refiner_path: model id or local directory of the refiner
            pipeline; defaults to the public SDXL 1.0 refiner.
        """
        # Base pipeline: fp16 weights from safetensors files.
        self.base = DiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        )
        # Replace the default scheduler with DPMSolverMultistepScheduler, built
        # from the pipeline's own scheduler config.
        self.base.scheduler = DPMSolverMultistepScheduler.from_config(
            self.base.scheduler.config
        )
        self.base = self.base.to(device)
        self.base.unet = torch.compile(
            self.base.unet, mode="reduce-overhead", fullgraph=True
        )

        # The refiner shares the base pipeline's text_encoder_2 and vae objects
        # instead of receiving separate ones.
        self.refiner = DiffusionPipeline.from_pretrained(
            refiner_path,
            text_encoder_2=self.base.text_encoder_2,
            vae=self.base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner.scheduler = DPMSolverMultistepScheduler.from_config(
            self.refiner.scheduler.config
        )
        self.refiner = self.refiner.to(device)
        self.refiner.unet = torch.compile(
            self.refiner.unet, mode="reduce-overhead", fullgraph=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate one image from a text prompt.

        Recognised keys (each is popped, so ``data`` is mutated):

        - ``inputs`` (required): the text prompt.
        - ``use_refiner``: truthy to run base + refiner (default False).
        - ``num_inference_steps``: default 30.
        - ``guidance_scale``: default 8.
        - ``negative_prompt``: default None.
        - ``high_noise_frac``: passed as the base call's ``denoising_end`` and
          the refiner call's ``denoising_start``; only used with the refiner
          (default 0.8).
        - ``height`` / ``width``: forwarded to the base pipeline; None defers
          to the pipeline default.

        Other keys are ignored.

        :param data: request payload dictionary.
        :return: ``{"image": <base64-encoded JPEG>}``, or
            ``{"error": "Please provide a prompt"}`` when ``inputs`` is missing.
        """
        prompt = data.pop("inputs", None)
        if prompt is None:
            return {"error": "Please provide a prompt"}

        # hyperparameters
        use_refiner = bool(data.pop("use_refiner", False))
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 8)
        negative_prompt = data.pop("negative_prompt", None)
        high_noise_frac = data.pop("high_noise_frac", 0.8)
        height = data.pop("height", None)
        width = data.pop("width", None)

        # Sampling settings that every pipeline call receives, so the user's
        # guidance_scale / negative_prompt apply with and without the refiner.
        common = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
        )

        if use_refiner:
            # The base stops at high_noise_frac and returns latents; the
            # refiner resumes from that point. height/width go to the base
            # only, since the refiner works on the latents it is given.
            latents = self.base(
                **common,
                num_images_per_prompt=1,
                height=height,
                width=width,
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images
            out = self.refiner(
                **common,
                denoising_start=high_noise_frac,
                image=latents,
            )
        else:
            out = self.base(
                **common,
                num_images_per_prompt=1,
                height=height,
                width=width,
            )

        return {"image": self._encode_jpeg_base64(out.images[0])}

    @staticmethod
    def _encode_jpeg_base64(image: Any) -> str:
        """Save ``image`` (expected to be a PIL image) as JPEG and return it base64-encoded."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode()