import gradio as gr
import numpy as np
import torch
import cv2
import matplotlib.patches as patches
from math import floor, ceil
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from transformers import SamModel, SamConfig, SamProcessor
from samgeo.text_sam import LangSAM

# Load the text-prompted SAM model (used below for tree detection)
sam = LangSAM()
# --- Methods for sidewalk inference ---

def get_input_image(image_file, processor, bbox=None):
    # Convert the PIL input to a CHW tensor
    img = torch.tensor(np.array(image_file)).permute(2, 0, 1)
    if bbox is None:
        # Default to the full image extent (img is CHW, so W = shape[2], H = shape[1])
        bbox = [0, 0, img.shape[2], img.shape[1]]
    # Prepare the image and box prompt for the model
    inputs = processor(img, input_boxes=[[bbox]], return_tensors="pt")
    # Remove the batch dimension which the processor adds by default
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}
    inputs["org_img"] = img
    return inputs
def process_image(inputs):
    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"].unsqueeze(0).to(device),
                        input_boxes=inputs["input_boxes"].unsqueeze(0).to(device),
                        multimask_output=False)
    # Convert the predicted mask logits to a per-pixel probability map
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    orig = inputs["org_img"].permute(1, 2, 0).cpu().numpy()
    return orig, medsam_seg_prob
def display_image(medsam_seg_prob, threshold=0.5):
    # Binarize the probability map at the given threshold
    medsam_seg = (medsam_seg_prob > threshold).astype(np.uint8)
    return medsam_seg
# Overlay the sidewalk mask on the original photo
def output_sidewalk(image, medsam_seg, alpha=0.7):
    # Color for 0: transparent, for 1: blue
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]  # RGBA tuples
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(np.array(image))
    axes.imshow(np.array(medsam_seg), cmap=cmap, alpha=alpha)
    axes.axis('off')
    # Ensure the figure canvas is drawn
    fig.canvas.draw()
    # Now convert it to a NumPy array
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data
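# Note: FigureCanvas.tostring_rgb was deprecated in Matplotlib 3.8 and later
# removed. On newer versions (assuming the Agg backend) an equivalent readout
# would be:
#   data = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()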
# --- Methods for a smoother sidewalk mask ---

def filter_weak(medsam_seg, size_threshold=10):
    # Keep only connected components with at least size_threshold pixels
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        medsam_seg, connectivity=8, ltype=cv2.CV_32S)
    result = np.zeros_like(medsam_seg)
    for i in range(1, num_labels):  # label 0 is the background
        if stats[i, cv2.CC_STAT_AREA] >= size_threshold:
            result[labels == i] = 1
    return result
def smoothing(mask, kernel_size=(6, 6)):
    # Morphological opening removes small speckle; closing then fills small holes
    kernel = np.ones(kernel_size, np.uint8)
    opening = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
    return closing
def pipeline(data, size_threshold=25, kernel_size=(9, 9)):
    result = filter_weak(data, size_threshold)
    result = smoothing(result, kernel_size)
    return result
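# Minimal usage sketch for the mask-refinement pipeline (toy input; the real
# masks come from display_image above):
#   toy = np.zeros((64, 64), dtype=np.uint8)
#   toy[20:40, 10:50] = 1   # a sidewalk-like blob, kept and smoothed
#   toy[5, 5] = 1           # isolated noise pixel, removed by filter_weak
#   refined = pipeline(toy, size_threshold=25, kernel_size=(9, 9))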
# --- Methods for occlusion handling ---

def create_boundary_mask_from_bbox(bbox, array_size, thickness=1):
    # Create an empty mask with the same dimensions as array_size
    mask = np.zeros(array_size, dtype=np.uint8)
    # Unpack the bbox coordinates
    xmin, ymin, xmax, ymax = bbox
    # Clamp the bbox coordinates to the array bounds to avoid IndexErrors
    xmin = floor(max(xmin, 0))
    xmax = ceil(min(xmax, array_size[1] - 1))
    ymin = floor(max(ymin, 0))
    ymax = ceil(min(ymax, array_size[0] - 1))
    # Draw the top and bottom horizontal edges (value 2 marks the box boundary)
    mask[ymin:ymin + thickness, xmin:xmax] = 2
    mask[ymax - thickness + 1:ymax + 1, xmin:xmax] = 2
    # Draw the left and right vertical edges
    mask[ymin:ymax, xmin:xmin + thickness] = 2
    mask[ymin:ymax, xmax - thickness + 1:xmax + 1] = 2
    return mask
def check_boundary(m1, m2, radius=1):
    # Initialize an output mask of the same shape as m2, filled with zeros
    boundary_mask = np.zeros_like(m2)
    # Get the dimensions of the masks
    rows, cols = m2.shape
    # Iterate through each pixel in the m2 mask
    for r in range(rows):
        for c in range(cols):
            # Check if the current pixel is a 'tree' boundary pixel
            if m2[r, c] == 2:
                # Count adjacent 'sidewalk' pixels
                found_sidewalk = 0
                # Check the square around the current pixel with the given radius
                for dr in range(-radius, radius + 1):
                    for dc in range(-radius, radius + 1):
                        # Calculate the neighbor's position
                        nr, nc = r + dr, c + dc
                        # Stay in bounds and skip the center pixel itself
                        if 0 <= nr < rows and 0 <= nc < cols and (dr != 0 or dc != 0):
                            if m1[nr, nc] == 1:
                                found_sidewalk += 1
                boundary_mask[r, c] = found_sidewalk
    return boundary_mask
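# Note: the pixel loop above is O(rows * cols * radius^2). A vectorized
# near-equivalent (a sketch; border handling at the image edges differs
# slightly from the explicit loop) would count neighbors with a box filter:
#   kernel = np.ones((2 * radius + 1, 2 * radius + 1), np.float32)
#   counts = cv2.filter2D(m1.astype(np.float32), -1, kernel) - m1  # exclude center
#   boundary_mask = np.where(m2 == 2, counts, 0).astype(m2.dtype)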
def linear_regression_two_points(point1, point2):
    # Create arrays of x and y values
    x = np.array([point1[0], point2[0]])
    y = np.array([point1[1], point2[1]])
    # Perform linear regression: np.polyfit returns the slope and intercept
    m, b = np.polyfit(x, y, 1)
    return m, b, x, y
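# Note: with only two points this is simply the line through them; np.polyfit
# is ill-conditioned when the two x-coordinates coincide (a vertical line), so
# callers should expect a RankWarning or an unusable slope in that case.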
def generate_road_mask(x1, x2, slope, intercept, road_width=5, image_size=(256, 256)):
    # Create a blank black image (all zeros)
    image = np.zeros(image_size, dtype=np.uint8)
    # Define x values within the range x1..x2 (assumes x1 <= x2)
    x_values = np.arange(x1, x2 + 1)
    # Calculate the corresponding y values from the slope and intercept
    y_values = (slope * x_values + intercept).astype(int)
    # Draw the road line with the specified width
    for i in range(len(x_values)):
        if 0 <= y_values[i] < image_size[0]:  # skip y-values outside the image
            # Draw filled circles along the line to create a thick stroke
            cv2.circle(image, (x_values[i], y_values[i]), road_width // 2, 1, -1)
    return image
def get_road_mask_per_bbox(filtered_med_seg, bbox, radius=1):
    array_size = (256, 256)  # size of the 2D mask
    mask = create_boundary_mask_from_bbox(bbox, array_size, thickness=1)
    # Find where the sidewalk mask intersects the bbox boundary
    output = check_boundary(filtered_med_seg, mask, radius)
    # Get connected components and their centroids
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(output, 8, cv2.CV_32S)
    centroids = centroids[1:]  # drop the background component
    centroids = sorted(centroids, key=lambda x: x[0])
    # Proceed only if the sidewalk meets the bbox at exactly two points
    if len(centroids) != 2:
        return None
    # Fit a line through the two intersection points
    slope, intercept, x, y = linear_regression_two_points(centroids[0], centroids[1])
    # Infer the occluded road segment from the tree bbox intersection points
    road_mask = generate_road_mask(int(x[0]), int(x[1]), slope, intercept, 3)
    return road_mask
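# Occlusion-handling idea, as implemented above: where a detected tree's
# bounding box crosses the sidewalk mask on exactly two sides, the hidden
# stretch of sidewalk is approximated by the straight segment joining the two
# crossing points.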
def analyze_sidewalk(sam, filtered_med_seg, image, alpha=0.7):
    # Use the LangSAM model to detect trees in the image
    text_prompt = "tree"
    masks, boxes, labels, logits = sam.predict(image, text_prompt, box_threshold=0.24,
                                               text_threshold=0.24, return_results=True)
    # Set up custom color maps for the overlays
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]  # blue for the sidewalk mask
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
    colors_alt = [(0, 0, 0, 0), (0, 1, 0, 1)]  # green for inferred occluded segments
    cmap_alt = LinearSegmentedColormap.from_list("custom_cmap", colors_alt)
    # Plot the refined sidewalk mask over the image
    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(image)
    axes.imshow(filtered_med_seg, cmap=cmap, alpha=alpha)
    axes.axis('off')
    # For each detected tree bbox, try to infer the occluded sidewalk segment
    for bbox in boxes:
        road_mask = get_road_mask_per_bbox(filtered_med_seg, bbox.tolist(), 1)
        if road_mask is not None:
            axes.imshow(road_mask, cmap=cmap_alt, alpha=alpha)
        rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                 linewidth=1, edgecolor='r', facecolor='none')
        axes.add_patch(rect)
    # Ensure the figure canvas is drawn
    fig.canvas.draw()
    # Now convert it to a NumPy array
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data
# Build the SAM architecture and load the project's trained checkpoint
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
model = SamModel(config=model_config)
model.load_state_dict(torch.load("model_checkpoint_final1.pth", map_location=torch.device('cpu')))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the model to the device once here instead of in the function
# --- Callback methods for Gradio ---

# Cache of intermediate results so the sliders can re-render without re-running the model
partial_results = {}

def process_pipeline(image, threshold, alpha):
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    processed_inputs = get_input_image(image, processor, bbox=[0, 0, 256, 256])
    orig, medsam_seg_prob = process_image(processed_inputs)
    medsam_seg = display_image(medsam_seg_prob, threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(orig, filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    # Cache intermediates so the threshold/alpha sliders can update cheaply
    partial_results["prob"] = medsam_seg_prob
    partial_results["orig"] = orig
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image
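# Design note: SamProcessor.from_pretrained is re-created on every call here;
# it could be hoisted next to the model load above to avoid repeated setup,
# at the cost of keeping it resident between requests.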
def update_output(image, threshold, alpha):
    # Re-threshold and re-render from cached results when the threshold slider moves
    if "prob" in partial_results and "orig" in partial_results:
        medsam_seg_prob = partial_results['prob']
        orig = partial_results['orig']
        medsam_seg = display_image(medsam_seg_prob, threshold)
        filtered_med_seg = pipeline(medsam_seg)
        output_image = output_sidewalk(orig, filtered_med_seg, alpha)
        filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
        partial_results["filtered_med_seg"] = filtered_med_seg
        return output_image, filled_image
    # Nothing has been processed yet: clear both outputs
    return None, None
def update_output_alpha(image, threshold, alpha):
    # Only the overlay opacity changed, so reuse the cached refined mask
    if "orig" in partial_results and "filtered_med_seg" in partial_results:
        orig = partial_results['orig']
        filtered_med_seg = partial_results["filtered_med_seg"]
        output_image = output_sidewalk(orig, filtered_med_seg, alpha=alpha)
        filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
        return output_image, filled_image
    return None, None
with gr.Blocks() as app:
    gr.Markdown("# Sidewalk Detection with SAM Model")
    gr.Markdown("#### by Dan Mao, Kevin Tan")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil", label="Upload Image")
            threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Threshold")
            alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label="Alpha for Mask Overlay")
            submit_button = gr.Button("Process Image")
        with gr.Column():
            img_out1 = gr.Image(label="Sidewalk Mask - Initial")
            img_out2 = gr.Image(label="Sidewalk Mask - Refined with Occlusion Handling")
    gr.ClearButton(components=[img_in, img_out1, img_out2])
    # Set up triggers for slider changes and button clicks
    threshold.change(fn=update_output, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
    alpha.change(fn=update_output_alpha, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
    submit_button.click(
        fn=process_pipeline,
        inputs=[img_in, threshold, alpha],
        outputs=[img_out1, img_out2]
    )

app.launch(debug=True)
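# Running this script starts the Gradio server; debug=True surfaces tracebacks
# in the UI, which is convenient while iterating on a hosted Space.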