|
from typing import Dict, List, Any |
|
from diffusers import AutoPipelineForInpainting |
|
from PIL import Image |
|
from io import BytesIO |
|
import base64 |
|
import torch |
|
|
|
|
|
class EndpointHandler():
    """Inference Endpoints handler that inpaints images with SDXL.

    The diffusers inpainting pipeline is loaded once, in half precision, when
    the handler is constructed. Each call decodes a base64 image and mask,
    then regenerates the masked region according to a text prompt.
    """

    # Weights are always pulled from this hub repository; the ``path``
    # argument handed to ``__init__`` is not used.
    MODEL_ID = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    # Device for both the pipeline and the RNG. They must match: diffusers
    # refuses to draw CPU latents from a CUDA generator.
    DEVICE = "cuda"

    def __init__(self, path=""):
        """Load the fp16 SDXL inpainting pipeline onto ``DEVICE``.

        Args:
            path: Model directory supplied by the endpoint runtime. Currently
                ignored; the weights come from ``MODEL_ID``.
        """
        self.pipeline = AutoPipelineForInpainting.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to(self.DEVICE)

    @staticmethod
    def _decode_image(b64_string):
        """Decode a base64 string (optionally a ``data:`` URI) into a PIL image."""
        # A data URI like "data:image/png;base64,...." would otherwise have its
        # header letters decoded as image bytes, corrupting the payload.
        if b64_string.startswith("data:"):
            b64_string = b64_string.partition(",")[2]
        return Image.open(BytesIO(base64.b64decode(b64_string)))

    def __call__(self, data: Dict[str, Any]):
        """Inpaint the masked region of an image.

        Args:
            data: Request payload. Fields may be nested under an ``"inputs"``
                key or given at the top level:

                - ``image``: base64-encoded source image.
                - ``mask``: base64-encoded mask image.
                - ``prompt``: text prompt describing the content to paint in.
                - ``seed`` (optional, default 0): RNG seed.
                - ``guidance_scale`` (optional, default 8.0).
                - ``num_inference_steps`` (optional, default 20).
                - ``strength`` (optional, default 0.99).

        Returns:
            The first image produced by the pipeline (a PIL image).

        Raises:
            KeyError: if ``image``, ``mask`` or ``prompt`` is missing.
        """
        # ``get`` rather than ``pop`` so the caller's payload is left intact.
        inputs = data.get("inputs", data)

        # The VAE expects 3-channel input; uploaded PNGs are frequently RGBA
        # or palette-mode, so normalize them. The mask is read as grayscale.
        image = self._decode_image(inputs["image"]).convert("RGB")
        mask = self._decode_image(inputs["mask"]).convert("L")
        prompt = inputs["prompt"]

        seed = int(inputs.get("seed", 0))
        guidance_scale = float(inputs.get("guidance_scale", 8.0))
        num_inference_steps = int(inputs.get("num_inference_steps", 20))
        strength = float(inputs.get("strength", 0.99))

        # Seeded generator on the pipeline's device -> reproducible output
        # for identical requests.
        generator = torch.Generator(device=self.DEVICE).manual_seed(seed)

        result = self.pipeline(
            prompt=prompt,
            image=image,
            mask_image=mask,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            strength=strength,
            generator=generator,
        ).images[0]

        return result