natexcvi commited on
Commit
cd7b3bb
1 Parent(s): 790a68a
Files changed (3) hide show
  1. Makefile +2 -0
  2. app.py +8 -4
  3. test_service.py +2 -2
Makefile ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ test:
2
+ python -m pytest -vv
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from typing import Union
3
 
 
4
  from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
5
  from fastapi.security import APIKeyQuery
6
 
@@ -42,7 +43,7 @@ async def calculate_embedding(
42
  try:
43
  image_content = await image.read()
44
  pred = model.predict(model.preprocess(image_content))
45
- return EmbeddingResponse(embedding=pred.tolist())
46
  except Exception as e:
47
  return Response(
48
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -62,9 +63,12 @@ async def calculate_similarity_score(
62
  try:
63
  image1_content = await image1.read()
64
  image2_content = await image2.read()
65
- pred1 = model.predict(model.preprocess(image1_content))
66
- pred2 = model.predict(model.preprocess(image2_content))
67
- return SimilarityResponse(score=float(model.distance(pred1, pred2)))
 
 
 
68
  except Exception as e:
69
  return Response(
70
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
 
1
  import os
2
  from typing import Union
3
 
4
+ import numpy as np
5
  from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
6
  from fastapi.security import APIKeyQuery
7
 
 
43
  try:
44
  image_content = await image.read()
45
  pred = model.predict(model.preprocess(image_content))
46
+ return EmbeddingResponse(embedding=pred[0].tolist())
47
  except Exception as e:
48
  return Response(
49
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
 
63
  try:
64
  image1_content = await image1.read()
65
  image2_content = await image2.read()
66
+ pred = model.predict(
67
+ np.vstack(
68
+ [model.preprocess(image1_content), model.preprocess(image2_content)]
69
+ )
70
+ )
71
+ return SimilarityResponse(score=float(model.distance(pred[0], pred[1])))
72
  except Exception as e:
73
  return Response(
74
  content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
test_service.py CHANGED
@@ -45,13 +45,13 @@ def test_embed(client: TestClient, image_path):
45
  pytest.param(
46
  "testdata/face_pic.jpeg",
47
  "testdata/face_pic3.jpeg",
48
- 0.00026,
49
  id="similar expression",
50
  ),
51
  pytest.param(
52
  "testdata/face_pic.jpeg",
53
  "testdata/face_pic2.jpeg",
54
- 2,
55
  id="different expression",
56
  ),
57
  ],
 
45
  pytest.param(
46
  "testdata/face_pic.jpeg",
47
  "testdata/face_pic3.jpeg",
48
+ 0.0588636,
49
  id="similar expression",
50
  ),
51
  pytest.param(
52
  "testdata/face_pic.jpeg",
53
  "testdata/face_pic2.jpeg",
54
+ 1.413582,
55
  id="different expression",
56
  ),
57
  ],