# sidewalks-seg / app.py
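
# Gradio demo for sidewalk segmentation with a fine-tuned SAM checkpoint
# (hmdliu/sidewalks-seg). Three tabs share the same box-prompted forward
# pass: a plain baseline, post-hoc mask guidance, and a dense mask prompt.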
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import SamModel, SamProcessor
to_tensor = ToTensor()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
model.to(device)
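
# Baseline: box-prompted SAM. The processor comes from the base
# 'facebook/sam-vit-base' checkpoint (it only supplies resizing and
# normalization), while the weights are the fine-tuned sidewalk model.
# Note: SamModel returns 256x256 low-resolution mask logits, which the demo
# displays directly; the default 0-256 box coordinates suggest 256x256 input
# tiles, so no upscaling (e.g. processor.image_processor.post_process_masks)
# is applied.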
def segment_image(image, threshold, x_min, y_min, x_max, y_max):
    # round-trip through PNG to tolerate TIFF image input
    image.save('image.png')
    img = Image.open('image.png').convert('RGB')
    # init input data (textbox values arrive as strings, so cast them)
    prompt = [int(x_min), int(y_min), int(x_max), int(y_max)]
    inputs = processor(img, input_boxes=[[prompt]], return_tensors='pt')
    # make prediction
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu()
    pred_mask = (prob_map > threshold).float().numpy()
    # visualize the probability map
    plt.figure(figsize=(8, 8))
    plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
    plt.close()
    # package outputs: echoed input, annotated prediction, probability map
    ret_image = Image.open('image.png')
    ret_pred = (ret_image, [(pred_mask, 'Sidewalks')])
    ret_prob = Image.open('prob.png')
    return ret_image, ret_pred, ret_prob
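
# Mask Guidance: run the same box-prompted forward pass, then raise the
# probability of every brushed pixel by a fixed offset before thresholding,
# i.e. enhanced = sigmoid(logits) + offset * brushed. The model never sees
# the sketch; only the decision boundary moves.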
def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y_max):
    # round-trip through PNG to tolerate TIFF image input
    image['background'].save('image.png')
    # init input data (textbox values arrive as strings, so cast them)
    prompt = [int(x_min), int(y_min), int(x_max), int(y_max)]
    img = Image.open('image.png').convert('RGB')
    inputs = processor(img, input_boxes=[[prompt]], return_tensors='pt')
    # make prediction
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu()
    # perform mask guidance: any pixel the user brushed (non-zero in the
    # editor's sketch layer) gets its probability raised by `offset`
    guidance_mask = (np.max(np.array(image['layers'][0]), axis=2) != 0).astype(float)
    enhance_map = prob_map.numpy() + offset * guidance_mask
    pred_mask = (enhance_map > threshold).astype(float)
    # visualize the enhanced probability map
    plt.figure(figsize=(8, 8))
    plt.imshow(enhance_map, cmap='jet', interpolation='nearest')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
    plt.close()
    # package outputs: annotated prediction and probability map
    regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
    return (image['background'], regions), Image.open('prob.png')
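
# Mask Prompt: instead of adjusting probabilities after the fact, feed the
# sketch to SamModel's input_masks argument so the prompt encoder consumes it
# as a dense prompt. SAM's prompt encoder expects this mask at the 256x256
# low-resolution prompt size, which again matches 256x256 input tiles.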
def segment_image_with_prompt(image, threshold, x_min, y_min, x_max, y_max):
    # round-trip through PNG to tolerate TIFF image input
    image['background'].save('image.png')
    # init input data (textbox values arrive as strings, so cast them)
    img = Image.open('image.png').convert('RGB')
    mask = (np.max(np.array(image['layers'][0]), axis=2) != 0)
    mask_prompt = to_tensor(mask).float()  # (1, H, W) float tensor of the sketch
    box_prompt = [[[int(x_min), int(y_min), int(x_max), int(y_max)]]]
    inputs = processor(img, input_boxes=box_prompt, return_tensors='pt')
    # make prediction; the sketched mask goes straight to the model as a dense
    # prompt (SamProcessor has no mask-prompt argument, so it is not routed
    # through the processor)
    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'].to(device),
                        input_boxes=inputs['input_boxes'].to(device),
                        input_masks=mask_prompt.to(device),
                        multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu()
    pred_mask = (prob_map > threshold).float().numpy()
    # visualize the probability map
    plt.figure(figsize=(8, 8))
    plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
    plt.close()
    # package outputs: annotated prediction and probability map
    regions = [(mask, 'Prompt'), (pred_mask, 'Sidewalks')]
    return (image['background'], regions), Image.open('prob.png')
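
# UI: each tab has an input column (image, box coordinates, threshold), an
# annotated prediction, and the raw probability map. The editor tabs accept
# uploads only and disable transforms, presumably so the sketch layer stays
# aligned with the background image.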
with gr.Blocks() as demo:
    with gr.Tab('Baseline'):
        with gr.Row():
            with gr.Column():
                t1_input = gr.Image(type='pil', label='Input Image')
                with gr.Row():
                    t1_x_min = gr.Textbox(value=0, label='x_min')
                    t1_y_min = gr.Textbox(value=0, label='y_min')
                    t1_x_max = gr.Textbox(value=256, label='x_max')
                    t1_y_max = gr.Textbox(value=256, label='y_max')
                t1_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
                t1_segment = gr.Button('Segment')
            with gr.Column():
                t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
            with gr.Column():
                t1_prob_map = gr.Image(type='pil', label='Probability Map')
    with gr.Tab('Mask Guidance (Best)'):
        with gr.Row():
            with gr.Column():
                t2_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
                                          brush=gr.Brush(default_size=5, color_mode='fixed'),
                                          sources=['upload'], transforms=[])
                with gr.Row():
                    t2_x_min = gr.Textbox(value=0, label='x_min')
                    t2_y_min = gr.Textbox(value=0, label='y_min')
                    t2_x_max = gr.Textbox(value=256, label='x_max')
                    t2_y_max = gr.Textbox(value=256, label='y_max')
                t2_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
                t2_offset = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.4, label='Guidance Offset')
                t2_segment = gr.Button('Segment')
            with gr.Column():
                t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
            with gr.Column():
                t2_prob_map = gr.Image(type='pil', label='Probability Map')
    with gr.Tab('Mask Prompt'):
        with gr.Row():
            with gr.Column():
                t3_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
                                          brush=gr.Brush(default_size=5, color_mode='fixed'),
                                          sources=['upload'], transforms=[])
                with gr.Row():
                    t3_x_min = gr.Textbox(value=0, label='x_min')
                    t3_y_min = gr.Textbox(value=0, label='y_min')
                    t3_x_max = gr.Textbox(value=256, label='x_max')
                    t3_y_max = gr.Textbox(value=256, label='y_max')
                t3_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
                t3_segment = gr.Button('Segment')
            with gr.Column():
                t3_pred = gr.AnnotatedImage(color_map={'Prompt': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
            with gr.Column():
                t3_prob_map = gr.Image(type='pil', label='Probability Map')
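
    # wire each tab's Segment button to its handler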
    t1_segment.click(
        segment_image,
        inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
        outputs=[t1_input, t1_pred, t1_prob_map]
    )
    t2_segment.click(
        segment_image_with_guidance,
        inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
        outputs=[t2_pred, t2_prob_map]
    )
    t3_segment.click(
        segment_image_with_prompt,
        inputs=[t3_input, t3_thresh, t3_x_min, t3_y_min, t3_x_max, t3_y_max],
        outputs=[t3_pred, t3_prob_map]
    )

demo.launch(debug=True, show_error=True)