sd-kerascv / handler.py
chansung's picture
update custom handler
44565ca
from typing import Dict, List, Any
import sys
import base64
import logging
import keras_cv
class EndpointHandler():
def __init__(self, path="", version="2"):
self.sd = self._instantiate_stable_diffusion(version)
if isinstance(self.sd, str):
sys.exit(self.sd)
else:
self.sd.text_to_image("test prompt", batch_size=1)
logging.warning(f"Stable Diffusion v{version} is fully loaded")
def _instantiate_stable_diffusion(self, version: str):
if version is "1.4":
return keras_cv.models.StableDiffusion(img_width=512, img_height=512)
elif version is "2":
return keras_cv.models.StableDiffusionV2(img_width=512, img_height=512)
else:
return f"v{version} is not supported"
def __call__(self, data: Dict[str, Any]) -> str:
prompt = data.pop("inputs", data)
batch_size = data.pop("batch_size", 1)
images = self.sd.text_to_image(prompt, batch_size=batch_size)
return base64.b64encode(images.tobytes()).decode()