|
from fastapi.responses import FileResponse |
|
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation |
|
from PIL import Image, ImageFilter |
|
from fastapi import FastAPI, File, UploadFile, HTTPException |
|
import numpy as np |
|
import uvicorn |
|
import torch |
|
import os |
|
|
|
app = FastAPI() |
|
|
|
|
|
preprocessor = AutoImageProcessor.from_pretrained("google/deeplabv3_mobilenet_v2_1.0_513") |
|
model = AutoModelForSemanticSegmentation.from_pretrained("google/deeplabv3_mobilenet_v2_1.0_513") |
|
|
|
|
|
os.makedirs("output", exist_ok=True) |
|
|
|
def get_segmentation_mask(image: Image.Image) -> Image.Image: |
|
"""Generate a binary segmentation mask with feathered edges from an input image.""" |
|
|
|
inputs = preprocessor(images=image, return_tensors="pt") |
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
|
|
predicted_mask = preprocessor.post_process_semantic_segmentation(outputs)[0] |
|
mask_np = predicted_mask.cpu().numpy().astype("uint8") * 255 |
|
binary_mask = Image.fromarray(mask_np) |
|
|
|
|
|
feathered_mask = binary_mask.filter(ImageFilter.GaussianBlur(1)) |
|
feathered_mask = feathered_mask.resize(image.size, Image.BICUBIC) |
|
|
|
return feathered_mask |
|
|
|
def apply_mask_to_image(image: Image.Image, mask: Image.Image) -> Image.Image: |
|
"""Apply the segmentation mask to the input image to create a transparent sticker.""" |
|
image = image.convert("RGBA") |
|
sticker = Image.new("RGBA", image.size) |
|
sticker.paste(image, (0, 0), mask) |
|
return sticker |
|
|
|
@app.post("/create_sticker/") |
|
async def create_sticker(file: UploadFile = File(...)): |
|
"""Endpoint to convert an uploaded image to a sticker.""" |
|
if file.content_type not in ["image/png", "image/jpeg"]: |
|
raise HTTPException(status_code=400, detail="Invalid image format. Only PNG and JPEG are supported.") |
|
|
|
|
|
input_image = Image.open(file.file).convert("RGB") |
|
|
|
|
|
mask = get_segmentation_mask(input_image) |
|
sticker = apply_mask_to_image(input_image, mask) |
|
|
|
|
|
output_path = f"output/sticker_{file.filename}" |
|
sticker.save(output_path, "PNG") |
|
|
|
return FileResponse(output_path, media_type="image/png") |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
uvicorn.run(app, host="0.0.0.0", port=8000) |
|
|