# SAMI156's picture
# Updated the code
# 94ea426
from fastapi.responses import FileResponse
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
from PIL import Image, ImageFilter
from fastapi import FastAPI, File, UploadFile, HTTPException
import numpy as np
import uvicorn
import torch
import os
# Single application object that the route decorator in this file attaches to.
app = FastAPI()

# The image processor and the segmentation network are built once, at import
# time, and shared by every request instead of being reloaded per call.
_CHECKPOINT = "google/deeplabv3_mobilenet_v2_1.0_513"
preprocessor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForSemanticSegmentation.from_pretrained(_CHECKPOINT)

# Make sure the folder that receives generated stickers is present.
os.makedirs("output", exist_ok=True)
def get_segmentation_mask(image: Image.Image) -> Image.Image:
    """Generate a binary segmentation mask with feathered edges from an input image.

    The image is run through the module-level ``preprocessor`` and ``model``.
    Every pixel whose predicted class index is non-zero is treated as
    foreground, the resulting 0/255 mask is lightly blurred, and the blurred
    mask is scaled to the size of ``image``.

    Args:
        image: The picture to segment.

    Returns:
        A mask image with the same size as ``image``. Because of the blur and
        the bicubic resize, pixels near region borders hold intermediate
        values rather than strictly 0 or 255.
    """
    # Step 1: Preprocess and run model inference (gradients are not needed).
    inputs = preprocessor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Step 2: Post-process the segmentation output to get a clean binary mask.
    predicted_mask = preprocessor.post_process_semantic_segmentation(outputs)[0]
    # The post-processed map holds class indices, not 0/1 flags. Scaling the
    # raw indices by 255 in uint8 wraps around (e.g. 15 * 255 -> 241), so the
    # map is thresholded first and only then scaled to 0 or 255.
    # NOTE(review): assumes class index 0 is the background label for this
    # checkpoint -- confirm against model.config.id2label.
    foreground = predicted_mask != 0
    mask_np = (foreground.to(torch.uint8) * 255).cpu().numpy()
    binary_mask = Image.fromarray(mask_np)

    # Step 3: Apply a slight Gaussian blur to soften the edges, then bring the
    # mask to the input image's size.
    feathered_mask = binary_mask.filter(ImageFilter.GaussianBlur(1))  # Adjust blur radius as needed
    feathered_mask = feathered_mask.resize(image.size, Image.BICUBIC)
    return feathered_mask
def apply_mask_to_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Cut the masked region of ``image`` out onto a new RGBA canvas.

    Args:
        image: Source picture; it is converted to RGBA before pasting.
        mask: Mask passed to ``paste`` to control how much of each source
            pixel is copied onto the canvas.

    Returns:
        A new RGBA image of the same size as ``image`` containing the
        masked paste result.
    """
    rgba_source = image.convert("RGBA")
    canvas = Image.new("RGBA", rgba_source.size)
    # Paste at the top-left corner; the mask decides the per-pixel blend.
    canvas.paste(rgba_source, box=(0, 0), mask=mask)
    return canvas
@app.post("/create_sticker/")
async def create_sticker(file: UploadFile = File(...)):
    """Endpoint to convert an uploaded image to a sticker.

    Args:
        file: The uploaded image; its declared content type must be
            ``image/png`` or ``image/jpeg``.

    Returns:
        A ``FileResponse`` serving the generated sticker from the ``output``
        directory with media type ``image/png``.

    Raises:
        HTTPException: With status 400 when the declared content type is not
            PNG/JPEG, or when the uploaded bytes cannot be decoded as an
            image.
    """
    if file.content_type not in ("image/png", "image/jpeg"):
        raise HTTPException(status_code=400, detail="Invalid image format. Only PNG and JPEG are supported.")

    # Load the image. The content type is supplied by the client, so the
    # payload may still be undecodable; report that as a client error rather
    # than letting it surface as an unhandled server error.
    try:
        input_image = Image.open(file.file).convert("RGB")
    except (OSError, ValueError) as exc:
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.") from exc

    # Generate segmentation mask and apply it to create a sticker
    mask = get_segmentation_mask(input_image)
    sticker = apply_mask_to_image(input_image, mask)

    # Save the output sticker. The filename is client-controlled and may be
    # missing or contain path separators, so only the stem of its final path
    # component is kept. The file is always written as PNG, hence the fixed
    # ".png" extension.
    raw_name = (file.filename or "").replace("\\", "/")
    stem = os.path.splitext(os.path.basename(raw_name))[0] or "upload"
    output_path = f"output/sticker_{stem}.png"
    sticker.save(output_path, "PNG")
    return FileResponse(output_path, media_type="image/png")
# Start the server only when this file is executed directly as a script.
if __name__ == "__main__":
    # Bind address and port handed to uvicorn.
    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **bind_options)