# FindMySidewalk / app.py
# Last commit by Rishie Nandhan: 0345fe3 "Using bounding box prompts"
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamConfig, SamProcessor, SamModel
# Load the SAM configuration and the matching pre/post-processor from the
# public "facebook/sam-vit-base" release. Only the architecture config and the
# processor come from there; the weights used for inference are the fine-tuned
# ones restored from the local checkpoint below.
print('Status Update: Loading SAM Model ...')
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Build the SAM architecture from the config alone; its weights are replaced by
# the fine-tuned checkpoint immediately afterwards.
sidewalk_model = SamModel(config=model_config)
# Restore the fine-tuned weights saved during training. The checkpoint is a dict
# whose "model" entry holds the model's state_dict.
print('Status Update: Loading SAM pre-train weights ...')
# map_location forces every stored tensor onto the CPU, whatever device the
# checkpoint was saved from.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load("checkpoints/checkpoint_at_epoch_1.pt", map_location=torch.device('cpu')) # or torch.device('mps') on Apple Silicon
sidewalk_model.load_state_dict(checkpoint["model"])
# Inference device. Hard-coded to CPU; the commented-out line below would pick
# Apple Silicon's MPS backend when it is available.
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
sidewalk_model.to(device)
print('Status Update: FindMySidewalk Ready for inference ...')
# Generate bounding box prompt for SAM
def get_bounding_box(W=256, H=256, x_min=0, y_min=0, x_max=256, y_max=256,
                     max_perturbation=20):
    """Return a SAM box prompt ``[x_min, y_min, x_max, y_max]``, randomly enlarged.

    Each edge is pushed outward by an independent random integer drawn from
    ``[0, max_perturbation)`` using NumPy's global RNG. The result is then
    clamped to the image bounds ``[0, W] x [0, H]``. Because of the clamping,
    a box that already spans the whole image comes back unchanged.

    Args:
        W: Image width in pixels (upper clamp for the x coordinates).
        H: Image height in pixels (upper clamp for the y coordinates).
        x_min, y_min, x_max, y_max: Box corners in pixel coordinates.
        max_perturbation: Exclusive upper bound of the per-edge outward jitter.
            Pass 0 (or any non-positive value) to get a deterministic box.

    Returns:
        list: ``[x_min, y_min, x_max, y_max]`` after jitter and clamping.
    """
    def jitter():
        # np.random.randint raises ValueError when high <= low, so skip the
        # draw entirely when jitter is disabled.
        if max_perturbation > 0:
            return np.random.randint(0, max_perturbation)
        return 0

    # Draw in the historical order (x_min, x_max, y_min, y_max) so that runs
    # seeded with np.random.seed reproduce the same boxes as before.
    x_min = max(0, x_min - jitter())
    x_max = min(W, x_max + jitter())
    y_min = max(0, y_min - jitter())
    y_max = min(H, y_max + jitter())
    return [x_min, y_min, x_max, y_max]
def segment_sidewalk(image):
    """Segment sidewalks in an uploaded image with the fine-tuned SAM model.

    The whole image is used as a single bounding-box prompt. The model's
    low-resolution mask logits are mapped back to the input resolution
    (removing SAM's padding) before thresholding, so the returned mask is
    pixel-aligned with the original image whatever its size.

    Args:
        image: Image array supplied by ``gr.Image()``, or ``None`` when the user
            submits without uploading anything.

    Returns:
        tuple: ``(original_image, segmented_image)``. The first is the input as
        an RGB PIL image. The second is a single-channel PIL mask of the same
        size, where 255 marks pixels whose predicted sidewalk probability
        exceeds 0.85 and 0 marks everything else.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image to segment.")

    test_image = Image.fromarray(image).convert("RGB")
    # Keep an untouched copy of the input for display next to the mask.
    original_image = test_image.copy()

    # Prompt SAM with a box covering the entire image (PIL size is (W, H)).
    width, height = test_image.size
    prompt = get_bounding_box(width, height, 0, 0, width, height)

    # Prepare the image and box prompt for the model.
    inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
    # MPS does not support float64, so cast floating-point tensors (pixel values,
    # box coordinates) to float32. Integer tensors (original_sizes,
    # reshaped_input_sizes) keep their dtype because they are used below as
    # exact pixel sizes when post-processing the mask.
    inputs = {
        k: (v.to(torch.float32) if v.is_floating_point() else v).to(device)
        for k, v in inputs.items()
    }

    sidewalk_model.eval()
    with torch.no_grad():
        outputs = sidewalk_model(**inputs, multimask_output=False)

    # pred_masks are low-resolution logits in SAM's padded input frame. Crop the
    # padding and resize them to the original image size. binarize=False keeps
    # raw logits so the app's own probability threshold can be applied.
    upscaled_logits = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
        binarize=False,
    )[0]
    # Logits -> probabilities. Then take the single (box, mask) pair explicitly
    # rather than squeeze(), so a 1-pixel-high/wide image keeps its 2-D shape.
    test_image_prob = torch.sigmoid(upscaled_logits)[0, 0].numpy()
    # Hard mask: strict 0.85 probability cut-off.
    pred_mask_np = (test_image_prob > 0.85).astype(np.uint8)
    segmented_image = Image.fromarray(pred_mask_np * 255)  # 0/255 grayscale mask
    return original_image, segmented_image
# Web UI: a single image upload feeds segment_sidewalk, whose two return values
# are shown as two image panels (the uploaded tile and the predicted mask).
demo = gr.Interface(
    segment_sidewalk,
    gr.Image(),
    [
        gr.Image(label="Original Image", type="pil"),
        gr.Image(label="Segmented Mask", type="pil"),
    ],
    description="Upload a satellite image tile to segment sidewalks.",
    title="Sidewalk Segmentation using SAM",
)

# Start the Gradio web server.
demo.launch()