| | from typing import Dict, List, Any |
| | from PIL import Image |
| | import io |
| | import base64 |
| | from rembg import remove, new_session |
| |
|
class EndpointHandler:
    """Background-removal handler backed by a single reusable rembg session.

    The ``__init__(path)`` / ``__call__(data)`` shape appears to follow the
    Hugging Face Inference Endpoints custom-handler convention
    (NOTE(review): confirm against the deployment that instantiates this).
    """

    def __init__(self, path="", model_name="u2netp"):
        """Create the rembg session once so every request reuses it.

        Args:
            path: Accepted for interface compatibility; it is not used here —
                the session is built from ``model_name`` alone.
            model_name: Model name passed to ``rembg.new_session``. Defaults to
                ``"u2netp"``, the model this handler has always used.
        """
        self.session = new_session(model_name)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Remove the background from the image in ``data["inputs"]``.

        Args:
            data: Request payload.
                - "inputs": a PIL.Image, a base64 string (optionally prefixed
                  with a ``data:<mime>;base64,`` header), or raw encoded
                  image bytes.
                - "parameters": optional dict of keyword arguments forwarded
                  to ``rembg.remove`` (e.g. ``alpha_matting``). Absent or
                  ``None`` means no extra options.

        Returns:
            The result of ``rembg.remove`` for a PIL input (presumably a PIL
            image, since rembg mirrors the input type — confirm), or
            ``{"error": "No inputs provided"}`` when ``inputs`` is missing,
            ``None``, or an empty string/bytes.
        """
        inputs = data.get("inputs")
        # An empty string/bytes payload cannot be an image; report it like a
        # missing one instead of failing deep inside PIL.
        if inputs is None or (
            isinstance(inputs, (str, bytes, bytearray)) and not inputs
        ):
            return {"error": "No inputs provided"}

        image = self._load_image(inputs)

        # JSON clients often send "parameters": null; `**None` would raise
        # TypeError, so treat a falsy value like an absent key.
        params = data.get("parameters") or {}

        return remove(image, session=self.session, **params)

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, dropping a leading ``data:...,`` header."""
        if payload.startswith("data:"):
            # b64decode discards non-alphabet characters by default, so an
            # un-stripped header would turn into garbage leading bytes.
            _, sep, body = payload.partition(",")
            if sep:
                payload = body
        return base64.b64decode(payload)

    @classmethod
    def _load_image(cls, inputs: Any) -> "Image.Image":
        """Convert any supported ``inputs`` form into a PIL image."""
        if isinstance(inputs, str):
            raw = cls._decode_base64(inputs)
            return Image.open(io.BytesIO(raw)).convert("RGB")
        if isinstance(inputs, Image.Image):
            # PIL inputs are passed through unchanged (no RGB conversion).
            return inputs
        # Anything else is treated as raw encoded image bytes.
        return Image.open(io.BytesIO(inputs)).convert("RGB")