```python
import gradio as gr
import plotly.graph_objects as go
from PIL import Image

from grounded_sam.inference import grounded_segmentation
from grounded_sam.plot import plot_detections_plotly


def app_fn(
    image: Image.Image,
    labels: str,
    threshold: float,
    bounding_box_selection: bool,
) -> go.Figure:
    # Grounding DINO expects each text label to end with a period.
    label_list = labels.split("\n")
    label_list = [label if label.endswith(".") else label + "." for label in label_list]
    image_array, detections = grounded_segmentation(image, label_list, threshold, True)
    return plot_detections_plotly(image_array, detections, bounding_box_selection)


if __name__ == "__main__":
    title = "Grounding SAM - Text-to-Segmentation Model"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            Grounded SAM is a text-to-segmentation model that generates segmentation masks from natural language descriptions.
            This demo uses Grounding DINO in tandem with SAM to generate segmentation masks from text.
            The workflow is as follows:

            1. Select text labels to generate bounding boxes with Grounding DINO.
            2. Prompt the SAM model to generate segmentation masks from the bounding boxes.
            3. Refine the masks if needed.
            4. Visualize the segmentation masks.

            ### Notes
            - To pass multiple labels, separate them with new lines.
            - Generating the masks may take a few seconds because each request runs through two models.
            - By default, refinement converts each mask to a polygon and back to a mask with OpenCV.
            - This demo uses a concise implementation; the full code is available on [GitHub](https://github.com/EduardoPach/grounded-sam).
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="Box Threshold")
            labels = gr.Textbox(lines=2, max_lines=5, label="Labels")
            bounding_box_selection = gr.Checkbox(label="Allow Box Selection")
            btn = gr.Button("Run")
        with gr.Row():
            img = gr.Image(type="pil")
            fig = gr.Plot(label="Segmentation Mask")
        btn.click(fn=app_fn, inputs=[img, labels, threshold, bounding_box_selection], outputs=[fig])
        gr.Examples(
            [
                ["input_image.jpeg", "a person.\na mountain.", 0.3, False],
            ],
            inputs=[img, labels, threshold, bounding_box_selection],
            outputs=[fig],
            fn=app_fn,
            cache_examples=True,
            label="Try this example input!",
        )
    demo.launch()
```
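For reference, here is a minimal sketch of what a `grounded_segmentation` helper covering steps 1 and 2 of the workflow above could look like, using the `transformers` zero-shot object detection pipeline for Grounding DINO and `SamModel` for box-prompted mask prediction. The checkpoint names and the helper's signature are assumptions for illustration; the actual implementation lives in the linked GitHub repository.

```python
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor, pipeline

DETECTOR_ID = "IDEA-Research/grounding-dino-tiny"  # assumed checkpoint
SEGMENTER_ID = "facebook/sam-vit-base"             # assumed checkpoint


def grounded_segmentation_sketch(image: Image.Image, labels: list[str], threshold: float):
    # Stage 1 - Grounding DINO: turn text labels into bounding boxes.
    detector = pipeline(task="zero-shot-object-detection", model=DETECTOR_ID)
    detections = detector(image, candidate_labels=labels, threshold=threshold)
    boxes = [
        [d["box"]["xmin"], d["box"]["ymin"], d["box"]["xmax"], d["box"]["ymax"]]
        for d in detections
    ]
    if not boxes:
        return np.array(image), []

    # Stage 2 - SAM: prompt the segmenter with the detected boxes.
    processor = SamProcessor.from_pretrained(SEGMENTER_ID)
    segmenter = SamModel.from_pretrained(SEGMENTER_ID)
    inputs = processor(images=image, input_boxes=[boxes], return_tensors="pt")
    with torch.no_grad():
        outputs = segmenter(**inputs)
    # One boolean mask per box, resized back to the original image size.
    masks = processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"],
    )[0]
    return np.array(image), masks
```

The returned image array and masks map directly onto what `app_fn` passes to the plotting helper.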
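The polygon round-trip refinement mentioned in the demo notes can be sketched with OpenCV's contour utilities: extract the mask's outer contours, keep the dominant polygon, and rasterize it back into a binary mask, which smooths ragged edges and drops small speckles. The helper below is an illustrative sketch under those assumptions, not the repository's exact code.

```python
import cv2
import numpy as np


def refine_mask(mask: np.ndarray) -> np.ndarray:
    """Smooth a binary mask by converting it to a polygon and back (sketch)."""
    mask = mask.astype(np.uint8)
    # Extract the outer contours (polygons) of the mask.
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return mask
    # Keep only the largest polygon, discarding small speckles.
    largest = max(contours, key=cv2.contourArea)
    # Rasterize the polygon back into a clean binary mask.
    refined = np.zeros_like(mask)
    cv2.fillPoly(refined, [largest], 1)
    return refined
```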