import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamConfig, SamProcessor, SamModel

# Load the model configuration and processor
print('Status Update: Loading SAM Model ...')
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Create an instance of the model architecture with the loaded configuration
sidewalk_model = SamModel(config=model_config)

# Update the model by loading the fine-tuned weights from the saved checkpoint
print('Status Update: Loading fine-tuned SAM weights ...')
checkpoint = torch.load("checkpoints/checkpoint_at_epoch_1.pt",
                        map_location=torch.device('cpu'))  # or 'mps' on Apple Silicon
sidewalk_model.load_state_dict(checkpoint["model"])

# Set the device
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cpu"
sidewalk_model.to(device)
print('Status Update: FindMySidewalk ready for inference ...')


# Generate a bounding box prompt for SAM
def get_bounding_box(W=256, H=256, x_min=0, y_min=0, x_max=256, y_max=256):
    # Add random perturbation to the input box coordinates, clamped to the image bounds
    x_min = max(0, x_min - np.random.randint(0, 20))
    x_max = min(W, x_max + np.random.randint(0, 20))
    y_min = max(0, y_min - np.random.randint(0, 20))
    y_max = min(H, y_max + np.random.randint(0, 20))
    return [x_min, y_min, x_max, y_max]


def segment_sidewalk(image):
    test_image = Image.fromarray(image).convert("RGB")

    # Keep a copy of the original image for display
    original_image = test_image.copy()

    # # Alternative: create a grid of points for prompting
    # array_size = 256
    # grid_size = 7
    # x = np.linspace(0, array_size - 1, grid_size)
    # y = np.linspace(0, array_size - 1, grid_size)
    # xv, yv = np.meshgrid(x, y)
    # input_points = [[[int(x), int(y)] for x, y in zip(x_row, y_row)]
    #                 for x_row, y_row in zip(xv.tolist(), yv.tolist())]
    # input_points = torch.tensor(input_points).view(1, 1, grid_size * grid_size, 2)

    # Obtain a bounding box prompt covering the entire image
    prompt = get_bounding_box(test_image.size[0], test_image.size[1],
                              0, 0, test_image.size[0], test_image.size[1])

    # Prepare the image and box prompt for the model
    inputs = processor(test_image, input_boxes=[[prompt]], return_tensors="pt")

    # Convert dtype to float32 because the MPS backend doesn't support float64
    inputs = {k: v.to(torch.float32).to(device) for k, v in inputs.items()}

    sidewalk_model.eval()
    with torch.no_grad():
        outputs = sidewalk_model(**inputs, multimask_output=False)

    # Apply sigmoid, then threshold the soft mask into a hard binary mask
    test_image_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    test_image_prob = test_image_prob.cpu().numpy().squeeze()
    pred_mask_np = (test_image_prob > 0.85).astype(np.uint8)
    segmented_image = Image.fromarray(pred_mask_np * 255)  # Convert the mask to an image

    return original_image, segmented_image


# Set up the Gradio interface
demo = gr.Interface(
    fn=segment_sidewalk,
    inputs=gr.Image(),
    outputs=[
        gr.Image(type="pil", label="Original Image"),
        gr.Image(type="pil", label="Segmented Mask")
    ],
    title="Sidewalk Segmentation using SAM",
    description="Upload a satellite image tile to segment sidewalks."
)

# Run the interface
demo.launch()
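
# --- Optional: map the mask back to the original tile size ---
# A minimal sketch, not wired into the app above: SAM's `pred_masks` come out at
# the model's low-res 256x256 grid, so for input tiles of other sizes the mask can
# be upsampled with the processor's built-in post-processing. This sketch assumes
# the `processor`, `inputs`, and `outputs` names from segment_sidewalk(), and that
# the size tensors are cast back to integer dtype (they were converted to float32
# above for MPS compatibility).
#
# masks = processor.image_processor.post_process_masks(
#     outputs.pred_masks.cpu(),                     # low-res mask logits
#     inputs["original_sizes"].cpu().long(),        # original (H, W) per image
#     inputs["reshaped_input_sizes"].cpu().long(),  # sizes after processor resizing
#     binarize=False,                               # keep logits so we can sigmoid + threshold
# )
# prob = torch.sigmoid(masks[0]).squeeze().numpy()
# full_res_mask = (prob > 0.85).astype(np.uint8)   # same 0.85 cutoff as segment_sidewalk()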