Tony Neel commited on
Commit
bf7dfcc
·
1 Parent(s): 796780d

add test endpoints and make handler work

Browse files
Files changed (5) hide show
  1. README.md +12 -0
  2. __pycache__/handler.cpython-310.pyc +0 -0
  3. handler.py +58 -59
  4. test_flask.py +44 -0
  5. test_local.py +48 -0
README.md CHANGED
@@ -8,8 +8,20 @@ Repository for SAM 2: Segment Anything in Images and Videos, a foundation model
8
 
9
  The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/).
10
 
 
 
 
 
 
 
 
 
 
 
11
  ## Usage
12
 
 
 
13
  For image prediction:
14
 
15
  ```python
 
8
 
9
  The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/).
10
 
11
+ # SAM2 Small Inference Endpoint
12
+
13
+ This repository contains the code for running SAM2 (Segment Anything Model 2) small model as a Hugging Face inference endpoint.
14
+
15
+ ## Model Details
16
+
17
+ - Model: SAM2 Hiera Small
18
+ - Source: facebook/sam2-hiera-small
19
+ - Type: Segmentation model
20
+
21
  ## Usage
22
 
23
+ Send a POST request with an image to get segmentation masks:
24
+
25
  For image prediction:
26
 
27
  ```python
__pycache__/handler.cpython-310.pyc ADDED
Binary file (2.2 kB). View file
 
handler.py CHANGED
@@ -5,11 +5,34 @@ import numpy as np
5
  from PIL import Image
6
  import io
7
  import base64
 
8
 
9
- class EndpointHandler:
10
- def __init__(self, path=""):
11
- """Initialize the handler with SAM2 model"""
 
12
  self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
15
  """Load image from binary or base64 data"""
@@ -24,67 +47,43 @@ class EndpointHandler:
24
  except Exception as e:
25
  raise ValueError(f"Failed to load image: {str(e)}")
26
 
27
- def __call__(self, data: Union[Dict[str, Any], bytes]) -> Dict[str, Any]:
28
- """
29
- Handle incoming request data
30
- Args:
31
- data: Either raw bytes or dictionary containing:
32
- - image data (raw binary or base64)
33
- - optional point_coords: List of [x,y] coordinates for clicks
34
- - optional point_labels: List of 1 (foreground) or 0 (background)
35
- Returns:
36
- Dictionary containing masks and scores
37
- """
38
- try:
39
- # Handle different input formats
40
- if isinstance(data, dict):
41
- image_data = data.get("inputs", data)
42
- # Get optional point prompts
43
- point_coords = data.get("point_coords", None)
44
- point_labels = data.get("point_labels", None)
45
- else:
46
- image_data = data
47
- point_coords = None
48
- point_labels = None
49
 
50
- # Load and convert image
51
- image = self._load_image(image_data)
52
- image_array = np.array(image)
 
 
53
 
54
- # Process with SAM2
55
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
56
- self.predictor.set_image(image_array)
57
-
58
- # If point prompts provided, use them
59
- if point_coords is not None and point_labels is not None:
60
- point_coords = np.array(point_coords)
61
- point_labels = np.array(point_labels)
62
- masks, scores, logits = self.predictor.predict(
63
  point_coords=point_coords,
64
  point_labels=point_labels
65
  )
66
- else:
67
- # Default automatic mask generation
68
- masks, scores, logits = self.predictor.predict()
69
-
70
- # Convert outputs to JSON-serializable format
71
- if masks is not None:
72
- masks = [mask.tolist() for mask in masks]
73
- scores = scores.tolist() if scores is not None else None
74
-
75
- return {
76
- "masks": masks,
77
- "scores": scores,
78
- "status": "success"
79
- }
80
  else:
81
- return {
82
- "error": "No masks generated",
83
- "status": "error"
84
- }
 
85
 
86
- except Exception as e:
 
87
  return {
88
- "error": str(e),
89
- "status": "error"
90
- }
 
 
 
5
  from PIL import Image
6
  import io
7
  import base64
8
+ from huggingface_hub import InferenceEndpoint
9
 
10
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for SAM2 small.

    NOTE(review): the previous version subclassed
    `huggingface_hub.InferenceEndpoint` (a client-side endpoint resource,
    not a handler base class) and then overwrote the real predictor with
    `MockPredictor()`, a name that only existed in commented-out code —
    raising NameError on startup. Both are fixed here.
    """

    def __init__(self, path: str = ""):
        """Initialize the handler with the SAM2 small image predictor.

        Args:
            path: Local model directory passed by the Inference Endpoints
                runtime. Required in the signature (the runtime calls
                ``EndpointHandler(path)``); unused here because weights are
                pulled directly from the hub.
        """
        self.predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
36
 
37
  def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
38
  """Load image from binary or base64 data"""
 
47
  except Exception as e:
48
  raise ValueError(f"Failed to load image: {str(e)}")
49
 
50
+ def __call__(self, image_bytes):
51
+ # Get point prompts if provided in request
52
+ if isinstance(image_bytes, dict):
53
+ point_coords = image_bytes.get('point_coords')
54
+ point_labels = image_bytes.get('point_labels')
55
+ image_bytes = image_bytes['image']
56
+ else:
57
+ point_coords = None
58
+ point_labels = None
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # Convert bytes to image
61
+ image = Image.open(io.BytesIO(image_bytes))
62
+ if image.mode != 'RGB':
63
+ image = image.convert('RGB')
64
+ image_array = np.array(image)
65
 
66
+ # Run inference (will use mock predictor locally)
67
+ with torch.inference_mode():
68
+ if torch.cuda.is_available():
69
+ with torch.autocast("cuda", dtype=torch.bfloat16):
70
+ self.predictor.set_image(image_array)
71
+ masks, scores, _ = self.predictor.predict(
 
 
 
72
  point_coords=point_coords,
73
  point_labels=point_labels
74
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  else:
76
+ self.predictor.set_image(image_array)
77
+ masks, scores, _ = self.predictor.predict(
78
+ point_coords=point_coords,
79
+ point_labels=point_labels
80
+ )
81
 
82
+ # Format output
83
+ if masks is not None:
84
  return {
85
+ "masks": [mask.tolist() for mask in masks],
86
+ "scores": scores.tolist() if scores is not None else None,
87
+ "status": "success"
88
+ }
89
+ return {"error": "No masks generated", "status": "error"}
test_flask.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json

from flask import Flask, request, jsonify
from handler import EndpointHandler
import torch

app = Flask(__name__)

# Initialize the handler once at startup so the model loads a single time.
handler = EndpointHandler()


@app.route('/predict', methods=['POST'])
def predict():
    """Accept a multipart image upload plus optional point prompts and
    return the handler's segmentation result as JSON.

    Form fields:
        file: the image file (required)
        point_coords: JSON string, e.g. "[[500, 375]]" (optional)
        point_labels: JSON string, e.g. "[1]" (optional)
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    # Read the uploaded image bytes.
    image_bytes = file.read()

    # Optional point prompts arrive as JSON-encoded form fields.
    point_coords = request.form.get('point_coords')
    point_labels = request.form.get('point_labels')

    try:
        if point_coords and point_labels:
            # SECURITY: parse with json.loads, never eval() — these
            # strings come straight from the client and eval() would
            # execute arbitrary code.
            point_coords = json.loads(point_coords)  # e.g. "[[500, 375]]"
            point_labels = json.loads(point_labels)  # e.g. "[1]"
            result = handler({
                'image': image_bytes,
                'point_coords': point_coords,
                'point_labels': point_labels
            })
        else:
            result = handler(image_bytes)
        return jsonify(result)
    except Exception as e:
        # Surface handler failures as a 500 with the message attached.
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    app.run(debug=True, port=5000)
test_local.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import sys
from pathlib import Path

import requests


def test_endpoint(image_path, point_coords=None, point_labels=None):
    """POST an image (and optional point prompts) to the local Flask
    server and print a summary of the response.

    Args:
        image_path: Path to the image file to upload.
        point_coords: Optional list of [x, y] click coordinates.
        point_labels: Optional list of 1 (foreground) / 0 (background);
            both prompt arguments must be given together to take effect.
    """
    # URL for the local Flask server (see test_flask.py).
    url = "http://localhost:5000/predict"

    with open(image_path, 'rb') as f:
        files = {'file': f}
        data = {}

        # Encode prompts as JSON so the server can parse them safely
        # (plain str() sends Python repr, which is not guaranteed JSON).
        if point_coords is not None and point_labels is not None:
            data['point_coords'] = json.dumps(point_coords)
            data['point_labels'] = json.dumps(point_labels)

        response = requests.post(url, files=files, data=data)

    print(f"Status Code: {response.status_code}")
    if response.status_code == 200:
        result = response.json()
        print("\nSuccess!")
        print(f"Number of masks: {len(result['masks']) if 'masks' in result else 0}")
        print(f"Scores: {result['scores'] if 'scores' in result else None}")
    else:
        print(f"Error: {response.text}")


if __name__ == "__main__":
    # Test with your image
    image_path = Path("images/20250121_gauge_0001.jpg")
    if not image_path.exists():
        print(f"Error: Image not found at {image_path}")
        # sys.exit is the script-safe form; bare exit() is a REPL helper.
        sys.exit(1)

    # Test without points (automatic mask generation on the server side)
    print("\nTesting without points...")
    print(f"Testing with image: {image_path}")
    test_endpoint(image_path)

    # Test with a single foreground point prompt
    print("\nTesting with points...")
    test_endpoint(
        image_path,
        point_coords=[[500, 375]],  # Example coordinates
        point_labels=[1]  # 1 for foreground
    )