File size: 1,743 Bytes
b87deef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import logging
import os
import secrets
from typing import Union

from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Query,
    Response,
    UploadFile,
    status,
)

from model import Model

# Application object that the route decorators below register on.
app = FastAPI()

# A single Model instance is built at import time and shared by every
# handler in this module. Its three positional arguments come from the
# environment variables below (in this order), each defaulting to "" when
# unset. Presumably a Hugging Face repo id, a file within it, and an access
# token; confirm against the Model class.
model = Model(
    *(
        os.getenv(env_name, "")
        for env_name in ("MODEL_REPO_ID", "MODEL_FILENAME", "HF_TOKEN")
    )
)


async def validate_token(
    token: Union[str, None] = Query(default=None),
):
    if token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")
    if token != os.getenv("CLIENT_TOKEN"):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
    return token


@app.post("/embed", status_code=status.HTTP_200_OK)
async def calculate_embedding(
    image: UploadFile = File(...), _: str = Depends(validate_token)
):
    """Return the model embedding of one uploaded image.

    Access requires a valid ``token`` query parameter, enforced by the
    ``validate_token`` dependency.

    Args:
        image: Uploaded file; its raw bytes are passed to ``model.preprocess``.

    Returns:
        ``{"embedding": pred.tolist()}`` on success. If reading, preprocessing
        or prediction raises, returns a 500 ``Response`` whose body is the
        exception text.
    """
    try:
        image_content = await image.read()
        pred = model.predict(model.preprocess(image_content))
        return {"embedding": pred.tolist()}
    except Exception as e:
        # The exception is converted into a response instead of propagating,
        # so the framework never logs it; record the traceback explicitly.
        logging.getLogger(__name__).exception(
            "Embedding failed for upload %r", image.filename
        )
        # NOTE(review): str(e) exposes internal error details to the client;
        # consider a generic message now that the traceback is logged.
        return Response(
            content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )


@app.post("/similarity", status_code=status.HTTP_200_OK)
async def calculate_similarity_score(
    image1: UploadFile = File(...),
    image2: UploadFile = File(...),
    _: str = Depends(validate_token),
):
    """Return the model distance score between two uploaded images.

    Access requires a valid ``token`` query parameter, enforced by the
    ``validate_token`` dependency. Each image is embedded independently and
    the pair is passed to ``model.distance``.

    Args:
        image1: First uploaded file.
        image2: Second uploaded file.

    Returns:
        ``{"score": float(model.distance(pred1, pred2))}`` on success. If any
        step raises, returns a 500 ``Response`` whose body is the exception
        text.
    """
    try:
        image1_content = await image1.read()
        image2_content = await image2.read()
        pred1 = model.predict(model.preprocess(image1_content))
        pred2 = model.predict(model.preprocess(image2_content))
        return {"score": float(model.distance(pred1, pred2))}
    except Exception as e:
        # The exception is converted into a response instead of propagating,
        # so the framework never logs it; record the traceback explicitly.
        logging.getLogger(__name__).exception(
            "Similarity failed for uploads %r and %r",
            image1.filename,
            image2.filename,
        )
        # NOTE(review): str(e) exposes internal error details to the client;
        # consider a generic message now that the traceback is logged.
        return Response(
            content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
        )