lm7154 committed on
Commit 0708936
1 Parent(s): a1c0e34

Create app.py

Files changed (1)
app.py +335 -0
app.py ADDED
@@ -0,0 +1,335 @@
+ import gradio as gr
+ import numpy as np
+ import torch
+ import cv2
+ import matplotlib.patches as patches
+ from matplotlib import pyplot as plt
+ from matplotlib.colors import LinearSegmentedColormap
+ from math import floor, ceil
+ from transformers import SamModel, SamConfig, SamProcessor
+ from samgeo.text_sam import LangSAM
+
+ # Load the text-prompted SAM model (used below to detect occluding trees)
+ sam = LangSAM()
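+ # Note: LangSAM (from the segment-geospatial package) downloads its GroundingDINO
+ # and SAM weights on first instantiation, so the first startup may take a while.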
+
+
+ # methods for sidewalk inference
+ def get_input_image(image_file, processor, bbox=None):
+     # Convert the PIL image to a (C, H, W) tensor
+     img = torch.tensor(np.array(image_file)).permute(2, 0, 1)
+     if bbox is None:
+         # Default to the full image extent as the box prompt: [xmin, ymin, xmax, ymax]
+         bbox = [0, 0, img.shape[2], img.shape[1]]
+     # Prepare the image and box prompt for the model
+     inputs = processor(img, input_boxes=[[bbox]], return_tensors="pt")
+     # Remove the batch dimension which the processor adds by default
+     inputs = {k: v.squeeze(0) for k, v in inputs.items()}
+     inputs["org_img"] = img
+     return inputs
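+ # For a 256x256 RGB input, the processor should rescale the image (and the box
+ # prompt) to SAM's 1024x1024 input resolution, so "pixel_values" ends up with
+ # shape (3, 1024, 1024) and "input_boxes" with shape (1, 4) after the squeeze above.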
+
+
+ def process_image(inputs):
+     model.eval()
+     with torch.no_grad():
+         outputs = model(pixel_values=inputs["pixel_values"].unsqueeze(0).to(device),
+                         input_boxes=inputs["input_boxes"].unsqueeze(0).to(device),
+                         multimask_output=False)
+     # Convert the predicted mask logits to per-pixel probabilities
+     medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
+     medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
+     orig = inputs["org_img"].permute(1, 2, 0).cpu().numpy()
+     return orig, medsam_seg_prob
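+ # SAM's mask decoder emits low-resolution 256x256 masks, which happens to match
+ # the 256x256 imagery this app works with; for other sizes the probabilities
+ # would need to be upscaled (e.g. with processor.post_process_masks).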
+
+
+ def display_image(medsam_seg_prob, threshold=0.5):
+     # Binarize the probability map at the given threshold
+     medsam_seg = (medsam_seg_prob > threshold).astype(np.uint8)
+     return medsam_seg
+
+
+ # Overlay the sidewalk mask on the original photo
+ def output_sidewalk(image, medsam_seg, alpha=0.7):
+     # Color for 0: transparent, for 1: blue (RGBA tuples)
+     colors = [(0, 0, 0, 0), (0, 0, 1, 1)]
+     cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
+
+     fig, axes = plt.subplots(1, 1, figsize=(8, 8))
+     axes.imshow(np.array(image))
+     axes.imshow(np.array(medsam_seg), cmap=cmap, alpha=alpha)
+     axes.axis('off')
+
+     # Ensure the figure canvas is drawn before reading its pixels
+     fig.canvas.draw()
+
+     # Convert the rendered canvas to a NumPy array
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close(fig)  # avoid accumulating open figures across requests
+
+     return data
+
+
+ # methods for a smoother sidewalk mask
+ def filter_weak(medsam_seg, size_threshold=10):
+     # Keep only connected components at least size_threshold pixels large
+     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+         medsam_seg, connectivity=8, ltype=cv2.CV_32S)
+     result = np.zeros_like(medsam_seg)
+     for i in range(1, num_labels):  # label 0 is the background
+         if stats[i, cv2.CC_STAT_AREA] >= size_threshold:
+             result[labels == i] = 1
+     return result
+
+
+ def smoothing(mask, kernel_size=(6, 6)):
+     kernel = np.ones(kernel_size, np.uint8)
+     opening = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
+     closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
+     return closing
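+ # Opening then closing (rather than closing the raw mask, which discarded the
+ # opening in the original code) first strips thin speckles and then bridges
+ # small gaps, keeping sidewalk segments contiguous.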
+
+
+ def pipeline(data, size_threshold=25, kernel_size=(9, 9)):
+     result = filter_weak(data, size_threshold)
+     result = smoothing(result, kernel_size)
+     return result
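+ # e.g., given a probability map `prob`, pipeline(display_image(prob, 0.5))
+ # turns it into a despeckled, smoothed binary sidewalk mask.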
+
+
+ # methods for occlusion handling
+ def create_boundary_mask_from_bbox(bbox, array_size, thickness=1):
+     # Create an empty mask with the same dimensions as array_size
+     mask = np.zeros(array_size, dtype=np.uint8)
+
+     # Unpack xmin, ymin, xmax, ymax from the bbox
+     xmin, ymin, xmax, ymax = bbox
+
+     # Clamp the bbox coordinates to the array bounds to avoid IndexErrors
+     xmin = floor(max(xmin, 0))
+     xmax = ceil(min(xmax, array_size[1] - 1))
+     ymin = floor(max(ymin, 0))
+     ymax = ceil(min(ymax, array_size[0] - 1))
+
+     # Draw top and bottom horizontal lines (label 2 marks the bbox boundary)
+     mask[ymin:ymin + thickness, xmin:xmax] = 2
+     mask[ymax - thickness + 1:ymax + 1, xmin:xmax] = 2
+
+     # Draw left and right vertical lines
+     mask[ymin:ymax, xmin:xmin + thickness] = 2
+     mask[ymin:ymax, xmax - thickness + 1:xmax + 1] = 2
+
+     return mask
+
+
+ def check_boundary(m1, m2, radius=1):
+     # Initialize an output mask of the same shape as m2, filled with zeros
+     boundary_mask = np.zeros_like(m2)
+
+     # Get the dimensions of the masks
+     rows, cols = m2.shape
+
+     # Iterate through each pixel in the m2 mask
+     for r in range(rows):
+         for c in range(cols):
+             # Check if the current pixel lies on the bbox boundary
+             if m2[r, c] == 2:
+                 # Count adjacent 'sidewalk' pixels
+                 found_sidewalk = 0
+
+                 # Scan the square neighborhood of the given radius
+                 for dr in range(-radius, radius + 1):
+                     for dc in range(-radius, radius + 1):
+                         # Calculate the neighbor's position
+                         nr, nc = r + dr, c + dc
+
+                         # Stay in bounds and skip the center pixel itself
+                         if 0 <= nr < rows and 0 <= nc < cols and (dr != 0 or dc != 0):
+                             if m1[nr, nc] == 1:
+                                 found_sidewalk += 1
+
+                 boundary_mask[r, c] = found_sidewalk
+
+     return boundary_mask
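+ # This is a brute-force O(rows x cols x (2*radius+1)^2) scan, which is fine for
+ # the 256x256 masks used here but would be slow on larger imagery.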
+
+
+ def linear_regression_two_points(point1, point2):
+     # Create arrays of x and y values
+     x = np.array([point1[0], point2[0]])
+     y = np.array([point1[1], point2[1]])
+
+     # Perform linear regression: np.polyfit with degree 1 returns slope and intercept
+     m, b = np.polyfit(x, y, 1)
+     return m, b, x, y
+
+
+ def generate_road_mask(x1, x2, slope, intercept, road_width=5, image_size=(256, 256)):
+     # Create a blank black image (all zeros)
+     image = np.zeros(image_size, dtype=np.uint8)
+
+     # Define x values within the specified range x1 to x2
+     x_values = np.array(range(x1, x2 + 1))
+
+     # Calculate corresponding y values using the slope and intercept
+     y_values = (slope * x_values + intercept).astype(int)
+
+     # Draw the road line with the specified width
+     for i in range(len(x_values)):
+         if 0 <= y_values[i] < image_size[0]:  # keep the y-value inside the image
+             # Draw filled circles to create a thick line
+             cv2.circle(image, (x_values[i], y_values[i]), road_width // 2, 1, -1)
+
+     return image
+
+
+ def get_road_mask_per_bbox(filtered_med_seg, bbox, radius=1):
+     array_size = (256, 256)  # size of the 2D mask
+     mask = create_boundary_mask_from_bbox(bbox, array_size, thickness=1)
+
+     # Intersect the sidewalk mask with the bbox boundary
+     output = check_boundary(filtered_med_seg, mask, radius)
+
+     # Get connected components and their centroids
+     num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(output, 8, cv2.CV_32S)
+     centroids = centroids[1:]  # drop the background component
+     centroids = sorted(centroids, key=lambda x: x[0])
+
+     # Proceed only if the sidewalk crosses the bbox boundary at exactly two points
+     if len(centroids) != 2:
+         return None
+
+     # Fit a line through the two intersection points
+     slope, intercept, x, y = linear_regression_two_points(centroids[0], centroids[1])
+     # Build the road mask inferred from the tree-bbox intersection points
+     road_mask = generate_road_mask(int(x[0]), int(x[1]), slope, intercept, 3)
+
+     return road_mask
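+ # Heuristic: if the sidewalk enters and exits a tree's bounding box at exactly
+ # two points, the occluded stretch is approximated by the straight line between them.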
+
+
+ def analyze_sidewalk(sam, filtered_med_seg, image, alpha=0.7):
+     # Use the text-prompted SAM model to find trees in the image
+     text_prompt = "tree"
+     masks, boxes, labels, logits = sam.predict(image, text_prompt, box_threshold=0.24,
+                                                text_threshold=0.24, return_results=True)
+
+     # Set up custom color maps for the overlays
+     colors = [(0, 0, 0, 0), (0, 0, 1, 1)]  # blue for the sidewalk mask
+     cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
+
+     colors_alt = [(0, 0, 0, 0), (0, 1, 0, 1)]  # green for the inferred road
+     cmap_alt = LinearSegmentedColormap.from_list("custom_cmap_alt", colors_alt)
+
+     # Plot the refined sidewalk mask over the image
+     fig, axes = plt.subplots(1, 1, figsize=(8, 8))
+     axes.imshow(image)
+     axes.imshow(filtered_med_seg, cmap=cmap, alpha=alpha)
+     axes.axis('off')
+
+     # For each detected tree, try to infer the occluded sidewalk and mark the bbox
+     for bbox in boxes:
+         road_mask = get_road_mask_per_bbox(filtered_med_seg, bbox.tolist(), 1)
+         if road_mask is not None:
+             axes.imshow(road_mask, cmap=cmap_alt, alpha=alpha)
+         rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
+                                  linewidth=1, edgecolor='r', facecolor='none')
+         axes.add_patch(rect)
+
+     # Ensure the figure canvas is drawn before reading its pixels
+     fig.canvas.draw()
+
+     # Convert the rendered canvas to a NumPy array
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+     plt.close(fig)
+
+     return data
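+ # sam.predict with a text prompt runs GroundingDINO to find "tree" boxes and then
+ # segments them with SAM; only the boxes are used here, to locate occluders.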
+
+
+ # Load the pretrained SAM backbone and the fine-tuned sidewalk checkpoint
+ model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
+ model = SamModel(config=model_config)
+ model.load_state_dict(torch.load("model_checkpoint_final1.pth", map_location=torch.device('cpu')))
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)  # Move the model to the device once here instead of in the function
+
+ # Shared state for Gradio: cached intermediate results so the sliders update cheaply
+ partial_results = {}
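+ # The fine-tuned checkpoint (model_checkpoint_final1.pth) is expected to sit next
+ # to app.py in the Space repository; torch.load raises FileNotFoundError otherwise.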
+
+
+ def process_pipeline(image, threshold, alpha):
+     # Full pass: run the fine-tuned SAM, then threshold, clean, and overlay
+     processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+     processed_inputs = get_input_image(image, processor, bbox=[0, 0, 256, 256])
+     orig, medsam_seg_prob = process_image(processed_inputs)
+     medsam_seg = display_image(medsam_seg_prob, threshold)
+     filtered_med_seg = pipeline(medsam_seg)
+     output_image = output_sidewalk(orig, filtered_med_seg, alpha)
+     filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
+     # Cache intermediates so slider changes can skip the expensive model pass
+     partial_results["prob"] = medsam_seg_prob
+     partial_results["orig"] = orig
+     partial_results["filtered_med_seg"] = filtered_med_seg
+     return output_image, filled_image
+
+
+ def update_output(image, threshold, alpha):
+     # Re-threshold the cached probability map; skip if nothing has been processed yet
+     if "prob" in partial_results and "orig" in partial_results:
+         medsam_seg_prob = partial_results["prob"]
+         orig = partial_results["orig"]
+         medsam_seg = display_image(medsam_seg_prob, threshold)
+         filtered_med_seg = pipeline(medsam_seg)
+         output_image = output_sidewalk(orig, filtered_med_seg, alpha)
+         filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
+         partial_results["filtered_med_seg"] = filtered_med_seg
+         return output_image, filled_image
+     return None, None
+
+
+ def update_output_alpha(image, threshold, alpha):
+     # Only the overlay alpha changed, so reuse the cached masks directly
+     if "orig" in partial_results and "filtered_med_seg" in partial_results:
+         orig = partial_results["orig"]
+         filtered_med_seg = partial_results["filtered_med_seg"]
+         output_image = output_sidewalk(orig, filtered_med_seg, alpha=alpha)
+         filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
+         return output_image, filled_image
+     return None, None
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("# Sidewalk Detection with SAM Model")
+     gr.Markdown("#### by Dan Mao, Kevin Tan")
+     with gr.Row():
+         with gr.Column():
+             img_in = gr.Image(type="pil", label="Upload Image")
+             threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Threshold")
+             alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label="Alpha for Mask Overlay")
+             submit_button = gr.Button("Process Image")
+
+         with gr.Column():
+             img_out1 = gr.Image(label="Sidewalk Mask - Initial")
+             img_out2 = gr.Image(label="Sidewalk Mask - Refined with Occlusion Handling")
+             gr.ClearButton(components=[img_in, img_out1, img_out2])
+
+     # Set up triggers for slider changes and the submit button
+     threshold.change(fn=update_output, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
+     alpha.change(fn=update_output_alpha, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
+     submit_button.click(
+         fn=process_pipeline,
+         inputs=[img_in, threshold, alpha],
+         outputs=[img_out1, img_out2]
+     )
+
+ app.launch(debug=True)
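+ # To run locally (assuming the dependencies and the checkpoint are present):
+ #   python app.py   -> Gradio serves the UI at http://127.0.0.1:7860 by default.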