| from fastapi import FastAPI, File, UploadFile, Form, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from segment_anything import sam_model_registry, SamPredictor |
| from PIL import Image |
| import numpy as np |
| import torch |
| import io |
| import base64 |
| import json |
|
|
# Application instance; the route decorators further down register on it.
app = FastAPI()

# CORS policy: any origin, any method, any header, credentials enabled.
# NOTE(review): FastAPI's CORS documentation states that wildcard origins
# must not be combined with allow_credentials=True. Confirm whether
# credentialed cross-origin requests are actually needed; if they are, list
# the allowed origins explicitly instead of using "*".
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
| |
# Architecture key into sam_model_registry and the matching checkpoint file.
model_type = "vit_b"
sam_checkpoint = "sam_vit_b.pth"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model is built once at import time; both segmentation endpoints share
# this single predictor instance.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device)
predictor = SamPredictor(sam)
|
|
@app.get("/")
def read_root():
    """Health-check endpoint: report that the service is up."""
    payload = {"status": "SAM API is running"}
    return payload
|
|
@app.post("/segment")
async def segment_image(file: UploadFile = File(...)):
    """Segment the object under the centre pixel of an uploaded image.

    The upload is decoded to RGB, the shared SAM predictor is prompted with a
    single foreground point at the image centre, and the candidate mask with
    the highest score is returned.

    Args:
        file: Uploaded image (multipart form field ``file``).

    Returns:
        A dict with ``score`` (float score of the selected mask) and ``mask``
        (the selected mask converted to nested lists of booleans).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 for any other failure.
    """
    try:
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (OSError, ValueError) as exc:
            # Undecodable bytes are a client error, not a server fault.
            # PIL's UnidentifiedImageError is an OSError subclass.
            raise HTTPException(
                status_code=400, detail=f"Invalid image: {exc}"
            ) from exc
        image_np = np.array(image)

        height, width = image_np.shape[:2]

        # One positive (label 1) prompt at the image centre, in (x, y) order.
        center_point = np.array([[width // 2, height // 2]])
        input_label = np.array([1])

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=center_point,
            point_labels=input_label,
            multimask_output=True,
        )

        # Several candidates are returned; keep the highest-scoring one.
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx].astype(bool)

        return {
            "score": float(scores[best_mask_idx]),
            "mask": mask.tolist(),
        }
    except HTTPException:
        # Let the deliberate 4xx above through without re-wrapping it as 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
@app.get("/models")
def list_models():
    """Describe the single model this service exposes.

    Returns:
        A dict whose ``models`` list holds one entry giving the model's name,
        its task type and the labels it can emit.
    """
    sam_entry = {
        "name": "sam-cvat",
        "type": "segmentation",
        "labels": ["object"],
    }
    return {"models": [sam_entry]}
|
|
| |
@app.post("/predict")
async def predict_for_cvat(body: str = Form(...)):
    """Run SAM on a base64-encoded image and return the best mask as RLE.

    Args:
        body: Form field containing a JSON object with:
            ``image``: base64-encoded image bytes.
            ``points``: optional list of ``[x, y]`` prompt points, all
            treated as foreground. When missing or empty, the image centre
            is used as the only prompt.

    Returns:
        A dict with ``model`` and a one-element ``annotations`` list. The
        annotation carries the label name, the float score of the selected
        mask, and a ``mask`` dict holding the run-length encoding produced by
        ``mask_to_rle`` together with the mask ``width`` and ``height``.

    Raises:
        HTTPException: 400 when the payload is malformed (invalid JSON,
            invalid base64, undecodable image, badly shaped points),
            500 for any other failure.
    """
    try:
        try:
            data = json.loads(body)
            if not isinstance(data, dict):
                raise ValueError("request body must be a JSON object")
            image_bytes = base64.b64decode(data.get('image', ''))
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except (ValueError, TypeError, OSError) as exc:
            # json.JSONDecodeError and binascii.Error are ValueErrors and
            # PIL's UnidentifiedImageError is an OSError: all client errors.
            raise HTTPException(
                status_code=400, detail=f"Invalid request payload: {exc}"
            ) from exc
        image_np = np.array(image)

        points = data.get('points', [])
        if not points:
            # No prompt supplied: fall back to the image centre, (x, y) order.
            height, width = image_np.shape[:2]
            points = [[width // 2, height // 2]]

        try:
            input_points = np.array(points)
        except (ValueError, TypeError) as exc:
            raise HTTPException(
                status_code=400, detail=f"Invalid 'points': {exc}"
            ) from exc
        if input_points.ndim != 2 or input_points.shape[1] != 2:
            raise HTTPException(
                status_code=400,
                detail="'points' must be a list of [x, y] pairs",
            )
        # Every supplied point is a foreground prompt (label 1); integer
        # labels match what the /segment endpoint passes.
        input_labels = np.ones(len(input_points), dtype=int)

        predictor.set_image(image_np)
        masks, scores, _ = predictor.predict(
            point_coords=input_points,
            point_labels=input_labels,
            multimask_output=True,
        )

        # Several candidates are returned; keep the highest-scoring one.
        best_mask_idx = np.argmax(scores)
        mask = masks[best_mask_idx].astype(bool)

        height, width = mask.shape
        rle = mask_to_rle(mask)
        return {
            "model": "sam-cvat",
            "annotations": [{
                "name": "object",
                "score": float(scores[best_mask_idx]),
                "mask": {
                    "rle": rle,
                    "width": width,
                    "height": height,
                },
                "type": "mask",
            }],
        }
    except HTTPException:
        # Let the deliberate 4xx above through without re-wrapping it as 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
| |
def mask_to_rle(mask):
    """Convert a binary mask to the run-length list returned by /predict.

    The mask is flattened in row-major order and encoded as alternating run
    lengths that always begin with the background (zero) run; a mask whose
    first pixel is set therefore starts with a run of length 0. This is the
    layout the service hands to CVAT.

    Args:
        mask: Array-like of any shape. Nonzero/True entries are foreground.

    Returns:
        A list of Python ints: background run, foreground run, background
        run, and so on. An empty mask yields ``[0]``.
    """
    # Coerce to bool so every nonzero value belongs to the same foreground
    # run and the runs strictly alternate background/foreground.
    flat = np.asarray(mask).astype(bool).ravel()
    if flat.size == 0:
        return [0]

    # Positions where the value differs from its predecessor start a new run.
    run_starts = np.flatnonzero(flat[1:] != flat[:-1]) + 1
    boundaries = np.concatenate(([0], run_starts, [flat.size]))
    runs = np.diff(boundaries).tolist()

    if flat[0]:
        # The encoding must open with a background run, even an empty one.
        runs.insert(0, 0)
    return runs