|
import os |
|
from typing import Any, Dict |
|
|
|
import cv2 |
|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
from gradio_image_annotation import image_annotator |
|
from sam2 import load_model |
|
from sam2.sam2_image_predictor import SAM2ImagePredictor |
|
|
|
from src.plot_utils import export_mask |
|
|
|
from spaces import GPU |
|
|
|
# Opt this Space into the ZeroGPU "v2" runtime via an environment flag.
# NOTE(review): this executes after `from spaces import GPU`; if `spaces` reads
# ZEROGPU_V2 at import time, this line must move above that import — confirm.
os.environ.update(ZEROGPU_V2="true")
|
|
|
@GPU()
def predict(model_choice: str, annotations: Dict[str, Any]):
    """Segment every box drawn in the annotator with the selected SAM 2 checkpoint.

    Args:
        model_choice: Checkpoint variant chosen in the dropdown ("tiny", "small",
            "base_plus" or "large"). It selects both the model variant and the
            checkpoint file ``assets/checkpoints/sam2_hiera_{model_choice}.pt``.
        annotations: Value emitted by ``image_annotator``: a dict holding the
            image under ``"image"`` and a ``"boxes"`` list whose items carry
            ``xmin``/``ymin``/``xmax``/``ymax`` pixel coordinates.

    Returns:
        Whatever ``export_mask`` builds from the predicted masks; it is shown in
        the ``gr.Image`` output component.

    Raises:
        gr.Error: If no image is loaded or no bounding box has been drawn, so the
            user sees an actionable message instead of a model-internal error.
    """
    # Validate the prompt before the (expensive) checkpoint load.
    if not annotations or annotations.get("image") is None:
        raise gr.Error("Please upload an image first.")
    boxes = annotations.get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw at least one bounding box.")

    # One [xmin, ymin, xmax, ymax] row per drawn box, in pixel coordinates.
    box_array = np.array(
        [
            [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
            for box in boxes
        ]
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): the checkpoint is reloaded on every request; caching it per
    # variant could help, but verify how that interacts with the @GPU worker.
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device=device,
    )
    predictor = SAM2ImagePredictor(sam2_model)
    predictor.set_image(annotations["image"])

    masks, _scores, _logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_array,
        multimask_output=False,
    )

    # With several boxes the predictor returns a per-box stack (4-D). With a
    # single box the leading per-box axis appears to be squeezed away, leaving a
    # 3-D array; add it back so both cases share one layout. Keyed on ndim
    # rather than shape[0] so a 4-D result is never expanded a second time.
    # (Layout presumed (num_boxes, 1, H, W) — confirm against installed sam2.)
    if masks.ndim == 3:
        masks = np.expand_dims(masks, axis=0)

    return export_mask(masks)
|
|
|
|
|
def _load_example_image(path: str = "assets/example.png"):
    """Read the bundled example image as an RGB array, or return ``None``.

    ``cv2.imread`` decodes to BGR channel order. The annotator widget and
    ``SAM2ImagePredictor.set_image`` treat arrays as RGB, so the channels are
    swapped here. ``cv2.imread`` returns ``None`` instead of raising when the
    file is missing or undecodable, and that ``None`` is passed through
    unchanged (matching the previous behaviour).
    """
    image_bgr = cv2.imread(path)
    if image_bgr is None:
        return None
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


# Page layout, top to bottom: model picker, annotator for image + boxes,
# trigger button, output mask image.
# delete_cache=(30, 30): per Gradio's Blocks docs, (frequency, age) in seconds
# for periodically removing cached files.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
# 1. Choose Model Checkpoint
"""
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
# 2. Upload your Image and draw bounding box(es)
"""
    )

    annotator = image_annotator(
        value={"image": _load_example_image()},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    # Created after the button so it renders below it, as before.
    mask_output = gr.Image(label="Mask(s)")
    btn.click(fn=predict, inputs=[model, annotator], outputs=[mask_output])

demo.launch()
|
|