File size: 1,927 Bytes
7aef3af
 
 
 
 
 
c456e88
7aef3af
 
 
 
 
 
 
 
 
5b6ac69
 
 
 
7aef3af
 
 
 
 
 
 
 
0ebcb8d
 
 
 
 
 
 
 
 
 
 
7aef3af
 
 
0ebcb8d
 
7aef3af
 
 
230e159
 
 
c85b146
 
 
 
 
 
0ebcb8d
7aef3af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from huggingface_hub import snapshot_download
import gradio as gr
import numpy as np
import torch
import sys
from tinysam import sam_model_registry, SamPredictor


# Fetch the "merve/tinysam" repository snapshot into the local "tinysam"
# directory so the checkpoint file referenced below exists on disk.
snapshot_download(repo_id="merve/tinysam", local_dir="tinysam")

# Look up the model constructor registered under this key and build the
# network from the downloaded checkpoint.
model_type = "vit_t"
_build_sam = sam_model_registry[model_type]
sam = _build_sam(checkpoint="./tinysam/tinysam.pth")

# Shared predictor instance used by the inference callback.
predictor = SamPredictor(sam)

def infer(img):
    """Segment the object under the user's point prompt.

    Args:
        img: Value emitted by the ``gr.ImageEditor`` component. It is read as
            a mapping with ``"background"`` (the original uploaded image, a
            PIL image) and ``"layers"`` (drawn layers; the first layer is
            used as the point prompt).

    Returns:
        A ``(image, annotations)`` tuple for ``gr.AnnotatedImage``: the RGB
        background image and a one-element list ``[(mask, "mask")]`` holding
        the first mask returned by the predictor.

    Raises:
        gr.Error: If no image was uploaded, or if nothing was drawn on top of
            the image to act as a point prompt.
    """
    # gr.Error is an exception: it only reaches the UI when it is raised.
    # Merely constructing it (as the code previously did) has no effect.
    if img is None or img.get("background") is None:
        raise gr.Error("Please upload an image and select a point.")

    # The first drawn layer carries the point prompt; there may be no layer
    # at all if the user has not drawn anything yet.
    layers = img.get("layers")
    if not layers:
        raise gr.Error("Please select a point on top of the image.")

    point_mask = np.array(layers[0])
    if not np.any(point_mask):
        raise gr.Error("Please select a point on top of the image.")

    # Collapse the drawn stroke to a single (x, y) point: the mean column
    # and mean row of all non-zero pixels in the layer.
    rows, cols = np.nonzero(point_mask)[:2]
    input_point = np.array([[int(np.mean(cols)), int(np.mean(rows))]])
    # Label 1 marks the point as foreground in the SAM predictor API.
    input_label = np.array([1])

    # Only run the image through the predictor once the prompt is known to
    # be valid, so invalid requests do not pay for it.
    image = img["background"].convert("RGB")
    predictor.set_image(np.array(image))
    masks, _scores, _logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
    )

    return image, [(masks[0, :, :], "mask")]


# Page layout: a title and two intro paragraphs, then one row holding the
# image editor (input, inside a column) next to the annotated-image output.
with gr.Blocks() as demo:
    for _intro_text in (
        "## TinySAM",
        "**[TinySAM](https://arxiv.org/abs/2312.13789) is a framework to distill Segment Anything Model.**",
        "**To try it out, simply upload an image and leave a point on what you would like to segment.**",
    ):
        gr.Markdown(_intro_text)

    with gr.Row():
        with gr.Column():
            im = gr.ImageEditor(type="pil")
        output = gr.AnnotatedImage()

    # Re-run segmentation whenever the editor value changes.
    im.change(fn=infer, inputs=im, outputs=output)

# Start the app with debug mode enabled.
demo.launch(debug=True)