# remove_bg_api / handler.py
# Uploaded by whlzy ("Upload handler.py", commit 9a6d4b3, verified).
import base64
import datetime
import os
import uuid
from io import BytesIO
from typing import Any, Dict, List

import boto3
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Run on the CUDA device when torch reports one is available, otherwise on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class EndpointHandler:
    """Inference Endpoint handler that removes an image's background.

    A segmentation model predicts a mask for the input image; the mask is
    attached as the image's alpha channel, the RGBA result is uploaded to a
    Cloudflare R2 bucket as WEBP, and the public URL of the upload is returned.
    """

    # Upload destination. Credentials are intentionally NOT stored in source;
    # ``update_to_s3`` reads them from the environment.
    BUCKET_NAME = 'popwear-assets'
    BUCKET_PREFIX_PATH = 'removebg'
    PUBLIC_BASE_URL = 'https://assets.popwear.ai'
    # Environment variables holding the R2 account id and API key pair.
    CREDENTIAL_ENV_VARS = ('R2_ACCOUNT_ID', 'R2_ACCESS_KEY_ID', 'R2_SECRET_ACCESS_KEY')

    def __init__(self, path=""):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            path: Repository path supplied by the endpoint runtime. Currently
                unused: the model is always loaded from the
                ``whlzy/remove_bg_api`` Hub repository, authenticated with the
                ``HUGGINGFACE_TOKEN`` environment variable when it is set.
        """
        # trust_remote_code executes model code shipped in the Hub repository.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            'whlzy/remove_bg_api',
            trust_remote_code=True,
            token=os.environ.get("HUGGINGFACE_TOKEN"),
        )
        self.model.to(device)
        self.model.eval()
        image_size = (1024, 1024)
        # Resize to the fixed model input size, then normalize with the
        # standard ImageNet per-channel mean/std (3 channels, hence RGB input).
        self.transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @classmethod
    def _make_object_key(cls, now=None):
        """Return a unique key ``<prefix>/<YYYYmmdd_HHMMSS>_<hex>.webp``.

        Args:
            now: Timestamp to embed; defaults to ``datetime.datetime.now()``.

        A random suffix is appended because a one-second-resolution timestamp
        alone lets concurrent requests overwrite each other's uploads. The
        extension is ``.webp`` to match the format actually written.
        """
        if now is None:
            now = datetime.datetime.now()
        stamp = now.strftime('%Y%m%d_%H%M%S')
        return f"{cls.BUCKET_PREFIX_PATH}/{stamp}_{uuid.uuid4().hex[:12]}.webp"

    def update_to_s3(self, image):
        """Encode ``image`` as WEBP, upload it to R2 and return its public URL.

        Credentials are read from the ``R2_ACCOUNT_ID``, ``R2_ACCESS_KEY_ID``
        and ``R2_SECRET_ACCESS_KEY`` environment variables. Never hard-code
        them; any key that has been committed to source must be rotated.

        Args:
            image: PIL image to upload (RGBA when called from ``__call__``).

        Returns:
            The object's URL under ``PUBLIC_BASE_URL``.

        Raises:
            RuntimeError: if any required credential variable is unset/empty.
        """
        missing = [name for name in self.CREDENTIAL_ENV_VARS if not os.environ.get(name)]
        if missing:
            raise RuntimeError(
                f"Missing R2 credentials in environment: {', '.join(missing)}"
            )
        account_id, access_key_id, secret_access_key = (
            os.environ[name] for name in self.CREDENTIAL_ENV_VARS
        )

        object_key = self._make_object_key()
        s3 = boto3.client(
            's3',
            endpoint_url=f'https://{account_id}.r2.cloudflarestorage.com',
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            region_name='auto',
        )
        output_buffer = BytesIO()
        image.save(output_buffer, format='WEBP', quality=85, method=4)
        output_buffer.seek(0)
        # Set the content type so the object is served as an image rather
        # than as a generic binary download.
        s3.upload_fileobj(
            output_buffer,
            self.BUCKET_NAME,
            object_key,
            ExtraArgs={'ContentType': 'image/webp'},
        )
        return f"{self.PUBLIC_BASE_URL}/{object_key}"

    @staticmethod
    def _extract_input(data):
        """Return ``data.pop("inputs", data)`` for dict payloads, else ``data``."""
        if isinstance(data, dict):
            return data.pop("inputs", data)
        return data

    def __call__(self, data: Any) -> str:
        """Remove the background from the request image and return its URL.

        Args:
            data: Either a dict whose ``"inputs"`` entry holds the image, or
                the image itself. The image may be a PIL image or a base64
                string/bytes (optionally a ``data:`` URI).

        Returns:
            Public URL of the uploaded RGBA WEBP image (see ``update_to_s3``).
        """
        image = self._extract_input(data)
        if isinstance(image, (str, bytes)):
            image = self.decode_base64_image(image)
        # The 3-channel Normalize fails on RGBA/L/P inputs, so work on an RGB
        # copy; copying also keeps putalpha from mutating the caller's image.
        image = image.convert('RGB')
        input_images = self.transform_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            # The last element of the model output is used as the mask logits;
            # sigmoid maps them to [0, 1].
            preds = self.model(input_images)[-1].sigmoid().cpu()
        # First (only) batch item; assumes a single-channel mask — TODO confirm.
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        # Scale the 1024x1024 mask back to the original image size.
        mask = pred_pil.resize(image.size)
        image.putalpha(mask)
        return self.update_to_s3(image)

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image (str or bytes) into a PIL image.

        A leading ``data:<mime>;base64,`` prefix is stripped if present, since
        its characters would otherwise corrupt the decoded bytes.
        """
        if isinstance(image_string, str) and image_string.startswith('data:') and ',' in image_string:
            image_string = image_string.split(',', 1)[1]
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image