"""Facial expression embedding service.

Exposes two token-protected endpoints:

* ``POST /embed`` -- returns the model embedding of one uploaded image.
* ``POST /similarity`` -- returns ``model.distance`` between the predictions
  for two uploaded images.

Configuration is read from environment variables at import time:
``CLIENT_TOKEN`` (shared secret expected in the ``token`` query parameter)
and ``MODEL_REPO_ID`` / ``MODEL_FILENAME`` / ``HF_TOKEN`` (passed to ``Model``).
"""

import hmac
import logging
import os
from typing import Union

import numpy as np
from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
from fastapi.security import APIKeyQuery

from model import Model
from schema import EmbeddingResponse, SimilarityResponse

logger = logging.getLogger(__name__)

app = FastAPI(
    title="Facial Expression Embedding Service",
)

# Clients authenticate with ``?token=...``. auto_error=False makes a missing
# token arrive as None so validate_token can produce its own 401 messages.
api_key = APIKeyQuery(name="token", auto_error=False)

# Shared secret. If the env var is unset this is "", and validate_token then
# rejects every request.
client_token: str = os.getenv("CLIENT_TOKEN", "")

# Loaded once, at import time.
# NOTE(review): the arguments look like a Hugging Face Hub repo id, file name
# and access token -- confirm against model.Model.
model = Model(
    os.getenv("MODEL_REPO_ID", ""),
    os.getenv("MODEL_FILENAME", ""),
    os.getenv("HF_TOKEN", ""),
)


async def validate_token(
    token: Union[str, None] = Depends(api_key),
) -> str:
    """Reject the request unless the ``token`` query parameter matches CLIENT_TOKEN.

    Args:
        token: Value of the ``token`` query parameter, or None when absent.

    Returns:
        The validated token.

    Raises:
        HTTPException: 401 "No token provided" when no token was sent;
            401 "Invalid token" when the token does not match or when no
            CLIENT_TOKEN is configured on the server.
    """
    if token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")
    # Fail closed: with no configured secret, nothing may authenticate.
    # Compare UTF-8 bytes: hmac.compare_digest raises TypeError for str
    # arguments containing non-ASCII characters, which would surface as a
    # 500 instead of a 401 for a malformed client token.
    if not client_token or not hmac.compare_digest(
        token.encode("utf-8"), client_token.encode("utf-8")
    ):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
    return token


async def _read_upload(upload: UploadFile) -> bytes:
    """Return the full contents of ``upload`` as bytes."""
    content = await upload.read()
    # Defensive normalisation in case read() ever yields text.
    if isinstance(content, str):
        content = content.encode()
    return content


def _internal_error(exc: Exception) -> Response:
    """Log ``exc`` with its traceback and return it as a plain-text 500 response."""
    # Keep the traceback server-side; the client only receives the message.
    logger.error("Request failed", exc_info=exc)
    # NOTE(review): echoing str(exc) to clients may expose internal details;
    # kept to preserve the existing API contract.
    return Response(
        content=str(exc), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
    )


@app.post(
    "/embed",
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(validate_token)],
    response_model=EmbeddingResponse,
)
async def calculate_embedding(
    image: UploadFile = File(...),
):
    """Return the model embedding of a single uploaded image.

    Any exception while reading, preprocessing or predicting is logged and
    returned as a plain-text 500 response carrying the exception message.
    """
    try:
        image_content = await _read_upload(image)
        # NOTE(review): predict() is called synchronously inside an async
        # handler, so it runs on the event-loop thread and blocks other
        # requests while it executes.
        pred = model.predict(model.preprocess(image_content))
        # The first row of the prediction is used as this image's embedding;
        # presumably pred is a numpy array of shape (1, dim) -- confirm.
        return EmbeddingResponse(embedding=pred[0].tolist())
    except Exception as e:
        return _internal_error(e)


@app.post(
    "/similarity",
    status_code=status.HTTP_200_OK,
    dependencies=[Depends(validate_token)],
    response_model=SimilarityResponse,
)
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
):
    """Return ``model.distance`` between the predictions for two uploaded images.

    Any exception while reading, preprocessing or predicting is logged and
    returned as a plain-text 500 response carrying the exception message.
    """
    try:
        image1_content = await _read_upload(image1)
        image2_content = await _read_upload(image2)
        # Stack both preprocessed inputs so a single predict() call handles
        # the pair; rows 0 and 1 of the result correspond to image1 and image2.
        batch = np.vstack(
            [model.preprocess(image1_content), model.preprocess(image2_content)]
        )
        pred = model.predict(batch)
        return SimilarityResponse(score=float(model.distance(pred[0], pred[1])))
    except Exception as e:
        return _internal_error(e)