# checking-model-1 / handler.py
# Author: anujg1508
# Commit message: "update handler"
# Commit: 3de8d46
import base64
import logging
from io import BytesIO
from typing import Any, Dict, List

import torch
from diffusers import StableDiffusionXLPipeline
from torch import autocast
# Resolve the compute device once, when the module is imported. This handler
# refuses to run without a visible CUDA device, so importing the module on a
# CPU-only machine fails immediately rather than at the first request.
# NOTE(review): RuntimeError would describe a missing-hardware condition more
# accurately, but ValueError is kept in case callers catch it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
if device.type != "cuda":
    raise ValueError('need to run on gpu')
class EndpointHandler:
    """Text-to-image inference handler backed by a Stable Diffusion XL pipeline.

    The pipeline is loaded once at construction and moved to the module-level
    ``device``. Each call turns one prompt (or a list of prompts) into the
    first generated image, returned as a base64-encoded JPEG string.
    """

    # Hub id of the checkpoint loaded when no ``model_id`` is given.
    DEFAULT_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
    # Guidance scale passed to the pipeline when none is given.
    DEFAULT_GUIDANCE_SCALE = 7.5

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        path: str = "",
        *,
        model_id: str = DEFAULT_MODEL_ID,
        guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
    ) -> None:
        """Load the fp16 SDXL pipeline and move it to ``device``.

        Args:
            path: Accepted to keep the constructor signature unchanged; it is
                not read. NOTE(review): presumably the local model-repository
                directory supplied by the serving toolkit — confirm before
                switching loading to use it.
            model_id: Checkpoint passed to ``from_pretrained``.
            guidance_scale: Value forwarded as ``guidance_scale`` on every call.
        """
        # Keep the device on the instance so __call__ does not depend on the
        # module global directly.
        self.device = device
        self.guidance_scale = guidance_scale
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to(self.device)

    @staticmethod
    def _is_valid_prompt(prompt: Any) -> bool:
        """Return True if *prompt* is a str or a non-empty list of str."""
        if isinstance(prompt, str):
            return True
        return (
            isinstance(prompt, list)
            and bool(prompt)
            and all(isinstance(item, str) for item in prompt)
        )

    def __call__(self, data: Any) -> Dict[str, str]:
        """Generate an image for the prompt carried by *data*.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the prompt, or
                the prompt itself. A prompt is a ``str`` or a non-empty
                ``list`` of ``str``. The caller's dict is not modified.

        Returns:
            ``{"image": <base64-encoded JPEG as str>}`` for the first image
            the pipeline produces.

        Raises:
            ValueError: if no valid prompt can be extracted from *data*.
        """
        self._logger.debug("Request payload: %r", data)
        # Read (not pop) so the caller's payload stays intact.
        prompt = data.get("inputs") if isinstance(data, dict) else data
        if not self._is_valid_prompt(prompt):
            raise ValueError(
                "Expected 'inputs' to be a prompt string or a non-empty list "
                f"of prompt strings, got {type(prompt).__name__}"
            )

        self._logger.debug("Running inference on %s", self.device)
        # The weights are already fp16; autocast is kept so numerics match
        # the previous behavior.
        with autocast(self.device.type):
            output = self.pipe(prompt, guidance_scale=self.guidance_scale)
        # Only the first generated image is returned, even for list prompts.
        image = output.images[0]

        with BytesIO() as buffer:
            image.save(buffer, format="JPEG")
            encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
        return {"image": encoded}