| | from typing import Dict, Any |
| | import os |
| | import requests |
| | from io import BytesIO |
| | from PIL import Image |
| | import torch |
| | from torchvision import transforms |
| | from transformers import AutoModelForImageSegmentation |
| |
|
| | |
# Allow float32 matmuls to use faster reduced-internal-precision kernels
# (e.g. TF32) where the hardware supports them, rather than "highest".
torch.set_float32_matmul_precision("high")

# Run on CUDA when a GPU is visible to PyTorch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
| |
|
class EndpointHandler:
    """Inference Endpoints handler that removes image backgrounds with BiRefNet.

    The segmentation model is loaded once at construction. Each call predicts
    a foreground mask for the input image and returns the image as RGBA with
    that mask as its alpha channel.
    """

    def __init__(self, path='', model_id='zhengpeng7/BiRefNet', request_timeout=30.0):
        """Load the segmentation model onto ``device`` in eval mode.

        Args:
            path: Path supplied by the endpoint runtime. Not used: the model is
                loaded from ``model_id`` instead.
            model_id: Repository id (or local directory) passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
            request_timeout: Seconds to wait when downloading an image URL.
        """
        self.model = AutoModelForImageSegmentation.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        self.model.to(device)
        self.model.eval()

        # fp16 saves memory and time on GPU. On CPU, float16 kernels for ops
        # such as conv2d are missing (older PyTorch raises) or much slower,
        # so stay in float32 there.
        if device == "cuda":
            self.model.half()
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        self.request_timeout = request_timeout

        # Preprocessing is identical for every request; build it once.
        self.transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _load_image(self, image_src):
        """Open ``image_src`` as a PIL image.

        Accepts an existing ``PIL.Image.Image``, an ``http://``/``https://``
        URL, a local file path, or raw encoded image bytes.

        Raises:
            requests.HTTPError: If downloading a URL returns an error status.
        """
        if isinstance(image_src, Image.Image):
            return image_src
        if isinstance(image_src, str):
            if image_src.startswith(("http://", "https://")):
                response = requests.get(image_src, timeout=self.request_timeout)
                # Report the HTTP failure itself instead of letting PIL choke
                # on an error page with "cannot identify image file".
                response.raise_for_status()
                return Image.open(BytesIO(response.content))
            return Image.open(image_src)
        return Image.open(BytesIO(image_src))

    def __call__(self, data: Dict[str, Any]) -> Image.Image:
        """Remove the background of the image in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` may be a PIL image, an
                ``http(s)://`` URL, a local file path, or raw encoded image bytes.

        Returns:
            An RGBA image at the input's original size whose alpha channel is
            the predicted foreground mask.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
            requests.HTTPError: If downloading an image URL fails.
        """
        # convert() returns a new image, so a caller-supplied PIL image is not mutated.
        image = self._load_image(data["inputs"]).convert("RGB")
        orig_size = image.size

        input_tensor = self.transform(image).unsqueeze(0).to(device=device, dtype=self.dtype)

        with torch.no_grad():
            # Keep the model's last output and squash it to (0, 1) with sigmoid.
            preds = self.model(input_tensor)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()
        mask_pil = transforms.ToPILImage()(pred)
        # The model ran at 1024x1024; scale the mask back to the input size.
        mask_pil = mask_pil.resize(orig_size, resample=Image.Resampling.LANCZOS)

        # putalpha on an RGB image converts it to RGBA in place.
        image.putalpha(mask_pil)
        return image