# Hugging Face Space: running on T4
from huggingface_hub import snapshot_download | |
import gradio as gr | |
import numpy as np | |
import torch | |
import sys | |
from tinysam import sam_model_registry, SamPredictor | |
# Download the TinySAM repository files from the Hugging Face Hub into ./tinysam.
snapshot_download("merve/tinysam", local_dir="tinysam")

# Registry key selecting which TinySAM model variant to build.
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint="./tinysam/tinysam.pth")
# Use the GPU when one is available, otherwise stay on the CPU.
# NOTE(review): assumes `sam` is a torch.nn.Module and that SamPredictor
# follows the model's device, as in upstream Segment Anything — confirm.
sam.to(device="cuda" if torch.cuda.is_available() else "cpu")
predictor = SamPredictor(sam)
def infer(img):
    """Segment the object under the user's point mark with TinySAM.

    Args:
        img: Value emitted by the image editor: a dict whose
            ``"background"`` entry is the uploaded image and whose
            ``"layers"`` entry is a list of drawn layers. The first layer
            is treated as the point prompt.

    Returns:
        A ``(image, annotations)`` pair: the background converted to RGB,
        and a one-element list ``[(mask, "mask")]`` holding the mask with
        the highest predicted score.

    Raises:
        gr.Error: If no image has been uploaded, or if nothing has been
            drawn on top of it.
    """
    # gr.Error is an exception type: it only reaches the UI when raised.
    if img is None or img.get("background") is None:
        raise gr.Error("Please upload an image and select a point.")

    # Validate the point prompt before the expensive image embedding step.
    layers = img.get("layers") or []
    point_layer = np.array(layers[0]) if layers else None
    if point_layer is None or not point_layer.any():
        raise gr.Error("Please select a point on top of the image.")

    # The prompt is the centroid of every non-zero pixel in the drawn layer.
    # np.nonzero yields (row, column, ...) index arrays, i.e. (y, x, ...).
    rows, cols = np.nonzero(point_layer)[:2]
    center_x = int(cols.mean())
    center_y = int(rows.mean())

    image = img["background"].convert("RGB")
    predictor.set_image(np.array(image))

    masks, scores, _ = predictor.predict(
        point_coords=np.array([[center_x, center_y]]),
        point_labels=np.array([1]),  # label 1, as in the original prompt
    )

    # Keep only the candidate mask with the highest score.
    best = scores.argmax()
    return image, [(masks[best, :, :], "mask")]
# Header text rendered, in order, at the top of the demo page.
_INTRO_MARKDOWN = (
    "## TinySAM",
    "**[TinySAM](https://arxiv.org/abs/2312.13789) is a framework to distill Segment Anything Model.**",
    "**To try it out, simply upload an image, click the green tick, and then leave a point mark on what you would like to segment using the pencil on Image Editor.**",
)

with gr.Blocks() as demo:
    for _text in _INTRO_MARKDOWN:
        gr.Markdown(_text)

    # NOTE(review): the original nesting is not recoverable from this copy;
    # the output is assumed to sit beside the editor column — confirm.
    with gr.Row():
        with gr.Column():
            im = gr.ImageEditor(type="pil")
        output = gr.AnnotatedImage()

    # Re-run segmentation every time the editor's value changes.
    im.change(fn=infer, inputs=im, outputs=output)

demo.launch(debug=True)