# Source: timothymhowe, commit 80b2ac3 ("Create handler.py").
from typing import Dict, List, Any
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
import base64
from io import BytesIO
# Prefer CUDA whenever PyTorch can see a GPU; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fail fast at import time: this handler refuses to run without a GPU.
# NOTE(review): the message says SDXL, but the handler below loads a
# StableDiffusionPipeline — confirm which model family is intended.
if device.type != "cuda":
    raise ValueError("Must run SDXL on a GPU instance.")
class EndpointHandler:
    """Inference handler that turns a text prompt into a base64-encoded JPEG.

    The diffusers ``StableDiffusionPipeline`` is loaded once in ``__init__``
    (float16 weights, moved to the module-level ``device``) and reused for
    every request.
    """

    def __init__(self, path: str = "") -> None:
        """Load the pipeline.

        Args:
            path: Model identifier or local directory, passed unchanged to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        # Weights are loaded as float16; the module-level guard ensures
        # ``device`` is CUDA before this class is ever instantiated.
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
        self.pipe = self.pipe.to(device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        """Generate one image for the prompt carried in ``data``.

        Args:
            data: Request payload. The prompt is read from ``data["inputs"]``;
                if that key is absent, ``data`` itself is passed to the
                pipeline as the prompt.

        Returns:
            ``{"image": <base64-encoded JPEG as an ASCII str>}``.
        """
        # Read rather than pop so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        with autocast(device.type):
            # Current diffusers pipelines return an output object whose
            # generated PIL images live in ``.images``; the legacy
            # ``["sample"]`` key is not available.
            image = self.pipe(inputs, guidance_scale=9).images[0]
        buffer = BytesIO()
        image.save(buffer, format="JPEG")
        # Base64-*encode* the raw JPEG bytes, then decode the ASCII result to
        # str so the response is JSON-serialisable.
        img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"image": img_str}