|
import base64
import io
import json
import logging
import os
import traceback
from typing import Union

import dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from rembg import remove

import models.face_classifier as classifier
from utils.helpers import image_to_base64, calculate_mask_area
|
|
|
|
|
# Load variables from a local .env file (if present) into os.environ before
# any configuration below reads them.
dotenv.load_dotenv()

app = FastAPI()

# Allowed CORS origins, configurable as a comma-separated list in the
# CORS_ALLOW_ORIGINS environment variable (e.g. "https://a.example,https://b.example").
# Defaults to "*" (any origin), matching the previous hard-coded behavior.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette's
# CORSMiddleware reflects the caller's Origin for cookie-bearing requests,
# effectively permitting credentialed cross-site calls from any site. Set
# CORS_ALLOW_ORIGINS to explicit origins for production deployments.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Segmentation model, constructed once at import time and shared by every
# request this process handles.
model = classifier.FaceSegmentationModel()
|
|
|
|
|
|
|
@app.post("/segment/", summary="Classify skin type based on image given", tags=["Classify"])
async def predict_image(file: UploadFile = File(...)):
    """Segment an uploaded face image and score each segmented region.

    The upload is decoded, converted to RGB and resized to 500x500. Its
    background is then replaced with solid blue by ``rembg.remove`` and the
    result is passed to ``model.infer``. Each returned mask is base64-encoded.
    Every non-background region gets a ``score``: its mask area divided by the
    number of non-background pixels (0.0 if there are none).

    Args:
        file: The uploaded image file.

    Returns:
        A JSON response with ``original_image`` (the background-replaced image,
        encoded with ``image_to_base64``) and ``segmentation_results`` (the
        list from ``model.infer`` with masks encoded and scores added).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure during processing.
    """
    # Working resolution; also used as the pixel total when normalizing scores.
    target_size = (500, 500)

    try:
        image_bytes = await file.read()

        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as e:
            # Pillow raises UnidentifiedImageError (an OSError subclass) for
            # unrecognized data and OSError for truncated files: a client error.
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        pil_image = pil_image.resize(target_size)

        # Replace the background with opaque blue before segmentation.
        image_bg_removed = remove(pil_image, bgcolor=(0, 0, 255, 255))

        # NOTE(review): remove() and model.infer() are synchronous calls inside
        # an async endpoint and block the event loop while they run; consider a
        # plain `def` endpoint or a thread pool if concurrency matters.
        results = model.infer(image_bg_removed)

        background_element = next(
            (element for element in results if element["label"] == "background"),
            None,
        )
        # The background mask is measured before it is base64-encoded below,
        # hence the second (True) argument to calculate_mask_area.
        background_area = (
            calculate_mask_area(background_element["mask"], True)
            if background_element is not None
            else 0
        )

        width, height = target_size
        foreground_area = width * height - background_area

        for element in results:
            element["mask"] = image_to_base64(element["mask"])
            if element["label"] == "background":
                continue
            # NOTE(review): the mask is already base64-encoded at this point;
            # calculate_mask_area presumably decodes it when its second
            # argument is omitted; confirm against utils.helpers.
            mask_area = calculate_mask_area(element["mask"])
            # Guard against division by zero when the background fills the image.
            element["score"] = mask_area / foreground_area if foreground_area > 0 else 0.0

        response = {
            "original_image": image_to_base64(image_bg_removed.convert("RGB")),
            "segmentation_results": results,
        }

        return Response(
            content=json.dumps(response),
            status_code=200,
            media_type="application/json",
        )

    except HTTPException:
        # Already carries the intended status code (e.g. 400 above).
        raise
    except Exception as e:
        # Record the full traceback server-side; the client gets only the message.
        logging.getLogger(__name__).exception("Segmentation of uploaded image failed")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e