# Source metadata (scraped page header, not code): commit 9471dba by haydpw,
# message "change calculation", file size 2.66 kB.
import base64
import io
import json
import logging
import os
import traceback
from typing import Union

import dotenv
from fastapi import FastAPI, File, HTTPException, UploadFile, Response
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from rembg import remove

import models.face_classifier as classifier
from utils.helpers import image_to_base64, calculate_mask_area
# Load variables from a local .env file (if present) before reading any config.
dotenv.load_dotenv()

app = FastAPI()

# Allowed CORS origins, read from the comma-separated CORS_ALLOW_ORIGINS
# environment variable. When it is unset or empty this falls back to ["*"],
# which is exactly the previous hard-coded behaviour.
# NOTE(review): a wildcard origin combined with allow_credentials=True
# effectively lets any site make credentialed requests (Starlette's
# CORSMiddleware echoes the request Origin in that case -- verify against the
# installed Starlette version). Set CORS_ALLOW_ORIGINS explicitly in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Segmentation model, constructed once at import time and reused by every
# request handled by the endpoint below.
model = classifier.FaceSegmentationModel()
# Side length (pixels) every upload is resized to before segmentation. The
# score denominator in predict_image is derived from it, so the two can no
# longer drift apart.
_IMAGE_SIZE = 500

logger = logging.getLogger(__name__)


@app.post("/segment/", summary="Classify skin type based on image given", tags=["Classify"])
async def predict_image(file: UploadFile = File(...)):
    """Segment an uploaded face image and score each segment by relative area.

    The upload is decoded, converted to RGB and resized to
    ``_IMAGE_SIZE`` x ``_IMAGE_SIZE``. Its background is then removed by
    ``rembg.remove`` with an opaque blue ``bgcolor``, and the result is passed
    to ``model.infer``.

    In the returned results, every ``"mask"`` is replaced by its
    ``image_to_base64`` encoding. Every element whose ``"label"`` is not
    ``"background"`` also gets a ``"score"``: its mask area divided by the
    image area minus the background mask area.

    Returns:
        A ``Response`` whose JSON body has the keys ``"original_image"`` (the
        base64 of the background-removed image, converted to RGB) and
        ``"segmentation_results"`` (the annotated result list).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure.
    """
    try:
        # Read the uploaded file.
        image_bytes = await file.read()
        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except OSError as e:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-image data; that is a client error, not a server error.
            raise HTTPException(
                status_code=400,
                detail=f"Uploaded file is not a valid image: {e}",
            ) from e

        pil_image = pil_image.resize((_IMAGE_SIZE, _IMAGE_SIZE))
        image_bg_removed = remove(pil_image, bgcolor=(0, 0, 255, 255))

        # Run the segmentation model on the background-removed image.
        # NOTE(review): remove() and model.infer() are called synchronously
        # inside an async handler, so they presumably block the event loop
        # while running. Consider offloading them once model.infer is
        # confirmed to be thread-safe.
        results = model.infer(image_bg_removed)
        logger.debug("Segmentation returned %d results", len(results))

        background_element = next(
            (element for element in results if element["label"] == "background"),
            None,
        )
        # The background area is measured on the raw mask, before it is
        # base64-encoded below.
        background_area = (
            calculate_mask_area(background_element["mask"], True)
            if background_element is not None
            else 0
        )
        foreground_area = _IMAGE_SIZE * _IMAGE_SIZE - background_area

        # Encode every mask as base64 and score the non-background segments.
        for element in results:
            element["mask"] = image_to_base64(element["mask"])
            if element["label"] == "background":
                continue
            # NOTE(review): unlike the background case, the area here is
            # computed from the base64-encoded mask, without the True flag.
            # Presumably calculate_mask_area handles both forms; confirm
            # against utils.helpers.
            mask_area = calculate_mask_area(element["mask"])
            # Guard against an image that is entirely background (the
            # original code raised ZeroDivisionError here).
            element["score"] = mask_area / foreground_area if foreground_area > 0 else 0.0

        # Return the background-removed image (as base64) alongside the results.
        response = {
            "original_image": image_to_base64(image_bg_removed.convert("RGB")),
            "segmentation_results": results,
        }
        return Response(
            content=json.dumps(response),
            status_code=200,
            media_type="application/json",
        )
    except HTTPException:
        # Already a well-formed HTTP error (e.g. the 400 above); pass it through.
        raise
    except Exception as e:
        # Log the full stack trace server-side instead of discarding it.
        logger.exception("Segmentation request failed")
        raise HTTPException(status_code=500, detail=f"An error occurred: {str(e)}") from e