merve HF staff committed on
Commit
2cd9d38
1 Parent(s): e8fa64c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import SamModel, SamProcessor
6
+ from gradio_image_prompter import ImagePrompter
7
+
8
+
9
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Original SAM (ViT-Huge) checkpoint and its processor.
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")
sam_model = sam_model.to(device)

# SlimSAM checkpoint and its processor, loaded the same way for comparison.
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform")
slimsam_model = slimsam_model.to(device)
14
+
15
def sam_box_inference(image, model, x_min, y_min, x_max, y_max):
    """Predict a segmentation mask for a box prompt with the given model.

    Args:
        image: Array-like image accepted by ``Image.fromarray``.
        model: Model called on the processed inputs (e.g. ``sam_model`` or
            ``slimsam_model``).
        x_min: Left edge of the box prompt.
        y_min: Top edge of the box prompt.
        x_max: Right edge of the box prompt.
        y_max: Bottom edge of the box prompt.

    Returns:
        A one-element list ``[(mask, "mask")]`` where ``mask`` is the selected
        post-processed mask with one leading axis added.
    """
    pil_image = Image.fromarray(image)
    box_prompt = [[[[x_min, y_min, x_max, y_max]]]]
    inputs = sam_processor(
        pil_image,
        input_boxes=box_prompt,
        return_tensors="pt",
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)

    processed = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # [0][0][0] presumably selects first image / first prompt / first
    # candidate mask -- TODO confirm against the SAM post-processing docs.
    mask = processed[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    print(mask)
    print(mask.shape)
    return [(mask, "mask")]
35
+
36
+
37
def sam_point_inference(image, model, x, y):
    """Predict a segmentation mask for a single point prompt with the given model.

    Args:
        image: Image passed straight to ``sam_processor``.
        model: Model to run on the processed inputs (e.g. ``sam_model`` or
            ``slimsam_model``).
        x: Horizontal pixel coordinate of the point prompt.
        y: Vertical pixel coordinate of the point prompt.

    Returns:
        A one-element list ``[(mask, "mask")]`` where ``mask`` is the selected
        post-processed mask with one leading axis added.
    """
    inputs = sam_processor(
        image,
        input_points=[[[x, y]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        # Use the model the caller asked for. The global ``sam_model`` was
        # hard-coded here before, so the SlimSAM comparison silently showed
        # SAM's output twice.
        outputs = model(**inputs)

    mask = sam_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    print(type(mask))
    print(mask.shape)
    return [(mask, "mask")]
55
+
56
def infer_point(img):
    """Run SlimSAM and SAM on the point the user drew in the image editor.

    Args:
        img: Editor value; read as a mapping with a ``"background"`` image
            (the uploaded picture) and a ``"layers"`` list whose first entry
            holds the drawn point.

    Returns:
        A 2-tuple ``((image, slimsam_annotations), (image, sam_annotations))``
        where each annotations value is what ``sam_point_inference`` returns.

    Raises:
        gr.Error: If no image was uploaded or no point was drawn.
    """
    # gr.Error only reaches the UI when it is raised; merely constructing it
    # does nothing and lets execution fall through to a crash below.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")

    layers = img["layers"]
    if layers is None or len(layers) == 0:
        raise gr.Error("Please select a point on top of the image.")

    point_layer = np.array(layers[0])
    if not np.any(point_layer):
        raise gr.Error("Please select a point on top of the image.")

    # Centre of the drawn stroke: mean row/column of all non-zero pixels.
    # The first two axes are (row, column) whether or not a channel axis exists.
    rows, cols = np.nonzero(point_layer)[:2]
    center_x = int(np.mean(cols))
    center_y = int(np.mean(rows))
    print("Point inference returned.")
    return (
        (image, sam_point_inference(image, slimsam_model, center_x, center_y)),
        (image, sam_point_inference(image, sam_model, center_x, center_y)),
    )
77
+
78
def infer_box(prompts):
    """Run SlimSAM and SAM on the first box the user drew on the image.

    Args:
        prompts: Prompter value; read as a mapping with an ``"image"`` entry
            and a ``"points"`` list. The first entry of ``"points"`` is
            indexed as ``[x_min, y_min, _, x_max, y_max, ...]``.

    Returns:
        A 2-tuple ``((image, slimsam_annotations), (image, sam_annotations))``
        where each annotations value is what ``sam_box_inference`` returns.

    Raises:
        gr.Error: If no image was uploaded or no box was drawn.
    """
    # gr.Error only reaches the UI when it is raised.
    if prompts is None or prompts["image"] is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    image = prompts["image"]

    # Check for an empty prompt list *before* indexing it; indexing first
    # would raise IndexError instead of showing the friendly message.
    boxes = prompts["points"]
    if boxes is None or len(boxes) == 0:
        raise gr.Error("Please draw a box before submitting.")
    points = boxes[0]
    print(points)

    # x_min = points[0] x_max = points[3] y_min = points[1] y_max = points[4]
    x_min, y_min, x_max, y_max = points[0], points[1], points[3], points[4]
    return (
        (image, sam_box_inference(image, slimsam_model, x_min, y_min, x_max, y_max)),
        (image, sam_box_inference(image, sam_model, x_min, y_min, x_max, y_max)),
    )
91
# Two-tab demo comparing SlimSAM and SAM on box prompts and point prompts.
# NOTE(review): the container nesting below is the conventional layout for
# this component sequence -- confirm it matches the intended page structure.
with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    with gr.Tab("Box Prompt"):
        # Title row
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Box Prompting")

        # Input on the left, both model outputs on the right.
        with gr.Row():
            with gr.Column():
                im = ImagePrompter()
                btn = gr.Button("Submit")
            with gr.Column():
                output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_box_sam = gr.AnnotatedImage(label="SAM Output")

        # Box inference runs only when the user presses the button.
        btn.click(
            infer_box,
            inputs=im,
            outputs=[output_box_slimsam, output_box_sam],
        )

    with gr.Tab("Point Prompt"):
        # Title row
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Point Prompting")

        # Editor on the left, both model outputs on the right.
        with gr.Row():
            with gr.Column():
                # ``im`` is rebound here; the box tab's click handler was
                # already registered with the earlier component.
                im = gr.ImageEditor(
                    type="pil",
                )
            with gr.Column():
                output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_sam = gr.AnnotatedImage(label="SAM Output")

        # Point inference is re-run on every change of the editor value.
        im.change(
            infer_point,
            inputs=im,
            outputs=[output_slimsam, output_sam],
        )

demo.launch(debug=True)