import gradio as gr
import plotly.graph_objects as go
from PIL import Image

from grounded_sam.inference import grounded_segmentation
from grounded_sam.plot import plot_detections_plotly

def app_fn(
    image: Image.Image,
    labels: str,
    threshold: float,
    bounding_box_selection: bool
) -> go.Figure:
    # One label per line; Grounding DINO expects each label to end with a period.
    labels = [label if label.endswith(".") else label + "." for label in labels.split("\n")]
    # The final positional argument enables the polygon-based mask refinement
    # described in the Notes of the UI text.
    image_array, detections = grounded_segmentation(image, labels, threshold, True)
    fig_detection = plot_detections_plotly(image_array, detections, bounding_box_selection)

    return fig_detection
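

# ---------------------------------------------------------------------------
# Illustrative sketch (not called by the app). `grounded_segmentation` is
# imported from the companion package above; the function below sketches the
# two-model workflow described in the UI text, using the Hugging Face
# `transformers` zero-shot object detection pipeline (Grounding DINO) and
# SAM. The checkpoint names here are assumptions for illustration, not
# necessarily the ones the package pins.
# ---------------------------------------------------------------------------
def _sketch_grounded_segmentation(image: Image.Image, labels: list[str], threshold: float):
    import torch
    from transformers import SamModel, SamProcessor, pipeline

    # Step 1: Grounding DINO turns the text labels into bounding boxes.
    detector = pipeline(
        "zero-shot-object-detection", model="IDEA-Research/grounding-dino-tiny"
    )
    detections = detector(image, candidate_labels=labels, threshold=threshold)
    boxes = [
        [d["box"]["xmin"], d["box"]["ymin"], d["box"]["xmax"], d["box"]["ymax"]]
        for d in detections
    ]

    # Step 2: SAM is prompted with those boxes to produce one mask per box.
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
    model = SamModel.from_pretrained("facebook/sam-vit-base")
    inputs = processor(image, input_boxes=[boxes], return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    return masks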

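# The Notes in the UI text mention refining each mask by converting it to a
# polygon and back with OpenCV; a minimal sketch of that idea, assuming a
# binary uint8 mask, and not necessarily the exact routine the package uses.
def _sketch_polygon_refinement(mask):
    import cv2
    import numpy as np

    # Trace the mask's outer boundaries as polygons.
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return mask
    # Rasterize the largest polygon back into a clean binary mask, which
    # smooths ragged borders and drops small speckles.
    largest = max(contours, key=cv2.contourArea)
    refined = np.zeros_like(mask, dtype=np.uint8)
    cv2.fillPoly(refined, [largest], 1)
    return refined
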
if __name__ == "__main__":
    title = "Grounded SAM - Text-to-Segmentation Model"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            Grounded SAM is a text-to-segmentation model that generates segmentation masks from natural language descriptions.
            This demo uses Grounding DINO in tandem with SAM to do so. The workflow is as follows:
            1. Select text labels to generate bounding boxes with Grounding DINO.
            2. Prompt the SAM model to generate segmentation masks from the bounding boxes.
            3. Refine the masks if needed.
            4. Visualize the segmentation masks.


            ### Notes
            - To pass multiple labels, put each label on its own line.
            - Generating the masks may take a few seconds, since the image is run through two models.
            - By default, each mask is refined by converting it to a polygon and back to a mask with OpenCV.
            - This demo uses a concise implementation; the full code is available on [GitHub](https://github.com/EduardoPach/grounded-sam).
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.05, label="Box Threshold")
            labels = gr.Textbox(lines=2, max_lines=5, label="Labels")
            bounding_box_selection = gr.Checkbox(label="Allow Box Selection")
        btn = gr.Button()
        with gr.Row():
            img = gr.Image(type="pil")
        fig = gr.Plot(label="Segmentation Mask")

        btn.click(fn=app_fn, inputs=[img, labels, threshold, bounding_box_selection], outputs=[fig])

        gr.Examples(
            [
                ["input_image.jpeg", "a person.\na mountain.", 0.3, False],
            ],
            inputs=[img, labels, threshold, bounding_box_selection],
            outputs=[fig],
            fn=app_fn,
            cache_examples=False,
            label="Try this example input!",
        )

    demo.launch()