natexcvi
Fix token comparison
2f18963 unverified
raw
history blame
2.46 kB
import hmac
import os
from typing import Union
import numpy as np
from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
from fastapi.security import APIKeyQuery
from model import Model
from schema import EmbeddingResponse, SimilarityResponse
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(title="Facial Expression Embedding Service")

# API key is taken from the "token" query parameter. auto_error is disabled,
# so the dependency that consumes this scheme has to handle a missing key.
api_key = APIKeyQuery(auto_error=False, name="token")

# Shared secret that clients must present; read once at import time and
# left as an empty string when CLIENT_TOKEN is not set.
client_token: str = os.environ.get("CLIENT_TOKEN", "")

# Model wrapper built from environment configuration (each value falls back
# to an empty string when its variable is not set).
model = Model(
    os.environ.get("MODEL_REPO_ID", ""),
    os.environ.get("MODEL_FILENAME", ""),
    os.environ.get("HF_TOKEN", ""),
)
async def validate_token(
    token: Union[str, None] = Depends(api_key),
):
    """Authenticate a request by its ``token`` query parameter.

    Args:
        token: Value supplied by the ``api_key`` security scheme, or ``None``
            when the request carried no token.

    Returns:
        The supplied token, unchanged, when it matches ``client_token``.

    Raises:
        HTTPException: 401 "No token provided" when ``token`` is ``None``;
            401 "Invalid token" when the token does not match, or when no
            client token has been configured.
    """
    if token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")

    # hmac.compare_digest only accepts str operands that are pure ASCII and
    # raises TypeError otherwise, which would surface as a 500 for a
    # non-ASCII token. Comparing UTF-8 bytes keeps the constant-time
    # comparison while accepting any text.
    supplied = token.encode("utf-8")
    expected = client_token.encode("utf-8")

    # Fail closed: client_token is "" when CLIENT_TOKEN is unset, and an
    # empty "?token=" must not authenticate against it.
    if not expected or not hmac.compare_digest(supplied, expected):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
    return token
# POST /embed: compute an embedding for one uploaded image.
#
# The upload is read fully, coerced to bytes if it arrives as str, passed
# through model.preprocess and model.predict, and the first row of the
# prediction is returned as a list inside an EmbeddingResponse. Any exception
# raised along the way is turned into a 500 response whose body is the
# exception text.
#
# NOTE: documented with comments rather than a docstring so the generated
# OpenAPI description of this operation stays unchanged.
@app.post(
    "/embed",
    response_model=EmbeddingResponse,
    dependencies=[Depends(validate_token)],
    status_code=status.HTTP_200_OK,
)
async def calculate_embedding(
    image: UploadFile = File(...),
):
    try:
        payload = await image.read()
        raw_bytes = payload.encode() if isinstance(payload, str) else payload
        model_input = model.preprocess(raw_bytes)
        predictions = model.predict(model_input)
        first_row = predictions[0]
        return EmbeddingResponse(embedding=first_row.tolist())
    except Exception as exc:
        failure = Response(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=str(exc),
        )
        return failure
# POST /similarity: score how far apart two uploaded images are.
#
# Both uploads are read (and coerced to bytes if they arrive as str) before
# either is preprocessed. The two preprocessed inputs are stacked with
# np.vstack, run through model.predict in one call, and model.distance of the
# first two prediction rows is returned as a float in a SimilarityResponse.
# Any exception raised along the way is turned into a 500 response whose body
# is the exception text.
#
# NOTE: documented with comments rather than a docstring so the generated
# OpenAPI description of this operation stays unchanged.
@app.post(
    "/similarity",
    response_model=SimilarityResponse,
    dependencies=[Depends(validate_token)],
    status_code=status.HTTP_200_OK,
)
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
):
    try:
        raw_images = []
        for upload in (image1, image2):
            payload = await upload.read()
            raw_images.append(
                payload.encode() if isinstance(payload, str) else payload
            )

        stacked = np.vstack([model.preprocess(raw) for raw in raw_images])
        predictions = model.predict(stacked)
        distance = model.distance(predictions[0], predictions[1])
        return SimilarityResponse(score=float(distance))
    except Exception as exc:
        failure = Response(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=str(exc),
        )
        return failure