natexcvi commited on
Commit
2f18963
1 Parent(s): cd7b3bb

Fix token comparison

Browse files
Files changed (2) hide show
  1. app.py +9 -1
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from typing import Union
3
 
@@ -13,6 +14,7 @@ app = FastAPI(
13
  )
14
 
15
  api_key = APIKeyQuery(name="token", auto_error=False)
 
16
 
17
  model = Model(
18
  os.getenv("MODEL_REPO_ID", ""),
@@ -26,7 +28,7 @@ async def validate_token(
26
  ):
27
  if token is None:
28
  raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")
29
- if token != os.getenv("CLIENT_TOKEN"):
30
  raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
31
  return token
32
 
@@ -42,6 +44,8 @@ async def calculate_embedding(
42
  ):
43
  try:
44
  image_content = await image.read()
 
 
45
  pred = model.predict(model.preprocess(image_content))
46
  return EmbeddingResponse(embedding=pred[0].tolist())
47
  except Exception as e:
@@ -62,7 +66,11 @@ async def calculate_similarity_score(
62
  ):
63
  try:
64
  image1_content = await image1.read()
 
 
65
  image2_content = await image2.read()
 
 
66
  pred = model.predict(
67
  np.vstack(
68
  [model.preprocess(image1_content), model.preprocess(image2_content)]
 
1
+ import hmac
2
  import os
3
  from typing import Union
4
 
 
14
  )
15
 
16
  api_key = APIKeyQuery(name="token", auto_error=False)
17
+ client_token: str = os.getenv("CLIENT_TOKEN", "")
18
 
19
  model = Model(
20
  os.getenv("MODEL_REPO_ID", ""),
 
28
  ):
29
  if token is None:
30
  raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")
31
+ if not hmac.compare_digest(token, client_token):
32
  raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
33
  return token
34
 
 
44
  ):
45
  try:
46
  image_content = await image.read()
47
+ if isinstance(image_content, str):
48
+ image_content = image_content.encode()
49
  pred = model.predict(model.preprocess(image_content))
50
  return EmbeddingResponse(embedding=pred[0].tolist())
51
  except Exception as e:
 
66
  ):
67
  try:
68
  image1_content = await image1.read()
69
+ if isinstance(image1_content, str):
70
+ image1_content = image1_content.encode()
71
  image2_content = await image2.read()
72
+ if isinstance(image2_content, str):
73
+ image2_content = image2_content.encode()
74
  pred = model.predict(
75
  np.vstack(
76
  [model.preprocess(image1_content), model.preprocess(image2_content)]
requirements.txt CHANGED
@@ -8,4 +8,5 @@ python-multipart
8
  mediapipe
9
  pandas
10
  pytest
11
- python-dotenv
 
 
8
  mediapipe
9
  pandas
10
  pytest
11
+ python-dotenv
12
+ bcrypt