halleewong commited on
Commit
86e6672
β€’
1 Parent(s): 2958b13

upgrade to gradio4

Browse files
README.md CHANGED
@@ -4,19 +4,19 @@ emoji: 🩻
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.41.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
11
  ---
12
 
13
- This demo uses the ScribblePrompt-UNet model described in ["ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Medical Image"](https://arxiv.org/abs/2312.07381)
14
 
15
  ```
16
- @article{wong2023scribbleprompt,
17
- title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Medical Image},
18
  author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
19
- journal={arXiv:2312.07381},
20
- year={2023},
21
  }
22
  ```
 
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.41.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
11
  ---
12
 
13
+ This demo uses the ScribblePrompt-UNet model described in ["ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image"](https://arxiv.org/abs/2312.07381)
14
 
15
  ```
16
+ @article{wong2024scribbleprompt,
17
+ title={ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image},
18
  author={Hallee E. Wong and Marianne Rakic and John Guttag and Adrian V. Dalca},
19
+ journal={European Conference on Computer Vision (ECCV)},
20
+ year={2024},
21
  }
22
  ```
app.py CHANGED
@@ -5,20 +5,19 @@ import torch.nn.functional as F
5
  import os
6
  import cv2
7
  import pathlib
 
8
 
9
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
10
 
11
  from predictor import Predictor
12
 
13
- RES = 256
 
14
 
15
  test_example_dir = pathlib.Path("./test_examples")
16
  test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
17
 
18
- val_example_dir = pathlib.Path("./val_od_examples")
19
- val_examples = [str(val_example_dir / x) for x in sorted(os.listdir(val_example_dir))]
20
-
21
- default_example = test_example_dir / "TotalSegmentator_2.jpg"
22
  exp_dir = pathlib.Path('./checkpoints')
23
  default_model = 'ScribblePrompt-Unet'
24
 
@@ -82,7 +81,7 @@ def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
82
 
83
  if contour:
84
  contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
85
- cv2.drawContours(output, contours[0], -1, (0, 255, 0), 1)
86
  else:
87
  mask_overlay = _get_overlay(img, mask)
88
  mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
@@ -111,26 +110,29 @@ def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coo
111
 
112
  out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)
113
 
 
 
 
114
  if point_coords is not None:
115
  for i,(col,row) in enumerate(point_coords):
116
  if point_labels[i] == 1:
117
- cv2.circle(out,(col, row), 2, (0,255,0), -1)
118
  else:
119
- cv2.circle(out,(col, row), 2, (255,0,0), -1)
120
 
121
  if bbox_coords is not None:
122
  for i in range(len(bbox_coords)//2):
123
- cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), 1)
124
  if len(bbox_coords) % 2 == 1:
125
  cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)
126
 
127
- return out
128
 
129
  # -----------------------------------------------------------------------------
130
  # Collect scribbles
131
  # -----------------------------------------------------------------------------
132
 
133
- def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img, label: int):
134
  """
135
  Record scribbles
136
  """
@@ -138,28 +140,19 @@ def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img, lab
138
 
139
  if scribble_img is not None:
140
 
141
- color_mask = scribble_img.get('mask')
142
- scribble_mask = color_mask[...,0]/255
143
 
144
- not_same = (scribble_mask != last_scribble_mask)
145
- if not isinstance(not_same, bool):
146
- not_same = not_same.any()
147
-
148
- if not_same:
149
- # In case any scribbles were removed
150
- corrected_scribble_masks = np.stack(2*[(scribble_mask > 0)], axis=0)*seperate_scribble_masks
151
- corrected_last_scribble_mask = last_scribble_mask*(scribble_mask > 0)
152
-
153
- delta = (scribble_mask - corrected_last_scribble_mask) > 0
154
- new_scribbles = scribble_mask * delta
155
- corrected_scribble_masks[label,...] = np.clip(corrected_scribble_masks[label,...] + new_scribbles, a_min=0, a_max=1)
156
-
157
- last_scribble_mask = scribble_mask
158
- seperate_scribble_masks = corrected_scribble_masks
159
 
160
  return seperate_scribble_masks, last_scribble_mask
161
 
162
- def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, low_res_mask, img_features, multimask_mode):
 
163
  """
164
  Make predictions
165
  """
@@ -194,8 +187,7 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
194
 
195
  # Record any new scribbles
196
  seperate_scribble_masks, last_scribble_mask = get_scribbles(
197
- seperate_scribble_masks, last_scribble_mask, scribble_img,
198
- label=(0 if brush_label == "Positive (green)" else 1) # current color of the brush
199
  )
200
 
201
  # Make prediction
@@ -206,12 +198,33 @@ def refresh_predictions(predictor, input_img, output_img, click_coords, click_la
206
  # Update input visualizations
207
  mask_to_viz = best_mask.numpy()
208
  click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
209
- scribble_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  out_viz = [
212
  viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
213
  input_img,
214
- 255*(mask_to_viz[...,None].repeat(axis=2, repeats=3)>0.5) if binary_checkbox else mask_to_viz[...,None].repeat(axis=2, repeats=3),
215
  ]
216
 
217
  return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
@@ -298,8 +311,8 @@ def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res
298
  with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
299
 
300
  # State variables
301
- seperate_scribble_masks = gr.State(np.zeros((2,RES,RES), dtype=np.float32))
302
- last_scribble_mask = gr.State(np.zeros((RES,RES), dtype=np.float32))
303
 
304
  click_coords = gr.State([])
305
  click_labels = gr.State([])
@@ -312,10 +325,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
312
  low_res_mask = gr.State(None)
313
 
314
  gr.HTML("""\
315
- <h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmention for Any Medical Image</h1>
316
- <p style="text-align: center; font-size: large;"><a href="https://scribbleprompt.csail.mit.edu">ScribblePrompt</a> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
317
- </p>
318
-
 
319
  """)
320
 
321
  with gr.Accordion("Open for instructions!", open=False):
@@ -351,34 +365,42 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
351
  value="Positive (green)", label="Scribble/Click Label")
352
  bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
353
  with gr.Column(scale=1):
 
354
  binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
355
  autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
356
- gr.Markdown("<span style='color:orange'>Troubleshooting:</span> If the image does not fully load in the Scribbles tab, click 'Clear Scribbles' or 'Clear All Inputs' to reload (it make take multiple tries). If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.")
 
357
  multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)
358
 
359
  with gr.Row():
360
  display_height = 500
361
 
 
 
 
362
  with gr.Column(scale=1):
363
  with gr.Tab("Scribbles"):
364
- scribble_img = gr.Image(
365
  label="Input",
366
- brush_radius=2,
367
- interactive=True,
368
- brush_color="#00FF00",
369
- tool="sketch",
370
- height=display_height,
371
  type='numpy',
372
- value=default_example,
 
 
 
 
373
  )
374
- clear_scribble_button = gr.ClearButton([scribble_img], value="Clear Scribbles", variant="stop")
375
 
376
  with gr.Tab("Clicks/Boxes") as click_tab:
377
  click_img = gr.Image(
378
  label="Input",
379
  type='numpy',
380
  value=default_example,
381
- height=display_height
 
 
 
382
  )
383
  with gr.Row():
384
  undo_click_button = gr.Button("Undo Last Click")
@@ -388,21 +410,20 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
388
  input_img = gr.Image(
389
  label="Input",
390
  image_mode="L",
391
- visible=True,
392
  value=default_example,
393
- height=display_height
394
  )
395
  gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
396
 
397
  with gr.Column(scale=1):
398
  with gr.Tab("Output"):
399
  output_img = gr.Gallery(
400
- label='Outputs',
401
  columns=1,
402
  elem_id="gallery",
403
  preview=True,
404
  object_fit="scale-down",
405
- height=display_height+50
406
  )
407
 
408
  submit_button = gr.Button("Refresh Prediction", variant='primary')
@@ -424,28 +445,9 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
424
 
425
  gr.Examples(examples=test_examples,
426
  inputs=[input_img],
427
- examples_per_page=10,
428
- label='Unseen Examples from Test Datasets'
429
- )
430
-
431
- gr.Examples(examples=val_examples,
432
- inputs=[input_img],
433
- examples_per_page=10,
434
- label='Unseen Examples from Validation Datasets'
435
  )
436
-
437
- # When clear scribble button is clicked
438
- def clear_scribble_history(input_img):
439
- if input_img is not None:
440
- input_shape = input_img.shape[:2]
441
- else:
442
- input_shape = (RES, RES)
443
- return input_img, input_img, np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None
444
-
445
- clear_scribble_button.click(clear_scribble_history,
446
- inputs=[input_img],
447
- outputs=[click_img, scribble_img, seperate_scribble_masks, last_scribble_mask, best_mask, low_res_mask]
448
- )
449
 
450
  # When clear clicks button is clicked
451
  def clear_click_history(input_img):
@@ -460,9 +462,25 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
460
  if input_img is not None:
461
  input_shape = input_img.shape[:2]
462
  else:
463
- input_shape = (RES, RES)
464
  return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None
465
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  input_img.change(clear_all_history,
467
  inputs=[input_img],
468
  outputs=[click_img, scribble_img,
@@ -527,7 +545,8 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
527
  undo_click_button.click(fn=undo_click,
528
  inputs=[
529
  predictor,
530
- input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
 
531
  seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
532
  output_img, binary_checkbox, multimask_mode, autopredict_checkbox
533
  ],
@@ -542,8 +561,7 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
542
  Draw scribbles in the click canvas
543
  """
544
  seperate_scribble_masks, last_scribble_mask = get_scribbles(
545
- seperate_scribble_masks, last_scribble_mask, scribble_img,
546
- label=(0 if brush_label == "Positive (green)" else 1) # previous color of the brush
547
  )
548
  click_input_viz = viz_pred_mask(
549
  input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
@@ -566,17 +584,11 @@ with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as de
566
  Recorn new scribbles when changing brush color
567
  """
568
  if label == "Negative (red)":
569
- brush_update = gr.Image.update(brush_color = "#FF0000") # red
570
  elif label == "Positive (green)":
571
- brush_update = gr.Image.update(brush_color = "#00FF00") # green
572
  else:
573
  raise TypeError("Invalid brush color")
574
-
575
- # Record latest scribbles
576
- seperate_scribble_masks, last_scribble_mask = get_scribbles(
577
- seperate_scribble_masks, last_scribble_mask, scribble_img,
578
- label=(1 if label == "Positive (green)" else 0) # previous color of the brush
579
- )
580
 
581
  return seperate_scribble_masks, last_scribble_mask, brush_update
582
 
 
5
  import os
6
  import cv2
7
  import pathlib
8
+ import math
9
 
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
 
12
  from predictor import Predictor
13
 
14
+ H = 256
15
+ W = 256
16
 
17
  test_example_dir = pathlib.Path("./test_examples")
18
  test_examples = [str(test_example_dir / x) for x in sorted(os.listdir(test_example_dir))]
19
 
20
+ default_example = test_examples[0]
 
 
 
21
  exp_dir = pathlib.Path('./checkpoints')
22
  default_model = 'ScribblePrompt-Unet'
23
 
 
81
 
82
  if contour:
83
  contours = cv2.findContours((mask[...,None]>0.5).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
84
+ cv2.drawContours(output, contours[0], -1, (0, 255, 0), 2)
85
  else:
86
  mask_overlay = _get_overlay(img, mask)
87
  mask2 = 0.5*np.repeat(mask[...,None], 3, axis=-1)
 
110
 
111
  out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)
112
 
113
+ H,W = img.shape[:2]
114
+ marker_size = min(H,W)//100
115
+
116
  if point_coords is not None:
117
  for i,(col,row) in enumerate(point_coords):
118
  if point_labels[i] == 1:
119
+ cv2.circle(out,(col, row), marker_size, (0,255,0), -1)
120
  else:
121
+ cv2.circle(out,(col, row), marker_size, (255,0,0), -1)
122
 
123
  if bbox_coords is not None:
124
  for i in range(len(bbox_coords)//2):
125
+ cv2.rectangle(out, bbox_coords[2*i], bbox_coords[2*i+1], (255,165,0), marker_size)
126
  if len(bbox_coords) % 2 == 1:
127
  cv2.circle(out, tuple(bbox_coords[-1]), 2, (255,165,0), -1)
128
 
129
+ return out.astype(np.uint8)
130
 
131
  # -----------------------------------------------------------------------------
132
  # Collect scribbles
133
  # -----------------------------------------------------------------------------
134
 
135
+ def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img):
136
  """
137
  Record scribbles
138
  """
 
140
 
141
  if scribble_img is not None:
142
 
143
+ # Only use first layer
144
+ color_mask = scribble_img.get('layers')[0]
145
 
146
+ positive_scribbles = 1.0*(color_mask[...,1] > 128)
147
+ negative_scribbles = 1.0*(color_mask[...,0] > 128)
148
+
149
+ seperate_scribble_masks = np.stack([positive_scribbles, negative_scribbles], axis=0)
150
+ last_scribble_mask = None
 
 
 
 
 
 
 
 
 
 
151
 
152
  return seperate_scribble_masks, last_scribble_mask
153
 
154
+ def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
155
+ low_res_mask, img_features, multimask_mode):
156
  """
157
  Make predictions
158
  """
 
187
 
188
  # Record any new scribbles
189
  seperate_scribble_masks, last_scribble_mask = get_scribbles(
190
+ seperate_scribble_masks, last_scribble_mask, scribble_img
 
191
  )
192
 
193
  # Make prediction
 
198
  # Update input visualizations
199
  mask_to_viz = best_mask.numpy()
200
  click_input_viz = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox)
 
201
 
202
+ empty_channel = np.zeros(input_img.shape[:2]).astype(np.uint8)
203
+ full_channel = 255*np.ones(input_img.shape[:2]).astype(np.uint8)
204
+ gray_mask = (255*mask_to_viz).astype(np.uint8)
205
+
206
+ bg = viz_pred_mask(input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox)
207
+ old_scribbles = scribble_img.get('layers')[0]
208
+
209
+ scribble_mask = 255*(old_scribbles > 0).any(-1)
210
+
211
+ scribble_input_viz = {
212
+ "background": np.stack([bg[...,i] for i in range(3)]+[full_channel], axis=-1),
213
+ ["layers"][0]: [np.stack([
214
+ (255*seperate_scribble_masks[1]).astype(np.uint8),
215
+ (255*seperate_scribble_masks[0]).astype(np.uint8),
216
+ empty_channel,
217
+ scribble_mask
218
+ ], axis=-1)],
219
+ "composite": np.stack([click_input_viz[...,i] for i in range(3)]+[empty_channel], axis=-1),
220
+ }
221
+
222
+ mask_img = 255*(mask_to_viz[...,None].repeat(axis=2, repeats=3)>0.5) if binary_checkbox else mask_to_viz[...,None].repeat(axis=2, repeats=3)
223
+
224
  out_viz = [
225
  viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=binary_checkbox),
226
  input_img,
227
+ mask_img,
228
  ]
229
 
230
  return click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask
 
311
  with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:
312
 
313
  # State variables
314
+ seperate_scribble_masks = gr.State(np.zeros((2, H, W), dtype=np.float32))
315
+ last_scribble_mask = gr.State(np.zeros((H, W), dtype=np.float32))
316
 
317
  click_coords = gr.State([])
318
  click_labels = gr.State([])
 
325
  low_res_mask = gr.State(None)
326
 
327
  gr.HTML("""\
328
+ <h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmention for Any Biomedical Image</h1>
329
+ <p style="text-align: center; font-size: large;">
330
+ <b>ScribblePrompt</b> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
331
+ [<a href="https://arxiv.org/abs/2312.07381">paper</a> | <a href="https://scribbleprompt.csail.mit.edu">website</a> | <a href="https://github.com/halleewong/ScribblePrompt">code</a>]
332
+ </p>
333
  """)
334
 
335
  with gr.Accordion("Open for instructions!", open=False):
 
365
  value="Positive (green)", label="Scribble/Click Label")
366
  bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
367
  with gr.Column(scale=1):
368
+
369
  binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
370
  autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
371
+ with gr.Accordion("Troubleshooting tips", open=False):
372
+ gr.Markdown("<span style='color:orange'>If you encounter an <span style='color:orange'>error</span> try clicking 'Clear All Inputs'.")
373
  multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)
374
 
375
  with gr.Row():
376
  display_height = 500
377
 
378
+ green_brush = gr.Brush(colors=["#00FF00"], color_mode="fixed", default_size=2)
379
+ red_brush = gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2)
380
+
381
  with gr.Column(scale=1):
382
  with gr.Tab("Scribbles"):
383
+ scribble_img = gr.ImageEditor(
384
  label="Input",
385
+ image_mode="RGB",
386
+ brush=green_brush,
 
 
 
387
  type='numpy',
388
+ value=default_example,
389
+ transforms=(),
390
+ sources=(),
391
+ show_download_button=True,
392
+ # height=display_height
393
  )
 
394
 
395
  with gr.Tab("Clicks/Boxes") as click_tab:
396
  click_img = gr.Image(
397
  label="Input",
398
  type='numpy',
399
  value=default_example,
400
+ show_download_button=True,
401
+ sources=(),
402
+ container=True,
403
+ # height=display_height-50
404
  )
405
  with gr.Row():
406
  undo_click_button = gr.Button("Undo Last Click")
 
410
  input_img = gr.Image(
411
  label="Input",
412
  image_mode="L",
 
413
  value=default_example,
414
+ # height=display_height
415
  )
416
  gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")
417
 
418
  with gr.Column(scale=1):
419
  with gr.Tab("Output"):
420
  output_img = gr.Gallery(
421
+ label='Output',
422
  columns=1,
423
  elem_id="gallery",
424
  preview=True,
425
  object_fit="scale-down",
426
+ # height=display_height
427
  )
428
 
429
  submit_button = gr.Button("Refresh Prediction", variant='primary')
 
445
 
446
  gr.Examples(examples=test_examples,
447
  inputs=[input_img],
448
+ examples_per_page=12,
449
+ label='Examples from datasets unseen during training'
 
 
 
 
 
 
450
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
451
 
452
  # When clear clicks button is clicked
453
  def clear_click_history(input_img):
 
462
  if input_img is not None:
463
  input_shape = input_img.shape[:2]
464
  else:
465
+ input_shape = (H, W)
466
  return input_img, input_img, [], [], [], [], np.zeros((2,)+input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None
467
 
468
+ # def clear_history_and_pad_input(input_img):
469
+ # if input_img is not None:
470
+ # h,w = input_img.shape[:2]
471
+ # if h != w:
472
+ # # Pad to square
473
+ # pad = abs(h-w)
474
+ # if h > w:
475
+ # padding = [(0,0), (math.ceil(pad/2),math.floor(pad/2))]
476
+ # else:
477
+ # padding = [(math.ceil(pad/2),math.floor(pad/2)), (0,0)]
478
+
479
+ # input_img = np.pad(input_img, padding, mode='constant', constant_values=0)
480
+
481
+ # return clear_all_history(input_img)
482
+
483
+
484
  input_img.change(clear_all_history,
485
  inputs=[input_img],
486
  outputs=[click_img, scribble_img,
 
545
  undo_click_button.click(fn=undo_click,
546
  inputs=[
547
  predictor,
548
+ input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
549
+ click_labels, bbox_coords,
550
  seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
551
  output_img, binary_checkbox, multimask_mode, autopredict_checkbox
552
  ],
 
561
  Draw scribbles in the click canvas
562
  """
563
  seperate_scribble_masks, last_scribble_mask = get_scribbles(
564
+ seperate_scribble_masks, last_scribble_mask, scribble_img
 
565
  )
566
  click_input_viz = viz_pred_mask(
567
  input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
 
584
  Recorn new scribbles when changing brush color
585
  """
586
  if label == "Negative (red)":
587
+ brush_update = gr.update(brush=red_brush)
588
  elif label == "Positive (green)":
589
+ brush_update = gr.update(brush=green_brush)
590
  else:
591
  raise TypeError("Invalid brush color")
 
 
 
 
 
 
592
 
593
  return seperate_scribble_masks, last_scribble_mask, brush_update
594
 
predictor.py CHANGED
@@ -3,7 +3,6 @@ import torch.nn.functional as F
3
  from typing import Dict, Tuple, Optional
4
  import network
5
 
6
-
7
  class Predictor:
8
  """
9
  Wrapper for ScribblePrompt Unet model
@@ -96,6 +95,7 @@ def rescale_inputs(inputs: Dict[str,any], res=128):
96
  Rescale the inputs
97
  """
98
  h,w = inputs['img'].shape[-2:]
 
99
  if h != res or w != res:
100
 
101
  inputs.update(dict(
 
3
  from typing import Dict, Tuple, Optional
4
  import network
5
 
 
6
  class Predictor:
7
  """
8
  Wrapper for ScribblePrompt Unet model
 
95
  Rescale the inputs
96
  """
97
  h,w = inputs['img'].shape[-2:]
98
+
99
  if h != res or w != res:
100
 
101
  inputs.update(dict(
{val_od_examples β†’ test_examples}/ACDC.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/BTCV.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/BUID.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/DRIVE.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/HipXRay.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/PanDental.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/SCD.jpg RENAMED
File without changes
test_examples/SCR.jpg DELETED
Binary file (13.6 kB)
 
{val_od_examples β†’ test_examples}/SpineWeb.jpg RENAMED
File without changes
{val_od_examples β†’ test_examples}/WBC.jpg RENAMED
File without changes