"""Gradio demo for sidewalk segmentation with a fine-tuned SAM model.

Three tabs are exposed:

* Baseline: box-prompted prediction only.
* Mask Guidance: a user-drawn mask is added (scaled by an offset) to the
  predicted probability map before thresholding.
* Mask Prompt: the user-drawn mask is passed to the model as a mask prompt.
"""
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import SamModel, SamProcessor

to_tensor = ToTensor()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
model.to(device)


def _reencode_as_png(image):
    """Return *image* re-encoded as an in-memory PNG.

    The round trip lets non-PNG uploads (e.g. TIFF) be handed back to the
    browser; doing it in memory avoids a shared scratch file that concurrent
    requests could overwrite.
    """
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    buffer.seek(0)
    reloaded = Image.open(buffer)
    # Force decoding now so the result does not depend on the buffer later.
    reloaded.load()
    return reloaded


def _box_prompt(x_min, y_min, x_max, y_max):
    """Build the nested box prompt from the (textual) coordinate inputs.

    Raises:
        gr.Error: if any coordinate cannot be parsed as a number.
    """
    try:
        box = [float(x_min), float(y_min), float(x_max), float(y_max)]
    except (TypeError, ValueError) as err:
        raise gr.Error('Bounding box coordinates must be numeric.') from err
    return [[box]]


def _require_image(image):
    """Return *image*, or raise a user-facing error if nothing was uploaded."""
    if image is None:
        raise gr.Error('Please upload an image first.')
    return image


def _require_background(editor_value):
    """Return the background image of an ImageEditor value.

    Raises:
        gr.Error: if the editor is empty or has no background image.
    """
    if editor_value is None or editor_value.get('background') is None:
        raise gr.Error('Please upload an image first.')
    return editor_value['background']


def _drawn_mask(editor_value):
    """Return a boolean mask of the pixels drawn on the first editor layer.

    A pixel counts as drawn when any channel of the layer is non-zero. When
    the editor reports no layers, an all-False mask matching the background
    size is returned instead of failing on the missing layer.
    """
    layers = editor_value.get('layers') or []
    if not layers:
        background = editor_value['background']
        return np.zeros((background.height, background.width), dtype=bool)
    return np.max(np.array(layers[0]), axis=2) != 0


def _predict_prob_map(image, box_prompt, mask_prompt=None):
    """Run the model and return the sigmoid probability map as a NumPy array.

    Args:
        image: PIL image passed to the processor.
        box_prompt: nested ``[[[x_min, y_min, x_max, y_max]]]`` list.
        mask_prompt: optional tensor forwarded as ``input_masks`` to both the
            processor and the model.
    """
    processor_kwargs = {}
    model_kwargs = {}
    if mask_prompt is not None:
        processor_kwargs['input_masks'] = mask_prompt
        model_kwargs['input_masks'] = mask_prompt.to(device)

    inputs = processor(image, input_boxes=box_prompt, return_tensors='pt',
                       **processor_kwargs)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        multimask_output=False,
                        **model_kwargs)
    return torch.sigmoid(outputs.pred_masks.squeeze()).cpu().numpy()


def _render_prob_map(prob_map):
    """Render *prob_map* as a 'jet' heat-map and return it as a PIL image."""
    fig, ax = plt.subplots(figsize=(8, 8))
    buffer = io.BytesIO()
    try:
        ax.imshow(prob_map, cmap='jet', interpolation='nearest')
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()
    return rendered


def segment_image(image, threshold, x_min, y_min, x_max, y_max):
    """Baseline: segment *image* using only the bounding-box prompt.

    Args:
        image: uploaded PIL image.
        threshold: probability cut-off for the binary prediction.
        x_min, y_min, x_max, y_max: box prompt coordinates (numeric text).

    Returns:
        A tuple ``(png_image, (png_image, [(mask, 'Sidewalks')]), prob_image)``
        feeding the input image, annotated prediction and probability map.
    """
    image = _require_image(image)
    box = _box_prompt(x_min, y_min, x_max, y_max)

    prob_map = _predict_prob_map(image.convert('RGB'), box)
    pred_mask = (prob_map > threshold).astype(np.float32)

    # tolerate TIFF image input: hand back a PNG-encoded copy
    png_image = _reencode_as_png(image)
    ret_pred = (png_image, [(pred_mask, 'Sidewalks')])
    return png_image, ret_pred, _render_prob_map(prob_map)


def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y_max):
    """Segment with the user-drawn mask added to the probability map.

    Args:
        image: ImageEditor value (dict with 'background' and 'layers').
        threshold: cut-off applied to the enhanced probability map.
        offset: amount added to the probability wherever the user has drawn.
        x_min, y_min, x_max, y_max: box prompt coordinates (numeric text).

    Returns:
        ``((background, regions), prob_image)`` where *regions* holds the
        'Guidance' and 'Sidewalks' masks.

    Raises:
        gr.Error: if no image was uploaded, the box is not numeric, or the
            drawn mask and the model output differ in size.
    """
    background = _require_background(image)
    box = _box_prompt(x_min, y_min, x_max, y_max)

    prob_map = _predict_prob_map(background.convert('RGB'), box)

    # perform mask guidance
    guidance_mask = _drawn_mask(image).astype(float)
    # NOTE(review): the element-wise sum below requires the model output and
    # the input image to share a resolution — presumably 256x256 tiles,
    # matching the default box; confirm against the fine-tuned model.
    if guidance_mask.shape != prob_map.shape:
        raise gr.Error(
            f'Image size {guidance_mask.shape} does not match the model '
            f'output size {prob_map.shape}.'
        )
    enhance_map = prob_map + offset * guidance_mask
    pred_mask = (enhance_map > threshold).astype(float)

    regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
    return (background, regions), _render_prob_map(enhance_map)


def segment_image_with_prompt(image, threshold, x_min, y_min, x_max, y_max):
    """Segment using both the box prompt and the user-drawn mask prompt.

    Args:
        image: ImageEditor value (dict with 'background' and 'layers').
        threshold: probability cut-off for the binary prediction.
        x_min, y_min, x_max, y_max: box prompt coordinates (numeric text).

    Returns:
        ``((background, regions), prob_image)`` where *regions* holds the
        'Prompt' and 'Sidewalks' masks.
    """
    background = _require_background(image)
    box = _box_prompt(x_min, y_min, x_max, y_max)

    mask = _drawn_mask(image)
    mask_prompt = to_tensor(mask).float()

    prob_map = _predict_prob_map(background.convert('RGB'), box, mask_prompt)
    pred_mask = (prob_map > threshold).astype(np.float32)

    regions = [(mask, 'Prompt'), (pred_mask, 'Sidewalks')]
    return (background, regions), _render_prob_map(prob_map)


def _box_inputs():
    """Create the row of four bounding-box textboxes in the current layout."""
    with gr.Row():
        return [
            gr.Textbox(value=0, label='x_min'),
            gr.Textbox(value=0, label='y_min'),
            gr.Textbox(value=256, label='x_max'),
            gr.Textbox(value=256, label='y_max'),
        ]


def _mask_editor():
    """Create the upload-only image editor used for drawing masks."""
    return gr.ImageEditor(
        type='pil',
        crop_size='1:1',
        label='Input Image',
        brush=gr.Brush(default_size='5', color_mode='fixed'),
        sources=['upload'],
        transforms=[],
    )


def _threshold_slider():
    """Create the prediction-threshold slider."""
    return gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5,
                     label='Prediction Threshold')


with gr.Blocks() as demo:
    with gr.Tab('Baseline'):
        with gr.Row():
            with gr.Column():
                t1_input = gr.Image(type='pil', label='Input Image')
                t1_x_min, t1_y_min, t1_x_max, t1_y_max = _box_inputs()
                t1_slider = _threshold_slider()
                t1_segment = gr.Button('Segment')
            with gr.Column():
                t1_pred = gr.AnnotatedImage(
                    color_map={'Sidewalks': '#0000FF'},
                    label='Prediction',
                )
            with gr.Column():
                t1_prob_map = gr.Image(type='pil', label='Probability Map')

    with gr.Tab('Mask Guidance (Best)'):
        with gr.Row():
            with gr.Column():
                t2_input = _mask_editor()
                t2_x_min, t2_y_min, t2_x_max, t2_y_max = _box_inputs()
                t2_thresh = _threshold_slider()
                t2_offset = gr.Slider(minimum=0, maximum=1, step=0.01,
                                      value=0.4, label='Guidance Offset')
                t2_segment = gr.Button('Segment')
            with gr.Column():
                t2_pred = gr.AnnotatedImage(
                    color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'},
                    label='Prediction',
                )
            with gr.Column():
                t2_prob_map = gr.Image(type='pil', label='Probability Map')

    with gr.Tab('Mask Prompt'):
        with gr.Row():
            with gr.Column():
                t3_input = _mask_editor()
                t3_x_min, t3_y_min, t3_x_max, t3_y_max = _box_inputs()
                t3_thresh = _threshold_slider()
                t3_segment = gr.Button('Segment')
            with gr.Column():
                t3_pred = gr.AnnotatedImage(
                    color_map={'Prompt': '#FF0000', 'Sidewalks': '#0000FF'},
                    label='Prediction',
                )
            with gr.Column():
                t3_prob_map = gr.Image(type='pil', label='Probability Map')

    t1_segment.click(
        segment_image,
        inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
        outputs=[t1_input, t1_pred, t1_prob_map],
    )
    t2_segment.click(
        segment_image_with_guidance,
        inputs=[t2_input, t2_thresh, t2_offset,
                t2_x_min, t2_y_min, t2_x_max, t2_y_max],
        outputs=[t2_pred, t2_prob_map],
    )
    t3_segment.click(
        segment_image_with_prompt,
        inputs=[t3_input, t3_thresh, t3_x_min, t3_y_min, t3_x_max, t3_y_max],
        outputs=[t3_pred, t3_prob_map],
    )

demo.launch(debug=True, show_error=True)