# Committed by lm7154 - commit 0708936 "Create app.py" (verified)
from math import ceil, floor

import cv2
import gradio as gr
import matplotlib.patches as patches
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from samgeo import tms_to_geotiff
from samgeo.text_sam import LangSAM
from torchvision import transforms
from transformers import SamConfig, SamModel, SamProcessor

# Text-prompted detector used to find occluding trees (see analyze_sidewalk).
# Constructed exactly once, after all imports.
# NOTE(review): construction presumably loads model weights, so building it
# more than once would waste start-up time and memory.
sam = LangSAM()
# methods for sidewalk inferences
def get_input_image(image_file, processor, bbox=None):
    """Prepare an image and a box prompt for the SAM processor.

    Args:
        image_file: Image convertible by ``np.array`` to an (H, W, C) array
            (the Gradio UI passes a PIL image).
        processor: Callable invoked as
            ``processor(img, input_boxes=[[bbox]], return_tensors="pt")`` that
            returns a mapping of batched tensors.
        bbox: Prompt box ``[x_min, y_min, x_max, y_max]``. Defaults to the
            whole image.

    Returns:
        dict: The processor outputs with their leading batch dimension removed,
        plus ``"org_img"``: the original image as a (C, H, W) tensor.
    """
    # HWC array -> CHW tensor; process_image permutes it back for display.
    img = torch.tensor(np.array(image_file)).permute(2, 0, 1)
    if bbox is None:
        # img is (C, H, W): width is shape[2], height is shape[1].
        bbox = [0, 0, img.shape[2], img.shape[1]]
    # Prepare image and prompt for the model.
    inputs = processor(img, input_boxes=[[bbox]], return_tensors="pt")
    # Remove the batch dimension that the processor adds by default.
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}
    inputs["org_img"] = img
    return inputs
def process_image(inputs):
    """Run the module-level SAM ``model`` on one prepared input.

    Args:
        inputs: Dict from ``get_input_image`` holding unbatched
            ``"pixel_values"`` and ``"input_boxes"`` tensors and the CHW
            ``"org_img"`` tensor.

    Returns:
        tuple: ``(orig, prob)``. ``orig`` is the original image as an HWC numpy
        array. ``prob`` holds the sigmoid of the predicted mask logits as a
        squeezed numpy array.
    """
    model.eval()
    # Re-add the batch dimension and move the prompt tensors to the model's device.
    batch = {
        "pixel_values": inputs["pixel_values"].unsqueeze(0).to(device),
        "input_boxes": inputs["input_boxes"].unsqueeze(0).to(device),
    }
    with torch.no_grad():
        outputs = model(**batch, multimask_output=False)
        probabilities = torch.sigmoid(outputs.pred_masks.squeeze(1))
    prob_map = probabilities.cpu().numpy().squeeze()
    orig = inputs["org_img"].permute(1, 2, 0).cpu().numpy()
    return orig, prob_map
def display_image(medsam_seg_prob, threshold=0.5):
    """Binarize a probability map.

    Returns a ``uint8`` array shaped like ``medsam_seg_prob``: 1 where the
    probability is strictly greater than ``threshold``, otherwise 0.
    """
    above = np.greater(medsam_seg_prob, threshold)
    return above.astype(np.uint8)
# output sidewalk with original photo
def output_sidewalk(image, medsam_seg, alpha=0.7):
    """Render the sidewalk mask in blue over the image and return the pixels.

    Args:
        image: Image data accepted by ``Axes.imshow``.
        medsam_seg: Binary mask; 0 is drawn transparent and 1 opaque blue
            (before ``alpha`` is applied).
        alpha: Opacity of the mask overlay.

    Returns:
        np.ndarray: The rendered 8x8-inch figure as an (H, W, 3) uint8 RGB array.
    """
    # Color for 0: transparent, for 1: blue (RGBA tuples).
    cmap = LinearSegmentedColormap.from_list("custom_cmap", [(0, 0, 0, 0), (0, 0, 1, 1)])
    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    try:
        ax.imshow(np.array(image))
        ax.imshow(np.array(medsam_seg), cmap=cmap, alpha=alpha)
        ax.axis('off')
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated since Matplotlib
        # 3.8) and already carries the buffer's true (H, W, 4) shape.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # pyplot keeps every open figure alive; close it so a long-running
        # app does not accumulate one figure per request.
        plt.close(fig)
    return data
# methods for smoother sidewalk mask
def filter_weak(medsam_seg, size_threshold=10):
    """Drop small connected components from a binary mask.

    Args:
        medsam_seg: Single-channel binary mask, in a form accepted by
            ``cv2.connectedComponentsWithStats``.
        size_threshold: Minimum component area, in pixels, to keep.

    Returns:
        np.ndarray: Mask with the same shape and dtype. Pixels of components
        with area >= ``size_threshold`` are 1; everything else is 0.
    """
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        medsam_seg, connectivity=8, ltype=cv2.CV_32S
    )
    # Decide once per label, then map every pixel through that lookup table.
    # This avoids a full-image comparison for each component.
    keep = stats[:, cv2.CC_STAT_AREA] >= size_threshold
    keep[0] = False  # label 0 is the background and is never kept
    return keep[labels].astype(medsam_seg.dtype)
def smoothing(mask, kernel_size=(6, 6)):
    """Fill small gaps in a binary mask with a morphological closing.

    Args:
        mask: Binary mask accepted by ``cv2.morphologyEx``.
        kernel_size: ``(rows, cols)`` of the all-ones structuring element.

    Returns:
        np.ndarray: ``mask`` after ``cv2.MORPH_CLOSE`` with that kernel.
    """
    kernel = np.ones(kernel_size, np.uint8)
    # Only a closing is applied to the input mask.
    # NOTE(review): if open-then-close smoothing was intended, feed
    # cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel) into the closing instead;
    # that would change the output.
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
def pipeline(data, size_threshold=25, kernel_size=(9, 9)):
    """Clean up a binary sidewalk mask.

    First removes connected components smaller than ``size_threshold`` pixels
    (``filter_weak``), then closes small gaps (``smoothing``).

    Returns:
        np.ndarray: The cleaned mask.
    """
    return smoothing(filter_weak(data, size_threshold), kernel_size)
# methods for occlusion handling
def create_boundary_mask_from_bbox(bbox, array_size, thickness=1):
    """Draw the outline of a bounding box into an otherwise empty mask.

    Args:
        bbox: ``(xmin, ymin, xmax, ymax)`` in pixel coordinates. Values may be
            fractional and may extend past the array.
        array_size: ``(rows, cols)`` shape of the mask to create.
        thickness: Outline thickness in pixels, drawn inward from the box edges.

    Returns:
        np.ndarray: ``uint8`` mask with 2 on the outline and 0 elsewhere. Box
        edges are inclusive, so all four corners belong to the outline.
    """
    mask = np.zeros(array_size, dtype=np.uint8)
    xmin, ymin, xmax, ymax = bbox
    # Round outward and clip to the array so the indices stay in bounds.
    xmin = floor(max(xmin, 0))
    xmax = ceil(min(xmax, array_size[1] - 1))
    ymin = floor(max(ymin, 0))
    ymax = ceil(min(ymax, array_size[0] - 1))
    # Inner edges of the bottom/right bands. Clamp at 0 so a slice start can
    # never go negative (NumPy would count that from the end of the axis).
    bottom = max(ymax - thickness + 1, 0)
    right = max(xmax - thickness + 1, 0)
    # Horizontal bands span xmin..xmax inclusive.
    mask[ymin:ymin + thickness, xmin:xmax + 1] = 2
    mask[bottom:ymax + 1, xmin:xmax + 1] = 2
    # Vertical bands span ymin..ymax inclusive (the bottom-right corner included).
    mask[ymin:ymax + 1, xmin:xmin + thickness] = 2
    mask[ymin:ymax + 1, right:xmax + 1] = 2
    return mask
def check_boundary(m1, m2, radius=1):
    """Count sidewalk pixels around every outline pixel.

    Args:
        m1: Sidewalk mask; pixels equal to 1 are counted. It must be at least
            as large as ``m2``; only its top-left ``m2.shape`` region is read.
        m2: 2-D outline mask; only pixels equal to 2 are scored.
        radius: Half-width of the square neighbourhood. The centre pixel is
            never counted, and neighbours outside the array count as 0.

    Returns:
        np.ndarray: Array with the shape and dtype of ``m2``. Where
        ``m2 == 2`` it holds the number of ``m1 == 1`` pixels in the
        ``(2 * radius + 1)``-square around that pixel; elsewhere it is 0.
    """
    rows, cols = m2.shape
    radius = max(radius, 0)  # a negative radius means an empty neighbourhood
    hits = (np.asarray(m1)[:rows, :cols] == 1).astype(np.int64)
    # Zero padding makes out-of-bounds neighbours count as "not sidewalk".
    padded = np.pad(hits, radius)
    counts = np.zeros((rows, cols), dtype=np.int64)
    # Add one shifted copy of the hit map per neighbour offset. This is a
    # vectorised pass per offset instead of a Python loop over every pixel.
    for dr in range(-radius, radius + 1):
        for dc in range(-radius, radius + 1):
            if dr == 0 and dc == 0:
                continue
            counts += padded[radius + dr:radius + dr + rows,
                             radius + dc:radius + dc + cols]
    return np.where(m2 == 2, counts, 0).astype(m2.dtype)
def linear_regression_two_points(point1, point2):
    """Fit the straight line ``y = m * x + b`` through two points.

    Args:
        point1: ``(x, y)`` pair.
        point2: ``(x, y)`` pair.

    Returns:
        tuple: ``(m, b, x, y)``. ``m`` and ``b`` are the slope and intercept;
        ``x`` and ``y`` are numpy arrays of the two x and two y coordinates.
    """
    pts = (point1, point2)
    xs = np.array([p[0] for p in pts])
    ys = np.array([p[1] for p in pts])
    # A degree-1 least-squares fit through two points with different x values
    # is exactly the line joining them.
    slope, intercept = np.polyfit(xs, ys, deg=1)
    return slope, intercept, xs, ys
def generate_road_mask(x1, x2, slope, intercept, road_width=5, image_size=(256, 256)):
    """Rasterise the line ``y = slope * x + intercept`` as a thick stroke.

    For every integer x in ``[x1, x2]``, y is computed and truncated with
    ``astype(int)``. A filled circle of radius ``road_width // 2`` with value 1
    is drawn there. Positions whose y falls outside ``[0, image_size[0])`` are
    skipped.

    Returns:
        np.ndarray: ``uint8`` mask of shape ``image_size``.
    """
    canvas = np.zeros(image_size, dtype=np.uint8)
    xs = np.array(range(x1, x2 + 1))
    ys = (slope * xs + intercept).astype(int)
    brush = road_width // 2
    for x, y in zip(xs, ys):
        # Only rows that exist in the image; cv2.circle clips the rest itself.
        if 0 <= y < image_size[0]:
            cv2.circle(canvas, (x, y), brush, 1, -1)
    return canvas
def get_road_mask_per_bbox(filtered_med_seg, bbox, radius=1):
    """Infer the sidewalk segment hidden behind one occluding box.

    The box outline is matched against the sidewalk mask with
    ``check_boundary``. If the matches form exactly two connected regions,
    their centroids (ordered by x) are joined by a straight 3-pixel-wide stroke.

    Args:
        filtered_med_seg: 2-D binary sidewalk mask (1 = sidewalk).
        bbox: ``(xmin, ymin, xmax, ymax)`` of the occluder, in the mask's pixel
            coordinates.
        radius: Neighbourhood radius passed to ``check_boundary``.

    Returns:
        np.ndarray | None: ``uint8`` mask shaped like ``filtered_med_seg`` with
        the inferred segment set to 1. Returns ``None`` unless the outline meets
        the sidewalk in exactly two regions.
    """
    # Work at the mask's own resolution instead of assuming 256 x 256.
    array_size = tuple(filtered_med_seg.shape[:2])
    outline = create_boundary_mask_from_bbox(bbox, array_size, thickness=1)
    # Nonzero where the outline touches the sidewalk.
    touching = check_boundary(filtered_med_seg, outline, radius)
    _, _, _, centroids = cv2.connectedComponentsWithStats(touching, 8, cv2.CV_32S)
    # Row 0 is the background component; order the rest left to right.
    centroids = sorted(centroids[1:], key=lambda c: c[0])
    if len(centroids) != 2:
        return None
    slope, intercept, x, _ = linear_regression_two_points(centroids[0], centroids[1])
    return generate_road_mask(int(x[0]), int(x[1]), slope, intercept, 3,
                              image_size=array_size)
def analyze_sidewalk(sam, filtered_med_seg, image, alpha=0.7):
    """Overlay the sidewalk mask and occlusion-inferred segments on the image.

    Trees are detected with the text-prompted ``sam`` model. For each detected
    box, ``get_road_mask_per_bbox`` may infer a hidden sidewalk segment, which
    is drawn in green. The sidewalk mask is drawn in blue and every box is
    outlined in red.

    Args:
        sam: Detector exposing ``predict(image, text_prompt, box_threshold=...,
            text_threshold=..., return_results=True)``.
        filtered_med_seg: Binary sidewalk mask aligned with ``image``.
        image: Image passed to both ``sam.predict`` and ``Axes.imshow``.
        alpha: Opacity of the mask overlays.

    Returns:
        np.ndarray: The rendered 8x8-inch figure as an (H, W, 3) uint8 RGB array.
    """
    text_prompt = "tree"
    result = sam.predict(image, text_prompt, box_threshold=0.24, text_threshold=0.24,
                         return_results=True)
    # NOTE(review): samgeo's LangSAM.predict appears to return None instead of
    # a (masks, boxes, labels, logits) tuple when nothing matches the prompt.
    # Treat that as "no trees" rather than failing on tuple unpacking.
    boxes = [] if result is None else result[1]

    # Colour for 0: transparent; for 1: blue (sidewalk) or green (inferred).
    cmap = LinearSegmentedColormap.from_list("custom_cmap", [(0, 0, 0, 0), (0, 0, 1, 1)])
    cmap_alt = LinearSegmentedColormap.from_list("custom_cmap", [(0, 0, 0, 0), (0, 1, 0, 1)])

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    try:
        ax.imshow(image)
        ax.imshow(filtered_med_seg, cmap=cmap, alpha=alpha)
        ax.axis('off')
        for bbox in boxes:
            road_mask = get_road_mask_per_bbox(filtered_med_seg, bbox.tolist(), 1)
            if road_mask is not None:
                ax.imshow(road_mask, cmap=cmap_alt, alpha=alpha)
            # NOTE(review): every detected box is outlined; confirm whether only
            # boxes that produced an inferred segment should be drawn.
            rect = patches.Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                                     linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb() (deprecated since Matplotlib
        # 3.8) and already carries the buffer's true (H, W, 4) shape.
        data = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # pyplot keeps open figures alive; close this one so repeated calls
        # do not leak figures.
        plt.close(fig)
    return data
# Load the SAM model: "facebook/sam-vit-base" architecture, weights from the
# local checkpoint (read onto the CPU first), then moved to the GPU when one
# is available.
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
model = SamModel(config=model_config)
model.load_state_dict(
    torch.load("model_checkpoint_final1.pth", map_location=torch.device('cpu'))
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model once here; process_image only moves the input tensors.
model.to(device)
# Special methods for Gradio.
# Results of the most recent "Process Image" run. The slider callbacks reuse
# them to re-threshold or re-blend without re-running the sidewalk model.
# NOTE(review): this dict is module-global, so concurrent users of the app
# share (and overwrite) it; gr.State would keep results per session.
partial_results = {}

# Lazily loaded SamProcessor, shared by every call to process_pipeline.
_sam_processor = None


def _get_sam_processor():
    """Return the shared SamProcessor, loading it on first use only."""
    global _sam_processor
    if _sam_processor is None:
        _sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    return _sam_processor


def process_pipeline(image, threshold, alpha):
    """Run the full sidewalk pipeline on the uploaded image.

    Args:
        image: Uploaded PIL image from the Gradio input, or ``None``.
        threshold: Probability cut-off for the sidewalk mask.
        alpha: Opacity of the mask overlays.

    Returns:
        tuple: ``(initial_overlay, occlusion_overlay)`` RGB arrays, or
        ``(None, None)`` when no image has been uploaded.
    """
    if image is None:
        # Nothing to process; without this guard get_input_image fails on None.
        return None, None
    # NOTE(review): the full-image prompt box is fixed at 256 x 256, so the
    # pipeline appears to assume 256 x 256 inputs.
    processed_inputs = get_input_image(image, _get_sam_processor(), bbox=[0, 0, 256, 256])
    orig, medsam_seg_prob = process_image(processed_inputs)
    medsam_seg = display_image(medsam_seg_prob, threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(orig, filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    partial_results["prob"] = medsam_seg_prob
    partial_results["orig"] = orig
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image
def update_output(image, threshold, alpha):
    """Re-threshold the cached probability map after the threshold slider moves.

    Uses the probabilities cached by the last ``process_pipeline`` run, so the
    sidewalk model is not re-run. Tree detection in ``analyze_sidewalk`` still is.

    Returns:
        tuple: ``(initial_overlay, occlusion_overlay)``, or ``(None, None)``
        when nothing has been processed yet or the input image was cleared.
    """
    if image is None or "prob" not in partial_results or "orig" not in partial_results:
        # Both output components need a value; the implicit bare None
        # fall-through does not match the two outputs.
        return None, None
    # NOTE(review): if a new image was uploaded but not processed yet, the
    # cached probabilities still belong to the previous image.
    medsam_seg = display_image(partial_results["prob"], threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(partial_results["orig"], filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image
def update_output_alpha(image, threshold, alpha):
    """Redraw both overlays with a new opacity after the alpha slider moves.

    Reuses the cached filtered mask, so the threshold and the sidewalk model
    are not re-applied. Tree detection in ``analyze_sidewalk`` still runs.
    ``threshold`` is accepted only to match the shared callback inputs.

    Returns:
        tuple: ``(initial_overlay, occlusion_overlay)``, or ``(None, None)``
        when nothing has been processed yet or the input image was cleared.
    """
    # Check exactly the cache entries that are read below.
    if image is None or "orig" not in partial_results or "filtered_med_seg" not in partial_results:
        return None, None
    filtered_med_seg = partial_results["filtered_med_seg"]
    output_image = output_sidewalk(partial_results["orig"], filtered_med_seg, alpha=alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    return output_image, filled_image
with gr.Blocks() as app:
    gr.Markdown("# Sidewalk Detection with SAM Model")
    gr.Markdown("#### by Dan Mao, Kevin Tan")
    with gr.Row():
        # Left column: the input image and its controls.
        with gr.Column():
            img_in = gr.Image(type="pil", label="Upload Image")
            threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                                  label="Threshold")
            alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7,
                              label="Alpha for Mask Overlay")
            submit_button = gr.Button("Process Image")
        # Right column: the two rendered overlays.
        with gr.Column():
            img_out1 = gr.Image(label="Sidewalk Mask - Initial")
            img_out2 = gr.Image(label="Sidewalk Mask - Refine with Occlusion Handling")
            gr.ClearButton(components=[img_in, img_out1, img_out2])
    # Every callback takes the same three inputs and fills both output images.
    callback_inputs = [img_in, threshold, alpha]
    callback_outputs = [img_out1, img_out2]
    threshold.change(fn=update_output, inputs=callback_inputs, outputs=callback_outputs)
    alpha.change(fn=update_output_alpha, inputs=callback_inputs, outputs=callback_outputs)
    submit_button.click(fn=process_pipeline, inputs=callback_inputs, outputs=callback_outputs)

app.launch(debug=True)