Oysiyl's picture
Update handler.py
323c56a verified
raw
history blame
2.72 kB
from typing import Dict, List, Any
import base64
from PIL import Image
from io import BytesIO
from diffusers import AutoPipelineForText2Image
import torch
import numpy as np
# This handler only supports CUDA: abort at import time when no GPU is visible.
if not torch.cuda.is_available():
    raise ValueError("need to run on GPU")
device = torch.device('cuda')

# Pick the half-precision dtype from the GPU's major compute capability:
# bfloat16 when it is exactly 8, float16 for every other value.
# NOTE(review): the `== 8` test sends newer majors (e.g. 9) to float16 — confirm intended.
if torch.cuda.get_device_capability()[0] == 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
class EndpointHandler():
    """Text-to-image request handler built on a Stable Diffusion pipeline with LoRA weights.

    The pipeline is built once in ``__init__``; each call to the instance
    runs one generation and returns the first produced image.
    """

    def __init__(self, path=""):
        """Load the base pipeline, attach the LoRA weights and seed the generator.

        :param path: Accepted for compatibility with the caller; not used by
            this implementation.
        """
        # Load StableDiffusionPipeline
        self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
        self.pipe = AutoPipelineForText2Image.from_pretrained(
            self.stable_diffusion_id,
            torch_dtype=dtype,
        )
        # NOTE(review): relative file name, so presumably resolved against the
        # process working directory rather than `path` — confirm for deployment.
        self.pipe.load_lora_weights("pytorch_lora_weights.bin")
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to(device)
        self.seed = 42
        # Define Generator with seed
        self.generator = torch.Generator(device="cpu").manual_seed(self.seed)

    def __call__(self, data: Any) -> Any:
        """Generate one image for the prompt found in ``data``.

        Keys read (and removed) from ``data``:

        * ``inputs`` - the text prompt (required).
        * ``seed`` - generator seed, default 42. The generator is re-created
          only when the seed differs from the one currently in use, so
          repeated calls with an unchanged seed keep drawing from the same
          generator.
        * ``num_inference_steps`` - default 50.
        * ``guidance_scale`` - default 7.5.
        * ``temperature`` - consumed for backward compatibility with existing
          requests but not forwarded to the pipeline.

        :param data: A dictionary containing at least ``inputs``.
        :return: The first generated image (``out.images[0]``), or
            ``{"error": ...}`` when no prompt was supplied.
        """
        prompt = data.pop("inputs", None)
        seed = data.pop("seed", 42)

        # Check if prompt is not provided
        if prompt is None:
            return {"error": "Please provide a prompt."}

        # Check if seed changed
        if seed is not None and seed != self.seed:
            print(f"changing seed from {self.seed} to {seed}")
            self.seed = seed
            self.generator = torch.Generator(device="cpu").manual_seed(self.seed)

        # hyperparameters
        num_inference_steps = data.pop("num_inference_steps", 50)
        guidance_scale = data.pop("guidance_scale", 7.5)
        # `temperature` is not an argument of the Stable Diffusion text-to-image
        # call, so it is dropped here instead of being passed through.
        data.pop("temperature", None)

        # No input image is decoded here: this is a text-to-image pipeline and
        # the request carries no image (the previous code read an unassigned
        # local `image`, which raised on every request).

        # run inference pipeline
        out = self.pipe(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            generator=self.generator,
        )

        # return first generated PIL image
        return out.images[0]

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image and open it with PIL.

        :param image_string: Base64 text (or bytes) holding the image file.
        :return: The object returned by ``Image.open`` on the decoded bytes.
        """
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image