Noursine commited on
Commit
67d88ca
·
verified ·
1 Parent(s): c978fec

Create app3.py

Browse files
Files changed (1) hide show
  1. app3.py +196 -0
app3.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import cv2
4
+ import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File
6
+ from fastapi.responses import JSONResponse
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from PIL import Image
9
+ import torch
10
+ import os
11
+ import uvicorn
12
+ from fastapi import FastAPI, UploadFile
13
+ from fastapi.responses import StreamingResponse, JSONResponse
14
+ from detectron2.engine import DefaultPredictor
15
+ from detectron2.config import get_cfg
16
+ from detectron2 import model_zoo
17
+ from detectron2.data import MetadataCatalog
18
+ from sam2.build_sam import build_sam2
19
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
20
+ from hydra import initialize, compose
21
+ from hydra.core.global_hydra import GlobalHydra
22
+ # -------------------
23
+ # Detectron2 setup
24
+ # -------------------
25
+ det_cfg = get_cfg()
26
+ det_cfg.merge_from_file(
27
+ model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
28
+ )
29
+ det_cfg.MODEL.WEIGHTS = "/app/model_final.pth" # your trained weights
30
+ det_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
31
+ det_cfg.MODEL.DEVICE = "cpu" # Hugging Face free tier is CPU only
32
+ det_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
33
+
34
+ # Register class metadata
35
+ MetadataCatalog.get("__unused__").thing_classes = ["toproof"]
36
+
37
+ predictor = DefaultPredictor(det_cfg)
38
+
39
+
40
+ # -------------------
41
+ # SAM2 setup
42
+ # -------------------
43
+ os.chdir("/app") # ensure hydra looks in the right place
44
+ if GlobalHydra.instance().is_initialized():
45
+ GlobalHydra.instance().clear()
46
+
47
+ # Make sure the filename matches your repo (sam2_1_hiera_l.yaml)
48
+ with initialize(version_base=None, config_path="."):
49
+ sam2_model = build_sam2("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt", device="cpu")
50
+
51
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
52
+
53
+ # -------------------
54
+ # FastAPI app
55
+ # -------------------
56
+ app = FastAPI()
57
+ app.add_middleware(
58
+ CORSMiddleware,
59
+ allow_origins=["*"], allow_credentials=True,
60
+ allow_methods=["*"], allow_headers=["*"],
61
+ )
62
+
63
+ @app.get("/")
64
+ def home():
65
+ return {"status": "running"}
66
+
67
+ # -------------------
68
+ # Helpers
69
+ # -------------------
70
+ def _largest_contour(mask):
71
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
72
+ if not contours:
73
+ return None
74
+ return max(contours, key=cv2.contourArea)
75
+
76
+ def _min_area_rect_to_poly(cnt):
77
+ rect = cv2.minAreaRect(cnt)
78
+ box = cv2.boxPoints(rect)
79
+ return box.astype(np.float32).reshape(-1,1,2)
80
+
81
+ def mask_to_polygon_no_holes(mask, epsilon_factor=0.005, min_area=150):
82
+ if mask.dtype != np.uint8:
83
+ if mask.max() <= 1: # case: 0/1
84
+ mask = (mask * 255).astype(np.uint8)
85
+ else:
86
+ mask = mask.astype(np.uint8)
87
+
88
+ mask = (mask > 0).astype(np.uint8) * 255
89
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
90
+ if not contours:
91
+ return None
92
+ contour = max(contours, key=cv2.contourArea)
93
+ if cv2.contourArea(contour) < min_area:
94
+ return None
95
+ epsilon = epsilon_factor * cv2.arcLength(contour, True)
96
+ approx = cv2.approxPolyDP(contour, epsilon, True)
97
+ return approx
98
+
99
+ def clean_polygon_strict(mask, epsilon_factor=0.01, min_area=150):
100
+ if mask.dtype != np.uint8:
101
+ if mask.max() <= 1:
102
+ mask = (mask * 255).astype(np.uint8)
103
+ else:
104
+ mask = mask.astype(np.uint8)
105
+
106
+ bw = (mask > 127).astype(np.uint8) * 255
107
+ cnt = _largest_contour(bw)
108
+ if cnt is None:
109
+ return None, "No contour"
110
+
111
+ rect_poly = _min_area_rect_to_poly(cnt)
112
+ polyB = mask_to_polygon_no_holes(bw, epsilon_factor=epsilon_factor, min_area=min_area)
113
+
114
+ if rect_poly is not None and polyB is not None:
115
+ rect_area = cv2.contourArea(rect_poly)
116
+ contour_area = cv2.contourArea(cnt)
117
+ area_ratio = rect_area / contour_area if contour_area > 0 else 0
118
+
119
+ # 🔹 If polygon has > 4 sides → prefer Candidate B
120
+ if len(polyB) > 4:
121
+ return polyB, "Candidate B (Polygon)"
122
+
123
+ # 🔹 Stricter rectangle test
124
+ if 0.95 < area_ratio < 1.05 and len(polyB) == 4:
125
+ return rect_poly, "Candidate A (Rectangle)"
126
+ else:
127
+ return polyB, "Candidate B (Polygon)"
128
+
129
+ elif rect_poly is not None:
130
+ return rect_poly, "Candidate A (Rectangle)"
131
+ elif polyB is not None:
132
+ return polyB, "Candidate B (Polygon)"
133
+ else:
134
+ return None, "No polygon"
135
+
136
+ # -------------------
137
+ # API Endpoint
138
+ # -------------------
139
+ @app.post("/polygon")
140
+ async def polygon_endpoint(file: UploadFile = File(...)):
141
+ contents = await file.read()
142
+ im = np.array(Image.open(io.BytesIO(contents)).convert("RGB"))
143
+
144
+ # --- Step 1: Mask R-CNN ---
145
+ outputs = predictor(im) # use the Detectron2 predictor you set up
146
+ instances = outputs["instances"].to("cpu")
147
+
148
+ boxes = instances.pred_boxes.tensor.numpy()
149
+ masks = instances.pred_masks.numpy()
150
+
151
+ if len(masks) == 0:
152
+ return JSONResponse(content={"chosen": "No mask found", "polygon": None, "image": None})
153
+
154
+ # --- Step 2: SAM2 Refinement ---
155
+ refined_all = []
156
+ sam2_predictor.set_image(im)
157
+
158
+ for i, box in enumerate(boxes):
159
+ mask_rcnn = (masks[i].astype(np.uint8) * 255)
160
+
161
+ sam_masks, sam_scores, _ = sam2_predictor.predict(
162
+ box=box[None, :], multimask_output=True
163
+ )
164
+ best_idx = np.argmax(sam_scores)
165
+ sam_mask = (sam_masks[best_idx].astype(np.uint8) * 255)
166
+
167
+ # Clean SAM2 mask
168
+ sam_clean = cv2.morphologyEx(sam_mask, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8))
169
+ sam_clean = cv2.GaussianBlur(sam_clean, (3,3), 0)
170
+ _, sam_clean = cv2.threshold(sam_clean, 127, 255, cv2.THRESH_BINARY)
171
+
172
+ # --- Step 3: Fusion ---
173
+ mask_rcnn_dilated = cv2.dilate(mask_rcnn, np.ones((5,5), np.uint8), iterations=1)
174
+ combined = cv2.bitwise_and(mask_rcnn_dilated, sam_clean)
175
+
176
+ # --- Step 4: Final polygonization ---
177
+ poly, chosen = clean_polygon_strict(combined)
178
+ refined_all.append((combined, poly, chosen))
179
+
180
+ # Take first polygon for demo
181
+ if not refined_all or refined_all[0][1] is None:
182
+ return JSONResponse(content={"chosen": "No polygon", "polygon": None, "image": None})
183
+
184
+ combined, final_poly, chosen = refined_all[0]
185
+
186
+ # --- Step 5: Preview overlay ---
187
+ overlay = im.copy()
188
+ cv2.polylines(overlay, [final_poly.astype(np.int32)], True, (0,0,255), 2)
189
+ _, buffer = cv2.imencode(".png", overlay)
190
+ img_b64 = base64.b64encode(buffer).decode("utf-8")
191
+
192
+ return {
193
+ "chosen": chosen,
194
+ "polygon": final_poly.reshape(-1, 2).tolist(),
195
+ "image": img_b64
196
+ }