hmdliu commited on
Commit
fadb2ab
1 Parent(s): 54b7544

Add mask prompt

Browse files
Files changed (2) hide show
  1. app.py +54 -2
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,8 +3,10 @@ import numpy as np
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
  from PIL import Image
 
6
  from transformers import SamModel, SamProcessor
7
 
 
8
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
  processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
10
  model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
@@ -62,6 +64,34 @@ def segment_image_with_guidance(image, threshold, offset, x_min, y_min, x_max, y
62
  regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
63
  return (image['background'], regions), Image.open('prob.png')
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  with gr.Blocks() as demo:
66
  with gr.Tab('Baseline'):
67
  with gr.Row():
@@ -78,10 +108,10 @@ with gr.Blocks() as demo:
78
  t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
79
  with gr.Column():
80
  t1_prob_map = gr.Image(type='pil', label='Probability Map')
81
- with gr.Tab('Mask Guidance'):
82
  with gr.Row():
83
  with gr.Column():
84
- t2_input = gr.ImageEditor(type='pil', crop_size='2:1', label='Input Image',
85
  brush=gr.Brush(default_size='5', color_mode='fixed'),
86
  sources=['upload'], transforms=[])
87
  with gr.Row():
@@ -96,6 +126,23 @@ with gr.Blocks() as demo:
96
  t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
97
  with gr.Column():
98
  t2_prob_map = gr.Image(type='pil', label='Probability Map')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  t1_segment.click(
100
  segment_image,
101
  inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
@@ -106,4 +153,9 @@ with gr.Blocks() as demo:
106
  inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
107
  outputs=[t2_pred, t2_prob_map]
108
  )
 
 
 
 
 
109
  demo.launch(debug=True, show_error=True)
 
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
  from PIL import Image
6
+ from torchvision.transforms import ToTensor
7
  from transformers import SamModel, SamProcessor
8
 
9
+ to_tensor = ToTensor()
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
  processor = SamProcessor.from_pretrained('facebook/sam-vit-base')
12
  model = SamModel.from_pretrained('hmdliu/sidewalks-seg')
 
64
  regions = [(guidance_mask, 'Guidance'), (pred_mask, 'Sidewalks')]
65
  return (image['background'], regions), Image.open('prob.png')
66
 
67
+ def segment_image_with_prompt(image, threshold, x_min, y_min, x_max, y_max):
68
+ # tolerate TIFF image input
69
+ image['background'].save('image.png')
70
+ # init input data
71
+ img = Image.open('image.png').convert('RGB')
72
+ mask = (np.max(np.array(image['layers'][0]), axis=2) != 0)
73
+ mask_prompt = to_tensor(mask).float()
74
+ box_prompt = [[[x_min, y_min, x_max, y_max]]]
75
+ inputs = processor(img, input_boxes=box_prompt,
76
+ input_masks=mask_prompt, return_tensors='pt')
77
+ # make prediction
78
+ outputs = model(pixel_values=inputs['pixel_values'].to(device),
79
+ input_boxes=inputs['input_boxes'].to(device),
80
+ input_masks=mask_prompt.to(device),
81
+ multimask_output=False)
82
+ prob_map = torch.sigmoid(outputs.pred_masks.squeeze()).cpu().detach()
83
+ pred_mask = (prob_map > threshold).float().numpy()
84
+ # visualize results
85
+ plt.figure(figsize=(8, 8))
86
+ plt.imshow(prob_map.numpy(), cmap='jet', interpolation='nearest')
87
+ plt.axis('off')
88
+ plt.tight_layout()
89
+ plt.savefig('prob.png', bbox_inches='tight', pad_inches=0)
90
+ plt.close()
91
+ # post-processing
92
+ regions = [(mask, 'Prompt'), (pred_mask, 'Sidewalks')]
93
+ return (image['background'], regions), Image.open('prob.png')
94
+
95
  with gr.Blocks() as demo:
96
  with gr.Tab('Baseline'):
97
  with gr.Row():
 
108
  t1_pred = gr.AnnotatedImage(color_map={'Sidewalks': '#0000FF'}, label='Prediction')
109
  with gr.Column():
110
  t1_prob_map = gr.Image(type='pil', label='Probability Map')
111
+ with gr.Tab('Mask Guidance (Best)'):
112
  with gr.Row():
113
  with gr.Column():
114
+ t2_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
115
  brush=gr.Brush(default_size='5', color_mode='fixed'),
116
  sources=['upload'], transforms=[])
117
  with gr.Row():
 
126
  t2_pred = gr.AnnotatedImage(color_map={'Guidance': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
127
  with gr.Column():
128
  t2_prob_map = gr.Image(type='pil', label='Probability Map')
129
+ with gr.Tab('Mask Prompt'):
130
+ with gr.Row():
131
+ with gr.Column():
132
+ t3_input = gr.ImageEditor(type='pil', crop_size='1:1', label='Input Image',
133
+ brush=gr.Brush(default_size='5', color_mode='fixed'),
134
+ sources=['upload'], transforms=[])
135
+ with gr.Row():
136
+ t3_x_min = gr.Textbox(value=0, label='x_min')
137
+ t3_y_min = gr.Textbox(value=0, label='y_min')
138
+ t3_x_max = gr.Textbox(value=256, label='x_max')
139
+ t3_y_max = gr.Textbox(value=256, label='y_max')
140
+ t3_thresh = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.5, label='Prediction Threshold')
141
+ t3_segment = gr.Button('Segment')
142
+ with gr.Column():
143
+ t3_pred = gr.AnnotatedImage(color_map={'Prompt': '#FF0000', 'Sidewalks': '#0000FF'}, label='Prediction')
144
+ with gr.Column():
145
+ t3_prob_map = gr.Image(type='pil', label='Probability Map')
146
  t1_segment.click(
147
  segment_image,
148
  inputs=[t1_input, t1_slider, t1_x_min, t1_y_min, t1_x_max, t1_y_max],
 
153
  inputs=[t2_input, t2_thresh, t2_offset, t2_x_min, t2_y_min, t2_x_max, t2_y_max],
154
  outputs=[t2_pred, t2_prob_map]
155
  )
156
+ t3_segment.click(
157
+ segment_image_with_prompt,
158
+ inputs=[t3_input, t3_thresh, t3_x_min, t3_y_min, t3_x_max, t3_y_max],
159
+ outputs=[t3_pred, t3_prob_map]
160
+ )
161
  demo.launch(debug=True, show_error=True)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  torch
 
2
  matplotlib
3
  transformers
 
1
  torch
2
+ torchvision
3
  matplotlib
4
  transformers