Alexander McKinney committed on
Commit 8cd1abb
1 Parent(s): 7d008e4

fixes bug loading new image with different masks and cleans up code

Files changed (2):
  1. README.md +1 -0
  2. app.py +42 -133
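
The substantive fix is the new input_image.change handler in the app.py hunk below: checkbox choices produced by fn_segmentation refer to segment ids of the image they were computed from, so the CheckboxGroup is cleared whenever a new image is loaded, before any stale index can be applied to the wrong masks. A minimal sketch of that pattern in isolation (gradio 3.x update API; component names here are illustrative, not the app's):

import gradio as gr

with gr.Blocks() as demo:
    image = gr.Image(type="pil")
    segment_boxes = gr.CheckboxGroup(interactive=True)

    # A newly uploaded image invalidates every previously offered choice,
    # so reset both the available choices and the current selection.
    image.change(
        lambda: gr.CheckboxGroup.update(choices=[], value=[]),
        outputs=segment_boxes,
    )

demo.launch()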
README.md CHANGED
@@ -16,3 +16,4 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
 - is there a way to stop the loading icon appearing? Would rather copy last input than flicker
 - onclick events for canvas? we can draw, but can I get coordinates?
 - checkboxes seem a bit busted with indexes
+- set canvas default to segmentation output, make small edits rather than doing whole thing
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 from PIL import Image
 from skimage.measure import block_reduce
-from typing import List
+from typing import List, Optional
 from functools import reduce
 
 import gradio as gr
@@ -14,15 +14,6 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
 
 from diffusers import StableDiffusionInpaintPipeline
 
-# TODO: maybe need to port to `Blocks` system
-# allegedly provides:
-# Have multi-step interfaces, in which the output of one model becomes the
-# input to the next model, or have more flexible data flows in general.
-
-# and:
-# Change a component’s properties (for example, the choices in a dropdown) or its visibility based on user input
-# https://huggingface.co/course/chapter9/7?fw=pt
-
 torch.inference_mode()
 torch.no_grad()
 
@@ -61,7 +52,6 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
 device = get_device()
 
 feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
-# segmentation_model = segmentation_model.to(device)
 
 pipe = load_diffusion_pipeline()
 pipe = pipe.to(device)
@@ -83,7 +73,6 @@ def fn_segmentation(image, max_kernel, min_kernel):
         m = panoptic_seg_id == s['id']
         raw_masks.append(m.astype(np.uint8) * 255)
 
-    # masks = fn_clean(raw_masks, max_kernel, min_kernel)
     checkbox_choices = [f"{s['id']}:{segmentation_cfg.id2label[s['category_id']]}" for s in result['segments_info']]
 
     checkbox_group = gr.CheckboxGroup.update(
@@ -119,7 +108,16 @@ def fn_update_mask(
 
     return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
 
-def fn_diffusion(prompt: str, masked_image: Image, mask: Image, num_diffusion_steps: int):
+def fn_diffusion(
+    prompt: str,
+    masked_image: Image,
+    mask: Image,
+    num_diffusion_steps: int,
+    guidance_scale: float,
+    negative_prompt: Optional[str] = None,
+):
+    if len(negative_prompt) == 0:
+        negative_prompt = None
     STABLE_DIFFUSION_SMALL_EDGE = 512
 
     w, h = masked_image.size
@@ -141,151 +139,62 @@ def fn_diffusion(prompt: str, masked_image: Image, mask: Image, num_diffusion_st
         prompt=prompt,
         image=masked_image,
         mask_image=mask,
-        num_inference_steps=num_diffusion_steps
+        num_inference_steps=num_diffusion_steps,
+        guidance_scale=guidance_scale,
+        negative_prompt=negative_prompt
     ).images[0]
 
     inpainted_image = inpainted_image.resize((w, h))
 
     return inpainted_image
 
-def fn_segmentation_diffusion(prompt, mask_indices, image, max_kernel, min_kernel, num_diffusion_steps):
-    mask_indices = [int(i) for i in mask_indices.split(',')]
-    inputs = feature_extractor(images=image, return_tensors="pt")
-    outputs = segmentation_model(**inputs)
-
-    processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
-    result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
-
-    panoptic_seg = Image.open(io.BytesIO(result["png_string"])).resize((image.width, image.height))
-    panoptic_seg = np.array(panoptic_seg, dtype=np.uint8)
-
-    class_str = '\n'.join(segmentation_cfg.id2label[s['category_id']] for s in result['segments_info'])
-
-    panoptic_seg_id = rgb_to_id(panoptic_seg)
-
-    if len(mask_indices) > 0:
-        mask = (panoptic_seg_id == mask_indices[0])
-        for idx in mask_indices[1:]:
-            mask = mask | (panoptic_seg_id == idx)
-        mask = clean_mask(mask, min_kernel=min_kernel, max_kernel=max_kernel)
-
-    masked_image = np.array(image).copy()
-    masked_image[mask] = 0
-
-    masked_image = Image.fromarray(masked_image).resize(image.size)
-    mask = Image.fromarray(mask.astype(np.uint8) * 255).resize(image.size)
-
-    if num_diffusion_steps == 0:
-        return masked_image, masked_image, class_str
-
-    STABLE_DIFFUSION_SMALL_EDGE = 512
-
-    assert masked_image.size == mask.size
-    w, h = masked_image.size
-    is_width_larger = w > h
-    resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / (h if is_width_larger else w)
-
-    new_width = int(w * resize_ratio) if is_width_larger else STABLE_DIFFUSION_SMALL_EDGE
-    new_height = STABLE_DIFFUSION_SMALL_EDGE if is_width_larger else int(h * resize_ratio)
-
-    new_width += 8 - (new_width % 8) if is_width_larger else 0
-    new_height += 0 if is_width_larger else 8 - (new_height % 8)
-
-    mask = mask.convert("RGB").resize((new_width, new_height))
-    masked_image = masked_image.convert("RGB").resize((new_width, new_height))
-
-    inpainted_image = pipe(
-        height=new_height,
-        width=new_width,
-        prompt=prompt,
-        image=masked_image,
-        mask_image=mask,
-        num_inference_steps=num_diffusion_steps
-    ).images[0]
-
-    return masked_image, inpainted_image, class_str
-
-
-# iface_segmentation = gr.Interface(
-#     fn=fn_segmentation,
-#     inputs=[
-#         "text",
-#         "text",
-#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg"),
-#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
-#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
-#         gr.Slider(minimum=0, maximum=100, value=50, step=1),
-#     ],
-#     outputs=["text", gr.Image(type="pil"), gr.Image(type="pil"), "number", "text"]
-# )
-
-# iface_diffusion = gr.Interface(
-#     fn=fn_diffusion,
-#     inputs=["text", gr.Image(type='pil'), gr.Image(type='pil'), "number", "text"],
-#     outputs=[gr.Image(), gr.Image(), gr.Textbox()]
-# )
-
-# iface = gr.Series(
-#     iface_segmentation, iface_diffusion,
-
-# iface = gr.Interface(
-#     fn=fn_segmentation_diffusion,
-#     inputs=[
-#         "text",
-#         "text",
-#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
-#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
-#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
-#         gr.Slider(minimum=0, maximum=100, value=50, step=1),
-#     ],
-#     outputs=[gr.Image(), gr.Image(), gr.Textbox(interactive=False)]
-# )
-
-# iface = gr.Interface(
-#     fn=fn_segmentation,
-#     inputs=[
-#         gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil'),
-#         gr.Slider(minimum=1, maximum=99, value=23, step=2),
-#         gr.Slider(minimum=1, maximum=99, value=5, step=2),
-#     ],
-#     outputs=gr.Gallery()
-# )
-
-# iface.launch()
-
 demo = gr.Blocks()
 
 with demo:
-    input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil')
+    input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil', label="Input Image")
 
     bt_masks = gr.Button("Compute Masks")
 
     with gr.Row():
-        mask_image = gr.Image(type='numpy')
-        masked_image = gr.Image(type='pil')
+        mask_image = gr.Image(type='numpy', label="Diffusion Mask")
+        masked_image = gr.Image(type='pil', label="Masked Image")
         mask_storage = gr.State()
 
     with gr.Row():
-        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2)
-        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2)
+        max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
+        min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
 
-    mask_checkboxes = gr.CheckboxGroup(interactive=True)
+    mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")
 
     with gr.Row():
        with gr.Column():
-            prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.")
-            steps_slider = gr.Slider(minimum=1, maximum=100, value=50)
+            prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.", label="Prompt")
+            negative_prompt = gr.Textbox(label="Negative Prompt")
+        with gr.Column():
+            steps_slider = gr.Slider(minimum=1, maximum=100, value=50, label="Inference Steps")
+            guidance_slider = gr.Slider(minimum=0.0, maximum=50.0, value=7.5, step=0.1, label="Guidance Scale")
            bt_diffusion = gr.Button("Run Diffusion")
 
-    inpainted_image = gr.Image(type='pil')
-
-    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
-
-    max_slider.change(fn_update_mask, inputs=[input_image, mask_storage, mask_checkboxes, max_slider, min_slider], outputs=[mask_image, masked_image])
-    min_slider.change(fn_update_mask, inputs=[input_image, mask_storage, mask_checkboxes, max_slider, min_slider], outputs=[mask_image, masked_image])
-    mask_checkboxes.change(fn_update_mask, inputs=[input_image, mask_storage, mask_checkboxes, max_slider, min_slider], outputs=[mask_image, masked_image])
-
-    bt_diffusion.click(fn_diffusion, inputs=[prompt, masked_image, mask_image, steps_slider], outputs=inpainted_image)
+    inpainted_image = gr.Image(type='pil', label="Inpainted Image")
+
+    update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
+    update_mask_outputs = [mask_image, masked_image]
+
+    input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
+
+    bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
+
+    max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
+    min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
+    mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
+
+    bt_diffusion.click(fn_diffusion, inputs=[
+        prompt,
+        masked_image,
+        mask_image,
+        steps_slider,
+        guidance_slider,
+        negative_prompt
+    ], outputs=inpainted_image)
 
 demo.launch()
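
A note on the sizing logic that fn_diffusion keeps (visible as context around the last hunk): the short edge is scaled to STABLE_DIFFUSION_SMALL_EDGE = 512 and the long edge is then padded up to a multiple of 8, since Stable Diffusion expects dimensions divisible by 8. A worked example of those rules, assuming a 640x480 input:

# Worked example of the resize rules in fn_diffusion for a 640x480 image.
STABLE_DIFFUSION_SMALL_EDGE = 512

w, h = 640, 480
is_width_larger = w > h                               # True
resize_ratio = STABLE_DIFFUSION_SMALL_EDGE / h        # 512 / 480 ~ 1.067

new_width = int(w * resize_ratio)                     # 682
new_height = STABLE_DIFFUSION_SMALL_EDGE              # 512
new_width += 8 - (new_width % 8)                      # 682 -> 688

assert new_width % 8 == 0 and new_height % 8 == 0     # final size: 688x512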
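
For reference, the three arguments newly wired through fn_diffusion map directly onto the diffusers inpainting call. A minimal sketch of the equivalent direct usage; the checkpoint id and file names are placeholders, not what the Space's load_diffusion_pipeline actually loads:

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

# Placeholder checkpoint; any StableDiffusionInpaintPipeline-compatible model works.
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("masked.png").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("RGB").resize((512, 512))

inpainted = pipe(
    prompt="Two ginger cats lying together on a pink sofa.",
    image=image,
    mask_image=mask,
    num_inference_steps=50,   # steps_slider default in the UI
    guidance_scale=7.5,       # guidance_slider default in the UI
    negative_prompt=None,     # an empty Negative Prompt textbox is mapped to None
).images[0]
inpainted.save("inpainted.png")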