natexcvi commited on
Commit
790a68a
1 Parent(s): 515d778

Add response schema

Browse files
Files changed (4) hide show
  1. app.py +6 -12
  2. schema.py +12 -0
  3. testdata/face_pic4.jpeg +0 -0
  4. testdata/face_pic5.jpeg +0 -0
app.py CHANGED
@@ -1,19 +1,11 @@
1
  import os
2
  from typing import Union
3
 
4
- from fastapi import (
5
- Depends,
6
- FastAPI,
7
- File,
8
- HTTPException,
9
- Query,
10
- Response,
11
- UploadFile,
12
- status,
13
- )
14
  from fastapi.security import APIKeyQuery
15
 
16
  from model import Model
 
17
 
18
  app = FastAPI(
19
  title="Facial Expression Embedding Service",
@@ -42,6 +34,7 @@ async def validate_token(
42
  "/embed",
43
  status_code=status.HTTP_200_OK,
44
  dependencies=[Depends(validate_token)],
 
45
  )
46
  async def calculate_embedding(
47
  image: UploadFile = File(...),
@@ -49,7 +42,7 @@ async def calculate_embedding(
49
  try:
50
  image_content = await image.read()
51
  pred = model.predict(model.preprocess(image_content))
52
- return {"embedding": pred.tolist()}
53
  except Exception as e:
54
  return Response(
55
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -60,6 +53,7 @@ async def calculate_embedding(
60
  "/similarity",
61
  status_code=status.HTTP_200_OK,
62
  dependencies=[Depends(validate_token)],
 
63
  )
64
  async def calculate_similarity_score(
65
  image1: UploadFile = File(...),
@@ -70,7 +64,7 @@ async def calculate_similarity_score(
70
  image2_content = await image2.read()
71
  pred1 = model.predict(model.preprocess(image1_content))
72
  pred2 = model.predict(model.preprocess(image2_content))
73
- return {"score": float(model.distance(pred1, pred2))}
74
  except Exception as e:
75
  return Response(
76
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
 
1
  import os
2
  from typing import Union
3
 
4
+ from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
 
 
 
 
 
 
 
 
 
5
  from fastapi.security import APIKeyQuery
6
 
7
  from model import Model
8
+ from schema import EmbeddingResponse, SimilarityResponse
9
 
10
  app = FastAPI(
11
  title="Facial Expression Embedding Service",
 
34
  "/embed",
35
  status_code=status.HTTP_200_OK,
36
  dependencies=[Depends(validate_token)],
37
+ response_model=EmbeddingResponse,
38
  )
39
  async def calculate_embedding(
40
  image: UploadFile = File(...),
 
42
  try:
43
  image_content = await image.read()
44
  pred = model.predict(model.preprocess(image_content))
45
+ return EmbeddingResponse(embedding=pred.tolist())
46
  except Exception as e:
47
  return Response(
48
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
 
53
  "/similarity",
54
  status_code=status.HTTP_200_OK,
55
  dependencies=[Depends(validate_token)],
56
+ response_model=SimilarityResponse,
57
  )
58
  async def calculate_similarity_score(
59
  image1: UploadFile = File(...),
 
64
  image2_content = await image2.read()
65
  pred1 = model.predict(model.preprocess(image1_content))
66
  pred2 = model.predict(model.preprocess(image2_content))
67
+ return SimilarityResponse(score=float(model.distance(pred1, pred2)))
68
  except Exception as e:
69
  return Response(
70
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
schema.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+
4
+
5
+ @dataclass
6
+ class EmbeddingResponse:
7
+ embedding: List[float]
8
+
9
+
10
+ @dataclass
11
+ class SimilarityResponse:
12
+ score: float
testdata/face_pic4.jpeg ADDED
testdata/face_pic5.jpeg ADDED