# natexcvi
# Add fecnet
# 0edd049 unverified
# raw
# history blame
# No virus
# 2.07 kB
import logging
import os

import numpy as np
from fastapi import Depends, FastAPI, File, Response, UploadFile, status
from fastapi.concurrency import run_in_threadpool

from auth import validate_token
from model.model import Model
from routers.fecnet_router import router as fecnet_router
from schema import EmbeddingResponse, SimilarityResponse
_SERVICE_TITLE = "Facial Expression Embedding Service"

# The ASGI application. Routes from the FECNet router are mounted on it in
# addition to the endpoints declared further down in this module.
app = FastAPI(title=_SERVICE_TITLE)
app.include_router(fecnet_router)
# Model configuration is read from the environment. Each variable that is not
# set is passed through as an empty string. Values are handed to Model
# positionally, in the order listed here.
_MODEL_ENV_VARS = ("MODEL_REPO_ID", "MODEL_FILENAME", "HF_TOKEN")
model = Model(*[os.getenv(var_name, "") for var_name in _MODEL_ENV_VARS])
@app.post(
    "/embed",
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(validate_token)],
    response_model=EmbeddingResponse,
)
async def calculate_embedding(
    image: UploadFile = File(...),
):
    """Compute the facial-expression embedding of one uploaded image.

    The request must pass ``validate_token``. The uploaded bytes are passed
    through ``model.preprocess`` and ``model.predict``. The first row of the
    prediction is returned as the embedding.

    Returns:
        ``EmbeddingResponse`` on success.
        A plain-text 400 response when the upload is empty.
        A plain-text 500 response containing the exception message when
        reading, preprocessing or inference fails.
    """
    try:
        image_content = await image.read()
        if isinstance(image_content, str):
            image_content = image_content.encode()
        if not image_content:
            # Nothing to decode. Report a client error rather than letting the
            # model fail and surfacing that as a server error.
            return Response(
                content="Uploaded image is empty.",
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        # preprocess/predict are synchronous calls. Running them directly in
        # this coroutine would stall the event loop (and every other request)
        # for the whole inference, so they run in the worker thread pool.
        # NOTE(review): assumes Model is safe to call from worker threads.
        pred = await run_in_threadpool(
            lambda: model.predict(model.preprocess(image_content))
        )
        # pred[0]: first (only) row of the batch. Presumably preprocess adds a
        # leading batch axis; TODO confirm.
        return EmbeddingResponse(embedding=pred[0].tolist())
    except Exception as e:
        # Top-level request boundary. Keep the traceback in the server logs
        # instead of discarding it, and answer with the same 500 as before.
        logging.getLogger(__name__).exception("Failed to compute embedding")
        return Response(
            content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
@app.post(
    "/similarity",
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(validate_token)],
    response_model=SimilarityResponse,
)
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
):
    """Score two uploaded images with ``model.distance`` on their embeddings.

    The request must pass ``validate_token``. Both images are preprocessed,
    stacked into one batch and embedded with a single ``model.predict`` call.
    The returned score is ``model.distance`` between the two embeddings,
    converted to ``float``.

    Returns:
        ``SimilarityResponse`` on success.
        A plain-text 400 response when either upload is empty.
        A plain-text 500 response containing the exception message when
        reading, preprocessing or inference fails.
    """
    try:
        image1_content = await image1.read()
        if isinstance(image1_content, str):
            image1_content = image1_content.encode()
        image2_content = await image2.read()
        if isinstance(image2_content, str):
            image2_content = image2_content.encode()
        if not image1_content or not image2_content:
            # An empty upload cannot be decoded. Treat it as a client error.
            return Response(
                content="Both uploaded images must be non-empty.",
                status_code=status.HTTP_400_BAD_REQUEST,
            )

        def _score() -> float:
            # Embed both images with one predict call, then compare the rows.
            # Presumably preprocess returns a batch of one, so vstack yields a
            # batch of two; TODO confirm.
            batch = np.vstack(
                [model.preprocess(image1_content), model.preprocess(image2_content)]
            )
            pred = model.predict(batch)
            return float(model.distance(pred[0], pred[1]))

        # The model calls are synchronous. Run them in the worker thread pool
        # so they do not block the event loop.
        # NOTE(review): assumes Model is safe to call from worker threads.
        score = await run_in_threadpool(_score)
        return SimilarityResponse(score=score)
    except Exception as e:
        # Top-level request boundary. Keep the traceback in the server logs
        # instead of discarding it, and answer with the same 500 as before.
        logging.getLogger(__name__).exception("Failed to compute similarity")
        return Response(
            content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )