Noursine commited on
Commit
734ea1c
·
verified ·
1 Parent(s): 93f67bb

Create main3.py

Browse files
Files changed (1) hide show
  1. main3.py +191 -0
main3.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import gdown
4
+ import base64
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ from typing import Optional
9
+ from fastapi import FastAPI, UploadFile, File, Form
10
+ from fastapi.responses import JSONResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from detectron2.engine import DefaultPredictor
13
+ from detectron2.config import get_cfg
14
+ from detectron2.projects.point_rend import add_pointrend_config
15
+
16
+ # -------------------------------
17
+ # FastAPI setup
18
+ # -------------------------------
19
+ app = FastAPI(title="Rooftop Segmentation API")
20
+
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"],
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ # -------------------------------
30
+ # Available epsilons
31
+ # -------------------------------
32
+ EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001]
33
+
34
+ @app.get("/epsilons")
35
+ def get_epsilons():
36
+ return {"epsilons": EPSILONS}
37
+
38
+ # -------------------------------
39
+ # Google Drive model download (irregular-flat)
40
+ # -------------------------------
41
+ MODEL_PATH_IRREGULAR = "/tmp/model_irregular_flat.pth"
42
+ DRIVE_FILE_ID = "15vi4zPhCs3aBnGepVnXFOqQjxdK1jpnA"
43
+
44
+ def download_irregular_model():
45
+ if not os.path.exists(MODEL_PATH_IRREGULAR):
46
+ url = f"https://drive.google.com/uc?id={DRIVE_FILE_ID}"
47
+ tmp_dir = "/tmp/gdown"
48
+ os.makedirs(tmp_dir, exist_ok=True)
49
+ os.environ["GDOWN_CACHE_DIR"] = tmp_dir
50
+ print("Downloading irregular-flat Detectron2 model...")
51
+ gdown.download(url, MODEL_PATH_IRREGULAR, quiet=False, fuzzy=True, use_cookies=False)
52
+ print("Download complete.")
53
+ else:
54
+ print("Irregular-flat model already exists, skipping download.")
55
+
56
+ download_irregular_model()
57
+
58
+ if os.path.exists(MODEL_PATH_IRREGULAR):
59
+ print("Irregular-flat model is ready at", MODEL_PATH_IRREGULAR)
60
+ else:
61
+ print("Irregular-flat model NOT found! Something went wrong!")
62
+
63
+ # -------------------------------
64
+ # Detectron2 model setup
65
+ # -------------------------------
66
+ def setup_model_rect(weights_path: str):
67
+ cfg = get_cfg()
68
+ add_pointrend_config(cfg)
69
+ cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
70
+ cfg.merge_from_file(cfg_path)
71
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
72
+ cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
73
+ cfg.MODEL.WEIGHTS = weights_path
74
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
75
+ cfg.MODEL.DEVICE = "cpu"
76
+ return DefaultPredictor(cfg)
77
+
78
+ def setup_model_irregular(weights_path: str):
79
+ cfg = get_cfg()
80
+ add_pointrend_config(cfg)
81
+ cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
82
+ cfg.merge_from_file(cfg_path)
83
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
84
+ cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
85
+ cfg.MODEL.WEIGHTS = weights_path
86
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
87
+ cfg.MODEL.DEVICE = "cpu"
88
+ return DefaultPredictor(cfg)
89
+
90
+ # Load models
91
+ predictor_rect = setup_model_rect("/app/model_rect_final.pth")
92
+ predictor_irregular_flat = setup_model_irregular(MODEL_PATH_IRREGULAR)
93
+
94
+ # -------------------------------
95
+ # Utility functions
96
+ # -------------------------------
97
+ def im_to_b64_png(im: np.ndarray) -> str:
98
+ _, buffer = cv2.imencode(".png", im)
99
+ return base64.b64encode(buffer).decode()
100
+
101
+ def extract_polygon(mask: np.ndarray, epsilon_ratio: float = 0.004):
102
+ mask_uint8 = (mask * 255).astype(np.uint8)
103
+ contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
104
+ if not contours:
105
+ return None
106
+ c = max(contours, key=cv2.contourArea)
107
+ epsilon = epsilon_ratio * cv2.arcLength(c, True)
108
+ polygon = cv2.approxPolyDP(c, epsilon, True)
109
+ return polygon.reshape(-1, 2)
110
+
111
+ def overlay_polygon(im: np.ndarray, polygon: Optional[np.ndarray], vertex_color=(0,0,255), line_color=(0,255,0)):
112
+ overlay = im.copy()
113
+ if polygon is not None:
114
+ # Draw polygon outline (thin)
115
+ cv2.polylines(overlay, [polygon.astype(np.int32)], True, line_color, thickness=2)
116
+
117
+ # Draw vertices
118
+ for i, (x, y) in enumerate(polygon):
119
+ cv2.circle(overlay, (int(x), int(y)), 4, vertex_color, -1)
120
+ # Draw vertex index (black number)
121
+ cv2.putText(overlay, str(i+1), (int(x)+5, int(y)-5),
122
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20,20,20), 1, cv2.LINE_AA)
123
+
124
+ # Display vertex count on top
125
+ vertex_count = len(polygon)
126
+ cv2.putText(overlay, f"num_vertices = {vertex_count}", (20, 35),
127
+ cv2.FONT_HERSHEY_SIMPLEX, 0.9, (20,20,20), 2, cv2.LINE_AA)
128
+
129
+ return overlay
130
+
131
+ # -------------------------------
132
+ # API endpoints
133
+ # -------------------------------
134
+ @app.get("/")
135
+ def root():
136
+ return {"message": "Rooftop Segmentation API is running!"}
137
+
138
+ @app.post("/predict")
139
+ async def predict(
140
+ file: UploadFile = File(...),
141
+ rooftop_type: str = Form(...),
142
+ epsilon: float = Form(0.004)
143
+ ):
144
+ contents = await file.read()
145
+ try:
146
+ im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
147
+ except Exception as e:
148
+ return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
149
+
150
+ im = np.array(im_pil)[:, :, ::-1].copy() # RGB -> BGR
151
+
152
+ if rooftop_type.lower() == "rectangular":
153
+ predictor = predictor_rect
154
+ model_used = "model_rect_final.pth"
155
+ elif rooftop_type.lower() == "irregular":
156
+ predictor = predictor_irregular_flat
157
+ model_used = "model_irregular_flat.pth"
158
+ else:
159
+ return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})
160
+
161
+ outputs = predictor(im)
162
+ instances = outputs["instances"].to("cpu")
163
+
164
+ if len(instances) == 0:
165
+ return {
166
+ "polygon": None,
167
+ "vertices": None,
168
+ "vertex_count": 0,
169
+ "image": None,
170
+ "model_used": model_used,
171
+ "rooftop_type": rooftop_type,
172
+ "epsilon": epsilon
173
+ }
174
+
175
+ idx = int(instances.scores.argmax().item())
176
+ raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)
177
+
178
+ polygon = extract_polygon(raw_mask, epsilon_ratio=epsilon)
179
+ vertex_count = len(polygon) if polygon is not None else 0
180
+
181
+ overlay = overlay_polygon(im, polygon)
182
+ img_b64 = im_to_b64_png(overlay)
183
+
184
+ return {
185
+ "polygon": polygon.tolist() if polygon is not None else None,
186
+ "vertex_count": vertex_count,
187
+ "image": img_b64,
188
+ "model_used": model_used,
189
+ "rooftop_type": rooftop_type,
190
+ "epsilon": epsilon
191
+ }