"""HTTP API exposing image embedding and image-pair similarity endpoints.

Every endpoint requires a ``token`` query parameter that must match the
``CLIENT_TOKEN`` environment variable.
"""

import logging
import os
import secrets
from typing import Union

from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Query,
    Response,
    UploadFile,
    status,
)

from model import Model

logger = logging.getLogger(__name__)

app = FastAPI()

# The model is built once at import time from environment configuration;
# unset variables fall back to empty strings.
model = Model(
    os.getenv("MODEL_REPO_ID", ""),
    os.getenv("MODEL_FILENAME", ""),
    os.getenv("HF_TOKEN", ""),
)


async def validate_token(
    token: Union[str, None] = Query(default=None),
) -> str:
    """Check the ``token`` query parameter against ``CLIENT_TOKEN``.

    Args:
        token: Token supplied by the client as a query parameter.

    Returns:
        The token, unchanged, when it matches the configured secret.

    Raises:
        HTTPException: 401 when the token is missing, when it does not
            match, or when ``CLIENT_TOKEN`` is not set in the environment.
    """
    if token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")

    # Read on every request so the behavior follows the current environment.
    expected = os.getenv("CLIENT_TOKEN")

    # Compare in constant time so response latency does not leak how much of
    # the secret a guess got right. Bytes are used because compare_digest
    # rejects non-ASCII str arguments.
    if expected is None or not secrets.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
    return token


def _embed_bytes(image_content: bytes):
    """Run the model's preprocess + predict pipeline on raw image bytes."""
    return model.predict(model.preprocess(image_content))


@app.post("/embed", status_code=status.HTTP_200_OK)
async def calculate_embedding(
    image: UploadFile = File(...), _: str = Depends(validate_token)
):
    """Return the model embedding of an uploaded image.

    Args:
        image: Uploaded image file.
        _: Unused; present only to enforce token validation.

    Returns:
        ``{"embedding": [...]}`` on success, or a 500 ``Response`` whose body
        is the exception message when anything in the pipeline fails.
    """
    # NOTE: no return annotation on purpose; FastAPI would interpret it as a
    # response model and change response handling.
    try:
        image_content = await image.read()
        pred = _embed_bytes(image_content)
        # Prediction is presumably array-like (it exposes ``tolist``).
        return {"embedding": pred.tolist()}
    except Exception as e:
        # Keep a server-side record; the client response is unchanged.
        logger.exception("Failed to compute embedding")
        return Response(
            content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )


@app.post("/similarity", status_code=status.HTTP_200_OK)
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
    _: str = Depends(validate_token),
):
    """Return the model's distance score between two uploaded images.

    Args:
        image1: First uploaded image file.
        image2: Second uploaded image file.
        _: Unused; present only to enforce token validation.

    Returns:
        ``{"score": <float>}`` on success, or a 500 ``Response`` whose body
        is the exception message when anything in the pipeline fails.
    """
    try:
        # Read both uploads before running the model, as the original did.
        image1_content = await image1.read()
        image2_content = await image2.read()
        pred1 = _embed_bytes(image1_content)
        pred2 = _embed_bytes(image2_content)
        return {"score": float(model.distance(pred1, pred2))}
    except Exception as e:
        # Keep a server-side record; the client response is unchanged.
        logger.exception("Failed to compute similarity score")
        return Response(
            content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )