dhkim2810 committed on
Commit
a8a11ec
1 Parent(s): 3a05dab

Initial Commit

Browse files
README.md CHANGED
@@ -1,13 +1,172 @@
1
- ---
2
- title: MobileSAM
3
- emoji: 🦀
4
- colorFrom: blue
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 3.35.2
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <p float="center">
2
+ <img src="assets/logo.png?raw=true" width="99.1%" />
3
+ </p>
4
+
5
+ # Faster Segment Anything (MobileSAM)
6
+
7
+ :pushpin: MobileSAM paper is available at [paper link](https://arxiv.org/pdf/2306.14289.pdf).
8
+
9
+ ![MobileSAM](assets/model_diagram.jpg?raw=true)
10
+
11
+ <p float="left">
12
+ <img src="assets/mask_comparision.jpg?raw=true" width="99.1%" />
13
+ </p>
14
+
15
+ **MobileSAM** performs on par with the original SAM (at least visually) and keeps exactly the same pipeline as the original SAM except for a change in the image encoder. Specifically, we replace the original heavyweight ViT-H encoder (632M) with a much smaller Tiny-ViT (5M). On a single GPU, MobileSAM takes around 12ms per image: 8ms for the image encoder and 4ms for the mask decoder.
16
+
17
+ The comparison of the ViT-based image encoders is summarized as follows:
18
+
19
+ Image Encoder | Original SAM | MobileSAM
20
+ :-----------------------------------------:|:---------|:-----:
21
+ Parameters | 611M | 5M
22
+ Speed | 452ms | 8ms
23
+
24
+ Original SAM and MobileSAM have exactly the same prompt-guided mask decoder:
25
+
26
+ Mask Decoder | Original SAM | MobileSAM
27
+ :-----------------------------------------:|:---------|:-----:
28
+ Parameters | 3.876M | 3.876M
29
+ Speed | 4ms | 4ms
30
+
31
+ The comparison of the whole pipeline is summarized as follows:
32
+ Whole Pipeline (Enc+Dec) | Original SAM | MobileSAM
33
+ :-----------------------------------------:|:---------|:-----:
34
+ Parameters | 615M | 9.66M
35
+ Speed | 456ms | 12ms
36
+
37
+ **Original SAM and MobileSAM with a (single) point as the prompt.**
38
+
39
+ <p float="left">
40
+ <img src="assets/mask_point.jpg?raw=true" width="99.1%" />
41
+ </p>
42
+
43
+ **Original SAM and MobileSAM with a box as the prompt.**
44
+ <p float="left">
45
+ <img src="assets/mask_box.jpg?raw=true" width="99.1%" />
46
+ </p>
47
+
48
+ **Is MobileSAM faster and smaller than FastSAM? Yes, to our knowledge!**
49
+ MobileSAM is around 7 times smaller and around 5 times faster than the concurrent FastSAM.
50
+ The comparison of the whole pipeline is summarized as follows:
51
+ Whole Pipeline (Enc+Dec) | FastSAM | MobileSAM
52
+ :-----------------------------------------:|:---------|:-----:
53
+ Parameters | 68M | 9.66M
54
+ Speed | 64ms | 12ms
55
+
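The "around 7 times smaller" and "around 5 times faster" figures can be checked against the whole-pipeline table. A minimal sketch in Python, using only the numbers reported above; the variable names are illustrative and not part of the repository:

```python
# Whole-pipeline figures copied from the FastSAM vs. MobileSAM table above.
fastsam = {"params_m": 68.0, "speed_ms": 64.0}
mobilesam = {"params_m": 9.66, "speed_ms": 12.0}

# Ratio > 1 means FastSAM is larger / slower than MobileSAM.
size_ratio = fastsam["params_m"] / mobilesam["params_m"]
speed_ratio = fastsam["speed_ms"] / mobilesam["speed_ms"]

print(f"FastSAM / MobileSAM parameters: {size_ratio:.1f}x")
print(f"FastSAM / MobileSAM latency: {speed_ratio:.1f}x")
```

The ratios come out slightly above 7 and 5, which matches the rounded claims in the text.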
56
+ **Is MobileSAM better than FastSAM for performance? Yes, to our knowledge!**
57
+ Unlike the original SAM or our MobileSAM, FastSAM cannot work with a single prompt point. Therefore, we compare the mIoU with two prompt points (at different pixel distances) and show the results as follows. Our MobileSAM is much better than FastSAM under this setup.
58
+ Distance between prompt points (pixels) | FastSAM (mIoU) | MobileSAM (mIoU)
59
+ :-----------------------------------------:|:---------|:-----:
60
+ 100 | 0.27 | 0.73
61
+ 200 | 0.33 |0.71
62
+ 300 | 0.37 |0.74
63
+ 400 | 0.41 |0.73
64
+ 500 | 0.41 |0.73
65
+
66
+
67
+
68
+
69
+ **How to Adapt from SAM to MobileSAM?** Since MobileSAM keeps exactly the same pipeline as the original SAM, we inherit pre-processing, post-processing, and all other interfaces from the original SAM. Users of the original SAM can adapt to MobileSAM with zero effort: everything is exactly the same except for the smaller image encoder.
70
+
71
+ **How is MobileSAM trained?** MobileSAM is trained on a single GPU with 100k images (1% of the original images) for less than a day. The training code will be available soon.
72
+
73
+
74
+
75
+ ## Installation
76
+
77
+ The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.
78
+
79
+ Install Mobile Segment Anything:
80
+
81
+ ```
82
+ pip install git+https://github.com/ChaoningZhang/MobileSAM.git
83
+ ```
84
+
85
+ or clone the repository locally and install with
86
+
87
+ ```
88
+ git clone git@github.com:ChaoningZhang/MobileSAM.git
89
+ cd MobileSAM; pip install -e .
90
+ ```
91
+
92
+
93
+ ## <a name="GettingStarted"></a>Getting Started
94
+ MobileSAM can be loaded as follows:
95
+
96
+ ```
97
+ import torch
+ from mobile_encoder.setup_mobile_sam import setup_model
98
+ checkpoint = torch.load('../weights/mobile_sam.pt')
99
+ mobile_sam = setup_model()
100
+ mobile_sam.load_state_dict(checkpoint,strict=True)
101
+ ```
102
+
103
+ Then the model can be easily used in just a few lines to get masks from a given prompt:
104
+
105
+ ```
106
+ from segment_anything import SamPredictor
107
+ device = "cuda"
108
+ mobile_sam.to(device=device)
109
+ mobile_sam.eval()
110
+ predictor = SamPredictor(mobile_sam)
111
+ predictor.set_image(<your_image>)
112
+ masks, _, _ = predictor.predict(<input_prompts>)
113
+ ```
114
+
115
+ or generate masks for an entire image:
116
+
117
+ ```
118
+ from segment_anything import SamAutomaticMaskGenerator
119
+
120
+ mask_generator = SamAutomaticMaskGenerator(mobile_sam)
121
+ masks = mask_generator.generate(<your_image>)
122
+ ```
123
+
124
+
125
+ ## BibTeX of our MobileSAM
126
+ If you use MobileSAM in your research, please use the following BibTeX entry. :mega: Thank you!
127
+
128
+ ```bibtex
129
+ @article{mobile_sam,
130
+ title={Faster Segment Anything: Towards Lightweight SAM for Mobile Applications},
131
+ author={Zhang, Chaoning and Han, Dongshen and Qiao, Yu and Kim, Jung Uk and Bae, Sung Ho and Lee, Seungkyu and Hong, Choong Seon},
132
+ journal={arXiv preprint arXiv:2306.14289},
133
+ year={2023}
134
+ }
135
+ ```
136
+
137
+ ## Acknowledgement
138
+
139
+ <details>
140
+ <summary>
141
+ <a href="https://github.com/facebookresearch/segment-anything">SAM</a> (Segment Anything) [<b>bib</b>]
142
+ </summary>
143
+
144
+ ```bibtex
145
+ @article{kirillov2023segany,
146
+ title={Segment Anything},
147
+ author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
148
+ journal={arXiv:2304.02643},
149
+ year={2023}
150
+ }
151
+ ```
152
+ </details>
153
+
154
+
155
+
156
+ <details>
157
+ <summary>
158
+ <a href="https://github.com/microsoft/Cream/tree/main/TinyViT">TinyViT</a> (TinyViT: Fast Pretraining Distillation for Small Vision Transformers) [<b>bib</b>]
159
+ </summary>
160
+
161
+ ```bibtex
162
+ @InProceedings{tiny_vit,
163
+ title={TinyViT: Fast Pretraining Distillation for Small Vision Transformers},
164
+ author={Wu, Kan and Zhang, Jinnian and Peng, Houwen and Liu, Mengchen and Xiao, Bin and Fu, Jianlong and Yuan, Lu},
165
+ booktitle={European conference on computer vision (ECCV)},
166
+ year={2022}
+ }
167
+ ```
168
+ </details>
169
+
170
+
171
+
172
+
app.py ADDED
@@ -0,0 +1,348 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
6
+ from PIL import ImageDraw
7
+ from utils.tools import box_prompt, format_results, point_prompt
8
+ from utils.tools_gradio import fast_process
9
+
10
+ # Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks to AN-619.
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ # Load the pre-trained model
15
+ sam_checkpoint = "./mobile_sam.pt"
16
+ model_type = "vit_t"
17
+
18
+ mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
19
+ mobile_sam = mobile_sam.to(device=device)
20
+ mobile_sam.eval()
21
+
22
+ mask_generator = SamAutomaticMaskGenerator(mobile_sam)
23
+ predictor = SamPredictor(mobile_sam)
24
+
25
+ # Description
26
+ title = "<center><strong><font size='8'>Faster Segment Anything(MobileSAM)<font></strong></center>"
27
+
28
+ description_e = """This is a demo of the GitHub project [Faster Segment Anything(MobileSAM) Model](https://github.com/ChaoningZhang/MobileSAM). Feel free to give it a star ⭐️.
29
+
30
+ 🎯 Upload an Image, segment it with Faster Segment Anything (Everything mode). The other modes will come soon.
31
+
32
+ ⌛️ It takes about 5 seconds to generate the segmentation results. The concurrency_count of the queue is 1, so please wait a moment when it is busy.
33
+
34
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
35
+
36
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1k6azd5wdOOYkFwi61uXoIHfP-qBzuoOu/view?usp=sharing)
37
+
38
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/dhkim2810/MobileSAM)
39
+
40
+ 😚 Most of our demo code is from [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks to AN-619.
41
+
42
+ """
43
+
44
+ description_p = """ # 🎯 Instructions for points mode
45
+ This is a demo of the GitHub project [Faster Segment Anything(MobileSAM) Model](https://github.com/ChaoningZhang/MobileSAM). Feel free to give it a star ⭐️.
46
+
47
+ 🎯 Upload an Image, segment it with Faster Segment Anything (Everything mode). The other modes will come soon.
48
+
49
+ ⌛️ It takes about 5 seconds to generate the segmentation results. The concurrency_count of the queue is 1, so please wait a moment when it is busy.
50
+
51
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
52
+
53
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1jibN6HTQcC4C2okoaKLRzHIo_pS0Eeom/view?usp=sharing)
54
+
55
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/dhkim2810/MobileSAM)
56
+
57
+
58
+
59
+ 1. Upload an image or choose an example.
60
+
61
+ 2. Choose the point label ('Add Mask' marks a positive point to include in the segmentation; 'Remove Area' marks a negative point to exclude from it).
62
+
63
+ 3. Add points one by one on the image.
64
+
65
+ 4. Click the 'Segment with points prompt' button to get the segmentation results.
66
+
67
+ **5. If you get an error, clicking the 'Clear points' button and trying again may help.**
68
+
69
+ """
70
+
71
+ examples = [
72
+ ["assets/sa_8776.jpg"],
73
+ ["assets/sa_414.jpg"],
74
+ ["assets/sa_1309.jpg"],
75
+ ["assets/sa_11025.jpg"],
76
+ ["assets/sa_561.jpg"],
77
+ ["assets/sa_192.jpg"],
78
+ ["assets/sa_10039.jpg"],
79
+ ["assets/sa_862.jpg"],
80
+ ]
81
+
82
+ default_example = examples[0]
83
+
84
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
85
+
86
+
87
+ @torch.no_grad()
88
+ def segment_everything(
89
+ image,
90
+ input_size=1024,
91
+ better_quality=False,
92
+ withContours=True,
93
+ use_retina=True,
94
+ mask_random_color=True,
95
+ ):
96
+ global mask_generator
97
+
98
+ input_size = int(input_size)
99
+ w, h = image.size
100
+ scale = input_size / max(w, h)
101
+ new_w = int(w * scale)
102
+ new_h = int(h * scale)
103
+ image = image.resize((new_w, new_h))
104
+
105
+ nd_image = np.array(image)
106
+ annotations = mask_generator.generate(nd_image)
107
+
108
+ fig = fast_process(
109
+ annotations=annotations,
110
+ image=image,
111
+ device=device,
112
+ scale=(1024 // input_size),
113
+ better_quality=better_quality,
114
+ mask_random_color=mask_random_color,
115
+ bbox=None,
116
+ use_retina=use_retina,
117
+ withContours=withContours,
118
+ )
119
+ return fig
120
+
121
+
122
+ def segment_with_points(
123
+ image,
124
+ input_size=1024,
125
+ better_quality=False,
126
+ withContours=True,
127
+ use_retina=True,
128
+ mask_random_color=True,
129
+ ):
130
+ global global_points
131
+ global global_point_label
132
+
133
+ input_size = int(input_size)
134
+ w, h = image.size
135
+ scale = input_size / max(w, h)
136
+ new_w = int(w * scale)
137
+ new_h = int(h * scale)
138
+ image = image.resize((new_w, new_h))
139
+
140
+ scaled_points = np.array([[int(x * scale) for x in point] for point in global_points])
141
+ global_point_label = np.array(global_point_label)
142
+
143
+ nd_image = np.array(image)
144
+ predictor.set_image(nd_image)
145
+ masks, scores, logits = predictor.predict(
146
+ point_coords=scaled_points,
147
+ point_labels=global_point_label,
148
+ multimask_output=True,
149
+ )
150
+
151
+ results = format_results(masks, scores, logits, 0)
152
+
153
+ annotations, _ = point_prompt(
154
+ results, scaled_points, global_point_label, new_h, new_w
155
+ )
156
+ annotations = np.array([annotations])
157
+
158
+ fig = fast_process(
159
+ annotations=annotations,
160
+ image=image,
161
+ device=device,
162
+ scale=(1024 // input_size),
163
+ better_quality=better_quality,
164
+ mask_random_color=mask_random_color,
165
+ bbox=None,
166
+ use_retina=use_retina,
167
+ withContours=withContours,
168
+ )
169
+
170
+ global_points = []
171
+ global_point_label = []
172
+ # return fig, None
173
+ return fig, image
174
+
175
+
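`segment_with_points` resizes the image so that its longer side equals `input_size` and applies the same factor to every clicked point, so the prompts stay aligned with the resized image. A minimal sketch of that arithmetic in plain Python; the helper name `scale_image_and_points` is hypothetical and does not exist in `app.py`:

```python
def scale_image_and_points(width, height, points, input_size=1024):
    """Scale an image size and its click coordinates by one shared factor."""
    # The longer side is mapped to input_size, as in segment_with_points.
    scale = input_size / max(width, height)
    new_w, new_h = int(width * scale), int(height * scale)
    # Each (x, y) click is scaled and truncated to an integer pixel.
    scaled_points = [[int(x * scale), int(y * scale)] for x, y in points]
    return (new_w, new_h), scaled_points


size, pts = scale_image_and_points(2048, 1024, [[1000, 500]], input_size=1024)
print(size, pts)
```

Because both the image and the points use the same `scale`, a click on an object in the original image lands on the same object after resizing.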
176
+ def get_points_with_draw(image, label, evt: gr.SelectData):
177
+ global global_points
178
+ global global_point_label
179
+
180
+ x, y = evt.index[0], evt.index[1]
181
+ point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (
182
+ 255,
183
+ 0,
184
+ 255,
185
+ )
186
+ global_points.append([x, y])
187
+ global_point_label.append(1 if label == "Add Mask" else 0)
188
+
189
+ print(x, y, label == "Add Mask")
190
+
191
+ # Create an object that can draw on the image
192
+ draw = ImageDraw.Draw(image)
193
+ draw.ellipse(
194
+ [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
195
+ fill=point_color,
196
+ )
197
+ return image
198
+
199
+
200
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type="pil")
201
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type="pil")
202
+
203
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type="pil")
204
+ segm_img_p = gr.Image(
205
+ label="Segmented Image with points", interactive=False, type="pil"
206
+ )
207
+
208
+ global_points = []
209
+ global_point_label = []
210
+
211
+ input_size_slider = gr.components.Slider(
212
+ minimum=512,
213
+ maximum=1024,
214
+ value=1024,
215
+ step=64,
216
+ label="Input_size",
217
+ info="Our model was trained on a size of 1024",
218
+ )
219
+
220
+ with gr.Blocks(css=css, title="Faster Segment Anything(MobileSAM)") as demo:
221
+ with gr.Row():
222
+ with gr.Column(scale=1):
223
+ # Title
224
+ gr.Markdown(title)
225
+
226
+ # with gr.Tab("Everything mode"):
227
+ # # Images
228
+ # with gr.Row(variant="panel"):
229
+ # with gr.Column(scale=1):
230
+ # cond_img_e.render()
231
+ #
232
+ # with gr.Column(scale=1):
233
+ # segm_img_e.render()
234
+ #
235
+ # # Submit & Clear
236
+ # with gr.Row():
237
+ # with gr.Column():
238
+ # input_size_slider.render()
239
+ #
240
+ # with gr.Row():
241
+ # contour_check = gr.Checkbox(
242
+ # value=True,
243
+ # label="withContours",
244
+ # info="draw the edges of the masks",
245
+ # )
246
+ #
247
+ # with gr.Column():
248
+ # segment_btn_e = gr.Button(
249
+ # "Segment Everything", variant="primary"
250
+ # )
251
+ # clear_btn_e = gr.Button("Clear", variant="secondary")
252
+ #
253
+ # gr.Markdown("Try some of the examples below ⬇️")
254
+ # gr.Examples(
255
+ # examples=examples,
256
+ # inputs=[cond_img_e],
257
+ # outputs=segm_img_e,
258
+ # fn=segment_everything,
259
+ # cache_examples=True,
260
+ # examples_per_page=4,
261
+ # )
262
+ #
263
+ # with gr.Column():
264
+ # with gr.Accordion("Advanced options", open=False):
265
+ # # text_box = gr.Textbox(label="text prompt")
266
+ # with gr.Row():
267
+ # mor_check = gr.Checkbox(
268
+ # value=False,
269
+ # label="better_visual_quality",
270
+ # info="better quality using morphologyEx",
271
+ # )
272
+ # with gr.Column():
273
+ # retina_check = gr.Checkbox(
274
+ # value=True,
275
+ # label="use_retina",
276
+ # info="draw high-resolution segmentation masks",
277
+ # )
278
+ # # Description
279
+ # gr.Markdown(description_e)
280
+ #
281
+ with gr.Tab("Points mode"):
282
+ # Images
283
+ with gr.Row(variant="panel"):
284
+ with gr.Column(scale=1):
285
+ cond_img_p.render()
286
+
287
+ with gr.Column(scale=1):
288
+ segm_img_p.render()
289
+
290
+ # Submit & Clear
291
+ with gr.Row():
292
+ with gr.Column():
293
+ with gr.Row():
294
+ add_or_remove = gr.Radio(
295
+ ["Add Mask", "Remove Area"],
296
+ value="Add Mask",
297
+ label="Point_label (foreground/background)",
298
+ )
299
+
300
+ with gr.Column():
301
+ segment_btn_p = gr.Button(
302
+ "Segment with points prompt", variant="primary"
303
+ )
304
+ clear_btn_p = gr.Button("Clear points", variant="secondary")
305
+
306
+ gr.Markdown("Try some of the examples below ⬇️")
307
+ gr.Examples(
308
+ examples=examples,
309
+ inputs=[cond_img_p],
310
+ # outputs=segm_img_p,
311
+ # fn=segment_with_points,
312
+ # cache_examples=True,
313
+ examples_per_page=4,
314
+ )
315
+
316
+ with gr.Column():
317
+ # Description
318
+ gr.Markdown(description_p)
319
+
320
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
321
+
322
+ # segment_btn_e.click(
323
+ # segment_everything,
324
+ # inputs=[
325
+ # cond_img_e,
326
+ # input_size_slider,
327
+ # mor_check,
328
+ # contour_check,
329
+ # retina_check,
330
+ # ],
331
+ # outputs=segm_img_e,
332
+ # )
333
+
334
+ segment_btn_p.click(
335
+ segment_with_points, inputs=[cond_img_p], outputs=[segm_img_p, cond_img_p]
336
+ )
337
+
338
+ def clear():
339
+ return None, None
340
+
341
+ def clear_text():
342
+ return None, None, None
343
+
344
+ # clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
345
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
346
+
347
+ demo.queue()
348
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ torch
2
+ torchvision
3
+ timm
4
+ git+https://github.com/dhkim2810/MobileSAM.git
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (170 Bytes). View file
 
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (163 Bytes). View file
 
utils/__pycache__/tools.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
utils/__pycache__/tools.cpython-38.pyc ADDED
Binary file (10.8 kB). View file
 
utils/__pycache__/tools_gradio.cpython-310.pyc ADDED
Binary file (4.17 kB). View file
 
utils/__pycache__/tools_gradio.cpython-38.pyc ADDED
Binary file (4.22 kB). View file
 
utils/tools.py ADDED
@@ -0,0 +1,406 @@
1
+ import os
2
+ import sys
3
+
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
+ def convert_box_xywh_to_xyxy(box):
12
+ x1 = box[0]
13
+ y1 = box[1]
14
+ x2 = box[0] + box[2]
15
+ y2 = box[1] + box[3]
16
+ return [x1, y1, x2, y2]
17
+
18
+
19
+ def segment_image(image, bbox):
20
+ image_array = np.array(image)
21
+ segmented_image_array = np.zeros_like(image_array)
22
+ x1, y1, x2, y2 = bbox
23
+ segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
24
+ segmented_image = Image.fromarray(segmented_image_array)
25
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
26
+ # transparency_mask = np.zeros_like((), dtype=np.uint8)
27
+ transparency_mask = np.zeros(
28
+ (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
29
+ )
30
+ transparency_mask[y1:y2, x1:x2] = 255
31
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
32
+ black_image.paste(segmented_image, mask=transparency_mask_image)
33
+ return black_image
34
+
35
+
36
+ def format_results(masks, scores, logits, filter=0):
37
+ annotations = []
38
+ n = len(scores)
39
+ for i in range(n):
40
+ annotation = {}
41
+
42
+ mask = masks[i]
43
+ tmp = np.where(mask != 0)
44
+ if np.sum(mask) < filter:
45
+ continue
46
+ annotation["id"] = i
47
+ annotation["segmentation"] = mask
48
+ annotation["bbox"] = [
49
+ np.min(tmp[1]),
50
+ np.min(tmp[0]),
51
+ np.max(tmp[1]),
52
+ np.max(tmp[0]),
53
+ ]
54
+ annotation["score"] = scores[i]
55
+ annotation["area"] = annotation["segmentation"].sum()
56
+ annotations.append(annotation)
57
+ return annotations
58
+
59
+
60
+ def filter_masks(annotations): # filter the overlap mask
61
+ annotations.sort(key=lambda x: x["area"], reverse=True)
62
+ to_remove = set()
63
+ for i in range(0, len(annotations)):
64
+ a = annotations[i]
65
+ for j in range(i + 1, len(annotations)):
66
+ b = annotations[j]
67
+ if i != j and j not in to_remove:
68
+ # check if
69
+ if b["area"] < a["area"]:
70
+ if (a["segmentation"] & b["segmentation"]).sum() / b[
71
+ "segmentation"
72
+ ].sum() > 0.8:
73
+ to_remove.add(j)
74
+
75
+ return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
76
+
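`filter_masks` sorts the annotations by area and drops a smaller mask when more than 80% of its pixels are covered by a larger one. The same rule can be sketched without NumPy, under the assumption that each mask is represented as a set of `(row, col)` pixels; the function name `filter_contained` and the empty-mask guard are illustrative additions, not part of `utils/tools.py`:

```python
def filter_contained(masks, threshold=0.8):
    """Drop masks that are mostly covered by a larger mask.

    `masks` is a list of sets of (row, col) pixels (an assumption of this
    sketch; the original works on boolean arrays).
    """
    ordered = sorted(masks, key=len, reverse=True)  # largest area first
    removed = set()
    for i, big in enumerate(ordered):
        for j in range(i + 1, len(ordered)):
            small = ordered[j]
            # Skip masks already removed, empty masks, and equal-area pairs.
            if j in removed or not small or len(small) >= len(big):
                continue
            # Fraction of the smaller mask that lies inside the larger one.
            if len(big & small) / len(small) > threshold:
                removed.add(j)
    kept = [m for k, m in enumerate(ordered) if k not in removed]
    return kept, removed


big = {(r, c) for r in range(4) for c in range(4)}
inner = {(0, 0), (0, 1), (1, 0), (1, 1)}
apart = {(10, 10), (10, 11)}
kept, removed = filter_contained([inner, big, apart])
print(len(kept), removed)
```

Note that the ratio is taken over the smaller mask's area, so this is a containment test rather than an IoU test.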
77
+
78
+ def get_bbox_from_mask(mask):
79
+ mask = mask.astype(np.uint8)
80
+ contours, hierarchy = cv2.findContours(
81
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
82
+ )
83
+ x1, y1, w, h = cv2.boundingRect(contours[0])
84
+ x2, y2 = x1 + w, y1 + h
85
+ if len(contours) > 1:
86
+ for b in contours:
87
+ x_t, y_t, w_t, h_t = cv2.boundingRect(b)
88
+ # Merge multiple bboxes into one
89
+ x1 = min(x1, x_t)
90
+ y1 = min(y1, y_t)
91
+ x2 = max(x2, x_t + w_t)
92
+ y2 = max(y2, y_t + h_t)
93
+ h = y2 - y1
94
+ w = x2 - x1
95
+ return [x1, y1, x2, y2]
96
+
97
+
98
+ def fast_process(
99
+ annotations, args, mask_random_color, bbox=None, points=None, edges=False
100
+ ):
101
+ if isinstance(annotations[0], dict):
102
+ annotations = [annotation["segmentation"] for annotation in annotations]
103
+ result_name = os.path.basename(args.img_path)
104
+ image = cv2.imread(args.img_path)
105
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
106
+ original_h = image.shape[0]
107
+ original_w = image.shape[1]
108
+ if sys.platform == "darwin":
109
+ plt.switch_backend("TkAgg")
110
+ plt.figure(figsize=(original_w / 100, original_h / 100))
111
+ # Add subplot with no margin.
112
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
113
+ plt.margins(0, 0)
114
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
115
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
116
+ plt.imshow(image)
117
+ if args.better_quality == True:
118
+ if isinstance(annotations[0], torch.Tensor):
119
+ annotations = np.array(annotations.cpu())
120
+ for i, mask in enumerate(annotations):
121
+ mask = cv2.morphologyEx(
122
+ mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
123
+ )
124
+ annotations[i] = cv2.morphologyEx(
125
+ mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
126
+ )
127
+ if args.device == "cpu":
128
+ annotations = np.array(annotations)
129
+ fast_show_mask(
130
+ annotations,
131
+ plt.gca(),
132
+ random_color=mask_random_color,
133
+ bbox=bbox,
134
+ points=points,
135
+ point_label=args.point_label,
136
+ retinamask=args.retina,
137
+ target_height=original_h,
138
+ target_width=original_w,
139
+ )
140
+ else:
141
+ if isinstance(annotations[0], np.ndarray):
142
+ annotations = torch.from_numpy(annotations)
143
+ fast_show_mask_gpu(
144
+ annotations,
145
+ plt.gca(),
146
+ random_color=mask_random_color,
147
+ bbox=bbox,
148
+ points=points,
149
+ point_label=args.point_label,
150
+ retinamask=args.retina,
151
+ target_height=original_h,
152
+ target_width=original_w,
153
+ )
154
+ if isinstance(annotations, torch.Tensor):
155
+ annotations = annotations.cpu().numpy()
156
+ if args.withContours == True:
157
+ contour_all = []
158
+ temp = np.zeros((original_h, original_w, 1))
159
+ for i, mask in enumerate(annotations):
160
+ if type(mask) == dict:
161
+ mask = mask["segmentation"]
162
+ annotation = mask.astype(np.uint8)
163
+ if args.retina == False:
164
+ annotation = cv2.resize(
165
+ annotation,
166
+ (original_w, original_h),
167
+ interpolation=cv2.INTER_NEAREST,
168
+ )
169
+ contours, hierarchy = cv2.findContours(
170
+ annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
171
+ )
172
+ for contour in contours:
173
+ contour_all.append(contour)
174
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
175
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
176
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
177
+ plt.imshow(contour_mask)
178
+
179
+ save_path = args.output
180
+ if not os.path.exists(save_path):
181
+ os.makedirs(save_path)
182
+ plt.axis("off")
183
+ fig = plt.gcf()
184
+ plt.draw()
185
+
186
+ try:
187
+ buf = fig.canvas.tostring_rgb()
188
+ except AttributeError:
189
+ fig.canvas.draw()
190
+ buf = fig.canvas.tostring_rgb()
191
+
192
+ cols, rows = fig.canvas.get_width_height()
193
+ img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
194
+ cv2.imwrite(
195
+ os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
196
+ )
197
+
198
+
199
+ # CPU post process
200
+ def fast_show_mask(
201
+ annotation,
202
+ ax,
203
+ random_color=False,
204
+ bbox=None,
205
+ points=None,
206
+ point_label=None,
207
+ retinamask=True,
208
+ target_height=960,
209
+ target_width=960,
210
+ ):
211
+ msak_sum = annotation.shape[0]
212
+ height = annotation.shape[1]
213
+ weight = annotation.shape[2]
214
+ # Sort the annotations by area
215
+ areas = np.sum(annotation, axis=(1, 2))
216
+ sorted_indices = np.argsort(areas)
217
+ annotation = annotation[sorted_indices]
218
+
219
+ index = (annotation != 0).argmax(axis=0)
220
+ if random_color == True:
221
+ color = np.random.random((msak_sum, 1, 1, 3))
222
+ else:
223
+ color = np.ones((msak_sum, 1, 1, 3)) * np.array(
224
+ [30 / 255, 144 / 255, 255 / 255]
225
+ )
226
+ transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
227
+ visual = np.concatenate([color, transparency], axis=-1)
228
+ mask_image = np.expand_dims(annotation, -1) * visual
229
+
230
+ show = np.zeros((height, weight, 4))
231
+ h_indices, w_indices = np.meshgrid(
232
+ np.arange(height), np.arange(weight), indexing="ij"
233
+ )
234
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
235
+ # Update the values of show using vectorized indexing
236
+ show[h_indices, w_indices, :] = mask_image[indices]
237
+ if bbox is not None:
238
+ x1, y1, x2, y2 = bbox
239
+ ax.add_patch(
240
+ plt.Rectangle(
241
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
242
+ )
243
+ )
244
+ # draw point
245
+ if points is not None:
246
+ plt.scatter(
247
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
248
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
249
+ s=20,
250
+ c="y",
251
+ )
252
+ plt.scatter(
253
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
254
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
255
+ s=20,
256
+ c="m",
257
+ )
258
+
259
+ if retinamask == False:
260
+ show = cv2.resize(
261
+ show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
262
+ )
263
+ ax.imshow(show)
264
+
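`fast_show_mask` sorts the masks by ascending area and then takes, for every pixel, the index of the first non-zero mask along the mask axis, so smaller masks are drawn on top of larger ones. A pure-Python sketch of that selection, assuming each mask is a list of rows of 0/1 values; the name `topmost_mask_index` is hypothetical, and `None` is used here for uncovered pixels where the original `argmax` would return index 0:

```python
def topmost_mask_index(masks):
    """For each pixel, find the smallest-area mask that covers it."""
    # Smallest area first, mirroring np.argsort(areas) in fast_show_mask.
    ordered = sorted(masks, key=lambda m: sum(map(sum, m)))
    height, width = len(ordered[0]), len(ordered[0][0])
    result = [[None] * width for _ in range(height)]
    for r in range(height):
        for c in range(width):
            for idx, mask in enumerate(ordered):
                if mask[r][c]:
                    result[r][c] = idx  # first hit is the smallest mask
                    break
    return ordered, result


big = [[1, 1], [1, 1]]
small = [[1, 0], [0, 0]]
ordered, result = topmost_mask_index([big, small])
print(result)
```

In the example, the top-left pixel is assigned to the small mask (index 0 after sorting) even though the big mask also covers it.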
265
+
266
+ def fast_show_mask_gpu(
267
+ annotation,
268
+ ax,
269
+ random_color=False,
270
+ bbox=None,
271
+ points=None,
272
+ point_label=None,
273
+ retinamask=True,
274
+ target_height=960,
275
+ target_width=960,
276
+ ):
277
+ msak_sum = annotation.shape[0]
278
+ height = annotation.shape[1]
279
+ weight = annotation.shape[2]
280
+ areas = torch.sum(annotation, dim=(1, 2))
281
+ sorted_indices = torch.argsort(areas, descending=False)
282
+ annotation = annotation[sorted_indices]
283
+ # Find the index of the first non-zero value at each position
284
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
285
+ if random_color == True:
286
+ color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
287
+ else:
288
+ color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
289
+ [30 / 255, 144 / 255, 255 / 255]
290
+ ).to(annotation.device)
291
+ transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
292
+ visual = torch.cat([color, transparency], dim=-1)
293
+ mask_image = torch.unsqueeze(annotation, -1) * visual
294
+ # Gather by index: index says which mask in the batch to take at each position, collapsing mask_image into a single image
295
+ show = torch.zeros((height, weight, 4)).to(annotation.device)
296
+ h_indices, w_indices = torch.meshgrid(
297
+ torch.arange(height), torch.arange(weight), indexing="ij"
298
+ )
299
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
300
+ # Update the values of show using vectorized indexing
301
+ show[h_indices, w_indices, :] = mask_image[indices]
302
+ show_cpu = show.cpu().numpy()
303
+ if bbox is not None:
304
+ x1, y1, x2, y2 = bbox
305
+ ax.add_patch(
306
+ plt.Rectangle(
307
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
308
+ )
309
+ )
310
+ # draw point
311
+ if points is not None:
312
+ plt.scatter(
313
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
314
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
315
+ s=20,
316
+ c="y",
317
+ )
318
+ plt.scatter(
319
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
320
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
321
+ s=20,
322
+ c="m",
323
+ )
324
+ if retinamask == False:
325
+ show_cpu = cv2.resize(
326
+ show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
327
+ )
328
+ ax.imshow(show_cpu)
329
+
330
+
+ def crop_image(annotations, image_like):
+     if isinstance(image_like, str):
+         image = Image.open(image_like)
+     else:
+         image = image_like
+     ori_w, ori_h = image.size
+     mask_h, mask_w = annotations[0]["segmentation"].shape
+     if ori_w != mask_w or ori_h != mask_h:
+         image = image.resize((mask_w, mask_h))
+     cropped_boxes = []
+     cropped_images = []
+     not_crop = []
+     filter_id = []
+     # annotations, _ = filter_masks(annotations)
+     # filter_id = list(_)
+     for i, mask in enumerate(annotations):
+         # Skip masks covering 100 pixels or fewer and record their indices.
+         if np.sum(mask["segmentation"]) <= 100:
+             filter_id.append(i)
+             continue
+         bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
+         cropped_boxes.append(segment_image(image, bbox))  # store the cropped image
+         # cropped_boxes.append(segment_image(image,mask["segmentation"]))
+         cropped_images.append(bbox)  # store the bbox of the cropped image
+
+     return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+
+
+ def box_prompt(masks, bbox, target_height, target_width):
+     h = masks.shape[1]
+     w = masks.shape[2]
+     if h != target_height or w != target_width:
+         bbox = [
+             int(bbox[0] * w / target_width),
+             int(bbox[1] * h / target_height),
+             int(bbox[2] * w / target_width),
+             int(bbox[3] * h / target_height),
+         ]
+     # Clamp the box to the mask bounds.
+     bbox[0] = max(round(bbox[0]), 0)
+     bbox[1] = max(round(bbox[1]), 0)
+     bbox[2] = min(round(bbox[2]), w)
+     bbox[3] = min(round(bbox[3]), h)
+
+     # IoUs = torch.zeros(len(masks), dtype=torch.float32)
+     bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
+
+     masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
+     orig_masks_area = torch.sum(masks, dim=(1, 2))
+
+     union = bbox_area + orig_masks_area - masks_area
+     IoUs = masks_area / union
+     max_iou_index = torch.argmax(IoUs)
+
+     return masks[max_iou_index].cpu().numpy(), max_iou_index
+
+
+ def point_prompt(masks, points, point_label, target_height, target_width):  # NumPy processing
+     h = masks[0]["segmentation"].shape[0]
+     w = masks[0]["segmentation"].shape[1]
+     if h != target_height or w != target_width:
+         points = [
+             [int(point[0] * w / target_width), int(point[1] * h / target_height)]
+             for point in points
+         ]
+     onemask = np.zeros((h, w))
+     for annotation in masks:
+         if isinstance(annotation, dict):
+             mask = annotation["segmentation"]
+         else:
+             mask = annotation
+         # Add masks hit by a label-1 point and subtract masks hit by a label-0 point.
+         for i, point in enumerate(points):
+             if mask[point[1], point[0]] == 1 and point_label[i] == 1:
+                 onemask += mask
+             if mask[point[1], point[0]] == 1 and point_label[i] == 0:
+                 onemask -= mask
+     onemask = onemask >= 1
+     return onemask, 0
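
Reviewer note on `box_prompt`: the function scores every mask by its IoU with the prompt box and returns the best one. The pure-Python sketch below restates that rule on nested lists so the arithmetic is easy to check by hand. The name `select_mask_by_box` and the list-of-lists mask format are assumptions made for illustration and are not part of this commit; the committed code works on a torch tensor of shape `(N, H, W)`. The zero-union guard is also an addition of the sketch.

```python
def select_mask_by_box(masks, bbox):
    """Return (index, iou) of the mask with the highest IoU against bbox.

    masks: list of 2-D lists holding 0/1 values, all of the same size.
    bbox: (x1, y1, x2, y2) in pixel coordinates, with x2 and y2 exclusive.
    """
    x1, y1, x2, y2 = bbox
    bbox_area = (x2 - x1) * (y2 - y1)
    best_index, best_iou = -1, -1.0
    for i, mask in enumerate(masks):
        # Total mask area and the part of it that falls inside the box.
        mask_area = sum(sum(row) for row in mask)
        inter = sum(sum(row[x1:x2]) for row in mask[y1:y2])
        union = bbox_area + mask_area - inter
        iou = inter / union if union > 0 else 0.0
        if iou > best_iou:
            best_index, best_iou = i, iou
    return best_index, best_iou
```

A mask that fills the box exactly scores 1.0, while a mask that covers the whole image is penalised by its area outside the box.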
utils/tools_gradio.py ADDED
@@ -0,0 +1,192 @@
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def fast_process(
+     annotations,
+     image,
+     device,
+     scale,
+     better_quality=False,
+     mask_random_color=True,
+     bbox=None,
+     use_retina=True,
+     withContours=True,
+ ):
+     if isinstance(annotations[0], dict):
+         annotations = [annotation["segmentation"] for annotation in annotations]
+
+     original_h = image.height
+     original_w = image.width
+     if better_quality:
+         if isinstance(annotations[0], torch.Tensor):
+             annotations = np.array(annotations.cpu())
+         for i, mask in enumerate(annotations):
+             mask = cv2.morphologyEx(
+                 mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
+             )
+             annotations[i] = cv2.morphologyEx(
+                 mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
+             )
+     if device == "cpu":
+         annotations = np.array(annotations)
+         inner_mask = fast_show_mask(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             retinamask=use_retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     else:
+         if isinstance(annotations[0], np.ndarray):
+             annotations = np.array(annotations)
+             annotations = torch.from_numpy(annotations)
+         inner_mask = fast_show_mask_gpu(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             retinamask=use_retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     if isinstance(annotations, torch.Tensor):
+         annotations = annotations.cpu().numpy()
+
+     if withContours:
+         contour_all = []
+         temp = np.zeros((original_h, original_w, 1))
+         for mask in annotations:
+             if isinstance(mask, dict):
+                 mask = mask["segmentation"]
+             annotation = mask.astype(np.uint8)
+             if not use_retina:
+                 annotation = cv2.resize(
+                     annotation,
+                     (original_w, original_h),
+                     interpolation=cv2.INTER_NEAREST,
+                 )
+             contours, _ = cv2.findContours(
+                 annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+             )
+             contour_all.extend(contours)
+         cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
+         color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
+         contour_mask = temp / 255 * color.reshape(1, 1, -1)
+
+     image = image.convert("RGBA")
+     overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
+     image.paste(overlay_inner, (0, 0), overlay_inner)
+
+     if withContours:
+         overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
+         image.paste(overlay_contour, (0, 0), overlay_contour)
+
+     return image
+
+
+ # CPU post process
+ def fast_show_mask(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     mask_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     width = annotation.shape[2]
+     # Sort the masks by area in ascending order.
+     areas = np.sum(annotation, axis=(1, 2))
+     sorted_indices = np.argsort(areas)
+     annotation = annotation[sorted_indices]
+
+     # For each pixel, the index of the first (smallest) mask that covers it.
+     index = (annotation != 0).argmax(axis=0)
+     if random_color:
+         color = np.random.random((mask_sum, 1, 1, 3))
+     else:
+         color = np.ones((mask_sum, 1, 1, 3)) * np.array(
+             [30 / 255, 144 / 255, 255 / 255]
+         )
+     transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
+     visual = np.concatenate([color, transparency], axis=-1)
+     mask_image = np.expand_dims(annotation, -1) * visual
+
+     mask = np.zeros((height, width, 4))
+
+     h_indices, w_indices = np.meshgrid(
+         np.arange(height), np.arange(width), indexing="ij"
+     )
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+
+     mask[h_indices, w_indices, :] = mask_image[indices]
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+
+     if not retinamask:
+         mask = cv2.resize(
+             mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+
+     return mask
+
+
+ def fast_show_mask_gpu(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     device = annotation.device
+     mask_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     width = annotation.shape[2]
+     # Sort the masks by area in ascending order.
+     areas = torch.sum(annotation, dim=(1, 2))
+     sorted_indices = torch.argsort(areas, descending=False)
+     annotation = annotation[sorted_indices]
+     # For each pixel, find the index of the first nonzero value.
+     index = (annotation != 0).to(torch.long).argmax(dim=0)
+     if random_color:
+         color = torch.rand((mask_sum, 1, 1, 3)).to(device)
+     else:
+         color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
+             [30 / 255, 144 / 255, 255 / 255]
+         ).to(device)
+     transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
+     visual = torch.cat([color, transparency], dim=-1)
+     mask_image = torch.unsqueeze(annotation, -1) * visual
+     # Gather by index: `index` says, for each pixel, which mask in the batch to
+     # take, which collapses mask_image into a single image.
+     mask = torch.zeros((height, width, 4)).to(device)
+     # indexing="ij" matches the previous default and avoids the deprecation warning.
+     h_indices, w_indices = torch.meshgrid(
+         torch.arange(height), torch.arange(width), indexing="ij"
+     )
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+     # Update `mask` using vectorized indexing.
+     mask[h_indices, w_indices, :] = mask_image[indices]
+     mask_cpu = mask.cpu().numpy()
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+     if not retinamask:
+         mask_cpu = cv2.resize(
+             mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+     return mask_cpu
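
Reviewer note on `fast_show_mask` and `fast_show_mask_gpu`: both sort the masks by area and then use `argmax` plus advanced indexing to pick, for every pixel, the first mask that covers it. The loop-based sketch below is meant as a plain reference for what that vectorized step appears to compute. The name `composite_masks` and the nested-list inputs are assumptions for illustration only; the committed functions operate on NumPy arrays or torch tensors and build the colors internally.

```python
def composite_masks(masks, colors):
    """Overlay binary masks so that smaller masks end up on top of larger ones.

    masks: list of 2-D lists holding 0/1 values, all of the same size.
    colors: list of RGBA tuples, one per mask.
    Returns a 2-D list of RGBA tuples; uncovered pixels stay fully transparent.
    """
    height, width = len(masks[0]), len(masks[0][0])
    # Visit masks from smallest to largest area, mirroring the ascending argsort.
    order = sorted(
        range(len(masks)), key=lambda i: sum(sum(row) for row in masks[i])
    )
    transparent = (0.0, 0.0, 0.0, 0.0)
    out = [[transparent for _ in range(width)] for _ in range(height)]
    for y in range(height):
        for x in range(width):
            # The first (smallest) mask covering the pixel wins, which is the
            # mask that argmax over the sorted stack selects.
            for i in order:
                if masks[i][y][x]:
                    out[y][x] = colors[i]
                    break
    return out
```

Because the smallest mask is chosen first, a small object inside a large region keeps its own color instead of being painted over.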