sergeipetrov's picture
Update handler.py
58c4864 verified
raw
history blame
1.32 kB
from typing import Dict, List, Any
from diffusers import AutoPipelineForInpainting
from PIL import Image
from io import BytesIO
import base64
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for SDXL inpainting.

    Loads an inpainting pipeline once at start-up and, per request, repaints
    the masked region of a base64-encoded image according to a text prompt.
    """

    def __init__(self, path="", model_id="diffusers/stable-diffusion-xl-1.0-inpainting-0.1"):
        """Load the inpainting pipeline onto the best available device.

        Args:
            path: Repository path supplied by the endpoint runtime. Unused:
                weights are always loaded from ``model_id``.
            model_id: Hub id (or local path) of the inpainting checkpoint.
                Defaults to the SDXL 1.0 inpainting model.
        """
        device, dtype = self._select_device()
        # The pipeline must live on the same device as the generator built in
        # __call__: diffusers refuses to draw CPU-side noise from a CUDA
        # generator, so place the pipeline explicitly.
        self.pipeline = AutoPipelineForInpainting.from_pretrained(
            model_id,
            torch_dtype=dtype,
            variant="fp16",  # fp16 weight files; cast to `dtype` on load
        ).to(device)

    @staticmethod
    def _select_device():
        """Return ``(device, dtype)``: fp16 on CUDA, fp32 on CPU.

        Half precision is poorly supported by CPU kernels, so the CPU fallback
        runs in float32.
        """
        if torch.cuda.is_available():
            return "cuda", torch.float16
        return "cpu", torch.float32

    @staticmethod
    def _decode_image(b64_string, mode):
        """Decode a base64-encoded image file into a PIL image in ``mode``."""
        # convert() also forces PIL's lazy decode and normalises RGBA/palette
        # uploads to the channel layout the pipeline expects.
        return Image.open(BytesIO(base64.b64decode(b64_string))).convert(mode)

    @staticmethod
    def _generation_options(inputs):
        """Return ``(seed, sampler kwargs)``, letting the request override defaults."""
        seed = int(inputs.get("seed", 0))
        options = {
            "guidance_scale": float(inputs.get("guidance_scale", 8.0)),
            # steps between 15 and 30 work well for us (from model card)
            "num_inference_steps": int(inputs.get("num_inference_steps", 20)),
            # make sure to use `strength` below 1.0
            "strength": float(inputs.get("strength", 0.99)),
        }
        return seed, options

    def __call__(self, data: Dict[str, Any]):
        """Inpaint the masked region of an image.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                that key is present, otherwise from ``data`` itself:
                image: base64-encoded image file to edit.
                mask: base64-encoded mask image marking the region to repaint.
                prompt: text describing what to paint into the masked region.
                seed, guidance_scale, num_inference_steps, strength: optional
                    overrides of the sampling defaults (0, 8.0, 20, 0.99).

        Returns:
            The first generated image (a ``PIL.Image.Image``).

        Raises:
            KeyError: if ``image``, ``mask`` or ``prompt`` is missing.
        """
        # Read without popping so the caller's payload is left intact.
        inputs = data.get("inputs", data)
        image = self._decode_image(inputs["image"], "RGB")
        mask = self._decode_image(inputs["mask"], "L")
        prompt = inputs["prompt"]
        seed, options = self._generation_options(inputs)
        # Seeded generator on the pipeline's own device -> reproducible output
        # and no device mismatch when sampling noise.
        generator = torch.Generator(device=self.pipeline.device).manual_seed(seed)
        result = self.pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask,
            generator=generator,
            **options,
        )
        return result.images[0]