Alexander McKinney committed
Commit 557cf2f
1 Parent(s): b4542eb

blocks example of segmentation with interactive sliders

Files changed (1): app.py (+79, -15)
app.py CHANGED
@@ -64,7 +64,42 @@ feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_mode
 pipe = load_diffusion_pipeline()
 pipe = pipe.to(device)
 
-# TODO: potentially use `gr.Gallery` to display different masks
+def fn_segmentation(image, max_kernel, min_kernel):
+    inputs = feature_extractor(images=image, return_tensors="pt")
+    outputs = segmentation_model(**inputs)
+
+    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
+    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
+
+    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
+    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
+
+    panoptic_seg_id = rgb_to_id(panoptic_seg)
+
+    raw_masks = []
+    for s in result['segments_info']:
+        m = panoptic_seg_id == s['id']
+        raw_masks.append(m.astype(np.uint8) * 255)
+
+    masks = fn_clean(raw_masks, max_kernel, min_kernel)
+
+    return masks, raw_masks
+
+def fn_clean(masks, max_kernel, min_kernel):
+    out = []
+    for m in masks:
+        m = torch.FloatTensor(m)[None, None]
+        m = min_pool(m, min_kernel)
+        m = max_pool(m, max_kernel)
+        m = m.squeeze().numpy().astype(np.uint8)
+        out.append(m)
+
+    return out
+
+def fn_mask(image, mask_enabled):
+    if len(mask_enabled) == 0:
+        return image
+
 def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
     mask_indices = [int(i) for i in mask_indices.split(',')]
     inputs = feature_extractor(images=image, return_tensors="pt")
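Note: `fn_clean` relies on `min_pool` and `max_pool`, which are presumably defined earlier in app.py (before line 64) and are not part of this diff. As a rough sketch only, assuming odd kernel sizes (which the sliders enforce via step=2), such helpers could be built on `torch.nn.functional.max_pool2d` like this:

# Hypothetical sketch of the pooling helpers fn_clean expects; the real
# definitions live outside this diff and may differ.
import torch
import torch.nn.functional as F

def max_pool(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # Dilation: an odd kernel with matching padding keeps the mask's spatial size.
    return F.max_pool2d(x, kernel_size, stride=1, padding=kernel_size // 2)

def min_pool(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    # Erosion: min-pooling is max-pooling applied to the negated input.
    return -F.max_pool2d(-x, kernel_size, stride=1, padding=kernel_size // 2)

Under that reading, `fn_clean` performs erosion followed by dilation, roughly a morphological opening (with extra growth when max_kernel > min_kernel): the min slider removes small speckle and the max slider expands the surviving mask regions.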
@@ -144,17 +179,46 @@ def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kerne
 
 # iface = gr.Series(
 #     iface_segmentation, iface_diffusion,
-iface = gr.Interface(
-    fn=fn_segmentation_diffusion,
-    inputs=[
-        "text",
-        "text",
-        gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
-        gr.Slider(minimum=1, maximum=99, value=23, step=2),
-        gr.Slider(minimum=1, maximum=99, value=5, step=2),
-        gr.Slider(minimum=0, maximum=100, value=50, step=1),
-    ],
-    outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
-)
-
-iface.launch()
+
+# iface = gr.Interface(
+#     fn=fn_segmentation_diffusion,
+#     inputs=[
+#         "text",
+#         "text",
+#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
+#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
+#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
+#         gr.Slider(minimum=0, maximum=100, value=50, step=1),
+#     ],
+#     outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
+# )
+
+# iface = gr.Interface(
+#     fn=fn_segmentation,
+#     inputs=[
+#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
+#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
+#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
+#     ],
+#     outputs=gr.Gallery()
+# )
+
+# iface.launch()
+
+demo = gr.Blocks()
+
+with demo:
+    input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil')
+    mask_gallery = gr.Gallery()
+    mask_storage = gr.State()
+
+    max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2)
+    min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2)
+
+    bt_masks = gr.Button("Compute Masks")
+
+    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_gallery, mask_storage])
+    max_slider.change(fn_clean, inputs=[mask_storage, max_slider, min_slider], outputs=mask_gallery)
+    min_slider.change(fn_clean, inputs=[mask_storage, max_slider, min_slider], outputs=mask_gallery)
+
+demo.launch()
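The substantive change in this hunk is the move from a single `gr.Interface` to `gr.Blocks`: the expensive segmentation forward pass runs only when "Compute Masks" is clicked, its raw masks are cached in `gr.State`, and moving either slider re-runs only the cheap `fn_clean` on the cached masks. A minimal, self-contained sketch of that cache-and-refine pattern (hypothetical names, not part of app.py):

# Hypothetical standalone Blocks demo: cache the result of an expensive step in
# gr.State so that slider changes only re-run a cheap post-processing step.
import gradio as gr

def expensive_step(n):
    data = list(range(int(n)))          # stand-in for the segmentation forward pass
    return cheap_step(data, 1), data    # initial display + cached raw result

def cheap_step(data, factor):
    # stand-in for fn_clean; data is None until the expensive step has run once
    if data is None:
        return ""
    return ", ".join(str(d * factor) for d in data)

with gr.Blocks() as demo:
    n = gr.Number(value=5)
    factor = gr.Slider(minimum=1, maximum=10, value=1, step=1)
    cache = gr.State()
    out = gr.Textbox()

    run = gr.Button("Run expensive step")
    run.click(expensive_step, inputs=n, outputs=[out, cache])
    factor.change(cheap_step, inputs=[cache, factor], outputs=out)

demo.launch()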