Spaces:
Build error
Build error
import os | |
from typing import Union | |
from fastapi import ( | |
Depends, | |
FastAPI, | |
File, | |
HTTPException, | |
Query, | |
Response, | |
UploadFile, | |
status, | |
) | |
from model import Model | |
# FastAPI application object for this module.
app = FastAPI()

# The model is configured entirely from the environment. Each setting falls
# back to an empty string when its variable is unset, and the three values are
# passed positionally in this order: repo id, filename, Hugging Face token.
model = Model(
    *(
        os.getenv(setting, "")
        for setting in ("MODEL_REPO_ID", "MODEL_FILENAME", "HF_TOKEN")
    )
)
async def validate_token(
    token: Union[str, None] = Query(default=None),
):
    """Check the ``token`` query parameter against the ``CLIENT_TOKEN`` env var.

    Args:
        token: Value of the ``token`` query parameter, or ``None`` when the
            request did not supply one.

    Returns:
        The token unchanged, once it has been accepted.

    Raises:
        HTTPException: 401 "No token provided" when the parameter is missing;
            401 "Invalid token" when it does not match ``CLIENT_TOKEN``
            (this includes the case where ``CLIENT_TOKEN`` is not set).
    """
    # Imported here so the function carries its own dependency.
    import hmac

    if token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")

    expected = os.getenv("CLIENT_TOKEN")
    # Fail closed when no client token is configured: nothing can match.
    if expected is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")

    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guess were correct. Compare as bytes because
    # compare_digest rejects non-ASCII str; "surrogatepass" keeps the encode
    # from raising on environment values that contain lone surrogates.
    supplied_bytes = token.encode("utf-8", "surrogatepass")
    expected_bytes = expected.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied_bytes, expected_bytes):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
    return token
async def calculate_embedding(
    image: UploadFile = File(...), _: str = Depends(validate_token)
):
    """Run the uploaded image through the model and return its embedding.

    Args:
        image: Uploaded file whose raw content is handed to
            ``model.preprocess``.
        _: Result of the ``validate_token`` dependency; unused here, it is
            declared only so the token check runs.

    Returns:
        ``{"embedding": <prediction.tolist()>}`` on success. If anything in
        the read/preprocess/predict/convert sequence raises, a ``Response``
        with status 500 whose body is the exception text.
    """
    try:
        raw_bytes = await image.read()
        model_input = model.preprocess(raw_bytes)
        prediction = model.predict(model_input)
        embedding = prediction.tolist()
    except Exception as exc:
        # Any failure is reported to the caller as a 500 with the error text.
        return Response(
            content=str(exc), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
    else:
        return {"embedding": embedding}
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
    _: str = Depends(validate_token),
):
    """Score two uploaded images against each other using the model.

    Both uploads are read first, then each is preprocessed and passed through
    ``model.predict`` (first image, then second); the two predictions are
    given to ``model.distance`` and the result is converted to ``float``.

    Args:
        image1: First uploaded file.
        image2: Second uploaded file.
        _: Result of the ``validate_token`` dependency; unused here, it is
            declared only so the token check runs.

    Returns:
        ``{"score": <float>}`` on success. If any step raises, a ``Response``
        with status 500 whose body is the exception text.
    """
    try:
        # Read both uploads before any model work, first image first.
        raw_images = [await image1.read(), await image2.read()]
        first_pred, second_pred = [
            model.predict(model.preprocess(raw)) for raw in raw_images
        ]
        score = float(model.distance(first_pred, second_pred))
    except Exception as exc:
        # Any failure is reported to the caller as a 500 with the error text.
        return Response(
            content=str(exc), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )
    else:
        return {"score": score}