import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamConfig, SamProcessor, SamModel
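# Dependencies (from the imports above): gradio, torch, numpy, Pillow, transformers.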
# Load the model configuration and processor
print('Status Update: Loading SAM Model ...')
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the model architecture with the loaded configuration
sidewalk_model = SamModel(config=model_config)

# Update the model by loading the fine-tuned weights from the saved checkpoint
print('Status Update: Loading fine-tuned SAM weights ...')
checkpoint = torch.load("checkpoints/checkpoint_at_epoch_1.pt", map_location=torch.device('cpu'))  # or 'mps' on Apple silicon
sidewalk_model.load_state_dict(checkpoint["model"])
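# Note: the line above assumes the checkpoint was saved as a dict with the
# state_dict under a "model" key, i.e. torch.save({"model": model.state_dict(), ...}, path).
# If it had been saved as a bare state_dict, the equivalent call would be
# sidewalk_model.load_state_dict(checkpoint).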
# Set the device (CPU here; uncomment the line below to prefer Apple's MPS backend when available)
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
sidewalk_model.to(device)
print('Status Update: FindMySidewalk Ready for inference ...')
# Generate a bounding box prompt for SAM
def get_bounding_box(W=256, H=256, x_min=0, y_min=0, x_max=256, y_max=256):
    # Add random perturbation to the input bounding box coordinates,
    # clamped to the image bounds
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))
    bbox = [x_min, y_min, x_max, y_max]
    return bbox
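# Example (illustrative values): get_bounding_box(256, 256, 50, 60, 200, 220)
# might return [38, 47, 215, 233], i.e. the same box expanded by up to 20 px per side.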
def segment_sidewalk(image):
    test_image = Image.fromarray(image).convert("RGB")
    # Keep a copy of the original image for display
    original_image = test_image.copy()

    # # Alternative prompting strategy: a grid of points over the image
    # array_size = 256
    # grid_size = 7
    # x = np.linspace(0, array_size - 1, grid_size)
    # y = np.linspace(0, array_size - 1, grid_size)
    # xv, yv = np.meshgrid(x, y)
    # input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)] for x_row, y_row in zip(xv.tolist(), yv.tolist())]
    # input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)

    # Obtain a bounding box prompt covering the entire image
    prompt = get_bounding_box(test_image.size[0], test_image.size[1], 0, 0, test_image.size[0], test_image.size[1])

    # Prepare the image for the model
    inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")
    # Cast to float32, as the MPS backend doesn't support float64
    inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}

    sidewalk_model.eval()
    with torch.no_grad():
        outputs = sidewalk_model(**inputs, multimask_output=False)

    # Apply sigmoid, then threshold the soft mask into a hard binary mask
    test_image_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    test_image_prob = test_image_prob.cpu().numpy().squeeze()
    pred_mask_np = (test_image_prob > 0.85).astype(np.uint8)
    segmented_image = Image.fromarray(pred_mask_np * 255)  # Convert mask to an image

    return original_image, segmented_image
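# Optional sketch (not wired into the app): SAM predicts a low-resolution mask,
# and transformers' SamImageProcessor can resize it back to the original image
# size. The hypothetical helper below assumes `raw_inputs` comes from a fresh
# processor(...) call, before the float32 cast above, so that the size tensors
# are still integers as post_process_masks expects.
def upsample_mask_to_original(outputs, raw_inputs):
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        raw_inputs["original_sizes"].cpu(),
        raw_inputs["reshaped_input_sizes"].cpu(),
    )
    return masks[0]  # binarized mask(s) at the original image resolution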
# Set up the Gradio interface
demo = gr.Interface(
    fn=segment_sidewalk,
    inputs=gr.Image(),
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Segmented Mask")
    ],
    title="Sidewalk Segmentation using SAM",
    description="Upload a satellite image tile to segment sidewalks."
)
# Run the interface
demo.launch()
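# Note: on Hugging Face Spaces, demo.launch() is sufficient; when running
# locally, demo.launch(share=True) would also expose a temporary public URL.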