ScribblePrompt / app.py
halleewong's picture
improve heights
8d1c21c
raw
history blame contribute delete
No virus
27 kB
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import os
import cv2
import pathlib
import math
# Run on the GPU when torch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from predictor import Predictor

# Height/width of the placeholder scribble-mask state used before an image
# has been loaded (gr.State defaults and the clear-all fallback).
H, W = 256, 256

# Bundled example images in filename order; the first one is the initial
# value of the input image widgets.
test_example_dir = pathlib.Path("./test_examples")
test_examples = [
    str(test_example_dir / name)
    for name in sorted(os.listdir(test_example_dir))
]
default_example = test_examples[0]

# Checkpoint directory and the mapping from user-facing model name to the
# checkpoint filename inside it.
exp_dir = pathlib.Path('./checkpoints')
default_model = 'ScribblePrompt-Unet'
model_dict = {
    'ScribblePrompt-Unet': 'ScribblePrompt_unet_v1_nf192_res128.pt',
}
# -----------------------------------------------------------------------------
# Model initialization functions
# -----------------------------------------------------------------------------
def load_model(exp_key: str = default_model):
    """
    Load the predictor for a model name listed in ``model_dict``.

    Args:
        exp_key: key of ``model_dict``; its value is the checkpoint filename
            inside ``exp_dir``.

    Returns:
        ``(predictor, None)``. The second item is wired to the ``img_features``
        state, so switching models also clears any cached image features.

    Raises:
        ValueError: if ``exp_key`` is not a known model name. Previously this
            surfaced as an opaque ``TypeError`` from ``exp_dir / None``.
    """
    if exp_key not in model_dict:
        raise ValueError(
            f"Unknown model {exp_key!r}; expected one of {sorted(model_dict)}"
        )
    fpath = exp_dir / model_dict[exp_key]
    return Predictor(fpath), None
# -----------------------------------------------------------------------------
# Visualization functions
# -----------------------------------------------------------------------------
def _get_overlay(img, lay, const_color="l_blue"):
"""
Helper function for preparing overlay
"""
assert lay.ndim==2, "Overlay must be 2D, got shape: " + str(lay.shape)
if img.ndim == 2:
img = np.repeat(img[...,None], 3, axis=-1)
assert img.ndim==3, "Image must be 3D, got shape: " + str(img.shape)
if const_color == "blue":
const_color = 255*np.array([0, 0, 1])
elif const_color == "green":
const_color = 255*np.array([0, 1, 0])
elif const_color == "red":
const_color = 255*np.array([1, 0, 0])
elif const_color == "l_blue":
const_color = np.array([31, 119, 180])
elif const_color == "orange":
const_color = np.array([255, 127, 14])
else:
raise NotImplementedError
x,y = np.nonzero(lay)
for i in range(img.shape[-1]):
img[x,y,i] = const_color[i]
return img
def image_overlay(img, mask=None, scribbles=None, contour=False, alpha=0.5):
    """
    Draw a mask and/or scribbles on top of a grayscale image.

    Args:
        img: 2D grayscale image; replicated to 3 channels for drawing.
        mask: optional 2D mask. When ``contour`` is True it is drawn as green
            contour lines. Otherwise it is a light-blue fill blended with
            per-pixel weight ``0.5 * mask``.
        scribbles: optional stack; index 0 is painted green, index 1 red.
        contour: draw ``mask`` as contours instead of a filled overlay.
        alpha: weight given to each scribble overlay when blending.

    Returns:
        3-channel image with the overlays applied.
    """
    assert img.ndim == 2, "Image must be 2D, got shape: " + str(img.shape)
    canvas = np.repeat(img[..., None], 3, axis=-1)

    if mask is not None:
        assert mask.ndim == 2, "Mask must be 2D, got shape: " + str(mask.shape)
        if not contour:
            tinted = _get_overlay(img, mask)
            weight = 0.5 * np.repeat(mask[..., None], 3, axis=-1)
            canvas = cv2.convertScaleAbs(tinted * weight + canvas * (1 - weight))
        else:
            binary_mask = (mask[..., None] > 0.5).astype(np.uint8)
            found = cv2.findContours(binary_mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(canvas, found[0], -1, (0, 255, 0), 2)

    if scribbles is not None:
        # Positive scribbles first, then negative, each blended into the canvas in place.
        for layer_idx, color_name in ((0, "green"), (1, "red")):
            painted = _get_overlay(canvas, scribbles[layer_idx, ...], const_color=color_name)
            cv2.addWeighted(painted, alpha, canvas, 1 - alpha, 0, canvas)

    return canvas
def viz_pred_mask(img, mask=None, point_coords=None, point_labels=None, bbox_coords=None, seperate_scribble_masks=None, binary=True):
    """
    Visualize image with clicks, scribbles, predicted mask overlaid.

    Args:
        img: 2D grayscale numpy image.
        mask: optional 2D predicted mask (numpy array or torch tensor);
            thresholded at 0.5 when ``binary`` is True.
        point_coords: optional sequence of (col, row) click positions.
        point_labels: labels aligned with ``point_coords``; 1 draws a green
            dot, anything else a red dot.
        bbox_coords: optional sequence of (col, row) corners. Consecutive pairs
            are drawn as orange rectangles; a trailing unpaired corner is drawn
            as an orange dot.
        seperate_scribble_masks: optional positive/negative scribble stack,
            forwarded to ``image_overlay``.
        binary: threshold the mask before drawing.

    Returns:
        uint8 RGB visualization.
    """
    assert isinstance(img, np.ndarray), "Image must be numpy array, got type: " + str(type(img))
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().numpy()
    if binary and mask is not None:
        mask = 1 * (mask > 0.5)

    out = image_overlay(img, mask=mask, scribbles=seperate_scribble_masks)

    height, width = img.shape[:2]
    # Markers scale with the image. Clamp to 1 px: for images under 100 px the
    # plain integer division gives 0, so clicks/boxes would not be visible.
    marker_size = max(1, min(height, width) // 100)

    if point_coords is not None:
        for (col, row), label in zip(point_coords, point_labels):
            color = (0, 255, 0) if label == 1 else (255, 0, 0)
            cv2.circle(out, (int(col), int(row)), marker_size, color, -1)

    if bbox_coords is not None:
        # OpenCV drawing functions expect integer point tuples.
        corners = [tuple(int(v) for v in corner) for corner in bbox_coords]
        for top_left, bottom_right in zip(corners[0::2], corners[1::2]):
            cv2.rectangle(out, top_left, bottom_right, (255, 165, 0), marker_size)
        if len(corners) % 2 == 1:
            # First corner of a box that is still waiting for its second click.
            cv2.circle(out, corners[-1], marker_size, (255, 165, 0), -1)

    return out.astype(np.uint8)
# -----------------------------------------------------------------------------
# Collect scribbles
# -----------------------------------------------------------------------------
def get_scribbles(seperate_scribble_masks, last_scribble_mask, scribble_img):
    """
    Record scribbles drawn in the scribble editor.

    Reads the first layer of the editor value and splits it by color into a
    positive mask (green channel > 128) and a negative mask (red channel > 128).

    Args:
        seperate_scribble_masks: current positive/negative mask stack. It is
            returned unchanged when the editor value is None or has no layer.
        last_scribble_mask: not used in the computation; always reset to None.
        scribble_img: image-editor value (a dict with a 'layers' list) or None.

    Returns:
        ``(seperate_scribble_masks, None)``, where the masks are stacked as
        ``[positive, negative]`` along axis 0.
    """
    assert isinstance(seperate_scribble_masks, np.ndarray), "seperate_scribble_masks must be numpy array, got type: " + str(type(seperate_scribble_masks))
    layers = scribble_img.get('layers') if scribble_img is not None else None
    # An editor value without any drawing layer carries no new scribbles; keep
    # the recorded ones instead of failing on layers[0].
    if layers is not None and len(layers) > 0:
        # Only use first layer
        color_mask = layers[0]
        positive_scribbles = 1.0 * (color_mask[..., 1] > 128)
        negative_scribbles = 1.0 * (color_mask[..., 0] > 128)
        seperate_scribble_masks = np.stack([positive_scribbles, negative_scribbles], axis=0)
    return seperate_scribble_masks, None
def get_predictions(predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                    low_res_mask, img_features, multimask_mode):
    """
    Make predictions.

    Packs the current prompts into tensors on ``device`` and calls
    ``predictor.predict``.

    Args:
        predictor: loaded model wrapper exposing
            ``predict(prompts, img_features, multimask_mode=...)``.
        input_img: numpy image; divided by 255 before being passed on.
        click_coords: list of click positions (may be empty).
        click_labels: list of 1/0 labels aligned with ``click_coords``.
        bbox_coords: recorded box corners; a box is only used when exactly two
            corners are present.
        seperate_scribble_masks: positive/negative scribble stack, or None.
        low_res_mask: previous low-resolution prediction, fed back as
            ``mask_input``, or None.
        img_features: cached image features passed through to the predictor.
        multimask_mode: forwarded to ``predictor.predict``.

    Returns:
        ``(mask, img_features, low_res_mask)`` exactly as returned by
        ``predictor.predict``.
    """
    box = None
    if len(bbox_coords) == 1:
        # Constructing gr.Error without raising it displays nothing.
        # gr.Warning shows the hint and lets the prediction proceed without a box.
        gr.Warning("Please click a second time to define the bounding box")
    elif len(bbox_coords) == 2:
        box = torch.Tensor(bbox_coords).flatten()[None, None, ...].int().to(device)  # B x n x 4

    if seperate_scribble_masks is not None:
        scribble = torch.from_numpy(seperate_scribble_masks)[None, ...].to(device)
    else:
        scribble = None

    prompts = dict(
        img=torch.from_numpy(input_img)[None, None, ...].to(device) / 255,
        point_coords=torch.Tensor([click_coords]).int().to(device) if len(click_coords) > 0 else None,
        point_labels=torch.Tensor([click_labels]).int().to(device) if len(click_labels) > 0 else None,
        scribble=scribble,
        mask_input=low_res_mask.to(device) if low_res_mask is not None else None,
        box=box,
    )

    mask, img_features, low_res_mask = predictor.predict(prompts, img_features, multimask_mode=multimask_mode)
    return mask, img_features, low_res_mask
def refresh_predictions(predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                        scribble_img, seperate_scribble_masks, last_scribble_mask,
                        best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode):
    """
    Re-run the model on all current prompts and rebuild every visualization.

    ``output_img`` and ``brush_label`` are accepted so the Gradio event
    signature stays uniform; they are not used here.

    Returns:
        ``(click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask,
        img_features, seperate_scribble_masks, last_scribble_mask)``.
        ``scribble_input_viz`` is an image-editor value dict with "background",
        "layers" and "composite" entries. ``out_viz`` is the gallery list
        ``[overlay, input image, mask image]``.
    """
    # Record any new scribbles
    seperate_scribble_masks, last_scribble_mask = get_scribbles(
        seperate_scribble_masks, last_scribble_mask, scribble_img
    )

    # Make prediction
    best_mask, img_features, low_res_mask = get_predictions(
        predictor, input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
        low_res_mask, img_features, multimask_mode
    )

    # .cpu() is a no-op for CPU tensors and avoids a crash if the mask is on the GPU.
    mask_to_viz = best_mask.cpu().numpy()

    # Click tab: mask + clicks/boxes + scribbles.
    click_input_viz = viz_pred_mask(
        input_img, mask_to_viz, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    # Scribble tab background: same, minus the scribbles (they live in the editable layer).
    bg = viz_pred_mask(
        input_img, mask_to_viz, click_coords, click_labels, bbox_coords, None, binary_checkbox
    )

    hw = input_img.shape[:2]
    empty_channel = np.zeros(hw, dtype=np.uint8)
    full_channel = np.full(hw, 255, dtype=np.uint8)

    # Alpha of the scribble layer: opaque wherever the editor's first layer has
    # any nonzero value. Fall back to the recorded masks when the editor value
    # has no layer, instead of failing on layers[0].
    layers = scribble_img.get('layers') if scribble_img is not None else None
    if layers is not None and len(layers) > 0:
        drawn = (layers[0] > 0).any(-1)
    else:
        drawn = seperate_scribble_masks.any(axis=0)
    # `255 * bool_array` is a wide integer array. Cast it to uint8 like the other
    # channels so np.stack does not promote the whole RGBA layer.
    scribble_alpha = (255 * drawn).astype(np.uint8)

    scribble_layer = np.stack([
        (255 * seperate_scribble_masks[1]).astype(np.uint8),  # red   <- negative scribbles
        (255 * seperate_scribble_masks[0]).astype(np.uint8),  # green <- positive scribbles
        empty_channel,
        scribble_alpha,
    ], axis=-1)
    scribble_input_viz = {
        "background": np.stack([bg[..., i] for i in range(3)] + [full_channel], axis=-1),
        "layers": [scribble_layer],
        "composite": np.stack([click_input_viz[..., i] for i in range(3)] + [empty_channel], axis=-1),
    }

    mask_rgb = mask_to_viz[..., None].repeat(axis=2, repeats=3)
    mask_img = 255 * (mask_rgb > 0.5) if binary_checkbox else mask_rgb

    out_viz = [
        viz_pred_mask(input_img, mask_to_viz, point_coords=None, point_labels=None, bbox_coords=None,
                      seperate_scribble_masks=None, binary=binary_checkbox),
        input_img,
        mask_img,
    ]

    return (click_input_viz, scribble_input_viz, out_viz, best_mask, low_res_mask, img_features,
            seperate_scribble_masks, last_scribble_mask)
def get_select_coords(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask,
                      click_coords, click_labels, bbox_coords,
                      seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                      output_img, binary_checkbox, multimask_mode, autopredict_checkbox, evt: gr.SelectData):
    """
    Record user click and update the prediction.

    The click position ``evt.index`` is appended to ``bbox_coords`` when the
    bounding-box checkbox is set. Otherwise it is appended to ``click_coords``
    with label 1 (positive) or 0 (negative). The lists are mutated in place and
    also returned.

    A new prediction is made only when auto-predict is on and no box is
    waiting for its second corner. Otherwise only the input visualizations are
    redrawn and ``output_img`` is returned unchanged.

    Raises:
        TypeError: if ``brush_label`` is not one of the two known labels.
    """
    # Record click coordinates
    if bbox_label:
        bbox_coords.append(evt.index)
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        click_coords.append(evt.index)
        click_labels.append(1 if brush_label == 'Positive (green)' else 0)
    else:
        # f-prefix was missing, so the message showed the literal "{brush_label}".
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make new prediction if not waiting for additional bounding box click
    if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    click_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    scribble_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
    )
    # Don't update output image if waiting for additional bounding box click
    return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
def undo_click(predictor, input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
               seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
               output_img, binary_checkbox, multimask_mode, autopredict_checkbox):
    """
    Remove last click and then update the prediction.

    When the bounding-box checkbox is set, the last box corner is removed.
    Otherwise the last click and its label are removed. Nothing is removed if
    the corresponding list is empty. The lists are mutated in place and also
    returned.

    A new prediction is made only when auto-predict is on and no box is
    waiting for its second corner. This is the same even-count rule
    ``get_select_coords`` uses. Otherwise only the input visualizations are
    redrawn and ``output_img`` is returned unchanged.

    Raises:
        TypeError: if ``brush_label`` is not one of the two known labels.
    """
    if bbox_label:
        if len(bbox_coords) > 0:
            bbox_coords.pop()
    elif brush_label in ['Positive (green)', 'Negative (red)']:
        if len(click_coords) > 0:
            click_coords.pop()
            click_labels.pop()
    else:
        # f-prefix was missing, so the message showed the literal "{brush_label}".
        raise TypeError(f"Invalid brush label: {brush_label}")

    # Only make new prediction if not waiting for additional bounding box click
    if (len(bbox_coords) % 2 == 0) and autopredict_checkbox:
        click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, seperate_scribble_masks, last_scribble_mask = refresh_predictions(
            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
            scribble_img, seperate_scribble_masks, last_scribble_mask,
            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
        )
        return click_input_viz, scribble_input_viz, output_viz, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask

    click_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
    )
    scribble_input_viz = viz_pred_mask(
        input_img, best_mask, click_coords, click_labels, bbox_coords, None, binary_checkbox
    )
    # Don't update output image if waiting for additional bounding box click
    return click_input_viz, scribble_input_viz, output_img, best_mask, low_res_mask, img_features, click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask
# --------------------------------------------------
with gr.Blocks(theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)) as demo:

    # State variables
    seperate_scribble_masks = gr.State(np.zeros((2, H, W), dtype=np.float32))
    last_scribble_mask = gr.State(np.zeros((H, W), dtype=np.float32))
    click_coords = gr.State([])
    click_labels = gr.State([])
    bbox_coords = gr.State([])

    # Load default model
    predictor = gr.State(load_model()[0])
    img_features = gr.State(None)  # For SAM models
    best_mask = gr.State(None)
    low_res_mask = gr.State(None)

    gr.HTML("""\
<h1 style="text-align: center; font-size: 28pt;">ScribblePrompt: Fast and Flexible Interactive Segmentation for Any Biomedical Image</h1>
<p style="text-align: center; font-size: large;">
<b>ScribblePrompt</b> is an interactive segmentation tool designed to help users segment <b>new</b> structures in medical images using scribbles, clicks <b>and</b> bounding boxes.
[<a href="https://arxiv.org/abs/2312.07381">paper</a> | <a href="https://scribbleprompt.csail.mit.edu">website</a> | <a href="https://github.com/halleewong/ScribblePrompt">code</a>]
</p>
""")

    with gr.Accordion("Open for instructions!", open=False):
        gr.Markdown(
            """
            * Select an input image from the examples below or upload your own image through the <b>'Input Image'</b> tab.
            * Use the <b>'Scribbles'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> scribbles.
                - Use the buttons in the top right hand corner of the canvas to undo or adjust the brush size
                - Note: the app cannot detect new scribbles drawn on top of previous scribbles in a different color. Please undo/erase the scribble before drawing on the same pixel in a different color.
            * Use the <b>'Clicks/Boxes'</b> tab to draw <span style='color:green'>positive</span> or <span style='color:red'>negative</span> clicks and <span style='color:orange'>bounding boxes</span> by placing two clicks.
            * The <b>'Output'</b> tab will show the model's prediction based on your current inputs and the previous prediction.
            * The <b>'Clear Input Mask'</b> button will clear the latest prediction (which is used as an input to the model).
            * The <b>'Clear All Inputs'</b> button will clear all inputs (including scribbles, clicks, bounding boxes, and the last prediction).
            """
        )

    # Interface ------------------------------------
    with gr.Row():
        # Hidden: only one model is registered in model_dict.
        model_dropdown = gr.Dropdown(
            label="Model",
            choices=list(model_dict.keys()),
            value=default_model,
            multiselect=False,
            interactive=False,
            visible=False
        )

    with gr.Row():
        with gr.Column(scale=1):
            brush_label = gr.Radio(["Positive (green)", "Negative (red)"],
                                   value="Positive (green)", label="Scribble/Click Label")
            bbox_label = gr.Checkbox(value=False, label="Bounding Box (2 clicks)")
        with gr.Column(scale=1):
            binary_checkbox = gr.Checkbox(value=True, label="Show binary masks", visible=False)
            autopredict_checkbox = gr.Checkbox(value=True, label="Auto-update prediction on clicks")
            with gr.Accordion("Troubleshooting tips", open=False):
                gr.Markdown("<span style='color:orange'>If you encounter an error, try clicking 'Clear All Inputs'.</span>")
            multimask_mode = gr.Checkbox(value=True, label="Multi-mask mode", visible=False)

    with gr.Row():
        display_height = 500

        green_brush = gr.Brush(colors=["#00FF00"], color_mode="fixed", default_size=2)
        red_brush = gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2)

        with gr.Column(scale=1):

            with gr.Tab("Scribbles"):
                scribble_img = gr.ImageEditor(
                    label="Input",
                    image_mode="RGB",
                    brush=green_brush,
                    type='numpy',
                    value=default_example,
                    transforms=(),
                    sources=(),
                    show_download_button=True,
                    # height=display_height
                )

            with gr.Tab("Clicks/Boxes") as click_tab:
                click_img = gr.Image(
                    label="Input",
                    type='numpy',
                    value=default_example,
                    show_download_button=True,
                    sources=(),
                    container=True,
                    # height=display_height-50
                )
                with gr.Row():
                    undo_click_button = gr.Button("Undo Last Click")
                    clear_click_button = gr.Button("Clear Clicks/Boxes", variant="stop")

            with gr.Tab("Input Image"):
                input_img = gr.Image(
                    label="Input",
                    image_mode="L",
                    value=default_example,
                    container=True
                    # height=display_height
                )
                gr.Markdown("To upload your own image: click the `x` in the top right corner to clear the current image, then drag & drop")

        with gr.Column(scale=1):
            with gr.Tab("Output"):
                output_img = gr.Gallery(
                    label='Output',
                    columns=1,
                    elem_id="gallery",
                    preview=True,
                    object_fit="scale-down",
                    # height=display_height,
                    container=True
                )

            submit_button = gr.Button("Refresh Prediction", variant='primary')
            # ClearButton empties the editor on the client; the .click handler
            # below also resets every state variable.
            clear_all_button = gr.ClearButton([scribble_img], value="Clear All Inputs", variant="stop")
            clear_mask_button = gr.Button("Clear Input Mask")

    # ----------------------------------------------
    # Loading Models
    # ----------------------------------------------
    model_dropdown.change(fn=load_model,
                          inputs=[model_dropdown],
                          outputs=[predictor, img_features]
                          )

    # ----------------------------------------------
    # Loading Examples
    # ----------------------------------------------
    gr.Examples(examples=test_examples,
                inputs=[input_img],
                examples_per_page=12,
                label='Examples from datasets unseen during training'
                )

    # When clear clicks button is clicked
    def clear_click_history(input_img):
        """Reset both input canvases to the raw image and drop clicks, boxes and the previous masks."""
        return input_img, input_img, [], [], [], None, None

    clear_click_button.click(clear_click_history,
                             inputs=[input_img],
                             outputs=[click_img, scribble_img, click_coords, click_labels, bbox_coords, best_mask, low_res_mask])

    # When clear all button is clicked
    def clear_all_history(input_img):
        """
        Reset every input, output and state variable.

        Scribble masks are re-created as zeros matching the current image size,
        or (H, W) when no image is loaded.
        """
        if input_img is not None:
            input_shape = input_img.shape[:2]
        else:
            input_shape = (H, W)
        return input_img, input_img, [], [], [], [], np.zeros((2,) + input_shape, dtype=np.float32), np.zeros(input_shape, dtype=np.float32), None, None, None

    # def clear_history_and_pad_input(input_img):
    #     if input_img is not None:
    #         h,w = input_img.shape[:2]
    #         if h != w:
    #             # Pad to square
    #             pad = abs(h-w)
    #             if h > w:
    #                 padding = [(0,0), (math.ceil(pad/2),math.floor(pad/2))]
    #             else:
    #                 padding = [(math.ceil(pad/2),math.floor(pad/2)), (0,0)]
    #             input_img = np.pad(input_img, padding, mode='constant', constant_values=0)
    #     return clear_all_history(input_img)

    # A new input image invalidates every prompt and prediction.
    input_img.change(clear_all_history,
                     inputs=[input_img],
                     outputs=[click_img, scribble_img,
                              output_img, click_coords, click_labels, bbox_coords,
                              seperate_scribble_masks, last_scribble_mask,
                              best_mask, low_res_mask, img_features
                              ])

    clear_all_button.click(clear_all_history,
                           inputs=[input_img],
                           outputs=[click_img, scribble_img,
                                    output_img, click_coords, click_labels, bbox_coords,
                                    seperate_scribble_masks, last_scribble_mask,
                                    best_mask, low_res_mask, img_features
                                    ])

    # clear previous prediction mask
    def clear_best_mask(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks):
        """Drop the previous prediction (used as model input) and redraw both canvases without it."""
        click_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, seperate_scribble_masks
        )
        scribble_input_viz = viz_pred_mask(
            input_img, None, click_coords, click_labels, bbox_coords, None
        )
        return None, None, click_input_viz, scribble_input_viz

    clear_mask_button.click(
        clear_best_mask,
        inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks],
        outputs=[best_mask, low_res_mask, click_img, scribble_img],
    )

    # ----------------------------------------------
    # Clicks
    # ----------------------------------------------
    click_img.select(get_select_coords,
                     inputs=[
                         predictor,
                         input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords, click_labels, bbox_coords,
                         seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                         output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                     ],
                     outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                              click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                     api_name="get_select_coords"
                     )

    submit_button.click(fn=refresh_predictions,
                        inputs=[
                            predictor, input_img, output_img, click_coords, click_labels, bbox_coords, brush_label,
                            scribble_img, seperate_scribble_masks, last_scribble_mask,
                            best_mask, low_res_mask, img_features, binary_checkbox, multimask_mode
                        ],
                        outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                 seperate_scribble_masks, last_scribble_mask],
                        api_name="refresh_predictions"
                        )

    undo_click_button.click(fn=undo_click,
                            inputs=[
                                predictor,
                                input_img, brush_label, bbox_label, best_mask, low_res_mask, click_coords,
                                click_labels, bbox_coords,
                                seperate_scribble_masks, last_scribble_mask, scribble_img, img_features,
                                output_img, binary_checkbox, multimask_mode, autopredict_checkbox
                            ],
                            outputs=[click_img, scribble_img, output_img, best_mask, low_res_mask, img_features,
                                     click_coords, click_labels, bbox_coords, seperate_scribble_masks, last_scribble_mask],
                            api_name="undo_click"
                            )

    def update_click_img(input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox,
                         last_scribble_mask, scribble_img, brush_label, best_mask):
        """
        Draw scribbles in the click canvas.

        Re-reads the scribble editor so the click tab shows the latest
        scribbles alongside clicks, boxes and the current mask.
        """
        seperate_scribble_masks, last_scribble_mask = get_scribbles(
            seperate_scribble_masks, last_scribble_mask, scribble_img
        )
        click_input_viz = viz_pred_mask(
            input_img, best_mask, click_coords, click_labels, bbox_coords, seperate_scribble_masks, binary_checkbox
        )
        return click_input_viz, seperate_scribble_masks, last_scribble_mask

    click_tab.select(fn=update_click_img,
                     inputs=[input_img, click_coords, click_labels, bbox_coords, seperate_scribble_masks,
                             binary_checkbox, last_scribble_mask, scribble_img, brush_label, best_mask],
                     outputs=[click_img, seperate_scribble_masks, last_scribble_mask],
                     api_name="update_click_img"
                     )

    # ----------------------------------------------
    # Scribbles
    # ----------------------------------------------
    def change_brush_color(seperate_scribble_masks, last_scribble_mask, scribble_img, label):
        """
        Switch the scribble editor's brush to the color of the selected label.

        The scribble states are passed through unchanged.

        Raises:
            TypeError: if ``label`` is not one of the two known labels.
        """
        if label == "Negative (red)":
            brush_update = gr.update(brush=red_brush)
        elif label == "Positive (green)":
            brush_update = gr.update(brush=green_brush)
        else:
            raise TypeError("Invalid brush color")
        return seperate_scribble_masks, last_scribble_mask, brush_update

    brush_label.change(fn=change_brush_color,
                       inputs=[seperate_scribble_masks, last_scribble_mask, scribble_img, brush_label],
                       outputs=[seperate_scribble_masks, last_scribble_mask, scribble_img],
                       api_name="change_brush_color"
                       )


if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False)