natexcvi commited on
Commit
515d778
1 Parent(s): 067d3bb

Refactor token auth

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -11,6 +11,7 @@ from fastapi import (
11
  UploadFile,
12
  status,
13
  )
 
14
 
15
  from model import Model
16
 
@@ -18,6 +19,8 @@ app = FastAPI(
18
  title="Facial Expression Embedding Service",
19
  )
20
 
 
 
21
  model = Model(
22
  os.getenv("MODEL_REPO_ID", ""),
23
  os.getenv("MODEL_FILENAME", ""),
@@ -26,7 +29,7 @@ model = Model(
26
 
27
 
28
  async def validate_token(
29
- token: Union[str, None] = Query(default=None),
30
  ):
31
  if token is None:
32
  raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")
@@ -35,9 +38,13 @@ async def validate_token(
35
  return token
36
 
37
 
38
- @app.post("/embed", status_code=status.HTTP_200_OK)
 
 
 
 
39
  async def calculate_embedding(
40
- image: UploadFile = File(...), _: str = Depends(validate_token)
41
  ):
42
  try:
43
  image_content = await image.read()
@@ -49,11 +56,14 @@ async def calculate_embedding(
49
  )
50
 
51
 
52
- @app.post("/similarity", status_code=status.HTTP_200_OK)
 
 
 
 
53
  async def calculate_similarity_score(
54
  image1: UploadFile = File(...),
55
  image2: UploadFile = File(...),
56
- _: str = Depends(validate_token),
57
  ):
58
  try:
59
  image1_content = await image1.read()
 
11
  UploadFile,
12
  status,
13
  )
14
+ from fastapi.security import APIKeyQuery
15
 
16
  from model import Model
17
 
 
19
  title="Facial Expression Embedding Service",
20
  )
21
 
22
+ api_key = APIKeyQuery(name="token", auto_error=False)
23
+
24
  model = Model(
25
  os.getenv("MODEL_REPO_ID", ""),
26
  os.getenv("MODEL_FILENAME", ""),
 
29
 
30
 
31
  async def validate_token(
32
+ token: Union[str, None] = Depends(api_key),
33
  ):
34
  if token is None:
35
  raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token provided")
 
38
  return token
39
 
40
 
41
+ @app.post(
42
+ "/embed",
43
+ status_code=status.HTTP_200_OK,
44
+ dependencies=[Depends(validate_token)],
45
+ )
46
  async def calculate_embedding(
47
+ image: UploadFile = File(...),
48
  ):
49
  try:
50
  image_content = await image.read()
 
56
  )
57
 
58
 
59
+ @app.post(
60
+ "/similarity",
61
+ status_code=status.HTTP_200_OK,
62
+ dependencies=[Depends(validate_token)],
63
+ )
64
  async def calculate_similarity_score(
65
  image1: UploadFile = File(...),
66
  image2: UploadFile = File(...),
 
67
  ):
68
  try:
69
  image1_content = await image1.read()