# testHF_sd / handler.py
# Last commit by dmingod: 473b89a ("a")
from typing import Dict, List, Any
from transformers import pipeline
import torch
import base64
from io import BytesIO
from PIL import Image
# from diffusers import StableDiffusionXLImg2ImgPipeline
# from diffusers.utils import load_image
import numpy as np
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for image-to-image generation.

    The diffusers image-to-image pipeline is loaded once at start-up. Each
    request decodes a base64 input image, runs the pipeline with the prompt and
    optional hyper-parameters taken from the request payload, and returns the
    first generated image.
    """

    def __init__(self, path: str = "", model_id: str = "stabilityai/sdxl-turbo"):
        """Load the image-to-image pipeline in fp16 and move it to the GPU.

        Args:
            path: Model directory supplied by the endpoint runtime. It is not
                used; weights are loaded from ``model_id`` instead.
            model_id: Hub repo id (or local path) handed to
                ``AutoPipelineForImage2Image.from_pretrained``.
        """
        self.pipe = AutoPipelineForImage2Image.from_pretrained(
            model_id, torch_dtype=torch.float16, variant="fp16"
        )
        self.pipe.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Run image-to-image generation for a single request.

        Args:
            data: Request payload. Recognised keys are popped, so the dict is
                mutated:
                  inputs (str): text prompt. If absent, the remaining ``data``
                      dict itself is passed as the prompt.
                  image (str): base64-encoded input image (required).
                  num_inference_steps (int, default 25)
                  guidance_scale (float, default 7.5)
                  negative_prompt (str, default None)
                  strength (float, default 0.7)
                  denoising_start_step (float, default None)
                  denoising_end_step (float, default None)
                  num_images_per_prompt (int, default 1)
                  aesthetic_score (float, default 0.6)

        Returns:
            The first generated PIL image. Any extra images requested through
            ``num_images_per_prompt`` are generated but discarded.

        Raises:
            ValueError: if the request carries no ``image``.
        """
        inputs = data.pop("inputs", data)
        encoded_image = data.pop("image", None)

        # Hyper-parameters.
        # NOTE(review): the sdxl-turbo model card recommends very few steps and
        # guidance_scale=0.0; these defaults look tuned for base SDXL — confirm.
        num_inference_steps = data.pop("num_inference_steps", 25)
        guidance_scale = data.pop("guidance_scale", 7.5)
        negative_prompt = data.pop("negative_prompt", None)
        strength = data.pop("strength", 0.7)
        # Default to None (the diffusers default) rather than 0 / 1: the SDXL
        # img2img pipeline only adds noise to the init latents when
        # denoising_start is None, so a default of 0 skipped noising entirely.
        denoising_start = data.pop("denoising_start_step", None)
        denoising_end = data.pop("denoising_end_step", None)
        num_images_per_prompt = data.pop("num_images_per_prompt", 1)
        aesthetic_score = data.pop("aesthetic_score", 0.6)

        if encoded_image is None:
            # Image-to-image cannot run without an input image; fail fast with
            # a clear message instead of an opaque error from inside the pipeline.
            raise ValueError(
                "An input 'image' (base64-encoded) is required for image-to-image generation"
            )
        image = self.decode_base64_image(encoded_image)
        print("Image is getting loaded")

        print(f"Prompt: {inputs}, strength: {strength}, inf steps: {num_inference_steps}, denoise start: {denoising_start}, denoise_end: {denoising_end}")
        print(f"Imgs per prompt: {num_images_per_prompt}, aesthetic_score: {aesthetic_score}, guidance_scale: {guidance_scale}, negative_prompt: {negative_prompt}")

        out = self.pipe(
            inputs,
            image=image,
            strength=strength,
            num_inference_steps=num_inference_steps,
            denoising_start=denoising_start,
            denoising_end=denoising_end,
            num_images_per_prompt=num_images_per_prompt,
            aesthetic_score=aesthetic_score,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
        )
        # Only the first generated image is returned.
        return out.images[0]

    def decode_base64_image(self, image_string: Any) -> "Image.Image":
        """Decode a base64-encoded image into an RGB PIL image.

        Args:
            image_string: Raw base64 data (str or bytes), or a
                ``data:<mime>;base64,<payload>`` URL string.

        Returns:
            The decoded image converted to RGB, so palette / RGBA / grayscale
            uploads all become 3-channel RGB input for the pipeline.
        """
        if isinstance(image_string, str) and image_string.startswith("data:") and "," in image_string:
            # Strip a data-URL header such as "data:image/png;base64,".
            image_string = image_string.split(",", 1)[1]
        raw_bytes = base64.b64decode(image_string)
        # convert() forces a full load, so the source image can be closed here.
        with Image.open(BytesIO(raw_bytes)) as img:
            return img.convert("RGB")