Alexander McKinney committed on
Commit 7d008e4
1 Parent(s): 557cf2f

full example in blocks


No click support still; also some bugs when changing the source image with masks.

Files changed (2)
  1. README.md +5 -0
  2. app.py +80 -13
README.md CHANGED
@@ -11,3 +11,8 @@ license: creativeml-openrail-m
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+### Notes on Gradio changes
+- is there a way to stop the loading icon appearing? Would rather copy last input than flicker
+- onclick events for canvas? we can draw, but can I get coordinates?
+- checkboxes seem a bit busted with indexes
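On the canvas/coordinates note above: the Gradio version used for this commit does not expose click positions, which is why the app still has no click support. Newer Gradio releases add a select event on gr.Image that passes a gr.SelectData object whose index field carries the clicked pixel; a minimal sketch assuming such a release (none of this is part of the commit):

import gradio as gr

# Sketch only: assumes a Gradio version where gr.Image exposes .select and
# injects gr.SelectData; that was not available when this commit was made.
with gr.Blocks() as demo:
    img = gr.Image(type="pil")
    coords = gr.Textbox(label="clicked pixel (x, y)")

    def on_select(evt: gr.SelectData):
        # evt.index holds the (x, y) position of the click on the image
        return str(evt.index)

    img.select(on_select, inputs=None, outputs=coords)

demo.launch()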
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
 import torch
 from PIL import Image
 from skimage.measure import block_reduce
+from typing import List
+from functools import reduce
 
 import gradio as gr
 
@@ -49,7 +51,7 @@ def max_pool(x: torch.Tensor, kernel_size: int):
     pad_size = (kernel_size - 1) // 2
     return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)
 
-def clean_mask(mask, min_kernel: int = 5, max_kernel: int = 23):
+def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
     mask = torch.Tensor(mask[None, None]).float()
     mask = min_pool(mask, min_kernel)
     mask = max_pool(mask, max_kernel)
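clean_mask erodes with the small kernel (min_pool) and then dilates with the large kernel (max_pool), so isolated speckles are removed before the surviving regions are grown back out. min_pool itself is not shown in this diff; a common implementation, assumed here, is max-pooling of the negated tensor:

import torch

# Assumed implementation of min_pool (not part of this diff):
# min-pooling is max-pooling applied to the negated input.
def min_pool(x: torch.Tensor, kernel_size: int):
    pad_size = (kernel_size - 1) // 2
    return -torch.nn.functional.max_pool2d(-x, kernel_size, (1, 1), padding=pad_size)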
 
@@ -81,9 +83,14 @@ def fn_segmentation(image, max_kernel, min_kernel):
         m = panoptic_seg_id == s['id']
         raw_masks.append(m.astype(np.uint8) * 255)
 
-    masks = fn_clean(raw_masks, max_kernel, min_kernel)
+    # masks = fn_clean(raw_masks, max_kernel, min_kernel)
+    checkbox_choices = [f"{s['id']}:{segmentation_cfg.id2label[s['category_id']]}" for s in result['segments_info']]
+
+    checkbox_group = gr.CheckboxGroup.update(
+        choices=checkbox_choices
+    )
 
-    return masks, raw_masks
+    return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
 
 def fn_clean(masks, max_kernel, min_kernel):
     out = []
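This is the workaround for the "checkboxes seem a bit busted with indexes" note in the README: rather than relying on checkbox positions, each choice label embeds the segment id as "id:label", and fn_update_mask below recovers the id by splitting on ':'. A small round-trip illustration with made-up values:

# Encode: one checkbox label per segment, id first (as in fn_segmentation)
segments = [{"id": 0, "label": "cat"}, {"id": 1, "label": "couch"}]
choices = [f"{s['id']}:{s['label']}" for s in segments]   # ['0:cat', '1:couch']

# Decode: selected labels come back as strings, so the id is recovered
# from the label text itself (as in fn_update_mask)
selected = ["1:couch"]
ids = [int(c.split(":")[0]) for c in selected]            # [1]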
 
@@ -96,9 +103,50 @@ def fn_clean(masks, max_kernel, min_kernel):
 
     return out
 
-def fn_mask(image, mask_enabled):
-    if len(mask_enabled) == 0:
-        return image
+def fn_update_mask(
+    image: Image,
+    masks: List[np.array],
+    masks_enabled: List[int],
+    max_kernel: int,
+    min_kernel: int,
+):
+    masks_enabled = [int(m.split(':')[0]) for m in masks_enabled]
+    combined_mask = reduce(lambda x, y: x | y, [masks[i] for i in masks_enabled], np.zeros_like(masks[0], dtype=bool))
+    combined_mask = clean_mask(combined_mask, max_kernel, min_kernel)
+
+    masked_image = np.array(image).copy()
+    masked_image[combined_mask] = 0.0
+
+    return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
+
+def fn_diffusion(prompt: str, masked_image: Image, mask: Image, num_diffusion_steps: int):
+    STABLE_DIFFUSION_SMALL_EDGE = 512
+
+    w, h = masked_image.size
+    is_width_larger = w > h
+    resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / (h if is_width_larger else w)
+
+    new_width = int(w * resize_ratio) if is_width_larger else STABLE_DIFFUSION_SMALL_EDGE
+    new_height = STABLE_DIFFUSION_SMALL_EDGE if is_width_larger else int(h * resize_ratio)
+
+    new_width += 8 - (new_width % 8) if is_width_larger else 0
+    new_height += 0 if is_width_larger else 8 - (new_height % 8)
+
+    mask = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
+    masked_image = masked_image.convert("RGB").resize((new_width, new_height))
+
+    inpainted_image = pipe(
+        height=new_height,
+        width=new_width,
+        prompt=prompt,
+        image=masked_image,
+        mask_image=mask,
+        num_inference_steps=num_diffusion_steps
+    ).images[0]
+
+    inpainted_image = inpainted_image.resize((w, h))
+
+    return inpainted_image
 
 def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
     mask_indices = [int(i) for i in mask_indices.split(',')]
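fn_diffusion scales the masked image so its short edge is 512 and then pads each edge up to a multiple of 8 before calling the inpainting pipeline, resizing the result back to the original size at the end. A worked example of that arithmetic (the 640x480 input is illustrative):

# Worked example of the resize logic in fn_diffusion
w, h = 640, 480                      # illustrative masked_image size
is_width_larger = w > h              # True
resize_ratio = 512 / h               # ~1.067, so the short edge becomes 512
new_width = int(w * resize_ratio)    # 682
new_height = 512
new_width += 8 - (new_width % 8)     # 688, now divisible by 8
assert new_width % 8 == 0 and new_height % 8 == 0

Note that an edge already divisible by 8 still gains a full 8 pixels here; that is harmless because the output is resized back to (w, h) afterwards.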
 
@@ -209,16 +257,35 @@ demo = gr.Blocks()
 
 with demo:
     input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil')
-    mask_gallery = gr.Gallery()
+
+    bt_masks = gr.Button("Compute Masks")
+
+    with gr.Row():
+        mask_image = gr.Image(type='numpy')
+        masked_image = gr.Image(type='pil')
     mask_storage = gr.State()
 
-    max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2)
-    min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2)
+    with gr.Row():
+        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2)
+        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2)
 
-    bt_masks = gr.Button("Compute Masks")
+    mask_checkboxes = gr.CheckboxGroup(interactive=True)
+
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.")
+            steps_slider = gr.Slider(minimum=1, maximum=100, value=50)
+            bt_diffusion = gr.Button("Run Diffusion")
+
+        inpainted_image = gr.Image(type='pil')
+
+
+    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
+
+    max_slider.change(fn_update_mask, inputs=[input_image, mask_storage, mask_checkboxes, max_slider, min_slider], outputs=[mask_image, masked_image])
+    min_slider.change(fn_update_mask, inputs=[input_image, mask_storage, mask_checkboxes, max_slider, min_slider], outputs=[mask_image, masked_image])
+    mask_checkboxes.change(fn_update_mask, inputs=[input_image, mask_storage, mask_checkboxes, max_slider, min_slider], outputs=[mask_image, masked_image])
 
-    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_gallery, mask_storage])
-    max_slider.change(fn_clean, inputs=[mask_storage, max_slider, min_slider], outputs=mask_gallery)
-    min_slider.change(fn_clean, inputs=[mask_storage, max_slider, min_slider], outputs=mask_gallery)
+    bt_diffusion.click(fn_diffusion, inputs=[prompt, masked_image, mask_image, steps_slider], outputs=inpainted_image)
 
 demo.launch()
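On the loading-icon note in the README: depending on the installed Gradio version, event listeners accept a show_progress argument that keeps the previous output visible instead of flashing the loading state. A sketch under that assumption (not part of this commit):

import gradio as gr

# Sketch only: assumes a Gradio version whose event listeners accept
# show_progress; it suppresses the loading overlay on the outputs.
with gr.Blocks() as demo:
    value = gr.Slider(minimum=1, maximum=99, value=5, step=2)
    echo = gr.Textbox()
    value.change(lambda v: str(v), inputs=value, outputs=echo, show_progress=False)

demo.launch()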