Alexander McKinney committed on
Commit
d16d053
1 Parent(s): 8cd1abb

adds comments to code

Browse files
Files changed (1) hide show
  1. app.py +23 -14
app.py CHANGED
@@ -17,6 +17,7 @@ from diffusers import StableDiffusionInpaintPipeline
17
  torch.inference_mode()
18
  torch.no_grad()
19
 
 
20
  def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
21
  feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
22
  model = DetrForSegmentation.from_pretrained(model_name)
@@ -24,6 +25,7 @@ def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic
24
 
25
  return feature_extractor, model, cfg
26
 
 
27
  def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
28
  return StableDiffusionInpaintPipeline.from_pretrained(
29
  model_name,
@@ -31,6 +33,7 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
31
  torch_dtype=torch.float16
32
  )
33
 
 
34
  def get_device(try_cuda=True):
35
  return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')
36
 
@@ -42,6 +45,7 @@ def max_pool(x: torch.Tensor, kernel_size: int):
42
  pad_size = (kernel_size - 1) // 2
43
  return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)
44
 
 
45
  def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
46
  mask = torch.Tensor(mask[None, None]).float()
47
  mask = min_pool(mask, min_kernel)
@@ -49,13 +53,14 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
49
  mask = mask.bool().squeeze().numpy()
50
  return mask
51
 
52
- device = get_device()
53
 
54
  feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
55
-
56
  pipe = load_diffusion_pipeline()
 
 
57
  pipe = pipe.to(device)
58
 
 
59
  def fn_segmentation(image, max_kernel, min_kernel):
60
  inputs = feature_extractor(images=image, return_tensors="pt")
61
  outputs = segmentation_model(**inputs)
@@ -81,17 +86,7 @@ def fn_segmentation(image, max_kernel, min_kernel):
81
 
82
  return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
83
 
84
- def fn_clean(masks, max_kernel, min_kernel):
85
- out = []
86
- for m in masks:
87
- m = torch.FloatTensor(m)[None, None]
88
- m = min_pool(m, min_kernel)
89
- m = max_pool(m, max_kernel)
90
- m = m.squeeze().numpy().astype(np.uint8)
91
- out.append(m)
92
-
93
- return out
94
-
95
  def fn_update_mask(
96
  image: Image,
97
  masks: List[np.array],
@@ -108,6 +103,7 @@ def fn_update_mask(
108
 
109
  return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
110
 
 
111
  def fn_diffusion(
112
  prompt: str,
113
  masked_image: Image,
@@ -118,6 +114,9 @@ def fn_diffusion(
118
  ):
119
  if len(negative_prompt) == 0:
120
  negative_prompt = None
 
 
 
121
  STABLE_DIFFUSION_SMALL_EDGE = 512
122
 
123
  w, h = masked_image.size
@@ -133,6 +132,7 @@ def fn_diffusion(
133
  mask = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
134
  masked_image = masked_image.convert("RGB").resize((new_width, new_height))
135
 
 
136
  inpainted_image = pipe(
137
  height=new_height,
138
  width=new_width,
@@ -144,6 +144,7 @@ def fn_diffusion(
144
  negative_prompt=negative_prompt
145
  ).images[0]
146
 
 
147
  inpainted_image = inpainted_image.resize((w, h))
148
 
149
  return inpainted_image
@@ -151,21 +152,24 @@ def fn_diffusion(
151
  demo = gr.Blocks()
152
 
153
  with demo:
 
154
  input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil', label="Input Image")
155
 
 
156
  bt_masks = gr.Button("Compute Masks")
157
-
158
  with gr.Row():
159
  mask_image = gr.Image(type='numpy', label="Diffusion Mask")
160
  masked_image = gr.Image(type='pil', label="Masked Image")
161
  mask_storage = gr.State()
162
 
 
163
  with gr.Row():
164
  max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
165
  min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
166
 
167
  mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")
168
 
 
169
  with gr.Row():
170
  with gr.Column():
171
  prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.", label="Prompt")
@@ -180,14 +184,19 @@ with demo:
180
  update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
181
  update_mask_outputs = [mask_image, masked_image]
182
 
 
183
  input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
184
 
 
185
  bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
186
 
 
 
187
  max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
188
  min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
189
  mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
190
 
 
191
  bt_diffusion.click(fn_diffusion, inputs=[
192
  prompt,
193
  masked_image,
 
17
  torch.inference_mode()
18
  torch.no_grad()
19
 
20
+ # Load segmentation models
21
  def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
22
  feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
23
  model = DetrForSegmentation.from_pretrained(model_name)
 
25
 
26
  return feature_extractor, model, cfg
27
 
28
+ # Load diffusion pipeline
29
  def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpainting'):
30
  return StableDiffusionInpaintPipeline.from_pretrained(
31
  model_name,
 
33
  torch_dtype=torch.float16
34
  )
35
 
36
+ # Device helper
37
  def get_device(try_cuda=True):
38
  return torch.device('cuda' if try_cuda and torch.cuda.is_available() else 'cpu')
39
 
 
45
  pad_size = (kernel_size - 1) // 2
46
  return torch.nn.functional.max_pool2d(x, kernel_size, (1, 1), padding=pad_size)
47
 
48
+ # Apply min-max pooling to clean up mask
49
  def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
50
  mask = torch.Tensor(mask[None, None]).float()
51
  mask = min_pool(mask, min_kernel)
 
53
  mask = mask.bool().squeeze().numpy()
54
  return mask
55
 
 
56
 
57
  feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
 
58
  pipe = load_diffusion_pipeline()
59
+
60
+ device = get_device()
61
  pipe = pipe.to(device)
62
 
63
+ # Callback function that runs segmentation and updates CheckboxGroup
64
  def fn_segmentation(image, max_kernel, min_kernel):
65
  inputs = feature_extractor(images=image, return_tensors="pt")
66
  outputs = segmentation_model(**inputs)
 
86
 
87
  return raw_masks, checkbox_group, gr.Image.update(value=np.zeros((image.height, image.width))), gr.Image.update(value=image)
88
 
89
+ # Callback function that updates the displayed mask based on selected checkboxes
 
 
 
 
 
 
 
 
 
 
90
  def fn_update_mask(
91
  image: Image,
92
  masks: List[np.array],
 
103
 
104
  return combined_mask.astype(np.uint8) * 255, Image.fromarray(masked_image)
105
 
106
+ # Callback function that runs diffusion given the current image, mask and prompt.
107
  def fn_diffusion(
108
  prompt: str,
109
  masked_image: Image,
 
114
  ):
115
  if len(negative_prompt) == 0:
116
  negative_prompt = None
117
+
118
+ # Resize image to a more stable diffusion friendly format.
119
+ # TODO: remove magic number
120
  STABLE_DIFFUSION_SMALL_EDGE = 512
121
 
122
  w, h = masked_image.size
 
132
  mask = Image.fromarray(mask).convert("RGB").resize((new_width, new_height))
133
  masked_image = masked_image.convert("RGB").resize((new_width, new_height))
134
 
135
+ # Run diffusion
136
  inpainted_image = pipe(
137
  height=new_height,
138
  width=new_width,
 
144
  negative_prompt=negative_prompt
145
  ).images[0]
146
 
147
+ # Resize back to the original size
148
  inpainted_image = inpainted_image.resize((w, h))
149
 
150
  return inpainted_image
 
152
  demo = gr.Blocks()
153
 
154
  with demo:
155
+ # Input image control
156
  input_image = gr.Image(value="http://images.cocodataset.org/val2017/000000039769.jpg", type='pil', label="Input Image")
157
 
158
+ # Combined mask controls
159
  bt_masks = gr.Button("Compute Masks")
 
160
  with gr.Row():
161
  mask_image = gr.Image(type='numpy', label="Diffusion Mask")
162
  masked_image = gr.Image(type='pil', label="Masked Image")
163
  mask_storage = gr.State()
164
 
165
+ # Mask editing controls
166
  with gr.Row():
167
  max_slider = gr.Slider(minimum=1, maximum=99, value=23, step=2, label="Mask Overflow")
168
  min_slider = gr.Slider(minimum=1, maximum=99, value=5, step=2, label="Mask Denoising")
169
 
170
  mask_checkboxes = gr.CheckboxGroup(interactive=True, label="Mask Selection")
171
 
172
+ # Diffusion controls and output
173
  with gr.Row():
174
  with gr.Column():
175
  prompt = gr.Textbox("Two ginger cats lying together on a pink sofa. There are two TV remotes. High definition.", label="Prompt")
 
184
  update_mask_inputs = [input_image, mask_storage, mask_checkboxes, max_slider, min_slider]
185
  update_mask_outputs = [mask_image, masked_image]
186
 
187
+ # Clear checkbox group on input image change
188
  input_image.change(lambda: gr.CheckboxGroup.update(choices=[], value=[]), outputs=mask_checkboxes)
189
 
190
+ # Segmentation button callback
191
  bt_masks.click(fn_segmentation, inputs=[input_image, max_slider, min_slider], outputs=[mask_storage, mask_checkboxes, mask_image, masked_image])
192
 
193
+ # Update mask callbacks
194
+ # TODO: can we replace this with `mask_image.change`? Not sure if it will actively update.
195
  max_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
196
  min_slider.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
197
  mask_checkboxes.change(fn_update_mask, inputs=update_mask_inputs, outputs=update_mask_outputs)
198
 
199
+ # Diffusion button callback
200
  bt_diffusion.click(fn_diffusion, inputs=[
201
  prompt,
202
  masked_image,