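"""
Gradio Space for sidewalk segmentation, built on the 'facebook/sam-vit-base'
processor and the fine-tuned checkpoint 'hmdliu/sidewalks-seg'. Three tabs
expose three prompting modes:

* Baseline      -- bounding-box prompt only.
* Mask Guidance -- a user-drawn layer is added to the predicted probability
                   map as a soft offset before thresholding.
* Mask Prompt   -- the user-drawn layer is fed to SAM as a mask prompt
                   together with the bounding box.
"""
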
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import ToTensor
from transformers import SamModel, SamProcessor

to_tensor = ToTensor()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SAM processor plus a checkpoint fine-tuned for sidewalk segmentation
processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
model.to(device)
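
# Optional refactor sketch (not wired into the handlers below): all three tabs
# share the same SAM forward pass, which could be factored out roughly like
# this; the mask-prompt tab additionally feeds a drawn mask to the model.
def _predict_prob_map(img, box, mask_prompt=None):
    inputs = processor(img, input_boxes=[[box]], return_tensors='pt')
    extra = {} if mask_prompt is None else {'input_masks': mask_prompt.to(device)}
    outputs = model(pixel_values=inputs['pixel_values'].to(device),
                    input_boxes=inputs['input_boxes'].to(device),
                    multimask_output=False, **extra)
    # per-pixel sidewalk probabilities in [0, 1]
    return torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
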
def segment_image(image, threshold, x_min, y_min, x_max, y_max):
    # round-trip through PNG to tolerate TIFF image input
    image.save('image.png')
    img = Image.open('image.png').convert('RGB')
    # init input data (textbox values arrive as strings, so cast the box coordinates)
    prompt = [float(x_min), float(y_min), float(x_max), float(y_max)]
    inputs = processor(img, input_boxes=[[prompt]], return_tensors='pt')
    # make prediction
    outputs = model(pixel_values=inputs['pixel_values'].to(device),
                    input_boxes=inputs['input_boxes'].to(device),
                    multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
    pred_mask = (prob_map > threshold).float().numpy()
    # visualize the probability map
    plt.figure(figsize=(8, 8))
    plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
    plt.close()
    # assemble outputs for the Gradio components
    ret_image = Image.open('image.png')
    ret_pred = (Image.open('image.png'), [(pred_mask, 'Sidewalks')])
    ret_prob = Image.open('prob.png')
    return ret_image, ret_pred, ret_prob
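
# Quick programmatic check of segment_image outside the Gradio UI (a sketch;
# 'sample.png' is a hypothetical 256x256 RGB tile, not shipped with the Space):
#
#   sample = Image.open('sample.png').convert('RGB')
#   _, annotated, prob = segment_image(sample, 0.5, 0, 0, 256, 256)
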
def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y_max):
    # round-trip through PNG to tolerate TIFF image input
    image['background'].save('image.png')
    # init input data (textbox values arrive as strings, so cast the box coordinates)
    prompt = [float(x_min), float(y_min), float(x_max), float(y_max)]
    img = Image.open('image.png').convert('RGB')
    inputs = processor(img, input_boxes=[[prompt]], return_tensors='pt')
    # make prediction
    outputs = model(pixel_values=inputs['pixel_values'].to(device),
                    input_boxes=inputs['input_boxes'].to(device),
                    multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
    # perform mask guidance: pixels covered by the drawn layer get their
    # probability raised by `offset` before thresholding, so strokes act as
    # a soft positive prior rather than a hard override
    guidance_mask = (np.max(np.array(image['layers'][0]), axis=2) != 0).astype(float)
    enhance_map = prob_map.numpy() + offset * guidance_mask
    pred_mask = (enhance_map > threshold).astype(float)
    # visualize the guidance-enhanced probability map
    plt.figure(figsize=(8, 8))
    plt.imshow(enhance_map, cmap='jet', interpolation='nearest')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
    plt.close()
    # overlay both the guidance strokes and the prediction on the input
    regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
    return (image['background'], regions), Image.open('prob.png')

def segment_image_with_prompt(image, threshold, x_min, y_min, x_max, y_max):
    # round-trip through PNG to tolerate TIFF image input
    image['background'].save('image.png')
    # init input data: the drawn layer becomes a mask prompt that is fed to SAM
    # alongside the bounding box (textbox values arrive as strings, so cast them)
    img = Image.open('image.png').convert('RGB')
    mask = (np.max(np.array(image['layers'][0]), axis=2) != 0)
    mask_prompt = to_tensor(mask).float()
    box_prompt = [[[float(x_min), float(y_min), float(x_max), float(y_max)]]]
    inputs = processor(img, input_boxes=box_prompt,
                       input_masks=mask_prompt, return_tensors='pt')
    # make prediction
    outputs = model(pixel_values=inputs['pixel_values'].to(device),
                    input_boxes=inputs['input_boxes'].to(device),
                    input_masks=mask_prompt.to(device),
                    multimask_output=False)
    prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
    pred_mask = (prob_map > threshold).float().numpy()
    # visualize the probability map
    plt.figure(figsize=(8, 8))
    plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
    plt.close()
    # overlay both the mask prompt and the prediction on the input
    regions = [(mask, 'Prompt'), (pred_mask, 'Sidewalks')]
    return (image['background'], regions), Image.open('prob.png')

with gr.Blocks() as demo:
    with gr.Tab('Baseline'):
        with gr.Row():
            with gr.Column():
                t1_input = gr.Image(type='pil', label='Input Image')
                with gr.Row():
                    t1_x_min = gr.Textbox(value=0, label='x_min')
                    t1_y_min = gr.Textbox(value=0, label='y_min')
                    t1_x_max = gr.Textbox(value=256, label='x_max')
                    t1_y_max = gr.Textbox(value=256, label='y_max')
                t1_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
                t1_segment = gr.Button('Segment')
            with gr.Column():
                t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
            with gr.Column():
                t1_prob_map = gr.Image(type='pil', label='Probability Map')
    with gr.Tab('Mask Guidance (Best)'):
        with gr.Row():
            with gr.Column():
                t2_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
                                          brush=gr.Brush(default_size=5, color_mode='fixed'),
                                          sources=['upload'], transforms=[])
                with gr.Row():
                    t2_x_min = gr.Textbox(value=0, label='x_min')
                    t2_y_min = gr.Textbox(value=0, label='y_min')
                    t2_x_max = gr.Textbox(value=256, label='x_max')
                    t2_y_max = gr.Textbox(value=256, label='y_max')
                t2_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
                t2_offset = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.4, label='Guidance Offset')
                t2_segment = gr.Button('Segment')
            with gr.Column():
                t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
            with gr.Column():
                t2_prob_map = gr.Image(type='pil', label='Probability Map')
    with gr.Tab('Mask Prompt'):
        with gr.Row():
            with gr.Column():
                t3_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
                                          brush=gr.Brush(default_size=5, color_mode='fixed'),
                                          sources=['upload'], transforms=[])
                with gr.Row():
                    t3_x_min = gr.Textbox(value=0, label='x_min')
                    t3_y_min = gr.Textbox(value=0, label='y_min')
                    t3_x_max = gr.Textbox(value=256, label='x_max')
                    t3_y_max = gr.Textbox(value=256, label='y_max')
                t3_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
                t3_segment = gr.Button('Segment')
            with gr.Column():
                t3_pred = gr.AnnotatedImage(color_map={'Prompt': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
            with gr.Column():
                t3_prob_map = gr.Image(type='pil', label='Probability Map')
    # wire each tab's Segment button to its segmentation function
    t1_segment.click(
        segment_image,
        inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
        outputs=[t1_input, t1_pred, t1_prob_map]
    )
    t2_segment.click(
        segment_image_with_guidance,
        inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
        outputs=[t2_pred, t2_prob_map]
    )
    t3_segment.click(
        segment_image_with_prompt,
        inputs=[t3_input, t3_thresh, t3_x_min, t3_y_min, t3_x_max, t3_y_max],
        outputs=[t3_pred, t3_prob_map]
    )

demo.launch(debug=True, show_error=True)