kelvinou01
committed on
Commit
•
841a649
1
Parent(s):
b435ec9
Update handler
Browse files- handler.py +29 -7
handler.py
CHANGED
@@ -1,9 +1,12 @@
|
|
1 |
|
|
|
|
|
2 |
import os
|
3 |
from typing import Dict, List, Any
|
|
|
4 |
import groundingdino
|
5 |
from groundingdino.util.inference import load_model, load_image, predict, annotate
|
6 |
-
import
|
7 |
|
8 |
# /app
|
9 |
HOME = os.getcwd()
|
@@ -20,6 +23,9 @@ class EndpointHandler():
|
|
20 |
|
21 |
self.model = load_model(CONFIG_PATH, os.path.join(path, "weights", "groundingdino_swint_ogc.pth"))
|
22 |
|
|
|
|
|
|
|
23 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
|
24 |
"""
|
25 |
data args:
|
@@ -29,10 +35,26 @@ class EndpointHandler():
|
|
29 |
A :obj:`list` | `dict`: will be serialized and returned
|
30 |
"""
|
31 |
inputs = data.pop("inputs")
|
32 |
-
|
33 |
prompt = inputs.pop("prompt")
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
+
import base64
|
3 |
+
from io import BytesIO
|
4 |
import os
|
5 |
from typing import Dict, List, Any
|
6 |
+
import cv2
|
7 |
import groundingdino
|
8 |
from groundingdino.util.inference import load_model, load_image, predict, annotate
|
9 |
+
import tempfile
|
10 |
|
11 |
# /app
|
12 |
HOME = os.getcwd()
|
|
|
23 |
|
24 |
self.model = load_model(CONFIG_PATH, os.path.join(path, "weights", "groundingdino_swint_ogc.pth"))
|
25 |
|
26 |
+
self.box_threshold = 0.35
|
27 |
+
self.text_threshold = 0.25
|
28 |
+
|
29 |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Run GroundingDINO open-vocabulary detection on a base64-encoded image.

    data args:
        inputs (dict):
            image (str): base64-encoded image bytes (JPEG assumed — the temp
                file is given a ".jpg" suffix; TODO confirm other formats).
            prompt (str): text caption describing the objects to detect.

    Return:
        A :obj:`list` | `dict`: will be serialized and returned. A single-item
        list with the annotated image (base64-encoded JPEG) and the prompt used.
    """
    inputs = data.pop("inputs")
    image_base64 = inputs.pop("image")
    prompt = inputs.pop("prompt")

    image_data = base64.b64decode(image_base64)

    # load_image only accepts a filesystem path, so stage the decoded bytes
    # in a temp file; delete=True cleans it up when the block exits.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=True) as f:
        f.write(image_data)
        # BUG FIX: flush the buffered write before load_image reopens the
        # file by name — otherwise the reader may see a truncated/empty file.
        f.flush()
        image_source, image = load_image(f.name)
        boxes, logits, phrases = predict(
            model=self.model,
            image=image,
            caption=prompt,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
        )

    annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
    # Re-encode the annotated frame as JPEG, then base64 so the result is
    # JSON-serializable for the endpoint response.
    _, annotated_image = cv2.imencode(".jpg", annotated_frame)
    annotated_image_b64 = base64.b64encode(annotated_image).decode("utf-8")

    return [{
        "image": annotated_image_b64,
        "prompt": prompt,
    }]
|