Tony Neel
committed on
Commit
·
bf7dfcc
1
Parent(s):
796780d
add test endpoints and make handler work
Browse files- README.md +12 -0
- __pycache__/handler.cpython-310.pyc +0 -0
- handler.py +58 -59
- test_flask.py +44 -0
- test_local.py +48 -0
README.md
CHANGED
@@ -8,8 +8,20 @@ Repository for SAM 2: Segment Anything in Images and Videos, a foundation model
|
|
8 |
|
9 |
The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/).
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
## Usage
|
12 |
|
|
|
|
|
13 |
For image prediction:
|
14 |
|
15 |
```python
|
|
|
8 |
|
9 |
The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/).
|
10 |
|
11 |
+
# SAM2 Small Inference Endpoint
|
12 |
+
|
13 |
+
This repository contains the code for running SAM2 (Segment Anything Model 2) small model as a Hugging Face inference endpoint.
|
14 |
+
|
15 |
+
## Model Details
|
16 |
+
|
17 |
+
- Model: SAM2 Hiera Small
|
18 |
+
- Source: facebook/sam2-hiera-small
|
19 |
+
- Type: Segmentation model
|
20 |
+
|
21 |
## Usage
|
22 |
|
23 |
+
Send a POST request with an image to get segmentation masks:
|
24 |
+
|
25 |
For image prediction:
|
26 |
|
27 |
```python
|
__pycache__/handler.cpython-310.pyc
ADDED
Binary file (2.2 kB). View file
|
|
handler.py
CHANGED
@@ -5,11 +5,34 @@ import numpy as np
|
|
5 |
from PIL import Image
|
6 |
import io
|
7 |
import base64
|
|
|
8 |
|
9 |
-
class EndpointHandler:
|
10 |
-
def __init__(self
|
11 |
-
"""Initialize the handler with
|
|
|
12 |
self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
|
15 |
"""Load image from binary or base64 data"""
|
@@ -24,67 +47,43 @@ class EndpointHandler:
|
|
24 |
except Exception as e:
|
25 |
raise ValueError(f"Failed to load image: {str(e)}")
|
26 |
|
27 |
-
def __call__(self,
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
Dictionary containing masks and scores
|
37 |
-
"""
|
38 |
-
try:
|
39 |
-
# Handle different input formats
|
40 |
-
if isinstance(data, dict):
|
41 |
-
image_data = data.get("inputs", data)
|
42 |
-
# Get optional point prompts
|
43 |
-
point_coords = data.get("point_coords", None)
|
44 |
-
point_labels = data.get("point_labels", None)
|
45 |
-
else:
|
46 |
-
image_data = data
|
47 |
-
point_coords = None
|
48 |
-
point_labels = None
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
point_coords = np.array(point_coords)
|
61 |
-
point_labels = np.array(point_labels)
|
62 |
-
masks, scores, logits = self.predictor.predict(
|
63 |
point_coords=point_coords,
|
64 |
point_labels=point_labels
|
65 |
)
|
66 |
-
else:
|
67 |
-
# Default automatic mask generation
|
68 |
-
masks, scores, logits = self.predictor.predict()
|
69 |
-
|
70 |
-
# Convert outputs to JSON-serializable format
|
71 |
-
if masks is not None:
|
72 |
-
masks = [mask.tolist() for mask in masks]
|
73 |
-
scores = scores.tolist() if scores is not None else None
|
74 |
-
|
75 |
-
return {
|
76 |
-
"masks": masks,
|
77 |
-
"scores": scores,
|
78 |
-
"status": "success"
|
79 |
-
}
|
80 |
else:
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
85 |
|
86 |
-
|
|
|
87 |
return {
|
88 |
-
"
|
89 |
-
"
|
90 |
-
|
|
|
|
|
|
5 |
from PIL import Image
|
6 |
import io
|
7 |
import base64
|
8 |
+
# NOTE(review): the previous version subclassed huggingface_hub.InferenceEndpoint,
# but that class is a management client for the Endpoints API, not a handler base
# class — custom handlers are plain classes named EndpointHandler.
class EndpointHandler:
    """Inference-endpoint handler wrapping the SAM2 Hiera Small predictor."""

    def __init__(self, path: str = ""):
        """Load the SAM2 small image predictor once at endpoint start-up.

        Args:
            path: Model directory supplied by the Inference Endpoints runtime
                (unused here; the predictor is pulled from the Hub directly).
        """
        self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
        # BUG FIX: the previous code commented out the MockPredictor class but
        # still executed `self.predictor = MockPredictor()` afterwards, which
        # raised NameError and clobbered the real predictor loaded above.
        # The dangling assignment and the dead commented-out mock are removed.
36 |
|
37 |
def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
|
38 |
"""Load image from binary or base64 data"""
|
|
|
47 |
except Exception as e:
|
48 |
raise ValueError(f"Failed to load image: {str(e)}")
|
49 |
|
50 |
+
def __call__(self, image_bytes):
    """Segment an image and return JSON-serializable masks and scores.

    Args:
        image_bytes: Raw image bytes, or a dict with key 'image' (bytes)
            plus optional 'point_coords' / 'point_labels' prompt lists.

    Returns:
        dict with 'masks' (nested lists), 'scores', and 'status' on
        success, or an 'error' payload when no masks were produced.
    """
    from contextlib import nullcontext  # local import: keeps file imports untouched

    # Unpack optional point prompts when the request is a dict.
    if isinstance(image_bytes, dict):
        point_coords = image_bytes.get('point_coords')
        point_labels = image_bytes.get('point_labels')
        image_bytes = image_bytes['image']
    else:
        point_coords = None
        point_labels = None

    # BUG FIX: predict() expects numpy arrays for prompts; the previous
    # version passed the raw Python lists straight through (the pre-refactor
    # code did convert with np.array).
    if point_coords is not None:
        point_coords = np.array(point_coords)
    if point_labels is not None:
        point_labels = np.array(point_labels)

    # Decode bytes -> RGB numpy array.
    image = Image.open(io.BytesIO(image_bytes))
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image_array = np.array(image)

    # Run inference; bfloat16 autocast only when CUDA is available.
    # (Deduplicates the previous copy-pasted CUDA / CPU branches.)
    autocast_ctx = (torch.autocast("cuda", dtype=torch.bfloat16)
                    if torch.cuda.is_available() else nullcontext())
    with torch.inference_mode(), autocast_ctx:
        self.predictor.set_image(image_array)
        masks, scores, _ = self.predictor.predict(
            point_coords=point_coords,
            point_labels=point_labels
        )

    # Format output as plain lists so the result is JSON-serializable.
    if masks is not None:
        return {
            "masks": [mask.tolist() for mask in masks],
            "scores": scores.tolist() if scores is not None else None,
            "status": "success"
        }
    return {"error": "No masks generated", "status": "error"}
|
test_flask.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify
from handler import EndpointHandler
import torch
import ast

app = Flask(__name__)

# Initialize the handler once at startup (loads the SAM2 model).
handler = EndpointHandler()

@app.route('/predict', methods=['POST'])
def predict():
    """Accept a multipart image upload (plus optional point prompts) and
    return the handler's segmentation result as JSON."""
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # Read the file bytes
    image_bytes = file.read()

    # Get point prompts if provided
    point_coords = request.form.get('point_coords')
    point_labels = request.form.get('point_labels')

    # Process with handler
    try:
        if point_coords and point_labels:
            # SECURITY FIX: these fields come from the client; eval() would
            # execute arbitrary code from the request. ast.literal_eval only
            # parses Python literals and is safe for "[[500, 375]]" / "[1]".
            point_coords = ast.literal_eval(point_coords)  # e.g. "[[500, 375]]"
            point_labels = ast.literal_eval(point_labels)  # e.g. "[1]"
            result = handler({
                'image': image_bytes,
                'point_coords': point_coords,
                'point_labels': point_labels
            })
        else:
            result = handler(image_bytes)
        return jsonify(result)
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    app.run(debug=True, port=5000)
|
test_local.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
from pathlib import Path

def test_endpoint(image_path, point_coords=None, point_labels=None):
    """POST an image (and optional point prompts) to the local Flask server
    and print a short summary of the response."""
    endpoint_url = "http://localhost:5000/predict"

    # Keep the file handle open for the duration of the upload.
    with open(image_path, 'rb') as image_file:
        form_fields = {}
        # Point prompts travel as stringified lists in the form data.
        if point_coords is not None and point_labels is not None:
            form_fields['point_coords'] = str(point_coords)
            form_fields['point_labels'] = str(point_labels)
        response = requests.post(endpoint_url,
                                 files={'file': image_file},
                                 data=form_fields)

    print(f"Status Code: {response.status_code}")
    if response.status_code == 200:
        result = response.json()
        print("\nSuccess!")
        print(f"Number of masks: {len(result['masks']) if 'masks' in result else 0}")
        print(f"Scores: {result['scores'] if 'scores' in result else None}")
    else:
        print(f"Error: {response.text}")

if __name__ == "__main__":
    # Test with your image
    image_path = Path("images/20250121_gauge_0001.jpg")
    if not image_path.exists():
        print(f"Error: Image not found at {image_path}")
        exit(1)

    # Test without points
    print("\nTesting without points...")
    print(f"Testing with image: {image_path}")
    test_endpoint(image_path)

    # Test with points
    print("\nTesting with points...")
    test_endpoint(
        image_path,
        point_coords=[[500, 375]],  # Example coordinates
        point_labels=[1]  # 1 for foreground
    )
|