from typing import Any, Dict

import cv2
import gradio as gr
import numpy as np
from gradio_image_annotation import image_annotator
from sam2 import load_model
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import export_mask


# Uncomment the decorator below (and add `import spaces`) to run on GPU
# hardware when hosted as a Hugging Face Space.
# @spaces.GPU()
def predict(model_choice: str, annotations: Dict[str, Any]):
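    """Segment the uploaded image with SAM2.

    If the user drew bounding boxes, each box prompts SAM2ImagePredictor;
    otherwise the whole image is segmented with SAM2AutomaticMaskGenerator.
    Returns the mask image produced by export_mask.
    """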
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device="cpu",
    )
    if annotations["boxes"]:
        predictor = SAM2ImagePredictor(sam2_model)  # type: ignore
        predictor.set_image(annotations["image"])
        # Convert each drawn box to [xmin, ymin, xmax, ymax] pixel coordinates.
        coordinates = [
            [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
            for box in annotations["boxes"]
        ]

        # Predict one mask per box; the IoU scores and low-res logits are unused.
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(coordinates),
            multimask_output=False,
        )

        # A single box yields masks of shape (1, H, W); multiple boxes yield
        # (num_boxes, 1, H, W). Normalize to the batched shape for export_mask.
        if masks.ndim == 3:
            masks = np.expand_dims(masks, axis=0)

        return export_mask(masks)

    else:
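        # No boxes drawn: fall back to automatic whole-image segmentation.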
        mask_generator = SAM2AutomaticMaskGenerator(sam2_model)  # type: ignore
        masks = mask_generator.generate(annotations["image"])
        return export_mask(masks, autogenerated=True)


# Build the Gradio UI; delete_cache=(30, 30) deletes cached files older than 30 s every 30 s.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        ## To read more about the Segment Anything project, see the [Lightly AI blogpost](https://www.lightly.ai/post/segment-anything-model-and-friends)
        """
    )
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
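        # Variant names map to checkpoint files: assets/checkpoints/sam2_hiera_<variant>.pt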
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload Your Image and Draw Bounding Box(es)
        """
    )

    annotator = image_annotator(
        # cv2.imread returns BGR; convert to RGB so Gradio displays colors correctly.
        value={"image": cv2.cvtColor(cv2.imread("assets/example.png"), cv2.COLOR_BGR2RGB)},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
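    # Run predict() on click and display the returned mask image.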
    btn.click(
        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
    )

demo.launch()