Spaces:
Build error
Build error
natexcvi
commited on
Commit
•
cd7b3bb
1
Parent(s):
790a68a
Fixes
Browse files- Makefile +2 -0
- app.py +8 -4
- test_service.py +2 -2
Makefile
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
test:
|
2 |
+
python -m pytest -vv
|
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
from typing import Union
|
3 |
|
|
|
4 |
from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
|
5 |
from fastapi.security import APIKeyQuery
|
6 |
|
@@ -42,7 +43,7 @@ async def calculate_embedding(
|
|
42 |
try:
|
43 |
image_content = await image.read()
|
44 |
pred = model.predict(model.preprocess(image_content))
|
45 |
-
return EmbeddingResponse(embedding=pred.tolist())
|
46 |
except Exception as e:
|
47 |
return Response(
|
48 |
content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
@@ -62,9 +63,12 @@ async def calculate_similarity_score(
|
|
62 |
try:
|
63 |
image1_content = await image1.read()
|
64 |
image2_content = await image2.read()
|
65 |
-
|
66 |
-
|
67 |
-
|
|
|
|
|
|
|
68 |
except Exception as e:
|
69 |
return Response(
|
70 |
content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
|
|
1 |
import os
|
2 |
from typing import Union
|
3 |
|
4 |
+
import numpy as np
|
5 |
from fastapi import Depends, FastAPI, File, HTTPException, Response, UploadFile, status
|
6 |
from fastapi.security import APIKeyQuery
|
7 |
|
|
|
43 |
try:
|
44 |
image_content = await image.read()
|
45 |
pred = model.predict(model.preprocess(image_content))
|
46 |
+
return EmbeddingResponse(embedding=pred[0].tolist())
|
47 |
except Exception as e:
|
48 |
return Response(
|
49 |
content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
|
|
63 |
try:
|
64 |
image1_content = await image1.read()
|
65 |
image2_content = await image2.read()
|
66 |
+
pred = model.predict(
|
67 |
+
np.vstack(
|
68 |
+
[model.preprocess(image1_content), model.preprocess(image2_content)]
|
69 |
+
)
|
70 |
+
)
|
71 |
+
return SimilarityResponse(score=float(model.distance(pred[0], pred[1])))
|
72 |
except Exception as e:
|
73 |
return Response(
|
74 |
content=str(e), status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
test_service.py
CHANGED
@@ -45,13 +45,13 @@ def test_embed(client: TestClient, image_path):
|
|
45 |
pytest.param(
|
46 |
"testdata/face_pic.jpeg",
|
47 |
"testdata/face_pic3.jpeg",
|
48 |
-
0.
|
49 |
id="similar expression",
|
50 |
),
|
51 |
pytest.param(
|
52 |
"testdata/face_pic.jpeg",
|
53 |
"testdata/face_pic2.jpeg",
|
54 |
-
|
55 |
id="different expression",
|
56 |
),
|
57 |
],
|
|
|
45 |
pytest.param(
|
46 |
"testdata/face_pic.jpeg",
|
47 |
"testdata/face_pic3.jpeg",
|
48 |
+
0.0588636,
|
49 |
id="similar expression",
|
50 |
),
|
51 |
pytest.param(
|
52 |
"testdata/face_pic.jpeg",
|
53 |
"testdata/face_pic2.jpeg",
|
54 |
+
1.413582,
|
55 |
id="different expression",
|
56 |
),
|
57 |
],
|