sergeipetrov's picture
Update handler.py
796fee2 verified
from typing import Dict, List, Any
from diffusers import AutoPipelineForInpainting
from PIL import Image
from io import BytesIO
import base64
import torch
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for SDXL inpainting.

    Loads the ``OzzyGT/RealVisXL_V4.0_inpainting`` checkpoint once (fp16,
    moved to CUDA) and serves requests whose image and mask arrive as
    base64-encoded strings.
    """

    # Defaults used when a request does not override them; identical to the
    # values this handler has always used.
    DEFAULT_GUIDANCE_SCALE = 8.0
    DEFAULT_NUM_INFERENCE_STEPS = 20  # steps between 15 and 30 work well for us (from model card)
    DEFAULT_STRENGTH = 0.99  # make sure to use `strength` below 1.0
    DEFAULT_BLUR_FACTOR = 20.0
    DEFAULT_SEED = 0

    def __init__(self, path=""):
        # `path` is accepted to satisfy the endpoint-handler interface but is
        # not used: the checkpoint is always loaded from the Hub repo below.
        self.pipeline = AutoPipelineForInpainting.from_pretrained(
            "OzzyGT/RealVisXL_V4.0_inpainting",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")

    @staticmethod
    def _b64_to_bytes(payload):
        """Decode a base64 str/bytes payload, tolerating a ``data:...;base64,`` prefix."""
        # b64decode (non-validating) silently discards ':', ';' and ',' but
        # keeps the prefix's alphabet characters ("dataimage/png..."), which
        # would corrupt the decoded bytes, so strip a data-URL header first.
        if isinstance(payload, str) and payload.startswith("data:"):
            payload = payload.partition(",")[2]
        return base64.b64decode(payload)

    @classmethod
    def _decode_image(cls, payload, mode):
        """Decode a base64 image payload into a PIL image converted to `mode`."""
        return Image.open(BytesIO(cls._b64_to_bytes(payload))).convert(mode)

    def __call__(self, data: Dict[str, Any]):
        """Inpaint the masked region of an image guided by a text prompt.

        data args:
            inputs: optional wrapper dict holding the fields below; when it is
                absent, `data` itself is read.
            image: base64 string (a ``data:<mime>;base64,`` prefix is accepted)
            mask: base64 string (same encoding rules as `image`)
            prompt: str
            Optional overrides, read from `inputs` first and then from a
            top-level `parameters` dict: negative_prompt, guidance_scale,
            num_inference_steps, strength, seed, blur_factor.
        returns:
            The first image of the pipeline output (a PIL image).
        raises:
            KeyError: if `image`, `mask` or `prompt` is missing.
        """
        # Use .get rather than .pop so the caller's request dict is not mutated.
        inputs = data.get("inputs", data)
        # Keys in `inputs` take precedence over a top-level `parameters` dict.
        options = {**(data.get("parameters") or {}), **inputs}

        # Normalise modes: an RGBA or palette upload would otherwise reach the
        # VAE with the wrong channel count (assumes 3-channel RGB input).
        image = self._decode_image(inputs["image"], "RGB")
        mask = self._decode_image(inputs["mask"], "L")
        prompt = inputs["prompt"]

        # blur mask for smooth transition
        mask = self.pipeline.mask_processor.blur(
            mask,
            blur_factor=float(options.get("blur_factor", self.DEFAULT_BLUR_FACTOR)),
        )

        # Fixed seed by default (0) so sampling is repeatable across requests.
        generator = torch.Generator(device="cuda").manual_seed(
            int(options.get("seed", self.DEFAULT_SEED))
        )
        return self.pipeline(
            prompt=prompt,
            negative_prompt=options.get("negative_prompt"),
            image=image,
            mask_image=mask,
            guidance_scale=float(
                options.get("guidance_scale", self.DEFAULT_GUIDANCE_SCALE)
            ),
            num_inference_steps=int(
                options.get("num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS)
            ),
            strength=float(options.get("strength", self.DEFAULT_STRENGTH)),
            generator=generator,
        ).images[0]