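"""FastAPI service: a Detectron2 Mask R-CNN model detects instances of the single
trained class ("toproof"), SAM2 refines each detected mask, and the /polygon endpoint
returns a simplified polygon plus a base64-encoded PNG preview overlay."""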
import base64
import io
import cv2
import numpy as np
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import os
import uvicorn
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra
# -------------------
# Detectron2 setup
# -------------------
det_cfg = get_cfg()
det_cfg.merge_from_file(
    model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)
det_cfg.MODEL.WEIGHTS = "/app/model_final.pth" # your trained weights
det_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
det_cfg.MODEL.DEVICE = "cpu" # Hugging Face free tier is CPU only
det_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
# Register class metadata
MetadataCatalog.get("__unused__").thing_classes = ["toproof"]
predictor = DefaultPredictor(det_cfg)
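# DefaultPredictor handles preprocessing and single-image inference; its output dict has
# an "instances" field (pred_boxes, pred_masks, scores) that the /polygon endpoint
# below unpacks.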
# -------------------
# SAM2 setup
# -------------------
os.chdir("/app") # ensure hydra looks in the right place
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
# Make sure the filename matches your repo (sam2_1_hiera_l.yaml)
with initialize(version_base=None, config_path="."):
    sam2_model = build_sam2("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt", device="cpu")
    sam2_predictor = SAM2ImagePredictor(sam2_model)
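# SAM2ImagePredictor is used in two steps below: set_image() embeds the (RGB) image once,
# then predict() is called per detection with a box prompt.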
# -------------------
# FastAPI app
# -------------------
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)
@app.get("/")
def home():
    return {"status": "running"}
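# Quick smoke test once the server is running (host/port are illustrative, not set here):
#   curl http://localhost:7860/   ->   {"status": "running"}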
# -------------------
# Helpers
# -------------------
def _largest_contour(mask):
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    return max(contours, key=cv2.contourArea)
def _min_area_rect_to_poly(cnt):
    rect = cv2.minAreaRect(cnt)
    box = cv2.boxPoints(rect)
    return box.astype(np.float32).reshape(-1, 1, 2)
def mask_to_polygon_no_holes(mask, epsilon_factor=0.005, min_area=150):
    if mask.dtype != np.uint8:
        if mask.max() <= 1:  # case: 0/1
            mask = (mask * 255).astype(np.uint8)
        else:
            mask = mask.astype(np.uint8)
    mask = (mask > 0).astype(np.uint8) * 255
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    contour = max(contours, key=cv2.contourArea)
    if cv2.contourArea(contour) < min_area:
        return None
    epsilon = epsilon_factor * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    return approx
def clean_polygon_strict(mask, epsilon_factor=0.01, min_area=150):
    if mask.dtype != np.uint8:
        if mask.max() <= 1:
            mask = (mask * 255).astype(np.uint8)
        else:
            mask = mask.astype(np.uint8)
    bw = (mask > 127).astype(np.uint8) * 255
    cnt = _largest_contour(bw)
    if cnt is None:
        return None, "No contour"
    rect_poly = _min_area_rect_to_poly(cnt)
    polyB = mask_to_polygon_no_holes(bw, epsilon_factor=epsilon_factor, min_area=min_area)
    if rect_poly is not None and polyB is not None:
        rect_area = cv2.contourArea(rect_poly)
        contour_area = cv2.contourArea(cnt)
        area_ratio = rect_area / contour_area if contour_area > 0 else 0
        # If the polygon has > 4 sides, prefer Candidate B
        if len(polyB) > 4:
            return polyB, "Candidate B (Polygon)"
        # Stricter rectangle test
        if 0.95 < area_ratio < 1.05 and len(polyB) == 4:
            return rect_poly, "Candidate A (Rectangle)"
        else:
            return polyB, "Candidate B (Polygon)"
    elif rect_poly is not None:
        return rect_poly, "Candidate A (Rectangle)"
    elif polyB is not None:
        return polyB, "Candidate B (Polygon)"
    else:
        return None, "No polygon"
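# Illustrative check (not part of the app): a clean axis-aligned rectangle mask should
# pass the strict rectangle test and come back labelled "Candidate A (Rectangle)", e.g.:
#   demo = np.zeros((200, 200), np.uint8); demo[50:150, 40:160] = 255
#   poly, label = clean_polygon_strict(demo)   # expect a 4-point box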
# -------------------
# API Endpoint
# -------------------
@app.post("/polygon")
async def polygon_endpoint(file: UploadFile = File(...)):
    contents = await file.read()
    im = np.array(Image.open(io.BytesIO(contents)).convert("RGB"))
    # --- Step 1: Mask R-CNN ---
    # DefaultPredictor expects BGR input with this config (cfg.INPUT.FORMAT defaults to
    # "BGR"), while `im` is RGB, so convert before inference.
    outputs = predictor(cv2.cvtColor(im, cv2.COLOR_RGB2BGR))
    instances = outputs["instances"].to("cpu")
    boxes = instances.pred_boxes.tensor.numpy()
    masks = instances.pred_masks.numpy()
    if len(masks) == 0:
        return JSONResponse(content={"chosen": "No mask found", "polygon": None, "image": None})
    # --- Step 2: SAM2 Refinement ---
    refined_all = []
    sam2_predictor.set_image(im)  # SAM2ImagePredictor expects an RGB image
    for i, box in enumerate(boxes):
        mask_rcnn = (masks[i].astype(np.uint8) * 255)
        sam_masks, sam_scores, _ = sam2_predictor.predict(
            box=box[None, :], multimask_output=True
        )
        best_idx = np.argmax(sam_scores)
        sam_mask = (sam_masks[best_idx].astype(np.uint8) * 255)
        # Clean SAM2 mask: close small gaps, smooth, and re-binarize
        sam_clean = cv2.morphologyEx(sam_mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
        sam_clean = cv2.GaussianBlur(sam_clean, (3, 3), 0)
        _, sam_clean = cv2.threshold(sam_clean, 127, 255, cv2.THRESH_BINARY)
        # --- Step 3: Fusion ---
        mask_rcnn_dilated = cv2.dilate(mask_rcnn, np.ones((5, 5), np.uint8), iterations=1)
        combined = cv2.bitwise_and(mask_rcnn_dilated, sam_clean)
        # --- Step 4: Final polygonization ---
        poly, chosen = clean_polygon_strict(combined)
        refined_all.append((combined, poly, chosen))
    # Take first polygon for demo
    if not refined_all or refined_all[0][1] is None:
        return JSONResponse(content={"chosen": "No polygon", "polygon": None, "image": None})
    combined, final_poly, chosen = refined_all[0]
    # --- Step 5: Preview overlay ---
    # cv2.imencode assumes BGR channel order, so convert the RGB image before drawing the
    # red (0, 0, 255 in BGR) outline and encoding the preview.
    overlay = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
    cv2.polylines(overlay, [final_poly.astype(np.int32)], True, (0, 0, 255), 2)
    _, buffer = cv2.imencode(".png", overlay)
    img_b64 = base64.b64encode(buffer).decode("utf-8")
    return {
        "chosen": chosen,
        "polygon": final_poly.reshape(-1, 2).tolist(),
        "image": img_b64,
    }
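# -------------------
# Local entry point
# -------------------
# uvicorn is imported above but never started in this file; a minimal sketch, assuming
# the Hugging Face Spaces default port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)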