Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from samgeo import tms_to_geotiff
|
2 |
+
from samgeo.text_sam import LangSAM
|
3 |
+
|
4 |
+
# NOTE: LangSAM is instantiated once further down (under "# Load the SAM
# model"), after all imports. Building it here as well only loaded the model
# weights twice; nothing between the two assignments used `sam`.
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
from torchvision import transforms
|
11 |
+
from matplotlib import pyplot as plt
|
12 |
+
from samgeo.text_sam import LangSAM
|
13 |
+
import cv2
|
14 |
+
import matplotlib.patches as patches
|
15 |
+
from transformers import SamModel, SamConfig, SamProcessor
|
16 |
+
from math import floor, ceil
|
17 |
+
from matplotlib.colors import LinearSegmentedColormap
|
18 |
+
from samgeo import tms_to_geotiff
|
19 |
+
from samgeo.text_sam import LangSAM
|
20 |
+
|
21 |
+
# Load the SAM model
|
22 |
+
sam = LangSAM()
|
23 |
+
|
24 |
+
|
25 |
+
# methods for sidewalk inferences
|
26 |
+
def get_input_image(image_file, processor, bbox=None):
    """Prepare one image and its box prompt for the SAM model.

    Args:
        image_file: Image as a PIL image or an ``(H, W, C)`` array-like.
            Objects exposing ``convert`` (PIL images) are converted to RGB
            first so grayscale/RGBA uploads still yield three channels.
        processor: Callable invoked as
            ``processor(img, input_boxes=[[bbox]], return_tensors="pt")`` that
            returns a mapping of batched tensors (the app passes a
            ``SamProcessor``).
        bbox: Box prompt ``[xmin, ymin, xmax, ymax]`` in pixels. Defaults to
            the whole image.

    Returns:
        dict: The processor outputs with the leading batch dimension squeezed
        away, plus ``"org_img"`` holding the channel-first image tensor.
    """
    if hasattr(image_file, "convert"):
        image_file = image_file.convert("RGB")

    # Channel-first layout: (C, H, W).
    img = torch.tensor(np.array(image_file)).permute(2, 0, 1)

    if bbox is None:
        # After the permute, shape is (C, H, W); the box is [x0, y0, x1, y1],
        # so width comes before height. (The previous code read shape[1] and
        # shape[0], i.e. height and channel count.)
        _, height, width = img.shape
        bbox = [0, 0, width, height]

    # prepare image and prompt for the model
    inputs = processor(img, input_boxes=[[bbox]], return_tensors="pt")
    # remove batch dimension which the processor adds by default
    inputs = {key: value.squeeze(0) for key, value in inputs.items()}
    inputs["org_img"] = img
    return inputs
|
41 |
+
|
42 |
+
|
43 |
+
def process_image(inputs):
    """Run the module-level SAM ``model`` on one prepared sample.

    Args:
        inputs: Mapping with ``"pixel_values"``, ``"input_boxes"`` and
            ``"org_img"`` tensors, as built by ``get_input_image`` (no batch
            dimension; it is re-added here).

    Returns:
        tuple: ``(orig, medsam_seg_prob)`` where ``orig`` is the original
        image as a channel-last NumPy array and ``medsam_seg_prob`` is the
        sigmoid of the predicted mask logits as a squeezed NumPy array.
    """
    model.eval()

    # Re-add the batch dimension and move the tensors next to the model.
    batched_pixels = inputs["pixel_values"].unsqueeze(0).to(device)
    batched_boxes = inputs["input_boxes"].unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(
            pixel_values=batched_pixels,
            input_boxes=batched_boxes,
            multimask_output=False,
        )

    mask_logits = prediction.pred_masks.squeeze(1)
    probabilities = torch.sigmoid(mask_logits).cpu().numpy().squeeze()

    # Back to channel-last (H, W, C) for plotting.
    original_hwc = inputs["org_img"].permute(1, 2, 0).cpu().numpy()
    return original_hwc, probabilities
|
53 |
+
|
54 |
+
|
55 |
+
def display_image(medsam_seg_prob, threshold=0.5):
    """Binarise a probability map.

    Args:
        medsam_seg_prob: NumPy array of per-pixel probabilities.
        threshold: Values strictly greater than this become 1.

    Returns:
        ``uint8`` array of the same shape holding 0/1.
    """
    above_threshold = medsam_seg_prob > threshold
    return above_threshold.astype(np.uint8)
|
58 |
+
|
59 |
+
|
60 |
+
# output sidewalk with original photo
|
61 |
+
def output_sidewalk(image, medsam_seg, alpha=0.7):
    """Render the sidewalk mask over the image and return the plot as pixels.

    Args:
        image: Background image (anything ``np.array`` accepts).
        medsam_seg: 0/1 mask drawn on top of the image.
        alpha: Opacity passed to ``imshow`` for the mask layer.

    Returns:
        ``uint8`` NumPy array of shape ``(height, width, 3)`` holding the
        rendered figure in RGB.
    """
    # Color for 0: transparent, for 1: blue (RGBA tuples)
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    try:
        axes.imshow(np.array(image))
        axes.imshow(np.array(medsam_seg), cmap=cmap, alpha=alpha)
        axes.axis('off')

        # Ensure the figure canvas is drawn before reading its buffer.
        fig.canvas.draw()

        # buffer_rgba() replaces the deprecated tostring_rgb(); drop the alpha
        # channel and copy so the result outlives the figure.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request leaked one figure.
        plt.close(fig)

    return data
|
79 |
+
|
80 |
+
|
81 |
+
# methods for smoother sidewalk mask
|
82 |
+
def filter_weak(medsam_seg, size_threshold=10):
    """Drop connected components smaller than ``size_threshold`` pixels.

    Components are found with 8-connectivity; label 0 (background) is never
    kept.

    Args:
        medsam_seg: Binary mask accepted by
            ``cv2.connectedComponentsWithStats``.
        size_threshold: Minimum component area (in pixels) to keep.

    Returns:
        Array with the shape and dtype of ``medsam_seg`` holding 1 on kept
        components and 0 elsewhere.
    """
    _, label_map, component_stats, _ = cv2.connectedComponentsWithStats(
        medsam_seg, connectivity=8, ltype=cv2.CV_32S
    )

    # Row 0 of the stats table is the background, so offset indices by one.
    foreground_areas = component_stats[1:, cv2.CC_STAT_AREA]
    kept_labels = np.flatnonzero(foreground_areas >= size_threshold) + 1

    return np.isin(label_map, kept_labels).astype(medsam_seg.dtype)
|
90 |
+
|
91 |
+
|
92 |
+
def smoothing(mask, kernel_size=(6, 6)):
    """Smooth a binary mask with a morphological closing.

    Args:
        mask: Mask array accepted by ``cv2.morphologyEx``.
        kernel_size: ``(rows, cols)`` of the all-ones structuring element.

    Returns:
        The closed mask.
    """
    kernel = np.ones(kernel_size, np.uint8)
    # NOTE(review): the previous version also computed a MORPH_OPEN result and
    # discarded it; only the closing of the *input* mask was ever returned.
    # The dead computation is removed and the output is unchanged. If
    # open-then-close was the intent, confirm before chaining them.
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
97 |
+
|
98 |
+
|
99 |
+
def pipeline(data, size_threshold=25, kernel_size=(9, 9)):
    """Clean a binary mask: remove small components, then smooth.

    Args:
        data: Binary mask handed to ``filter_weak``.
        size_threshold: Minimum component area forwarded to ``filter_weak``.
        kernel_size: Structuring-element size forwarded to ``smoothing``.

    Returns:
        The filtered and smoothed mask.
    """
    without_specks = filter_weak(data, size_threshold)
    return smoothing(without_specks, kernel_size)
|
103 |
+
|
104 |
+
|
105 |
+
# methods for occlusion handling
|
106 |
+
def create_boundary_mask_from_bbox(bbox, array_size, thickness=1):
    """Draw the outline of a bounding box into an empty mask.

    Args:
        bbox: ``(xmin, ymin, xmax, ymax)``; may be fractional and may extend
            beyond the array.
        array_size: ``(rows, cols)`` of the mask to create.
        thickness: Outline thickness in pixels, grown inwards from the box
            edges.

    Returns:
        ``uint8`` array of shape ``array_size`` with the outline pixels set to
        2 and everything else 0. A box that does not intersect the array (or
        a non-positive thickness) yields an all-zero mask.
    """
    mask = np.zeros(array_size, dtype=np.uint8)
    height, width = array_size[0], array_size[1]

    xmin, ymin, xmax, ymax = bbox

    # Clamp to the array and snap outwards to whole pixels.
    xmin = floor(max(xmin, 0))
    xmax = ceil(min(xmax, width - 1))
    ymin = floor(max(ymin, 0))
    ymax = ceil(min(ymax, height - 1))

    # A box entirely outside the array would otherwise produce negative or
    # reversed indices, which NumPy slicing silently wraps.
    if thickness <= 0 or xmin > xmax or ymin > ymax:
        return mask

    # Edges are inclusive of xmax/ymax so all four corners get drawn; the
    # earlier exclusive slices left the bottom-right corner pixel empty.
    top_stop = min(ymin + thickness, ymax + 1)
    bottom_start = max(ymax - thickness + 1, ymin)
    left_stop = min(xmin + thickness, xmax + 1)
    right_start = max(xmax - thickness + 1, xmin)

    # Top and bottom horizontal lines
    mask[ymin:top_stop, xmin:xmax + 1] = 2
    mask[bottom_start:ymax + 1, xmin:xmax + 1] = 2

    # Left and right vertical lines
    mask[ymin:ymax + 1, xmin:left_stop] = 2
    mask[ymin:ymax + 1, right_start:xmax + 1] = 2

    return mask
|
128 |
+
|
129 |
+
|
130 |
+
def check_boundary(m1, m2, radius=1):
    """Count mask-1 neighbours around every marked pixel of ``m2``.

    For each pixel where ``m2 == 2``, the square window of the given radius
    around it (excluding the pixel itself, clipped to the bounds of ``m2``)
    is scanned and the number of positions where ``m1 == 1`` is stored.

    Args:
        m1: 2-D array; pixels equal to 1 are the ones being counted.
        m2: 2-D array; pixels equal to 2 are the ones being examined.
        radius: Half-width of the square neighbourhood.

    Returns:
        Array shaped and typed like ``m2`` holding the neighbour count at
        each examined pixel and 0 everywhere else.
    """
    boundary_mask = np.zeros_like(m2)
    rows, cols = m2.shape

    # Every offset in the window except the centre.
    span = range(-radius, radius + 1)
    offsets = [(dr, dc) for dr in span for dc in span if (dr, dc) != (0, 0)]

    for r, c in np.argwhere(m2 == 2):
        hits = 0
        for dr, dc in offsets:
            nr, nc = r + dr, c + dc
            if nr < 0 or nr >= rows or nc < 0 or nc >= cols:
                continue
            if m1[nr, nc] == 1:
                hits += 1
        boundary_mask[r, c] = hits

    return boundary_mask
|
159 |
+
|
160 |
+
|
161 |
+
def linear_regression_two_points(point1, point2):
    """Fit a straight line through two ``(x, y)`` points.

    Args:
        point1: First point; index 0 is x, index 1 is y.
        point2: Second point, same layout.

    Returns:
        tuple: ``(slope, intercept, x, y)`` where ``x`` and ``y`` are the
        two-element arrays of coordinates handed to ``np.polyfit``.
    """
    xs = np.array([point1[0], point2[0]])
    ys = np.array([point1[1], point2[1]])

    # A degree-1 polyfit returns [slope, intercept].
    coefficients = np.polyfit(xs, ys, 1)
    slope, intercept = coefficients
    return slope, intercept, xs, ys
|
169 |
+
|
170 |
+
|
171 |
+
def generate_road_mask(x1, x2, slope, intercept, road_width=5, image_size=(256, 256)):
    """Rasterise the line ``y = slope * x + intercept`` for ``x1 <= x <= x2``.

    A filled circle of radius ``road_width // 2`` is stamped at every integer
    x whose (truncated) y falls inside the image height, producing a thick
    line.

    Args:
        x1: First integer x (inclusive).
        x2: Last integer x (inclusive).
        slope: Line slope.
        intercept: Line intercept.
        road_width: Approximate line thickness in pixels.
        image_size: ``(rows, cols)`` of the mask to create.

    Returns:
        ``uint8`` array of shape ``image_size`` with line pixels set to 1.
    """
    canvas = np.zeros(image_size, dtype=np.uint8)

    xs = np.array(range(x1, x2 + 1))
    ys = (slope * xs + intercept).astype(int)

    stamp_radius = road_width // 2
    max_rows = image_size[0]
    for x, y in zip(xs, ys):
        # Skip samples whose y lies outside the image rows.
        if y < 0 or y >= max_rows:
            continue
        cv2.circle(canvas, (x, y), stamp_radius, 1, -1)

    return canvas
|
187 |
+
|
188 |
+
|
189 |
+
def get_road_mask_per_bbox(filtered_med_seg, bbox, radius=1):
    """Infer the sidewalk segment hidden under one bounding box.

    The box outline is intersected with the sidewalk mask; if the sidewalk
    touches the outline in exactly two connected regions, a straight segment
    joining the two region centroids is rasterised.

    Args:
        filtered_med_seg: 2-D 0/1 sidewalk mask.
        bbox: ``(xmin, ymin, xmax, ymax)`` of the occluding object, in the
            mask's pixel coordinates.
        radius: Neighbourhood radius forwarded to ``check_boundary``.

    Returns:
        ``uint8`` mask shaped like ``filtered_med_seg`` with the inferred
        segment set to 1, or ``None`` when the outline is not crossed in
        exactly two places.
    """
    # Follow the mask's own size instead of assuming 256x256, so the outline,
    # the neighbour scan and the rasterised segment always share one shape.
    array_size = tuple(np.shape(filtered_med_seg)[:2])
    outline = create_boundary_mask_from_bbox(bbox, array_size, thickness=1)

    # Outline pixels that have sidewalk pixels next to them.
    touching = check_boundary(filtered_med_seg, outline, radius)

    # Connected components of those pixels; drop the background centroid and
    # order the rest left to right.
    _, _, _, centroids = cv2.connectedComponentsWithStats(touching, 8, cv2.CV_32S)
    crossings = sorted(centroids[1:], key=lambda point: point[0])

    if len(crossings) != 2:
        return None

    slope, intercept, x, _ = linear_regression_two_points(crossings[0], crossings[1])
    return generate_road_mask(int(x[0]), int(x[1]), slope, intercept, 3,
                              image_size=array_size)
|
211 |
+
|
212 |
+
|
213 |
+
def analyze_sidewalk(sam, filtered_med_seg, image, alpha=0.7):
    """Render the sidewalk mask plus segments inferred under detected trees.

    Args:
        sam: Object whose ``predict(image, text_prompt, box_threshold=...,
            text_threshold=..., return_results=True)`` returns
            ``(masks, boxes, labels, logits)``; the app passes a ``LangSAM``.
        filtered_med_seg: 2-D 0/1 sidewalk mask.
        image: Image shown as the background and passed to ``sam.predict``.
        alpha: Opacity of the mask overlays.

    Returns:
        ``uint8`` NumPy array of shape ``(height, width, 3)`` holding the
        rendered figure in RGB.
    """
    # Using SAM model to predict on the image with a specific prompt; only the
    # boxes are used below.
    text_prompt = "tree"
    _, boxes, _, _ = sam.predict(image, text_prompt, box_threshold=0.24,
                                 text_threshold=0.24, return_results=True)

    # Two-stop colormaps: value 0 transparent, value 1 opaque blue / green.
    sidewalk_cmap = LinearSegmentedColormap.from_list(
        "custom_cmap", [(0, 0, 0, 0), (0, 0, 1, 1)])
    inferred_cmap = LinearSegmentedColormap.from_list(
        "custom_cmap", [(0, 0, 0, 0), (0, 1, 0, 1)])

    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    try:
        axes.imshow(image)
        axes.imshow(filtered_med_seg, cmap=sidewalk_cmap, alpha=alpha)
        axes.axis('off')

        for bbox in boxes:
            xmin, ymin, xmax, ymax = bbox.tolist()
            road_mask = get_road_mask_per_bbox(
                filtered_med_seg, [xmin, ymin, xmax, ymax], 1)
            if road_mask is not None:
                axes.imshow(road_mask, cmap=inferred_cmap, alpha=alpha)
            # NOTE(review): the source's indentation was lost; the outline is
            # drawn for every detected box here. Confirm it was not meant only
            # for boxes that produced an inferred segment.
            rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                     linewidth=1, edgecolor='r', facecolor='none')
            axes.add_patch(rect)

        # Ensure the figure canvas is drawn before reading its buffer.
        fig.canvas.draw()

        # buffer_rgba() replaces the deprecated tostring_rgb(); drop the alpha
        # channel and copy so the result outlives the figure.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        data = rgba[..., :3].copy()
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request leaked one figure.
        plt.close(fig)

    return data
|
262 |
+
|
263 |
+
|
264 |
+
# Load pretrained model
# The architecture is built from the "facebook/sam-vit-base" config and the
# weights are then taken from the local checkpoint file.
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
model = SamModel(config=model_config)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports
# it). The tensors are mapped to CPU first and moved to `device` below.
model.load_state_dict(torch.load("model_checkpoint_final1.pth", map_location=torch.device('cpu')))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) # Move model to device once here instead of in the function

# special methods for gradio
# Module-level cache written by process_pipeline and read by the slider
# callbacks. Keys used in this file: "prob", "orig", "filtered_med_seg".
# NOTE(review): being a plain global, it is presumably shared by every
# session of the app — confirm whether per-session state is needed.
partial_results = {}
|
273 |
+
|
274 |
+
|
275 |
+
# Lazily created SamProcessor shared by every call to process_pipeline.
_SAM_PROCESSOR = None


def _get_sam_processor():
    """Return the shared ``SamProcessor``, creating it on first use."""
    global _SAM_PROCESSOR
    if _SAM_PROCESSOR is None:
        _SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-base")
    return _SAM_PROCESSOR


def process_pipeline(image, threshold, alpha):
    """Run the full segmentation pipeline for the "Process Image" button.

    Args:
        image: Uploaded image, or ``None`` when nothing has been uploaded.
        threshold: Probability cut-off used to binarise the prediction.
        alpha: Opacity of the mask overlays.

    Returns:
        tuple: ``(output_image, filled_image)`` — the plain sidewalk overlay
        and the overlay with occlusion handling. ``(None, None)`` when no
        image was supplied.

    Side effects:
        Stores ``"prob"``, ``"orig"`` and ``"filtered_med_seg"`` in the
        module-level ``partial_results`` for the slider callbacks.
    """
    # The button can be clicked before an image is uploaded.
    if image is None:
        return None, None

    # Reuse one processor instead of re-loading it on every click.
    processor = _get_sam_processor()
    # The box prompt is fixed to a 256x256 frame.
    processed_inputs = get_input_image(image, processor, bbox=[0, 0, 256, 256])
    orig, medsam_seg_prob = process_image(processed_inputs)
    medsam_seg = display_image(medsam_seg_prob, threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(orig, filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)

    partial_results["prob"] = medsam_seg_prob
    partial_results["orig"] = orig
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image
|
287 |
+
|
288 |
+
|
289 |
+
def update_output(image, threshold, alpha):
    """Re-threshold the cached prediction when the threshold slider moves.

    Args:
        image: Image currently in the input component (may be ``None``).
        threshold: New probability cut-off.
        alpha: Opacity of the mask overlays.

    Returns:
        tuple: ``(output_image, filled_image)``, or ``(None, None)`` when
        there is no image or no cached prediction yet. Two values are always
        returned because the callback is wired to two output components; the
        previous version fell through and returned a bare ``None``.

    Side effects:
        Updates ``partial_results["filtered_med_seg"]``.
    """
    if image is None:
        return None, None
    if "prob" not in partial_results or "orig" not in partial_results:
        return None, None

    medsam_seg_prob = partial_results['prob']
    orig = partial_results['orig']
    medsam_seg = display_image(medsam_seg_prob, threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(orig, filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image
|
299 |
+
|
300 |
+
|
301 |
+
def update_output_alpha(image, threshold, alpha):
    """Redraw the overlays with a new opacity when the alpha slider moves.

    Args:
        image: Image currently in the input component (may be ``None``).
        threshold: Unused; present because the callback shares its input list
            with the other handlers.
        alpha: New opacity of the mask overlays.

    Returns:
        tuple: ``(output_image, filled_image)``, or ``(None, None)`` when
        there is no image or no cached mask yet. Two values are always
        returned because the callback is wired to two output components; the
        previous version fell through and returned a bare ``None``.
    """
    if image is None:
        return None, None
    # Guard on the keys that are actually read below.
    if "orig" not in partial_results or "filtered_med_seg" not in partial_results:
        return None, None

    orig = partial_results['orig']
    filtered_med_seg = partial_results["filtered_med_seg"]
    output_image = output_sidewalk(orig, filtered_med_seg, alpha=alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    return output_image, filled_image
|
309 |
+
|
310 |
+
|
311 |
+
# Gradio UI: inputs (image, two sliders, a button) in one column, the two
# rendered overlays in the other.
# NOTE(review): nesting below is reconstructed from a source whose
# indentation was lost — confirm the placement of the ClearButton and of the
# event wiring against the deployed app.
with gr.Blocks() as app:
    gr.Markdown("# Sidewalk Detection with SAM Model")
    gr.Markdown("#### by Dan Mao, Kevin Tan")
    with gr.Row():
        with gr.Column():
            # type="pil" makes the callbacks receive a PIL image.
            img_in = gr.Image(type="pil", label="Upload Image")
            threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Threshold")
            alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label="Alpha for Mask Overlay")
            submit_button = gr.Button("Process Image")

        with gr.Column():
            img_out1 = gr.Image(label="Sidewalk Mask - Initial")
            img_out2 = gr.Image(label="Sidewalk Mask - Refine with Occlusion Handling")
            # Clears the input image and both outputs; it does not touch the
            # module-level partial_results cache.
            gr.ClearButton(components=[img_in, img_out1, img_out2])

    # Setting up triggers for changes and button clicks
    # All three handlers take the same inputs and fill the same two outputs.
    threshold.change(fn=update_output, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
    alpha.change(fn=update_output_alpha, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
    submit_button.click(
        fn=process_pipeline,
        inputs=[img_in, threshold, alpha],
        outputs=[img_out1, img_out2]
    )

# Runs at import time; debug=True is passed through to Gradio's launch().
app.launch(debug=True)
|