from typing import Any, Dict

import base64
from io import BytesIO

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
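
# Custom handler for a Hugging Face Inference Endpoint. The endpoint runtime
# imports a class named `EndpointHandler` from handler.py and invokes it with
# the parsed request payload; whatever `__call__` returns is serialized as the
# response body.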

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type != "cuda":
    raise ValueError("This handler requires a CUDA GPU to run.")


class EndpointHandler:
    def __init__(self, path=""):
        # Load the SDXL base pipeline in half precision from the fp16
        # safetensors variant shipped with the checkpoint.
        self.base = DiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
        )
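
        # Swap in DPMSolverMultistepScheduler, a multistep solver that reaches
        # good quality in relatively few sampling steps.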
        self.base.scheduler = DPMSolverMultistepScheduler.from_config(
            self.base.scheduler.config
        )
        self.base = self.base.to(device)
        # Compile the UNet for faster repeated inference; the first request
        # pays a one-time compilation cost.
        self.base.unet = torch.compile(
            self.base.unet, mode="reduce-overhead", fullgraph=True
        )
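
        # Load the SDXL refiner, sharing the base pipeline's second text
        # encoder and VAE so the two pipelines don't hold duplicate weights.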
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.base.text_encoder_2,
            vae=self.base.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        self.refiner.scheduler = DPMSolverMultistepScheduler.from_config(
            self.refiner.scheduler.config
        )
        self.refiner = self.refiner.to(device)
        self.refiner.unet = torch.compile(
            self.refiner.unet, mode="reduce-overhead", fullgraph=True
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """
        :param data: A dictionary with an `inputs` prompt and optional generation parameters.
        :return: A dictionary whose `image` field holds the generated image, base64-encoded.
        """
        prompt = data.pop("inputs", None)
        if prompt is None:
            return {"error": "Please provide a prompt"}

        use_refiner = bool(data.pop("use_refiner", False))
        num_inference_steps = data.pop("num_inference_steps", 30)
        guidance_scale = data.pop("guidance_scale", 8)
        negative_prompt = data.pop("negative_prompt", None)
        high_noise_frac = data.pop("high_noise_frac", 0.8)
        height = data.pop("height", None)
        width = data.pop("width", None)
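
        # Ensemble-of-experts handoff: the base pipeline denoises the first
        # `high_noise_frac` of the schedule and emits latents, which the
        # refiner then denoises for the remaining steps.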
        if use_refiner:
            image = self.base(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
                denoising_end=high_noise_frac,
                output_type="latent",
            ).images
            out = self.refiner(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=negative_prompt,
                denoising_start=high_noise_frac,
                image=image,
            )
        else:
            out = self.base(
                prompt=prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                negative_prompt=negative_prompt,
                height=height,
                width=width,
            )

        # Serialize the first generated image as a base64-encoded JPEG so the
        # response stays JSON-friendly.
        buffered = BytesIO()
        out.images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return {"image": img_str.decode()}
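

if __name__ == "__main__":
    # Minimal local smoke test, not part of the endpoint contract. The model
    # path "./sdxl-base" is a placeholder for wherever the SDXL base weights
    # live on disk; a CUDA GPU is assumed.
    handler = EndpointHandler(path="./sdxl-base")
    result = handler({"inputs": "an astronaut riding a green horse", "use_refiner": True})
    with open("output.jpg", "wb") as f:
        f.write(base64.b64decode(result["image"]))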