import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
import spaces
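# `ImagePrompter` (from the gradio_image_prompter custom component) lets users draw
# boxes on an image; `spaces.GPU` requests GPU time on ZeroGPU-backed Spaces.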

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load SAM (ViT-H) and its pruned SlimSAM counterpart once at startup.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")

def get_processor_and_model(slim: bool):
    if slim:
        return slimsam_processor, slimsam_model
    return sam_processor, sam_model
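
# Both inference functions return annotations for gr.AnnotatedImage: a list of
# (mask, label) tuples, where mask is a boolean array over the input image.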

@spaces.GPU
def sam_box_inference(image, x_min, y_min, x_max, y_max, *, slim=False):
    processor, model = get_processor_and_model(slim)

    # input_boxes is nested as (batch, boxes per image, 4), in xyxy pixel coordinates.
    inputs = processor(
        Image.fromarray(image),
        input_boxes=[[[x_min, y_min, x_max, y_max]]],
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the predicted mask to the original resolution and take the first
    # mask of the first box of the first image.
    mask = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]  # add a leading axis: (1, H, W)
    return [(mask, "mask")]

@spaces.GPU
def sam_point_inference(image, x, y, *, slim=False):
    processor, model = get_processor_and_model(slim)

    # A single (x, y) point prompt for one image.
    inputs = processor(
        image,
        input_points=[[[x, y]]],
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    mask = processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]  # (1, H, W), matching the box branch
    return [(mask, "mask")]
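
# The gr.ImageEditor value is a dict: "background" is the uploaded image, "layers"
# holds the user's drawn strokes, and "composite" is the flattened result.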

def infer_point(img):
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")
    image = img["background"].convert("RGB")
    point_prompt = img["layers"][0]
    img_arr = np.array(point_prompt)
    if not np.any(img_arr):
        raise gr.Error("Please select a point on top of the image.")
    # Use the centroid of the drawn stroke as the point prompt.
    nonzero_indices = np.nonzero(img_arr)
    center_x = int(np.mean(nonzero_indices[1]))
    center_y = int(np.mean(nonzero_indices[0]))
    return ((image, sam_point_inference(image, center_x, center_y, slim=True)),
            (image, sam_point_inference(image, center_x, center_y)))
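
# The ImagePrompter value is a dict with "image" (numpy array) and "points"; each
# drawn box appears to be encoded as [x_min, y_min, label, x_max, y_max, label]
# (assumption inferred from the indices used below and the gradio_image_prompter docs).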

def infer_box(prompts):
    image = prompts["image"]
    if image is None:
        raise gr.Error("Please upload an image and draw a box before submitting.")
    if not prompts["points"]:
        raise gr.Error("Please draw a box before submitting.")
    # Use the first drawn box; indices 0/1 are x_min/y_min, 3/4 are x_max/y_max.
    points = prompts["points"][0]
    return ((image, sam_box_inference(image, points[0], points[1], points[3], points[4], slim=True)),
            (image, sam_box_inference(image, points[0], points[1], points[3], points[4])))
with gr.Blocks(title="SlimSAM") as demo:
  gr.Markdown("# SlimSAM")
  gr.Markdown("SlimSAM is a pruned and distilled version of SAM that is significantly smaller.")
  gr.Markdown("In this demo, you can compare the outputs of SlimSAM and SAM for point and box prompts.")
  
  with gr.Tab("Box Prompt"):
    with gr.Row():
        with gr.Column(scale=1):
            # Instructions
            gr.Markdown("To try box prompting, simply upload an image and draw a box on it.")
    with gr.Row():
        with gr.Column():
            im = ImagePrompter()
            btn = gr.Button("Submit")
        with gr.Column():
            output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
            output_box_sam = gr.AnnotatedImage(label="SAM Output")


    btn.click(infer_box, inputs=im, outputs=[output_box_slimsam, output_box_sam])

  with gr.Tab("Point Prompt"):
    with gr.Row():
        with gr.Column(scale=1):
            # Instructions
            gr.Markdown("To try point prompting, simply upload an image and leave a dot on it.")
    with gr.Row():
        with gr.Column():
            im = gr.ImageEditor(
                type="pil",
            )
        with gr.Column():
            output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
            output_sam = gr.AnnotatedImage(label="SAM Output")

    im.change(infer_point, inputs=im, outputs=[output_slimsam, output_sam])
demo.launch(debug=True)
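
# To run locally (assumed dependencies): pip install torch transformers gradio
# gradio_image_prompter spaces -- then: python app.py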