dragonSwing committed on
Commit
67d4d3e
1 Parent(s): 8b0abc0

Upload files

Browse files
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ *.pyo
3
+ *.pyd
4
+ __py
5
+ **/__pycache__/
6
+ data
7
+ onnx
8
+ results
9
+ **.egg-info
10
+ *.log
11
+ *.onnx
12
+ .hypothesis
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+
4
+ from nanosam import Predictor
5
+
6
+ import gradio as gr
7
+ import time
8
+ from PIL import ImageDraw
9
+ from utils import download_file_from_url, fast_process, format_results, point_prompt
10
+
11
+ # Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks for AN-619.
12
+
13
# Local paths of the ONNX weights; they are downloaded once and reused on later launches.
ENCODER_PATH = "onnx/sam_hgv2_b4_ln_nonorm_image_encoder.onnx"
DECODER_PATH = "onnx/efficientvit_l0_mask_decoder.onnx"

if not os.path.exists(ENCODER_PATH):
    download_file_from_url(
        "https://huggingface.co/dragonSwing/nanosam/resolve/main/sam_hgv2_b4_ln_nonorm_image_encoder.onnx",
        ENCODER_PATH,
    )

if not os.path.exists(DECODER_PATH):
    download_file_from_url(
        "https://huggingface.co/dragonSwing/nanosam/resolve/main/efficientvit_l0_mask_decoder.onnx",
        DECODER_PATH,
    )

# Load the pre-trained model
image_encoder_cfg = {
    "path": ENCODER_PATH,
    "provider": "cpu",
    "normalize_input": False,
}
mask_decoder_cfg = {
    "path": DECODER_PATH,
    "provider": "cpu",
}
predictor = Predictor(image_encoder_cfg, mask_decoder_cfg)

# Description
# The <font> element is now closed properly (it was opened twice before).
title = "<center><strong><font size='8'>Faster Segment Anything(NanoSAM)</font></strong></center>"

# The Model Card link uses the same account name as the download URLs above.
description_p = """ ## This is a demo of [Faster Segment Anything(NanoSAM) Model](https://github.com/binh234/nanosam).
# Instructions for point mode
0. Restart by clicking the Restart button
1. Select a point with Add Mask for the foreground (Must)
2. Select a point with Remove Area for the background (Optional)
3. Click the Start Segmenting.
- Github [link](https://github.com/binh234/nanosam)
- Model Card [link](https://huggingface.co/dragonSwing/nanosam)
We will provide box mode soon.
Enjoy!
"""

# Example images offered below the canvas; each entry is one row of inputs.
examples = [
    ["assets/picture3.jpg"],
    ["assets/picture4.jpg"],
    ["assets/picture5.jpg"],
    ["assets/picture6.jpg"],
    ["assets/picture1.jpg"],
    ["assets/picture2.jpg"],
    ["assets/dogs.jpg"],
]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
63
+
64
+
65
def get_empty_state():
    """Return a brand-new session state.

    The dict holds two independent empty lists under ``points`` and
    ``point_labels`` and ``None`` under ``features``. A new dict (and new
    lists) is built on every call, so callers never share mutable state.
    """
    state = {}
    state["points"] = []
    state["point_labels"] = []
    state["features"] = None
    return state
67
+
68
+
69
def clear():
    """Produce the reset values for the Restart button.

    Returns a 3-tuple: ``None`` for each of the two image outputs, followed
    by a fresh state from ``get_empty_state()``.
    """
    blank_input = blank_output = None
    return blank_input, blank_output, get_empty_state()
71
+
72
+
73
def set_image(image):
    """Hand ``image`` to the predictor and return a fresh state with its features.

    The time spent inside ``predictor.set_image`` is printed. The returned
    state has no points or labels; its ``features`` entry is whatever
    ``predictor.features`` holds after the call.
    """
    t0 = time.perf_counter()
    predictor.set_image(image)
    elapsed = time.perf_counter() - t0
    print(f"Encoder time: {elapsed: .3f}s")

    fresh_state = get_empty_state()
    fresh_state["features"] = predictor.features
    return fresh_state
81
+
82
+
83
def segment_with_points(
    image,
    state,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment ``image`` from the prompt points stored in ``state``.

    Args:
        image: PIL image currently shown in the input component, or ``None``
            when that component is empty.
        state: session dict with ``points``, ``point_labels`` and ``features``
            (see ``get_empty_state``).
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.

    Returns:
        Whatever ``fast_process`` returns for the best-scoring mask.

    Raises:
        gr.Error: if no points were selected, points and labels differ in
            length, no features were computed yet, or there is no image.
    """
    points = np.asarray(state["points"])
    point_labels = np.asarray(state["point_labels"])
    if len(points) == 0 and len(point_labels) == 0:
        raise gr.Error("No points selected")
    if len(points) != len(point_labels):
        raise gr.Error("Mismatch length between points and point labels")
    if state["features"] is None:
        raise gr.Error(
            "Image was not set correctly, please wait for a moment after uploading image before drawing points!"
        )
    if image is None:
        # The input component can be emptied while the state still holds
        # points and features; report that instead of failing on image.size.
        raise gr.Error("No image found, please upload an image or pick an example first!")

    # The predictor is shared, so restore this session's features and image
    # size before decoding.
    predictor.features = state["features"]
    img_w, img_h = image.size
    predictor.original_size = (img_h, img_w)

    start = time.perf_counter()
    masks, scores, logits = predictor.predict(
        points=points,
        point_labels=point_labels,
    )
    end = time.perf_counter()
    print(f"Decoder time: {end - start: .3f}s")

    # Keep only the highest-scoring mask and threshold it at 0.
    fig = fast_process(
        annotations=[masks[0, scores.argmax()] > 0],
        image=image,
        scale=1,
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    return fig
133
+
134
+
135
def get_points_with_draw(image, label, evt: gr.SelectData, state):
    """Record a clicked point in ``state`` and paint a marker on ``image``.

    The click position comes from ``evt.index``. An "Add Mask" click is
    stored with label 1 and drawn in (255, 255, 0); any other label is stored
    with label 0 and drawn in (255, 0, 255). The image is modified in place.

    Returns:
        The ``(image, state)`` pair after the update.
    """
    x, y = evt.index[0], evt.index[1]
    is_add = label == "Add Mask"

    point_radius = 15
    if is_add:
        point_color = (255, 255, 0)
    else:
        point_color = (255, 0, 255)

    state["points"].append([x, y])
    state["point_labels"].append(1 if is_add else 0)

    print(x, y, is_add)

    top_left = (x - point_radius, y - point_radius)
    bottom_right = (x + point_radius, y + point_radius)
    ImageDraw.Draw(image).ellipse([top_left, bottom_right], fill=point_color)
    return image, state
157
+
158
+
159
+ cond_img_p = gr.Image(label="Input with points", type="pil", interactive=True)
160
+
161
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type="pil")
162
+
163
+ global_points = []
164
+ global_point_labels = []
165
+
166
+ with gr.Blocks(css=css, title="Faster Segment Anything(NanoSAM)") as demo:
167
+ state = gr.State(value=get_empty_state())
168
+ with gr.Row():
169
+ with gr.Column(scale=1):
170
+ # Title
171
+ gr.Markdown(title)
172
+
173
+ with gr.Tab("Point mode"):
174
+ # Images
175
+ with gr.Row(variant="panel"):
176
+ with gr.Column(scale=1):
177
+ cond_img_p.render()
178
+
179
+ with gr.Column(scale=1):
180
+ segm_img_p.render()
181
+
182
+ # Submit & Clear
183
+ with gr.Row():
184
+ with gr.Column():
185
+ with gr.Row():
186
+ add_or_remove = gr.Radio(
187
+ ["Add Mask", "Remove Area"],
188
+ value="Add Mask",
189
+ )
190
+
191
+ with gr.Column():
192
+ segment_btn_p = gr.Button("Start segmenting!", variant="primary")
193
+ restart_btn_p = gr.Button("Restart", variant="secondary")
194
+
195
+ gr.Markdown("Try some of the examples below ⬇️")
196
+ gr.Examples(
197
+ examples=examples,
198
+ inputs=[cond_img_p],
199
+ outputs=[state],
200
+ fn=set_image,
201
+ run_on_click=True,
202
+ examples_per_page=4,
203
+ )
204
+
205
+ with gr.Column():
206
+ # Description
207
+ gr.Markdown(description_p)
208
+
209
+ cond_img_p.upload(set_image, inputs=[cond_img_p], outputs=[state])
210
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove, state], [cond_img_p, state])
211
+ segment_btn_p.click(segment_with_points, [cond_img_p, state], [segm_img_p])
212
+ restart_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, state])
213
+
214
+ demo.queue().launch()
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/dogs.jpg ADDED
assets/picture1.jpg ADDED
assets/picture2.jpg ADDED
assets/picture3.jpg ADDED
assets/picture4.jpg ADDED
assets/picture5.jpg ADDED
assets/picture6.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ onnx>=1.14.0
3
+ onnxruntime>=1.14.0
4
+ opencv-python-headless
5
+ git+https://github.com/binh234/nanosam.git
utils.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import requests
7
+ from PIL import Image
8
+
9
+
10
def point_prompt(masks, points, point_label, target_height, target_width):
    """Combine the masks hit by the prompt points into one boolean mask.

    Args:
        masks: non-empty sequence whose entries are either dicts with a
            ``"segmentation"`` array or bare mask arrays.
        points: sequence of ``[x, y]`` points in target-image coordinates.
        point_label: per-point label; 0 subtracts the hit mask, any other
            value adds it.
        target_height, target_width: size the points refer to. When it
            differs from the mask size the points are rescaled to the mask.

    Returns:
        ``(onemask, 0)`` where ``onemask`` is a boolean array of mask size.
    """

    def _segmentation(entry):
        # Entries may be annotation dicts or bare arrays; the size lookup
        # below used to assume a dict and failed on bare arrays.
        return entry["segmentation"] if isinstance(entry, dict) else entry

    first_mask = _segmentation(masks[0])
    h = first_mask.shape[0]
    w = first_mask.shape[1]
    if h != target_height or w != target_width:
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]

    onemask = np.zeros((h, w))
    for annotation in masks:
        mask = _segmentation(annotation)
        for point_idx, point in enumerate(points):
            if mask[point[1], point[0]] == 1:
                if point_label[point_idx] == 0:
                    onemask -= mask
                else:
                    onemask += mask
                # NOTE(review): the text this was recovered from had lost its
                # indentation; `break` is assumed to belong to the hit branch,
                # so each mask is applied at most once - confirm.
                break
    onemask = onemask > 0
    return onemask, 0
33
+
34
+
35
def format_results(masks, scores, logits, filter=0):
    """Turn raw mask scores into a list of annotation dicts.

    Args:
        masks: indexable collection of arrays; entry ``i`` is thresholded
            with ``> 0`` to get the boolean segmentation.
        scores: one score per mask; its length decides how many masks are read.
        logits: accepted for interface compatibility, not used.
        filter: accepted for interface compatibility, not used.

    Returns:
        A list with one dict per score, holding ``id``, ``segmentation``,
        ``bbox``, ``score`` and ``area``. ``bbox`` is ordered as
        ``[min_row, min_col, max_col, max_row]``. A mask with no positive
        pixel gets ``[0, 0, 0, 0]`` instead of raising ``ValueError``.
    """
    annotations = []
    for i in range(len(scores)):
        mask = masks[i] > 0
        hits = np.where(mask)
        if hits[0].size == 0:
            # np.min / np.max raise on empty input; use a degenerate box.
            bbox = [0, 0, 0, 0]
        else:
            bbox = [
                np.min(hits[0]),
                np.min(hits[1]),
                np.max(hits[1]),
                np.max(hits[0]),
            ]
        annotations.append(
            {
                "id": i,
                "segmentation": mask,
                "bbox": bbox,
                "score": scores[i],
                "area": mask.sum(),
            }
        )
    return annotations
55
+
56
+
57
def fast_process(
    annotations,
    image,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    use_retina=True,
    withContours=True,
):
    """Overlay the given masks (and optionally their contours) on ``image``.

    Args:
        annotations: non-empty list of mask arrays, or of dicts carrying the
            mask under ``"segmentation"``.
        image: PIL image to draw on; a converted RGBA copy is returned.
        scale: divisor for the contour thickness (``2 // scale``).
        better_quality: when true, each mask is cleaned with a morphological
            close (3x3) followed by an open (8x8). This writes the cleaned
            masks back into ``annotations``.
        mask_random_color: forwarded to ``fast_show_mask`` as ``random_color``.
        bbox: optional ``(x1, y1, x2, y2)`` forwarded to ``fast_show_mask``.
        use_retina: forwarded to ``fast_show_mask`` as ``retinamask``; when
            false, contour masks are resized to the image size first.
        withContours: when true, mask outlines are pasted on top.

    Returns:
        The RGBA image with the overlays pasted in.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]

    original_h = image.height
    original_w = image.width
    if better_quality:
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )
    annotations = np.asarray(annotations)

    # fast_show_mask only touches the axes when a bbox is given, so avoid
    # plt.gca() otherwise: it would create and keep a pyplot figure on every
    # call.
    ax = plt.gca() if bbox is not None else None
    inner_mask = fast_show_mask(
        annotations,
        ax,
        random_color=mask_random_color,
        bbox=bbox,
        retinamask=use_retina,
        target_height=original_h,
        target_width=original_w,
    )

    if withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            # `== False` is kept on purpose: fast_show_mask uses the same
            # test, so both resize under exactly the same condition.
            if use_retina == False:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            contour_all.extend(contours)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        # RGBA colour of the outline, as fractions of 255.
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert("RGBA")
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    # Each overlay doubles as its own paste mask, so its alpha drives blending.
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image
120
+
121
+
122
+ # CPU post process
123
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Render a stack of masks into a single RGBA float image.

    Args:
        annotation: array indexed as ``[mask, row, col]``.
        ax: axes that receive the bbox rectangle; only used when ``bbox`` is
            given.
        random_color: ``True`` draws one random RGB colour per mask,
            otherwise every mask uses (30, 144, 255) / 255.
        bbox: optional ``(x1, y1, x2, y2)`` rectangle to add to ``ax``.
        retinamask: when ``False`` the result is resized to
            ``(target_width, target_height)`` with nearest-neighbour sampling.
        target_height, target_width: output size for that resize.

    Returns:
        Array of shape ``(rows, cols, 4)`` with alpha 0.6 on covered pixels
        and zeros elsewhere (before any resize).
    """
    num_masks, rows, cols = annotation.shape[0], annotation.shape[1], annotation.shape[2]

    # Order masks by ascending area so that, where masks overlap, the
    # smallest one is the first non-zero layer picked below.
    by_area = np.argsort(np.sum(annotation, axis=(1, 2)))
    annotation = annotation[by_area]
    first_layer = (annotation != 0).argmax(axis=0)

    if random_color == True:
        rgb = np.random.random((num_masks, 1, 1, 3))
    else:
        rgb = np.ones((num_masks, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
    alpha = np.ones((num_masks, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    layers = np.expand_dims(annotation, -1) * rgba

    # For every pixel copy the RGBA value of its selected layer.
    canvas = np.zeros((rows, cols, 4))
    row_idx, col_idx = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
    canvas[row_idx, col_idx, :] = layers[first_layer[row_idx, col_idx], row_idx, col_idx, :]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1)
        ax.add_patch(outline)

    if retinamask == False:
        canvas = cv2.resize(canvas, (target_width, target_height), interpolation=cv2.INTER_NEAREST)

    return canvas
164
+
165
+
166
def download_file_from_url(url, output_file, chunk_size=8192, timeout=30):
    """Stream ``url`` to ``output_file``, best effort.

    The data is first written to ``<output_file>.part`` and only moved to
    ``output_file`` after the whole body was received, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        url: address to fetch.
        output_file: destination path; its parent directory is created when
            the path has one.
        chunk_size: number of bytes requested per streamed chunk.
        timeout: seconds passed to ``requests.get`` as its ``timeout``.

    Returns:
        None. Failures are reported with ``print`` rather than raised.
    """
    output_dir = os.path.dirname(output_file)
    if output_dir:
        # os.makedirs("") raises, so skip it for a bare filename.
        os.makedirs(output_dir, exist_ok=True)

    partial_file = f"{output_file}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download file. Status code: {response.status_code}")
                return
            with open(partial_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(partial_file, output_file)
    except Exception as e:
        # Deliberately broad: the download stays best effort, as before.
        print(f"An error occurred: {e}")
        try:
            os.remove(partial_file)
        except OSError:
            pass  # nothing was written, or it is already gone