"""
Hugging Face Spaces deployment for SAM2 Auto Annotation API.
This file serves as the entry point for the FastAPI application on Hugging Face Spaces.
"""
import sys
import os
# Add sam2 folder to path to import from local sam2 directory
_current_dir = os.path.dirname(os.path.abspath(__file__))
_sam2_dir = os.path.join(_current_dir, "sam2")
# Add sam2 directory to sys.path if not already there
abs_sam2_dir = os.path.abspath(_sam2_dir)
if abs_sam2_dir not in sys.path:
sys.path.insert(0, abs_sam2_dir)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import cv2
import numpy as np
import torch
import psutil
import PIL.Image
from requests.exceptions import Timeout, RequestException
# Import sam2 from local folder
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from model.sam_model import predict_polygon, predict_polygon_from_point
from model.utils import load_image_from_url, mask_to_polygon
from model.sam2_detection_function import SAM2AutoAnnotation, create_sam2_auto_annotation
# Hugging Face model ID for SAM2.1 Hiera Large model
HUGGINGFACE_MODEL_ID = "facebook/sam2.1-hiera-large"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Global SAM2 auto annotation (initialized once)
sam2_auto_annotation_global = None
app = FastAPI(
title="SAM Auto Annotation API (BBox ➜ Polygon)",
description="AI-powered auto-annotation API using Meta's Segment Anything Model (SAM)",
version="1.0.0"
)
# Add CORS middleware to handle preflight OPTIONS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allows all origins
allow_credentials=True,
allow_methods=["*"], # Allows all methods including OPTIONS
allow_headers=["*"], # Allows all headers
)
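# Note: per the CORS spec, browsers reject credentialed responses that carry a
# wildcard Access-Control-Allow-Origin; if cookies/credentials are actually
# needed, list explicit origins instead of "*".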
@app.get("/")
def root():
"""Root endpoint - API information."""
return {
"status": "Service is up and running!",
"message": "Backend service is active",
"api": "SAM Auto Annotation API",
"version": "1.0.0"
}
@app.get("/health")
def health_check():
"""Health check endpoint."""
return {"status": "healthy", "service": "same model segmenticAPI"}
@app.post("/segment")
def segment(data: dict):
"""
Segment image using SAM2 model to convert bounding box to polygon (CVAT-style).
Bbox is used as a prompt to identify the object, not as a constraint.
**Input:**
```json
{
"imageUrl": "https://example.com/image.jpg",
"bbox": {"x": 494.97, "y": 187.22, "width": 137.99, "height": 98.00, "label": "Object"},
"imageSize": {"width": 663.07, "height": 442}
}
```
OR
```json
{
"imageUrl": "https://example.com/image.jpg",
"bbox": [494.97, 187.22, 137.99, 98.00], // [x, y, width, height]
"imageSize": [663.07, 442] // [width, height]
}
```
**Output:**
```json
{
"polygon": [x1, y1, x2, y2, x3, y3, ...], // CVAT format: flattened coordinates
"confidence": 0.96
}
```
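    **Example request** (a sketch using `requests`; the base URL and bbox values are hypothetical):
    ```python
    import requests

    resp = requests.post(
        "https://<your-space>.hf.space/segment",
        json={
            "imageUrl": "https://example.com/image.jpg",
            "bbox": [494.97, 187.22, 137.99, 98.00],
        },
        timeout=60,
    )
    resp.raise_for_status()
    print(resp.json()["polygon"])  # flattened [x1, y1, x2, y2, ...]
    ```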
"""
try:
# Validate input
if "imageUrl" not in data:
raise HTTPException(status_code=400, detail="Missing required field: imageUrl")
if "bbox" not in data:
raise HTTPException(status_code=400, detail="Missing required field: bbox")
image_url = data["imageUrl"]
bbox = data["bbox"]
image_size = data.get("imageSize") # Optional: for coordinate scaling
# Validate bbox format
if isinstance(bbox, dict):
required_keys = ["x", "y", "width", "height"]
if not all(key in bbox for key in required_keys):
raise HTTPException(
status_code=400,
detail=f"bbox dict must contain: {required_keys}"
)
elif isinstance(bbox, list):
if len(bbox) != 4:
raise HTTPException(
status_code=400,
detail="bbox list must contain exactly 4 values: [x, y, width, height]"
)
else:
raise HTTPException(
status_code=400,
detail="bbox must be either a dict or a list"
)
# Validate imageSize format if provided
if image_size is not None:
if isinstance(image_size, dict):
if not ("width" in image_size and "height" in image_size):
raise HTTPException(
status_code=400,
detail="imageSize dict must contain 'width' and 'height'"
)
elif isinstance(image_size, list):
if len(image_size) != 2:
raise HTTPException(
status_code=400,
detail="imageSize list must contain exactly 2 values: [width, height]"
)
else:
raise HTTPException(
status_code=400,
detail="imageSize must be either a dict or a list"
)
# Load image from URL
img_bgr = load_image_from_url(image_url)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# Predict polygon using SAM2 (bbox as prompt, CVAT-style)
mask, confidence, scale_factors = predict_polygon(img_rgb, bbox, image_size)
# Convert mask to polygon (CVAT-style)
polygon = mask_to_polygon(mask, scale_factors)
if not polygon:
raise HTTPException(status_code=400, detail="No polygon found in mask")
return {
"polygon": polygon, # CVAT format: flattened coordinates
"confidence": confidence
}
except KeyError as e:
raise HTTPException(status_code=400, detail=f"Missing required field: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=500, detail=str(e))
except ImportError as e:
raise HTTPException(
status_code=500,
detail=f"Segment Anything library not installed. Please run: pip install -e . in segment-anything directory"
)
except Timeout as e:
raise HTTPException(
status_code=504,
detail=f"Image download timeout: {str(e)}. The image server may be slow or unreachable. Please try again or use a different image URL."
)
except RequestException as e:
raise HTTPException(
status_code=502,
detail=f"Failed to fetch image from URL: {str(e)}. Please check the image URL and try again."
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/segment/point")
def segment_from_point(data: dict):
"""
Segment image using SAM2 model with a point click to select object.
The point identifies which object to segment.
**Input:**
```json
{
"imageUrl": "https://example.com/image.jpg",
"point": {"x": 494.97, "y": 187.22},
"imageSize": {"width": 663.07, "height": 442}
}
```
OR
```json
{
"imageUrl": "https://example.com/image.jpg",
"point": [494.97, 187.22], // [x, y]
"imageSize": [663.07, 442] // [width, height]
}
```
**Output:**
```json
{
"polygon": [x1, y1, x2, y2, x3, y3, ...], // CVAT format: flattened coordinates
"confidence": 0.96
}
```
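    **Example request** (sketch; base URL and point values are hypothetical):
    ```python
    import requests

    resp = requests.post(
        "https://<your-space>.hf.space/segment/point",
        json={"imageUrl": "https://example.com/image.jpg", "point": [320, 240]},
        timeout=60,
    )
    resp.raise_for_status()
    print(resp.json()["confidence"])
    ```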
"""
try:
# Validate input
if "imageUrl" not in data:
raise HTTPException(status_code=400, detail="Missing required field: imageUrl")
if "point" not in data:
raise HTTPException(status_code=400, detail="Missing required field: point")
image_url = data["imageUrl"]
point = data["point"]
image_size = data.get("imageSize") # Optional: for coordinate scaling
# Validate point format
if isinstance(point, dict):
required_keys = ["x", "y"]
if not all(key in point for key in required_keys):
raise HTTPException(
status_code=400,
detail=f"point dict must contain: {required_keys}"
)
elif isinstance(point, list):
if len(point) != 2:
raise HTTPException(
status_code=400,
detail="point list must contain exactly 2 values: [x, y]"
)
else:
raise HTTPException(
status_code=400,
detail="point must be either a dict or a list"
)
# Validate imageSize format if provided
if image_size is not None:
if isinstance(image_size, dict):
if not ("width" in image_size and "height" in image_size):
raise HTTPException(
status_code=400,
detail="imageSize dict must contain 'width' and 'height'"
)
elif isinstance(image_size, list):
if len(image_size) != 2:
raise HTTPException(
status_code=400,
detail="imageSize list must contain exactly 2 values: [width, height]"
)
else:
raise HTTPException(
status_code=400,
detail="imageSize must be either a dict or a list"
)
# Load image from URL
img_bgr = load_image_from_url(image_url)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# Predict polygon using SAM2 (point click as prompt)
mask, confidence, scale_factors = predict_polygon_from_point(img_rgb, point, image_size)
# Convert mask to polygon (CVAT-style)
polygon = mask_to_polygon(mask, scale_factors)
if not polygon:
raise HTTPException(status_code=400, detail="No polygon found in mask. Try clicking on a different point.")
return {
"polygon": polygon, # CVAT format: flattened coordinates
"confidence": confidence
}
except KeyError as e:
raise HTTPException(status_code=400, detail=f"Missing required field: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=500, detail=str(e))
except ImportError as e:
raise HTTPException(
status_code=500,
detail=f"Segment Anything library not installed. Please run: pip install -e . in segment-anything directory"
)
except Timeout as e:
raise HTTPException(
status_code=504,
detail=f"Image download timeout: {str(e)}. The image server may be slow or unreachable. Please try again or use a different image URL."
)
except RequestException as e:
raise HTTPException(
status_code=502,
detail=f"Failed to fetch image from URL: {str(e)}. Please check the image URL and try again."
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.post("/auto-annotate")
def auto_annotate(data: dict):
"""
Automatically detect and segment all objects in an image using SAM2 from Hugging Face.
Uses SAM2AutomaticMaskGenerator (facebook/sam2.1-hiera-large) to detect all objects without requiring prompts (bbox or points).
**Input:**
```json
{
"imageUrl": "https://example.com/image.jpg",
"imageSize": {"width": 663.07, "height": 442},
"minArea": 100,
"minConfidence": 0.5,
"maxImageDimension": 1024,
"pointsPerSide": 32,
"pointsPerBatch": 64,
"filterObjectsOnly": true
}
```
**Output:**
```json
{
"masks": [
{
"polygon": [x1, y1, x2, y2, x3, y3, ...],
"confidence": 0.93,
"area": 12345
},
...
],
"count": 10,
"memoryInfo": {
"before_mb": 512.5,
"after_mb": 1024.3,
"peak_mb": 1024.3,
"estimated_mb": 800.0,
"memory_used_mb": 511.8
},
"imageInfo": {
"wasResized": true,
"originalSize": [1920, 1080],
"processedSize": [1024, 576],
"resizeScale": [1.875, 1.875]
}
}
```
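    **Example request** (sketch; base URL hypothetical, tuning values illustrative):
    ```python
    import requests

    resp = requests.post(
        "https://<your-space>.hf.space/auto-annotate",
        json={
            "imageUrl": "https://example.com/image.jpg",
            "pointsPerSide": 16,   # fewer sampled points: faster but coarser
            "minConfidence": 0.7,  # keep only high-confidence masks
            "filterObjectsOnly": True,
        },
        timeout=300,  # full-image annotation can take minutes on CPU
    )
    resp.raise_for_status()
    print(resp.json()["count"], "masks returned")
    ```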
"""
try:
# Validate input
if "imageUrl" not in data:
raise HTTPException(status_code=400, detail="Missing required field: imageUrl")
image_url = data["imageUrl"]
image_size = data.get("imageSize") # Optional: for coordinate scaling
min_area = data.get("minArea", 100) # Optional: minimum mask area
min_confidence = data.get("minConfidence", 0.5) # Optional: minimum confidence
max_image_dimension = data.get("maxImageDimension", 1024) # Optional: max dimension before resizing
# Lower default values for faster processing
points_per_side = data.get("pointsPerSide", 32) # Optional: points per side (lower = faster)
points_per_batch = data.get("pointsPerBatch", 64) # Optional: points per batch (lower = faster)
filter_objects_only = data.get("filterObjectsOnly", False) # Optional: filter out background masks
# Validate imageSize format if provided
if image_size is not None:
if isinstance(image_size, dict):
if not ("width" in image_size and "height" in image_size):
raise HTTPException(
status_code=400,
detail="imageSize dict must contain 'width' and 'height'"
)
elif isinstance(image_size, list):
if len(image_size) != 2:
raise HTTPException(
status_code=400,
detail="imageSize list must contain exactly 2 values: [width, height]"
)
else:
raise HTTPException(
status_code=400,
detail="imageSize must be either a dict or a list"
)
# Validate minArea and minConfidence
try:
min_area = int(min_area)
if min_area < 0:
raise HTTPException(status_code=400, detail="minArea must be >= 0")
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="minArea must be an integer")
try:
min_confidence = float(min_confidence)
if not (0.0 <= min_confidence <= 1.0):
raise HTTPException(status_code=400, detail="minConfidence must be between 0.0 and 1.0")
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="minConfidence must be a float between 0.0 and 1.0")
# Validate maxImageDimension
try:
max_image_dimension = int(max_image_dimension)
if max_image_dimension < 256:
raise HTTPException(status_code=400, detail="maxImageDimension must be >= 256")
if max_image_dimension > 4096:
raise HTTPException(status_code=400, detail="maxImageDimension must be <= 4096")
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="maxImageDimension must be an integer between 256 and 4096")
# Validate pointsPerSide
try:
points_per_side = int(points_per_side)
if points_per_side < 8:
raise HTTPException(status_code=400, detail="pointsPerSide must be >= 8")
if points_per_side > 128:
raise HTTPException(status_code=400, detail="pointsPerSide must be <= 128")
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="pointsPerSide must be an integer between 8 and 128")
# Validate pointsPerBatch
try:
points_per_batch = int(points_per_batch)
if points_per_batch < 16:
raise HTTPException(status_code=400, detail="pointsPerBatch must be >= 16")
if points_per_batch > 256:
raise HTTPException(status_code=400, detail="pointsPerBatch must be <= 256")
except (ValueError, TypeError):
raise HTTPException(status_code=400, detail="pointsPerBatch must be an integer between 16 and 256")
# Get memory before processing
process = psutil.Process(os.getpid())
memory_before = process.memory_info().rss / (1024 * 1024) # MB
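        # (RSS is measured for the whole process, so concurrent requests can
        # skew the before/after numbers.)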
# Load image from URL
img_bgr = load_image_from_url(image_url)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
# Resize image if needed to reduce memory usage
original_h, original_w = img_rgb.shape[:2]
original_size = [original_w, original_h]
processed_image = img_rgb
resize_scale = [1.0, 1.0]
was_resized = False
if max(original_h, original_w) > max_image_dimension:
was_resized = True
if original_h > original_w:
new_h = max_image_dimension
new_w = int(original_w * (max_image_dimension / original_h))
else:
new_w = max_image_dimension
new_h = int(original_h * (max_image_dimension / original_w))
processed_image = cv2.resize(img_rgb, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
resize_scale = [original_w / new_w, original_h / new_h]
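            # e.g. a 1920x1080 input with max_image_dimension=1024 becomes
            # 1024x576, giving resize_scale = [1920/1024, 1080/576] = [1.875, 1.875]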
processed_h, processed_w = processed_image.shape[:2]
processed_size = [processed_w, processed_h]
        # Estimate memory requirements (rough heuristic based on processed resolution)
        estimated_mb = (
            (processed_w * processed_h * 3 * 4)
            + (processed_w * processed_h * 256 * 4)
            + (processed_w * processed_h * 100 * 1)
        ) / (1024 * 1024)
# Calculate scale factors for coordinate scaling (matching predict_polygon_from_point logic)
# We need to scale FROM processed image TO display size (imageSize)
# mask_to_polygon expects scale_factors that represent: FROM processed TO display
# It divides by these factors, so we pass (processed_w/display_w, processed_h/display_h)
scale_factor_x, scale_factor_y = 1.0, 1.0
if image_size is not None:
if isinstance(image_size, dict):
display_w = float(image_size.get("width", processed_w))
display_h = float(image_size.get("height", processed_h))
else:
display_w, display_h = float(image_size[0]), float(image_size[1])
# Calculate scale factors: FROM processed image TO display size
# These will be used in mask_to_polygon: polygon / scale_factor = display coords
scale_factor_x = processed_w / display_w if display_w > 0 else 1.0
scale_factor_y = processed_h / display_h if display_h > 0 else 1.0
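            # Worked example: processed width 1024, display width 663.07 ->
            # scale_factor_x = 1024 / 663.07 ≈ 1.544, and mask_to_polygon maps a
            # mask x of 512 to 512 / 1.544 ≈ 331.6 display pixels.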
# Get image dimensions for filtering
total_image_area = processed_w * processed_h
# Initialize SAM2 Auto Annotation
# This uses facebook/sam2.1-hiera-large model from Hugging Face
# Cache the annotation instance globally to avoid reloading on every request
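        # Behavioral caveat: the first request's pointsPerSide / pointsPerBatch /
        # minArea are baked into this cached generator. Later requests reuse it,
        # so a different minArea only affects the post-hoc filtering below, and
        # different point settings have no effect until the process restarts.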
global sam2_auto_annotation_global
if sam2_auto_annotation_global is None:
try:
sam2_auto_annotation_global = create_sam2_auto_annotation(
points_per_side=points_per_side,
points_per_batch=points_per_batch,
pred_iou_thresh=0.88,
stability_score_thresh=0.95,
min_mask_region_area=min_area,
)
except ImportError as e:
raise HTTPException(
status_code=500,
detail=f"Failed to import required modules. Please ensure 'sam2' and 'huggingface_hub' are installed. Error: {str(e)}"
)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to load SAM2 Auto Annotation from Hugging Face ({HUGGINGFACE_MODEL_ID}). Error: {str(e)}"
)
# Generate masks using SAM2AutoAnnotation with proper scaling (matching predict_polygon_from_point)
# Pass scale_factors to scale FROM processed image TO display size
mask_results = sam2_auto_annotation_global.generate_masks(
image=processed_image,
min_confidence=min_confidence,
min_area=min_area,
filter_blank_regions=True,
scale_factors=(scale_factor_x, scale_factor_y)
)
# Get memory after processing
memory_after = process.memory_info().rss / (1024 * 1024) # MB
memory_used = memory_after - memory_before
# Process mask results (polygons are already scaled to display size by generate_masks)
results = []
for mask_result in mask_results:
# Extract mask information
polygon = mask_result.get("polygon")
score = mask_result.get("confidence")
area = mask_result.get("area")
            # Early filtering: skip masks that are missing data or don't meet
            # the area/confidence thresholds (guard against None before comparing)
            if area is None or score is None or area < min_area or score < min_confidence:
                continue
# Filter out background masks if filterObjectsOnly is True
if filter_objects_only:
coverage_ratio = area / total_image_area if total_image_area > 0 else 0
if coverage_ratio >= 0.8: # Skip masks covering >80% (likely background)
continue
# Polygon is already scaled to display size by generate_masks (using mask_to_polygon with scale_factors)
# Return polygon in flattened format [x1, y1, x2, y2, ...]
if polygon and len(polygon) >= 6: # At least 3 points
mask_obj = {
"polygon": polygon # Already in flattened format and scaled to display size
}
if score is not None:
mask_obj["confidence"] = score
if area is not None:
mask_obj["area"] = area
results.append(mask_obj)
# Build response with all required fields
response = {
"masks": results,
"count": len(results),
"memoryInfo": {
"before_mb": round(memory_before, 2),
"after_mb": round(memory_after, 2),
"peak_mb": round(memory_after, 2),
"estimated_mb": round(estimated_mb, 2),
"memory_used_mb": round(memory_used, 2)
},
"imageInfo": {
"wasResized": was_resized,
"originalSize": original_size,
"processedSize": processed_size,
"resizeScale": resize_scale
}
}
return response
except KeyError as e:
raise HTTPException(status_code=400, detail=f"Missing required field: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except FileNotFoundError as e:
raise HTTPException(status_code=500, detail=str(e))
except ImportError as e:
raise HTTPException(
status_code=500,
detail=f"Segment Anything library not installed. Please ensure 'sam2' and 'huggingface_hub' are installed."
)
except Timeout as e:
raise HTTPException(
status_code=504,
detail=f"Image download timeout: {str(e)}. The image server may be slow or unreachable. Please try again or use a different image URL."
)
except RequestException as e:
raise HTTPException(
status_code=502,
detail=f"Failed to fetch image from URL: {str(e)}. Please check the image URL and try again."
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
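

# Minimal local entry point. A sketch under the assumption that the Spaces
# runtime launches the app itself (e.g. via its Docker/uvicorn config); port
# 7860 is the Hugging Face Spaces default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)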