"""Sidewalk detection demo.

A fine-tuned SAM model segments sidewalks in aerial imagery. LangSAM (with a
"tree" text prompt) locates occluding trees so that sidewalk gaps hidden under
tree canopies can be bridged with an inferred straight segment. Gradio
provides the interactive UI.

Fixes relative to the original script:
  * duplicate imports and a second (wasteful) ``LangSAM()`` instantiation removed;
  * ``get_input_image`` default bbox used ``(C, H, W)`` shape indices 1/0,
    i.e. (height, channels) — corrected to (width, height);
  * ``smoothing`` computed an opening and then discarded it — now open→close;
  * matplotlib figures are closed after rendering (previously leaked per call);
  * the ``SamProcessor`` is constructed once at startup, not per click.
"""

from math import floor, ceil

import cv2
import gradio as gr
import matplotlib.patches as patches
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from torchvision import transforms
from transformers import SamConfig, SamModel, SamProcessor

from samgeo import tms_to_geotiff
from samgeo.text_sam import LangSAM

# Load the (large) LangSAM text-prompt model exactly once.
sam = LangSAM()


# --------------------------------------------------------------------------
# Sidewalk inference helpers
# --------------------------------------------------------------------------

def get_input_image(image_file, processor, bbox=None):
    """Prepare SAM inputs for a single image.

    Args:
        image_file: PIL image (or anything ``np.array`` accepts) in HWC order.
        processor: a ``SamProcessor`` instance.
        bbox: ``[xmin, ymin, xmax, ymax]`` prompt box; defaults to the whole
            image.

    Returns:
        Dict of unbatched processor tensors plus ``"org_img"`` — the original
        image as a ``(C, H, W)`` tensor.
    """
    img = torch.tensor(np.array(image_file)).permute(2, 0, 1)  # HWC -> CHW
    if bbox is None:
        # img is (C, H, W): width = shape[2], height = shape[1].
        # FIX: original read shape[1]/shape[0] (height, channels) here.
        bbox = [0, 0, img.shape[2], img.shape[1]]
    inputs = processor(img, input_boxes=[[bbox]], return_tensors="pt")
    # Drop the batch dimension the processor adds; re-added at inference time.
    inputs = {k: v.squeeze(0) for k, v in inputs.items()}
    inputs["org_img"] = img
    return inputs


def process_image(inputs):
    """Run the fine-tuned SAM model on prepared inputs.

    Returns:
        Tuple ``(orig, medsam_seg_prob)``: the original image as an HWC numpy
        array and the sigmoid probability map for the sidewalk mask.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(
            pixel_values=inputs["pixel_values"].unsqueeze(0).to(device),
            input_boxes=inputs["input_boxes"].unsqueeze(0).to(device),
            multimask_output=False,
        )
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    orig = inputs["org_img"].permute(1, 2, 0).cpu().numpy()
    return orig, medsam_seg_prob


def display_image(medsam_seg_prob, threshold=0.5):
    """Binarize a probability map into a uint8 {0, 1} mask."""
    return (medsam_seg_prob > threshold).astype(np.uint8)


def _fig_to_array(fig):
    """Render a matplotlib figure to an RGB uint8 array and close it.

    Closing the figure fixes the leak in the original code, where every
    Gradio callback created a figure that was never released.
    """
    fig.canvas.draw()
    # NOTE(review): tostring_rgb is deprecated in matplotlib >= 3.8
    # (buffer_rgba is the replacement); kept for compatibility with the
    # matplotlib version the original targeted.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)
    return data


def output_sidewalk(image, medsam_seg, alpha=0.7):
    """Overlay the binary sidewalk mask (blue) on the original photo.

    Returns the composited image as an RGB uint8 array.
    """
    # Color for 0: fully transparent; for 1: opaque blue (RGBA tuples).
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(np.array(image))
    axes.imshow(np.array(medsam_seg), cmap=cmap, alpha=alpha)
    axes.axis('off')
    return _fig_to_array(fig)


# --------------------------------------------------------------------------
# Mask post-processing (smoothing)
# --------------------------------------------------------------------------

def filter_weak(medsam_seg, size_threshold=10):
    """Drop connected components smaller than ``size_threshold`` pixels."""
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        medsam_seg, connectivity=8, ltype=cv2.CV_32S
    )
    result = np.zeros_like(medsam_seg)
    for i in range(1, num_labels):  # label 0 is background
        if stats[i, cv2.CC_STAT_AREA] >= size_threshold:
            result[labels == i] = 1
    return result


def smoothing(mask, kernel_size=(6, 6)):
    """Morphological open-then-close: despeckle, then fill small gaps.

    FIX: the original computed the opening but then applied the closing to
    the raw mask, discarding the opening entirely.
    """
    kernel = np.ones(kernel_size, np.uint8)
    opening = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    return cv2.morphologyEx(opening, cv2.MORPH_CLOSE, kernel)


def pipeline(data, size_threshold=25, kernel_size=(9, 9)):
    """Full mask clean-up: size filtering followed by morphological smoothing."""
    result = filter_weak(data, size_threshold)
    return smoothing(result, kernel_size)


# --------------------------------------------------------------------------
# Occlusion handling (trees covering sidewalks)
# --------------------------------------------------------------------------

def create_boundary_mask_from_bbox(bbox, array_size, thickness=1):
    """Draw the rectangle outline of ``bbox`` (value 2) into a zero mask.

    Coordinates are clamped to ``array_size`` to avoid IndexErrors.
    """
    mask = np.zeros(array_size, dtype=np.uint8)
    xmin, ymin, xmax, ymax = bbox
    xmin = floor(max(xmin, 0))
    xmax = ceil(min(xmax, array_size[1] - 1))
    ymin = floor(max(ymin, 0))
    ymax = ceil(min(ymax, array_size[0] - 1))
    # Top and bottom horizontal edges.
    mask[ymin:ymin + thickness, xmin:xmax] = 2
    mask[ymax - thickness + 1:ymax + 1, xmin:xmax] = 2
    # Left and right vertical edges.
    mask[ymin:ymax, xmin:xmin + thickness] = 2
    mask[ymin:ymax, xmax - thickness + 1:xmax + 1] = 2
    return mask


def check_boundary(m1, m2, radius=1):
    """For each boundary pixel (value 2) in ``m2``, count nearby sidewalk
    pixels (value 1) in ``m1`` within ``radius``.

    Returns a mask of the same shape as ``m2`` holding those counts; nonzero
    entries mark where a sidewalk touches the tree bbox outline.
    """
    boundary_mask = np.zeros_like(m2)
    rows, cols = m2.shape
    for r in range(rows):
        for c in range(cols):
            if m2[r, c] == 2:  # pixel lies on the tree-bbox outline
                found_sidewalk = 0
                for dr in range(-radius, radius + 1):
                    for dc in range(-radius, radius + 1):
                        nr, nc = r + dr, c + dc
                        # Stay in bounds; skip the center pixel itself.
                        if 0 <= nr < rows and 0 <= nc < cols and (dr != 0 or dc != 0):
                            if m1[nr, nc] == 1:
                                found_sidewalk += 1
                boundary_mask[r, c] = found_sidewalk
    return boundary_mask


def linear_regression_two_points(point1, point2):
    """Fit the line through two points; returns ``(slope, intercept, x, y)``."""
    x = np.array([point1[0], point2[0]])
    y = np.array([point1[1], point2[1]])
    m, b = np.polyfit(x, y, 1)
    return m, b, x, y


def generate_road_mask(x1, x2, slope, intercept, road_width=5, image_size=(256, 256)):
    """Rasterize the line segment from x1 to x2 as a thick binary mask.

    The segment is drawn as overlapping filled circles of diameter
    ``road_width`` so the result is a solid band rather than a 1-px line.
    """
    image = np.zeros(image_size, dtype=np.uint8)
    x_values = np.array(range(x1, x2 + 1))
    y_values = (slope * x_values + intercept).astype(int)
    for i in range(len(x_values)):
        if 0 <= y_values[i] < image_size[0]:  # keep y inside the image
            cv2.circle(image, (x_values[i], y_values[i]), road_width // 2, 1, -1)
    return image


def get_road_mask_per_bbox(filtered_med_seg, bbox, radius=1):
    """Infer the sidewalk segment hidden behind one tree bbox.

    Finds where the sidewalk mask intersects the bbox outline; when exactly
    two intersection clusters exist, connects their centroids with a straight
    band. Returns the band mask, or ``None`` if the two entry/exit points
    cannot be identified.
    """
    array_size = (256, 256)  # working resolution of the 2-D mask
    mask = create_boundary_mask_from_bbox(bbox, array_size, thickness=1)
    # Where does the sidewalk touch the bbox outline?
    output = check_boundary(filtered_med_seg, mask, radius)
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        output, 8, cv2.CV_32S
    )
    centroids = centroids[1:]  # drop the background centroid
    centroids = sorted(centroids, key=lambda x: x[0])
    if len(centroids) != 2:
        # Need exactly an entry and an exit point to bridge the occlusion.
        return None
    slope, intercept, x, y = linear_regression_two_points(centroids[0], centroids[1])
    return generate_road_mask(int(x[0]), int(x[1]), slope, intercept, 3)


def analyze_sidewalk(sam, filtered_med_seg, image, alpha=0.7):
    """Detect trees with LangSAM and bridge sidewalk gaps beneath them.

    Overlays the sidewalk mask (blue), inferred bridging segments (green) and
    tree bounding boxes (red) on the image; returns the figure as an RGB
    uint8 array.
    """
    text_prompt = "tree"
    masks, boxes, labels, logits = sam.predict(
        image, text_prompt, box_threshold=0.24, text_threshold=0.24,
        return_results=True,
    )

    # Transparent-to-blue for the sidewalk, transparent-to-green for bridges.
    colors = [(0, 0, 0, 0), (0, 0, 1, 1)]
    cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)
    colors_alt = [(0, 0, 0, 0), (0, 1, 0, 1)]
    cmap_alt = LinearSegmentedColormap.from_list("custom_cmap", colors_alt)

    fig, axes = plt.subplots(1, 1, figsize=(8, 8))
    axes.imshow(image)
    axes.imshow(filtered_med_seg, cmap=cmap, alpha=alpha)
    axes.axis('off')

    for bbox in boxes:
        road_mask = get_road_mask_per_bbox(filtered_med_seg, bbox.tolist(), 1)
        if road_mask is not None:
            axes.imshow(road_mask, cmap=cmap_alt, alpha=alpha)
        rect = patches.Rectangle(
            (bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
            linewidth=1, edgecolor='r', facecolor='none',
        )
        axes.add_patch(rect)

    return _fig_to_array(fig)


# --------------------------------------------------------------------------
# Model loading (once, at startup)
# --------------------------------------------------------------------------

model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
model = SamModel(config=model_config)
model.load_state_dict(
    torch.load("model_checkpoint_final1.pth", map_location=torch.device('cpu'))
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the model to the device once, not per request

# FIX: construct the processor once instead of on every "Process Image" click.
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


# --------------------------------------------------------------------------
# Gradio callbacks
# --------------------------------------------------------------------------

# Cache of intermediate results so the sliders can re-render without a full
# re-inference. NOTE(review): module-level state is shared across concurrent
# users/sessions — fine for a single-user demo, not for multi-user serving.
partial_results = {}


def process_pipeline(image, threshold, alpha):
    """Full inference for a newly uploaded image; caches intermediates."""
    processed_inputs = get_input_image(image, processor, bbox=[0, 0, 256, 256])
    orig, medsam_seg_prob = process_image(processed_inputs)
    medsam_seg = display_image(medsam_seg_prob, threshold)
    filtered_med_seg = pipeline(medsam_seg)
    output_image = output_sidewalk(orig, filtered_med_seg, alpha)
    filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
    partial_results["prob"] = medsam_seg_prob
    partial_results["orig"] = orig
    partial_results["filtered_med_seg"] = filtered_med_seg
    return output_image, filled_image


def update_output(image, threshold, alpha):
    """Re-threshold and re-render using cached probabilities (no SAM rerun).

    Returns ``None`` implicitly (clearing the outputs) if nothing has been
    processed yet.
    """
    if "prob" in partial_results and "orig" in partial_results:
        medsam_seg_prob = partial_results['prob']
        orig = partial_results['orig']
        medsam_seg = display_image(medsam_seg_prob, threshold)
        filtered_med_seg = pipeline(medsam_seg)
        output_image = output_sidewalk(orig, filtered_med_seg, alpha)
        filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
        partial_results["filtered_med_seg"] = filtered_med_seg
        return output_image, filled_image


def update_output_alpha(image, threshold, alpha):
    """Re-render with a new overlay alpha using the cached filtered mask."""
    if "prob" in partial_results and "filtered_med_seg" in partial_results:
        orig = partial_results['orig']
        filtered_med_seg = partial_results["filtered_med_seg"]
        output_image = output_sidewalk(orig, filtered_med_seg, alpha=alpha)
        filled_image = analyze_sidewalk(sam, filtered_med_seg, image, alpha=alpha)
        return output_image, filled_image


# --------------------------------------------------------------------------
# Gradio UI
# --------------------------------------------------------------------------

with gr.Blocks() as app:
    gr.Markdown("# Sidewalk Detection with SAM Model")
    gr.Markdown("#### by Dan Mao, Kevin Tan")
    with gr.Row():
        with gr.Column():
            img_in = gr.Image(type="pil", label="Upload Image")
            threshold = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                                  label="Threshold")
            alpha = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.7,
                              label="Alpha for Mask Overlay")
            submit_button = gr.Button("Process Image")
        with gr.Column():
            img_out1 = gr.Image(label="Sidewalk Mask - Initial")
            img_out2 = gr.Image(label="Sidewalk Mask - Refine with Occlusion Handling")
    gr.ClearButton(components=[img_in, img_out1, img_out2])

    # Sliders re-render from cache; the button runs the full pipeline.
    threshold.change(fn=update_output, inputs=[img_in, threshold, alpha],
                     outputs=[img_out1, img_out2])
    alpha.change(fn=update_output_alpha, inputs=[img_in, threshold, alpha],
                 outputs=[img_out1, img_out2])
    submit_button.click(
        fn=process_pipeline,
        inputs=[img_in, threshold, alpha],
        outputs=[img_out1, img_out2]
    )

app.launch(debug=True)