Initial commit of the Label Studio Segment Anything space
Implementation of a Label Studio ML backend using MobileSAM for image segmentation.
- Dockerfile +39 -0
- _wsgi.py +113 -0
- download_models.sh +23 -0
- model.py +145 -0
- requirements.txt +13 -0
- sam_predictor.py +198 -0
- start.sh +4 -0
Dockerfile
ADDED
@@ -0,0 +1,39 @@
+FROM python:3.8-slim
+
+# Install Dependencies
+RUN apt-get update -q \
+    && apt-get install -qy --no-install-recommends wget git libopencv-dev python3-opencv \
+    && apt-get autoremove -y \
+    && apt-get clean \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set up a non-root user
+RUN useradd -m -u 1000 user \
+    && mkdir /app \
+    && chown -R user /app
+
+# Switch to the "user" user
+USER user
+
+# Set the working directory to the user's home directory
+WORKDIR /app
+
+ENV PYTHONUNBUFFERED=True \
+    VITH_CHECKPOINT=/app/models/sam_vit_h_4b8939.pth \
+    MOBILESAM_CHECKPOINT=/app/models/mobile_sam.pt \
+    ONNX_CHECKPOINT=/app/models/sam_onnx_quantized_example.onnx \
+    PORT=7860
+
+# Copy and run the model download script
+COPY download_models.sh .
+RUN bash /app/download_models.sh
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip install --user --no-cache-dir -r requirements.txt
+
+COPY . ./
+
+EXPOSE 7860
+
+CMD ["/app/start.sh"]
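As a quick check, the image can be built and served locally (a sketch; the label-studio-sam tag is an arbitrary placeholder, and port 7860 matches the EXPOSE and PORT values above):

    docker build -t label-studio-sam .
    docker run -p 7860:7860 label-studio-sam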
_wsgi.py
ADDED
@@ -0,0 +1,113 @@
+import os
+import argparse
+import logging
+import logging.config
+import json
+
+logging.config.dictConfig({
+    "version": 1,
+    "formatters": {
+        "standard": {
+            "format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
+        }
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "level": os.getenv('LOG_LEVEL', 'INFO'),
+            "stream": "ext://sys.stdout",
+            "formatter": "standard"
+        }
+    },
+    "root": {
+        "level": os.getenv('LOG_LEVEL', 'INFO'),
+        "handlers": [
+            "console"
+        ],
+        "propagate": True
+    }
+})
+
+from label_studio_ml.api import init_app
+from model import SamMLBackend
+
+_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')
+
+
+def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
+    if not os.path.exists(config_path):
+        return dict()
+    with open(config_path) as f:
+        config = json.load(f)
+    assert isinstance(config, dict)
+    return config
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Label studio')
+    parser.add_argument(
+        '-p', '--port', dest='port', type=int, default=9090,
+        help='Server port')
+    parser.add_argument(
+        '--host', dest='host', type=str, default='0.0.0.0',
+        help='Server host')
+    parser.add_argument(
+        '--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
+        help='Additional LabelStudioMLBase model initialization kwargs')
+    parser.add_argument(
+        '-d', '--debug', dest='debug', action='store_true',
+        help='Switch debug mode')
+    parser.add_argument(
+        '--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
+        help='Logging level')
+    parser.add_argument(
+        '--model-dir', dest='model_dir', default=os.path.dirname(__file__),
+        help='Directory where models are stored (relative to the project directory)')
+    parser.add_argument(
+        '--check', dest='check', action='store_true',
+        help='Validate model instance before launching server')
+
+    args = parser.parse_args()
+
+    # setup logging level
+    if args.log_level:
+        logging.root.setLevel(args.log_level)
+
+    def isfloat(value):
+        try:
+            float(value)
+            return True
+        except ValueError:
+            return False
+
+    def parse_kwargs():
+        param = dict()
+        for k, v in args.kwargs:
+            if v.isdigit():
+                param[k] = int(v)
+            elif v == 'True' or v == 'true':
+                param[k] = True
+            elif v == 'False' or v == 'false':
+                param[k] = False
+            elif isfloat(v):
+                param[k] = float(v)
+            else:
+                param[k] = v
+        return param
+
+    kwargs = get_kwargs_from_config()
+
+    if args.kwargs:
+        kwargs.update(parse_kwargs())
+
+    if args.check:
+        print('Check "' + SamMLBackend.__name__ + '" instance creation...')
+        model = SamMLBackend(**kwargs)
+
+    app = init_app(model_class=SamMLBackend)
+
+    app.run(host=args.host, port=args.port, debug=args.debug)
+
+else:
+    # for uWSGI use
+    app = init_app(model_class=SamMLBackend)
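For development outside the container, _wsgi.py also runs as a standalone Flask server; every flag below comes from the argparse definitions above:

    python _wsgi.py -p 9090 --log-level DEBUG --check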
download_models.sh
ADDED
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+MODELS_DIR="models"
+mkdir -p ${MODELS_DIR}
+
+download_model() {
+    FILE_PATH="${MODELS_DIR}/$1"
+    URL="$2"
+
+    if [ ! -f "${FILE_PATH}" ]; then
+        wget -q "${URL}" -P ${MODELS_DIR}/
+    fi
+}
+
+# Model files and their corresponding URLs
+declare -A MODELS
+# We just run with MobileSAM for this example
+# MODELS["sam_vit_h_4b8939.pth"]="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+MODELS["mobile_sam.pt"]="https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
+
+for model in "${!MODELS[@]}"; do
+    download_model "${model}" "${MODELS[${model}]}"
+done
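Only the MobileSAM weights are fetched by default. To serve full SAM instead, uncomment the sam_vit_h_4b8939.pth entry above (so the checkpoint is baked in at build time) and select it via the SAM_CHOICE variable that model.py reads (a sketch, assuming the image was tagged label-studio-sam as above):

    docker run -e SAM_CHOICE=SAM -p 7860:7860 label-studio-sam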
model.py
ADDED
@@ -0,0 +1,145 @@
+import os
+
+from label_studio_converter import brush
+from typing import List, Dict, Optional
+from uuid import uuid4
+from sam_predictor import SAMPredictor
+from label_studio_ml.model import LabelStudioMLBase
+
+SAM_CHOICE = os.environ.get("SAM_CHOICE", "MobileSAM")  # the other option is plain SAM
+PREDICTOR = SAMPredictor(SAM_CHOICE)
+
+
+class SamMLBackend(LabelStudioMLBase):
+
+    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
+        """Returns the predicted mask for a smart keypoint that has been placed."""
+
+        from_name, to_name, value = self.get_first_tag_occurence('BrushLabels', 'Image')
+
+        if not context or not context.get('result'):
+            # if there is no context, no interaction has happened yet
+            return []
+
+        image_width = context['result'][0]['original_width']
+        image_height = context['result'][0]['original_height']
+
+        # collect context information
+        point_coords = []
+        point_labels = []
+        input_box = None
+        selected_label = None
+        for ctx in context['result']:
+            x = ctx['value']['x'] * image_width / 100
+            y = ctx['value']['y'] * image_height / 100
+            ctx_type = ctx['type']
+            selected_label = ctx['value'][ctx_type][0]
+            if ctx_type == 'keypointlabels':
+                point_labels.append(int(ctx['is_positive']))
+                point_coords.append([int(x), int(y)])
+            elif ctx_type == 'rectanglelabels':
+                box_width = ctx['value']['width'] * image_width / 100
+                box_height = ctx['value']['height'] * image_height / 100
+                input_box = [int(x), int(y), int(box_width + x), int(box_height + y)]
+
+        print(f'Point coords are {point_coords}, point labels are {point_labels}, input box is {input_box}')
+
+        img_path = tasks[0]['data'][value]
+        predictor_results = PREDICTOR.predict(
+            img_path=img_path,
+            point_coords=point_coords or None,
+            point_labels=point_labels or None,
+            input_box=input_box
+        )
+
+        predictions = self.get_results(
+            masks=predictor_results['masks'],
+            probs=predictor_results['probs'],
+            width=image_width,
+            height=image_height,
+            from_name=from_name,
+            to_name=to_name,
+            label=selected_label)
+
+        return predictions
+
+    def get_results(self, masks, probs, width, height, from_name, to_name, label):
+        results = []
+        for mask, prob in zip(masks, probs):
+            # create a random ID for the label every time, so there is no chance of collisions
+            label_id = str(uuid4())[:4]
+            # convert the mask from the model to the RLE format used by Label Studio
+            mask = mask * 255
+            rle = brush.mask2rle(mask)
+
+            results.append({
+                'id': label_id,
+                'from_name': from_name,
+                'to_name': to_name,
+                'original_width': width,
+                'original_height': height,
+                'image_rotation': 0,
+                'value': {
+                    'format': 'rle',
+                    'rle': rle,
+                    'brushlabels': [label],
+                },
+                'score': prob,
+                'type': 'brushlabels',
+                'readonly': False
+            })
+
+        return [{
+            'result': results,
+            'model_version': PREDICTOR.model_name
+        }]
+
+
+if __name__ == '__main__':
+    # test the model
+    model = SamMLBackend()
+    model.use_label_config('''
+    <View>
+        <Image name="image" value="$image" zoom="true"/>
+        <BrushLabels name="tag" toName="image">
+            <Label value="Banana" background="#FF0000"/>
+            <Label value="Orange" background="#0d14d3"/>
+        </BrushLabels>
+        <KeyPointLabels name="tag2" toName="image" smart="true">
+            <Label value="Banana" background="#000000" showInline="true"/>
+            <Label value="Orange" background="#000000" showInline="true"/>
+        </KeyPointLabels>
+        <RectangleLabels name="tag3" toName="image">
+            <Label value="Banana" background="#000000" showInline="true"/>
+            <Label value="Orange" background="#000000" showInline="true"/>
+        </RectangleLabels>
+    </View>
+    ''')
+    results = model.predict(
+        tasks=[{
+            'data': {
+                'image': 'https://s3.amazonaws.com/htx-pub/datasets/images/125245483_152578129892066_7843809718842085333_n.jpg'
+            }}],
+        context={
+            'result': [{
+                'original_width': 1080,
+                'original_height': 1080,
+                'image_rotation': 0,
+                'value': {
+                    'x': 49.441786283891545,
+                    'y': 59.96810207336522,
+                    'width': 0.3189792663476874,
+                    'labels': ['Banana'],
+                    'keypointlabels': ['Banana']
+                },
+                'is_positive': True,
+                'id': 'fBWv1t0S2L',
+                'from_name': 'tag2',
+                'to_name': 'image',
+                'type': 'keypointlabels',
+                'origin': 'manual'
+            }]}
+    )
+    import json
+    results[0]['result'][0]['value']['rle'] = f'...{len(results[0]["result"][0]["value"]["rle"])} integers...'
+    print(json.dumps(results, indent=2))
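The __main__ block doubles as a smoke test against a sample keypoint interaction; assuming the MobileSAM checkpoint is in place and the sample image URL is reachable, it can be run standalone:

    python model.py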
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+label_studio_converter
+opencv-python
+onnxruntime
+onnx
+torch==2.0.1
+torchvision==0.15.2
+gunicorn==20.1.0
+rq==1.10.1
+timm==0.4.12
+
+segment_anything @ git+https://github.com/facebookresearch/segment-anything.git
+mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
+label-studio-ml @ git+https://github.com/heartexlabs/label-studio-ml-backend.git
sam_predictor.py
ADDED
@@ -0,0 +1,198 @@
+import os
+import logging
+import torch
+import cv2
+import numpy as np
+
+from typing import List, Dict, Optional
+from label_studio_ml.utils import get_image_local_path, InMemoryLRUDictCache
+
+logger = logging.getLogger(__name__)
+
+VITH_CHECKPOINT = os.environ.get("VITH_CHECKPOINT")
+ONNX_CHECKPOINT = os.environ.get("ONNX_CHECKPOINT")
+MOBILESAM_CHECKPOINT = os.environ.get("MOBILESAM_CHECKPOINT", "mobile_sam.pt")
+LABEL_STUDIO_ACCESS_TOKEN = os.environ.get("LABEL_STUDIO_ACCESS_TOKEN")
+LABEL_STUDIO_HOST = os.environ.get("LABEL_STUDIO_HOST")
+
+
+class SAMPredictor(object):
+
+    def __init__(self, model_choice):
+        self.model_choice = model_choice
+
+        # cache for embeddings
+        # TODO: currently it supports only one image in cache,
+        # since predictor.set_image() should be called each time a new image comes in,
+        # before making predictions
+        # to extend it to >1 image, we need to store the "active image" state in the cache
+        self.cache = InMemoryLRUDictCache(1)
+
+        # use CUDA when available, otherwise fall back to CPU (expect much slower inference)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.debug(f"Using device {self.device}")
+
+        if model_choice == 'ONNX':
+            import onnxruntime
+            from segment_anything import sam_model_registry, SamPredictor
+
+            self.model_checkpoint = VITH_CHECKPOINT
+            if self.model_checkpoint is None:
+                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
+            if ONNX_CHECKPOINT is None:
+                raise FileNotFoundError("ONNX_CHECKPOINT is not set: please set it to the path to the ONNX checkpoint")
+            logger.info(f"Using ONNX checkpoint {ONNX_CHECKPOINT} and SAM checkpoint {self.model_checkpoint}")
+
+            self.ort = onnxruntime.InferenceSession(ONNX_CHECKPOINT)
+            reg_key = "vit_h"
+
+        elif model_choice == 'SAM':
+            from segment_anything import SamPredictor, sam_model_registry
+
+            self.model_checkpoint = VITH_CHECKPOINT
+            if self.model_checkpoint is None:
+                raise FileNotFoundError("VITH_CHECKPOINT is not set: please set it to the path to the SAM checkpoint")
+
+            logger.info(f"Using SAM checkpoint {self.model_checkpoint}")
+            reg_key = "vit_h"
+
+        elif model_choice == 'MobileSAM':
+            from mobile_sam import SamPredictor, sam_model_registry
+
+            self.model_checkpoint = MOBILESAM_CHECKPOINT
+            if not self.model_checkpoint:
+                raise FileNotFoundError("MOBILESAM_CHECKPOINT is not set: please set it to the path to the MobileSAM checkpoint")
+            logger.info(f"Using MobileSAM checkpoint {self.model_checkpoint}")
+            reg_key = 'vit_t'
+        else:
+            raise ValueError(f"Invalid model choice {model_choice}")
+
+        sam = sam_model_registry[reg_key](checkpoint=self.model_checkpoint)
+        sam.to(device=self.device)
+        self.predictor = SamPredictor(sam)
+
+    @property
+    def model_name(self):
+        return f'{self.model_choice}:{self.model_checkpoint}:{self.device}'
+
+    def set_image(self, img_path, calculate_embeddings=True):
+        payload = self.cache.get(img_path)
+        if payload is None:
+            # Get image and embeddings
+            logger.debug(f'Payload not found for {img_path} in `IN_MEM_CACHE`: calculating from scratch')
+            image_path = get_image_local_path(
+                img_path,
+                label_studio_access_token=LABEL_STUDIO_ACCESS_TOKEN,
+                label_studio_host=LABEL_STUDIO_HOST
+            )
+            image = cv2.imread(image_path)
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            self.predictor.set_image(image)
+            payload = {'image_shape': image.shape[:2]}
+            logger.debug(f'Finished set_image({img_path}) in `IN_MEM_CACHE`: image shape {image.shape[:2]}')
+            if calculate_embeddings:
+                image_embedding = self.predictor.get_image_embedding().cpu().numpy()
+                payload['image_embedding'] = image_embedding
+                logger.debug(f'Finished storing embeddings for {img_path} in `IN_MEM_CACHE`: '
+                             f'embedding shape {image_embedding.shape}')
+            self.cache.put(img_path, payload)
+        else:
+            logger.debug(f"Using embeddings for {img_path} from `IN_MEM_CACHE`")
+        return payload
+
+    def predict_onnx(
+        self,
+        img_path,
+        point_coords: Optional[List[List]] = None,
+        point_labels: Optional[List] = None,
+        input_box: Optional[List] = None
+    ):
+        # calculate embeddings
+        payload = self.set_image(img_path, calculate_embeddings=True)
+        image_shape = payload['image_shape']
+        image_embedding = payload['image_embedding']
+
+        onnx_point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
+        onnx_point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
+        onnx_box_coords = np.array(input_box, dtype=np.float32).reshape(2, 2) if input_box else None
+
+        onnx_coords, onnx_labels = None, None
+        if onnx_point_coords is not None and onnx_box_coords is not None:
+            # both keypoints and boxes are present
+            onnx_coords = np.concatenate([onnx_point_coords, onnx_box_coords], axis=0)[None, :, :]
+            onnx_labels = np.concatenate([onnx_point_labels, np.array([2, 3])], axis=0)[None, :].astype(np.float32)
+
+        elif onnx_point_coords is not None:
+            # only keypoints are present
+            onnx_coords = np.concatenate([onnx_point_coords, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
+            onnx_labels = np.concatenate([onnx_point_labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
+
+        elif onnx_box_coords is not None:
+            # only boxes are present
+            raise NotImplementedError("Boxes without keypoints are not supported yet")
+
+        onnx_coords = self.predictor.transform.apply_coords(onnx_coords, image_shape).astype(np.float32)
+
+        # TODO: support mask inputs
+        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+
+        onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+
+        ort_inputs = {
+            "image_embeddings": image_embedding,
+            "point_coords": onnx_coords,
+            "point_labels": onnx_labels,
+            "mask_input": onnx_mask_input,
+            "has_mask_input": onnx_has_mask_input,
+            "orig_im_size": np.array(image_shape, dtype=np.float32)
+        }
+
+        masks, prob, low_res_logits = self.ort.run(None, ort_inputs)
+        masks = masks > self.predictor.model.mask_threshold
+        mask = masks[0, 0, :, :].astype(np.uint8)  # each mask has shape [H, W]
+        prob = float(prob[0][0])
+        # TODO: support the real multimask output as in https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
+        return {
+            'masks': [mask],
+            'probs': [prob]
+        }
+
+    def predict_sam(
+        self,
+        img_path,
+        point_coords: Optional[List[List]] = None,
+        point_labels: Optional[List] = None,
+        input_box: Optional[List] = None
+    ):
+        self.set_image(img_path, calculate_embeddings=False)
+        point_coords = np.array(point_coords, dtype=np.float32) if point_coords else None
+        point_labels = np.array(point_labels, dtype=np.float32) if point_labels else None
+        input_box = np.array(input_box, dtype=np.float32) if input_box else None
+
+        masks, probs, logits = self.predictor.predict(
+            point_coords=point_coords,
+            point_labels=point_labels,
+            box=input_box,
+            # TODO: support multimask output
+            multimask_output=False
+        )
+        mask = masks[0, :, :].astype(np.uint8)  # each mask has shape [H, W]
+        prob = float(probs[0])
+        return {
+            'masks': [mask],
+            'probs': [prob]
+        }
+
+    def predict(
+        self, img_path: str,
+        point_coords: Optional[List[List]] = None,
+        point_labels: Optional[List] = None,
+        input_box: Optional[List] = None
+    ):
+        if self.model_choice == 'ONNX':
+            return self.predict_onnx(img_path, point_coords, point_labels, input_box)
+        elif self.model_choice in ('SAM', 'MobileSAM'):
+            return self.predict_sam(img_path, point_coords, point_labels, input_box)
+        else:
+            raise NotImplementedError(f"Model choice {self.model_choice} is not supported yet")
+
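Used standalone, the predictor takes pixel coordinates plus 0/1 labels for negative/positive keypoints. A minimal sketch (the image path and click coordinates are placeholders):

    from sam_predictor import SAMPredictor

    predictor = SAMPredictor("MobileSAM")
    result = predictor.predict(
        img_path="image.jpg",        # placeholder local path
        point_coords=[[320, 240]],   # one positive click, in pixels
        point_labels=[1],
    )
    mask, prob = result['masks'][0], result['probs'][0]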
start.sh
ADDED
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+# Execute the gunicorn command
+exec /home/user/.local/bin/gunicorn --preload --bind :$PORT --workers 1 --threads 8 --timeout 0 _wsgi:app