import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from matplotlib import pyplot as plt
import cv2
import matplotlib.patches as patches
from transformers import SamModel, SamConfig, SamProcessor
from math import floor, ceil
from matplotlib.colors import LinearSegmentedColormap
from samgeo import tms_to_geotiff
from samgeo.text_sam import LangSAM

# Load the text-prompted SAM model (used below to detect occluding trees)
sam = LangSAM()

# Methods for sidewalk inference
def get_input_image(image_file, processor, bbox=None):
    # Convert the PIL image to a (C, H, W) tensor
    img = torch.tensor(np.array(image_file)).permute(2, 0, 1)
    if bbox is None:
        # Use the full image as the bounding box; after permute,
        # width is shape[2] and height is shape[1]
        bbox = [0, 0, img.shape[2], img.shape[1]]
    # Prepare the image and box prompt for the model
    inputs = processor(img, input_boxes=[[bbox]], return_tensors="pt")
    # Remove the batch dimension which the processor adds by default
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}
    inputs["org_img"] = img
    return inputs

def process_image(inputs):
    model.eval()
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"].unsqueeze(0).to(device),
                        input_boxes=inputs["input_boxes"].unsqueeze(0).to(device),
                        multimask_output=False)
    # Convert mask logits to a per-pixel probability map
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    orig = inputs["org_img"].permute(1, 2, 0).cpu().numpy()
    return orig, medsam_seg_prob

def display_image(medsam_seg_prob, threshold=0.5):
    # Binarize the probability map at the given threshold
    medsam_seg = (medsam_seg_prob > threshold).astype(np.uint8)
    return medsam_seg

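# Illustrative check of the thresholding (toy values are assumptions, not app data):
#
#   demo_prob = np.array([[0.2, 0.7], [0.9, 0.4]])
#   display_image(demo_prob, threshold=0.5)  # -> array([[0, 1], [1, 0]], dtype=uint8)
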
# Overlay the sidewalk mask on the original photo
def output_sidewalk(image, medsam_seg, alpha=0.7):
    # Color map: 0 -> transparent, 1 -> blue (RGBA tuples)
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(np.array(image))
    axes.imshow(np.array(medsam_seg), cmap=cmap, alpha=alpha)
    axes.axis('off')
    # Render the figure, then convert the canvas to a NumPy array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # free the figure so repeated calls don't leak memory
    return data

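# Note: FigureCanvasAgg.tostring_rgb() is deprecated in newer Matplotlib releases.
# If it is unavailable, a drop-in sketch using the RGBA buffer works the same way
# (this applies equally to the identical conversion in analyze_sidewalk below):
#
#   data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
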
# Methods for a smoother sidewalk mask
def filter_weak(medsam_seg, size_threshold=10):
    # Drop connected components smaller than size_threshold pixels
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        medsam_seg, connectivity=8, ltype=cv2.CV_32S)
    result = np.zeros_like(medsam_seg)
    for i in range(1, num_labels):  # label 0 is the background
        if stats[i, cv2.CC_STAT_AREA] >= size_threshold:
            result[labels == i] = 1
    return result

def smoothing(mask, kernel_size=(6, 6)):
    # Opening removes small speckles; closing the opened mask then fills small holes
    # (the original closed the raw mask, leaving the opening result unused)
    kernel = np.ones(kernel_size, np.uint8)
    opening = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    closing = cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)
    return closing

def pipeline(data, size_threshold=25, kernel_size=(9, 9)):
    result = filter_weak(data, size_threshold)
    result = smoothing(result, kernel_size)
    return result

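# Minimal sanity check for the refinement pipeline (the toy mask and the small
# kernel are assumptions for illustration only):
#
#   toy = np.zeros((32, 32), dtype=np.uint8)
#   toy[8:24, 8:24] = 1   # a 256-pixel blob that should survive filtering
#   toy[2, 2] = 1         # a single stray pixel that should be removed
#   refined = pipeline(toy, size_threshold=25, kernel_size=(3, 3))
#   # refined[2, 2] == 0 and refined[16, 16] == 1
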
# Methods for occlusion handling
def create_boundary_mask_from_bbox(bbox, array_size, thickness=1):
    # Create an empty mask with the same dimensions as array_size
    mask = np.zeros(array_size, dtype=np.uint8)
    xmin, ymin, xmax, ymax = bbox
    # Clamp the bbox coordinates to the array bounds to avoid IndexErrors
    xmin = floor(max(xmin, 0))
    xmax = ceil(min(xmax, array_size[1] - 1))
    ymin = floor(max(ymin, 0))
    ymax = ceil(min(ymax, array_size[0] - 1))
    # Draw top and bottom horizontal edges (value 2 marks the bbox boundary)
    mask[ymin:ymin + thickness, xmin:xmax] = 2
    mask[ymax - thickness + 1:ymax + 1, xmin:xmax] = 2
    # Draw left and right vertical edges
    mask[ymin:ymax, xmin:xmin + thickness] = 2
    mask[ymin:ymax, xmax - thickness + 1:xmax + 1] = 2
    return mask

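# Sketch of the boundary mask on a toy grid (sizes are assumptions): for
# bbox=[2, 2, 5, 5] on an 8x8 mask with thickness=1, only a one-pixel-wide
# rectangle outline is set to 2 and the interior stays 0:
#
#   demo = create_boundary_mask_from_bbox([2, 2, 5, 5], (8, 8))
#   # demo[2, 2:5] -> array([2, 2, 2], dtype=uint8)  (top edge)
#   # demo[3, 3]   -> 0                              (interior)
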
def check_boundary(m1, m2, radius=1):
    # For every bbox-boundary pixel in m2, count the sidewalk pixels of m1
    # within the given radius; nonzero counts mark where the sidewalk
    # crosses the bbox edge.
    boundary_mask = np.zeros_like(m2)
    rows, cols = m2.shape
    for r in range(rows):
        for c in range(cols):
            # Only consider bbox-boundary pixels
            if m2[r, c] == 2:
                found_sidewalk = 0
                # Scan the square neighborhood around the current pixel
                for dr in range(-radius, radius + 1):
                    for dc in range(-radius, radius + 1):
                        nr, nc = r + dr, c + dc
                        # Stay in bounds and skip the center pixel itself
                        if 0 <= nr < rows and 0 <= nc < cols and (dr != 0 or dc != 0):
                            if m1[nr, nc] == 1:
                                found_sidewalk += 1
                boundary_mask[r, c] = found_sidewalk
    return boundary_mask

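# A vectorized equivalent for large masks (a sketch only; it assumes SciPy,
# which this app does not otherwise import): box-filter the sidewalk mask to
# get per-pixel neighborhood sums, subtract the center pixel, and keep the
# counts only where m2 marks a bbox boundary.
#
#   from scipy.ndimage import uniform_filter
#   k = 2 * radius + 1
#   sums = uniform_filter(m1.astype(float), size=k, mode='constant') * (k * k) - m1
#   boundary_mask = np.where(m2 == 2, np.rint(sums), 0).astype(m2.dtype)
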
def linear_regression_two_points(point1, point2):
    # Fit a line through the two points: np.polyfit with degree 1
    # returns the slope and intercept
    x = np.array([point1[0], point2[0]])
    y = np.array([point1[1], point2[1]])
    m, b = np.polyfit(x, y, 1)
    return m, b, x, y

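# Worked example (values are illustrative): the points (2, 3) and (6, 11) give
# slope m = (11 - 3) / (6 - 2) = 2 and intercept b = 3 - 2 * 2 = -1. Note that
# np.polyfit cannot represent a vertical line, so this step assumes the two
# boundary crossings differ in x.
#
#   m, b, x, y = linear_regression_two_points((2, 3), (6, 11))  # m ≈ 2.0, b ≈ -1.0
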
def generate_road_mask(x1, x2, slope, intercept, road_width=5, image_size=(256, 256)):
    # Create a blank mask (all zeros)
    image = np.zeros(image_size, dtype=np.uint8)
    # Sample x values within the range x1 to x2
    x_values = np.array(range(x1, x2 + 1))
    # Compute the corresponding y values from the slope and intercept
    y_values = (slope * x_values + intercept).astype(int)
    # Draw filled circles along the line to create a segment of the given width
    for i in range(len(x_values)):
        if 0 <= y_values[i] < image_size[0]:  # keep the y-value inside the image
            cv2.circle(image, (x_values[i], y_values[i]), road_width // 2, 1, -1)
    return image

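# Illustrative call (parameter values are assumptions): draw a 3-pixel-wide
# segment from x=10 to x=20 along y = 2x - 1 on the default 256x256 canvas:
#
#   seg = generate_road_mask(10, 20, 2.0, -1.0, road_width=3)
#   # seg is a uint8 mask with 1s tracing the segment between the endpoints
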
def get_road_mask_per_bbox(filtered_med_seg, bbox, radius=1):
    array_size = (256, 256)  # size of the 2D mask
    mask = create_boundary_mask_from_bbox(bbox, array_size, thickness=1)
    # Find where the sidewalk mask crosses the bbox boundary
    output = check_boundary(filtered_med_seg, mask, radius)
    # Get the connected components of the crossings and their centroids
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(output, 8, cv2.CV_32S)
    centroids = centroids[1:]  # drop the background component
    centroids = sorted(centroids, key=lambda x: x[0])
    # Only bridge the occlusion if exactly two crossing points were found
    if len(centroids) == 2:
        # Fit a line through the two crossings ...
        slope, intercept, x, y = linear_regression_two_points(centroids[0], centroids[1])
        # ... and rasterize it as the road segment hidden under the tree
        road_mask = generate_road_mask(int(x[0]), int(x[1]), slope, intercept, 3)
        return road_mask
    return None

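# How the pieces fit together (a sketch; the app draws the inferred segment as a
# separate overlay, but merging it back into the mask would be one alternative):
#
#   refined = pipeline(binary_mask)                      # smoothed sidewalk mask
#   bridge = get_road_mask_per_bbox(refined, tree_bbox)  # tree_bbox from LangSAM
#   if bridge is not None:
#       refined = np.maximum(refined, bridge)            # optional merge
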
def analyze_sidewalk(sam, filtered_med_seg, image, alpha=0.7):
    # Use the text-prompted SAM model to locate trees that may occlude the sidewalk
    text_prompt = "tree"
    masks, boxes, labels, logits = sam.predict(image, text_prompt, box_threshold=0.24,
                                               text_threshold=0.24, return_results=True)
    # Custom color maps for the overlays: blue for the sidewalk mask,
    # green for the inferred occluded segments
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]
    cmap = LinearSegmentedColormap.from_list("sidewalk_cmap", colors)
    colors_alt = [(0, 0, 0, 0), (0, 1, 0, 1)]
    cmap_alt = LinearSegmentedColormap.from_list("road_cmap", colors_alt)
    # Plot the refined mask with occlusion handling
    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(image)
    axes.imshow(filtered_med_seg, cmap=cmap, alpha=alpha)
    axes.axis('off')
    for bbox in boxes:
        road_mask = get_road_mask_per_bbox(filtered_med_seg, bbox.tolist(), 1)
        if road_mask is not None:
            axes.imshow(road_mask, cmap=cmap_alt, alpha=alpha)
        rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                 linewidth=1, edgecolor='r', facecolor='none')
        axes.add_patch(rect)
    # Render the figure and convert the canvas to a NumPy array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)  # free the figure so repeated calls don't leak memory
    return data

# Load the fine-tuned SAM checkpoint
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
model = SamModel(config=model_config)
model.load_state_dict(torch.load("model_checkpoint_final1.pth", map_location=torch.device('cpu')))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the model to the device once here instead of in the function
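# The checkpoint filename above is expected to sit next to this script. An
# equivalent variant (a sketch) loads the weights straight onto the detected
# device, provided `device` is computed first:
#
#   state = torch.load("model_checkpoint_final1.pth", map_location=device)
#   model.load_state_dict(state)
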
# Methods wired to the Gradio UI
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")  # load once, not per request
partial_results = {}  # cache of intermediates so slider changes can skip recomputation

def process_pipeline(image, threshold, alpha):
    processed_inputs = get_input_image(image, processor, bbox=[0, 0, 256, 256])
    orig, medsam_seg_prob = process_image(processed_inputs)
    medsam_seg = display_image(medsam_seg_prob, threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(orig, filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    # Cache intermediates so the sliders can update without re-running the model
    partial_results["prob"] = medsam_seg_prob
    partial_results["orig"] = orig
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image

def update_output(image, threshold, alpha):
    # Recompute from the cached probability map when the threshold slider moves
    if "prob" in partial_results and "orig" in partial_results:
        medsam_seg_prob = partial_results['prob']
        orig = partial_results['orig']
        medsam_seg = display_image(medsam_seg_prob, threshold)
        filtered_med_seg = pipeline(medsam_seg)
        output_image = output_sidewalk(orig, filtered_med_seg, alpha)
        filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
        partial_results["filtered_med_seg"] = filtered_med_seg
        return output_image, filled_image
    return None, None  # nothing to update until an image has been processed

def update_output_alpha(image, threshold, alpha):
    # Only re-render the overlays when the alpha slider moves; the masks stay cached
    if "prob" in partial_results and "filtered_med_seg" in partial_results:
        orig = partial_results['orig']
        filtered_med_seg = partial_results["filtered_med_seg"]
        output_image = output_sidewalk(orig, filtered_med_seg, alpha=alpha)
        filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
        return output_image, filled_image
    return None, None  # nothing to update until an image has been processed

# Gradio UI
with gr.Blocks() as app:
    gr.Markdown("# Sidewalk Detection with SAM Model")
    gr.Markdown("#### by Dan Mao, Kevin Tan")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil", label="Upload Image")
            threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label="Threshold")
            alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7, label="Alpha for Mask Overlay")
            submit_button = gr.Button("Process Image")
        with gr.Column():
            img_out1 = gr.Image(label="Sidewalk Mask - Initial")
            img_out2 = gr.Image(label="Sidewalk Mask - Refined with Occlusion Handling")
    gr.ClearButton(components=[img_in, img_out1, img_out2])

    # Wire up the slider changes and the submit button
    threshold.change(fn=update_output, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
    alpha.change(fn=update_output_alpha, inputs=[img_in, threshold, alpha], outputs=[img_out1, img_out2])
    submit_button.click(fn=process_pipeline, inputs=[img_in, threshold, alpha],
                        outputs=[img_out1, img_out2])

app.launch(debug=True)