# FindMySidewalk / app.py
# Author: Rishie Nandhan
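"""Gradio demo that segments sidewalks in satellite image tiles with a
fine-tuned SAM (Segment Anything Model) checkpoint, using a bounding box
over the full tile as the prompt."""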
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamConfig, SamProcessor, SamModel
# Load the model and processor
print('Status Update: Loading SAM Model ...')
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Create an instance of the model architecture with the loaded configuration
sidewalk_model = SamModel(config=model_config)
# Update the model by loading the fine-tuned weights from the saved checkpoint
print('Status Update: Loading fine-tuned SAM weights ...')
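# The checkpoint is assumed to be a dict saved during fine-tuning, with the
# model's state_dict stored under the "model" key (loaded below)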
checkpoint = torch.load("checkpoints/checkpoint_at_epoch_1.pt", map_location=torch.device('cpu'))  # or 'mps' on Apple Silicon
sidewalk_model.load_state_dict(checkpoint["model"])
# Set the device
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
sidewalk_model.to(device)
# print('Status Update: Using GPU.')
print('Status Update: FindMySidewalk Ready for inference ...')
# Generate bounding box prompt for SAM
def get_bounding_box(W=256, H=256, x_min=0, y_min=0, x_max=256, y_max=256):
    # Add random perturbation to the input bounding box coordinates,
    # clamped to the image bounds
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))
    bbox = [x_min, y_min, x_max, y_max]
    return bbox
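# Example (hypothetical values): get_bounding_box(256, 256, 50, 60, 200, 210)
# might return [38, 47, 213, 225]; outputs vary with the random perturbation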
def segment_sidewalk(image):
    test_image = Image.fromarray(image).convert("RGB")
    # Keep a copy of the original image for display
    original_image = test_image.copy()
    # # Create grid of points for prompting
    # array_size = 256
    # grid_size = 7
    # x = np.linspace(0, array_size - 1, grid_size)
    # y = np.linspace(0, array_size - 1, grid_size)
    # xv, yv = np.meshgrid(x, y)
    # input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv.tolist(), yv.tolist())]
    # input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)
    # Obtain a bounding box prompt spanning the entire image
    prompt = get_bounding_box(test_image.size[0], test_image.size[1], 0, 0, test_image.size[0], test_image.size[1])
    # Prepare the image and box prompt for the model
    inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
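    # Note: input_boxes is nested as [batch, boxes_per_image, 4] -- here a
    # single image with a single box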
    # Cast to float32 (the MPS backend does not support float64) and move to the device
    inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}
    sidewalk_model.eval()
    with torch.no_grad():
        outputs = sidewalk_model(**inputs, multimask_output=False)
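    # pred_masks has shape (batch, boxes, masks, 256, 256); with one box and
    # multimask_output=False, the singleton dims are squeezed away below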
    # Apply sigmoid, then threshold the soft mask into a hard binary mask
    test_image_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    test_image_prob = test_image_prob.cpu().numpy().squeeze()
    pred_mask_np = (test_image_prob > 0.85).astype(np.uint8)
    segmented_image = Image.fromarray(pred_mask_np * 255)  # Convert the binary mask to an image
    return original_image, segmented_image
# Set up the Gradio interface
demo = gr.Interface(
    fn=segment_sidewalk,
    inputs=gr.Image(),
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Segmented Mask")
    ],
    title="Sidewalk Segmentation using SAM",
    description="Upload a satellite image tile to segment sidewalks."
)
# Run the interface
demo.launch()
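# Run with `python app.py`; by default Gradio serves the app at
# http://127.0.0.1:7860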