natexcvi
Impl. + tests
b87deef unverified
raw
history blame
1.74 kB
import os
from typing import Union
from fastapi import (
Depends,
FastAPI,
File,
HTTPException,
Query,
Response,
UploadFile,
status,
)
from model import Model
# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Single module-level Model instance, built at import time from environment
# variables. Each variable falls back to an empty string when it is not set.
model = Model(
    os.environ.get("MODEL_REPO_ID", ""),
    os.environ.get("MODEL_FILENAME", ""),
    os.environ.get("HF_TOKEN", ""),
)
async def validate_token(
    token: Union[str, None] = Query(default=None),
):
    """Check the ``token`` query parameter against the ``CLIENT_TOKEN`` env var.

    Intended for use as a FastAPI dependency (``Depends(validate_token)``).

    Args:
        token: Value of the ``token`` query parameter, or ``None`` when the
            request did not include one.

    Returns:
        The supplied token, unchanged, when it matches ``CLIENT_TOKEN``.

    Raises:
        HTTPException: 401 with detail ``"No token provided"`` when the
            parameter is missing; 401 with detail ``"Invalid token"`` when it
            does not match, or when ``CLIENT_TOKEN`` is not set at all.
    """
    # Function-scope import keeps this block self-contained.
    import hmac

    if token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")

    expected = os.getenv("CLIENT_TOKEN")
    if expected is None:
        # No secret configured: fail closed rather than accept anything.
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")

    # Compare in constant time so response timing does not reveal how many
    # leading characters of the secret a guess got right. Bytes are compared
    # because compare_digest() rejects non-ASCII str arguments; "surrogatepass"
    # keeps the encoding lossless even for unusual environment values.
    supplied = token.encode("utf-8", "surrogatepass")
    required = expected.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, required):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
    return token
@app.post("/embed", status_code=status.HTTP_200_OK)
async def calculate_embedding(
    image: UploadFile = File(...), _: str = Depends(validate_token)
):
    """Return the model's prediction for one uploaded image.

    Args:
        image: The uploaded file; its full contents are read and handed to
            ``model.preprocess``.
        _: Result of the ``validate_token`` dependency; unused beyond
            enforcing that the dependency ran.

    Returns:
        ``{"embedding": <prediction.tolist()>}`` on success. If anything in
        the pipeline raises, a ``Response`` with status 500 whose body is the
        exception's string form.
    """
    try:
        raw_bytes = await image.read()
        model_input = model.preprocess(raw_bytes)
        prediction = model.predict(model_input)
        return {"embedding": prediction.tolist()}
    except Exception as exc:
        # Any failure (read, preprocess, predict, serialisation) is reported
        # to the client as a plain 500 carrying the exception text.
        failure = Response(
            content=str(exc),
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
        return failure
@app.post("/similarity", status_code=status.HTTP_200_OK)
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
    _: str = Depends(validate_token),
):
    """Score two uploaded images against each other using the model.

    Args:
        image1: First uploaded file.
        image2: Second uploaded file.
        _: Result of the ``validate_token`` dependency; unused beyond
            enforcing that the dependency ran.

    Returns:
        ``{"score": <float>}`` where the value is ``model.distance`` applied
        to the two predictions, cast to ``float``. If anything in the pipeline
        raises, a ``Response`` with status 500 whose body is the exception's
        string form.
    """
    try:
        # Both uploads are read in full before any model work starts.
        payloads = [await upload.read() for upload in (image1, image2)]
        # Predictions are computed in order: image1 first, then image2.
        first, second = [
            model.predict(model.preprocess(payload)) for payload in payloads
        ]
        score = model.distance(first, second)
        return {"score": float(score)}
    except Exception as exc:
        # Any failure is reported to the client as a plain 500 carrying the
        # exception text.
        failure = Response(
            content=str(exc),
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
        return failure