sushmanth committed
Commit
a974071
1 Parent(s): b922f46

Upload 4 files

Files changed (4)
  1. README.md +7 -6
  2. app.py +242 -0
  3. demo.py +87 -0
  4. requirements.txt +3 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Segment
- emoji:
- colorFrom: purple
- colorTo: green
+ title: Segment-Anything-Video
+ emoji: 🐨
+ colorFrom: blue
+ colorTo: yellow
  sdk: gradio
- sdk_version: 3.28.0
+ sdk_version: 3.19.0
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,242 @@
+ import gradio as gr
+ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
+
+
+ def image_app():
+     with gr.Blocks():
+         with gr.Row():
+             with gr.Column():
+                 seg_automask_image_file = gr.Image(type="filepath").style(height=260)
+                 with gr.Row():
+                     with gr.Column():
+                         seg_automask_image_model_type = gr.Dropdown(
+                             choices=[
+                                 "vit_h",
+                                 "vit_l",
+                                 "vit_b",
+                             ],
+                             value="vit_l",
+                             label="Model Type",
+                         )
+
+                         seg_automask_image_min_area = gr.Number(
+                             value=0,
+                             label="Min Area",
+                         )
+                 with gr.Row():
+                     with gr.Column():
+                         seg_automask_image_points_per_side = gr.Slider(
+                             minimum=0,
+                             maximum=32,
+                             step=2,
+                             value=16,
+                             label="Points per Side",
+                         )
+
+                         seg_automask_image_points_per_batch = gr.Slider(
+                             minimum=0,
+                             maximum=64,
+                             step=2,
+                             value=64,
+                             label="Points per Batch",
+                         )
+
+                 seg_automask_image_predict = gr.Button(value="Generator")
+
+             with gr.Column():
+                 output_image = gr.Image()
+
+         seg_automask_image_predict.click(
+             fn=automask_image_app,
+             inputs=[
+                 seg_automask_image_file,
+                 seg_automask_image_model_type,
+                 seg_automask_image_points_per_side,
+                 seg_automask_image_points_per_batch,
+                 seg_automask_image_min_area,
+             ],
+             outputs=[output_image],
+         )
+
+
+ def video_app():
+     with gr.Blocks():
+         with gr.Row():
+             with gr.Column():
+                 seg_automask_video_file = gr.Video().style(height=260)
+                 with gr.Row():
+                     with gr.Column():
+                         seg_automask_video_model_type = gr.Dropdown(
+                             choices=[
+                                 "vit_h",
+                                 "vit_l",
+                                 "vit_b",
+                             ],
+                             value="vit_l",
+                             label="Model Type",
+                         )
+                         seg_automask_video_min_area = gr.Number(
+                             value=1000,
+                             label="Min Area",
+                         )
+
+                 with gr.Row():
+                     with gr.Column():
+                         seg_automask_video_points_per_side = gr.Slider(
+                             minimum=0,
+                             maximum=32,
+                             step=2,
+                             value=16,
+                             label="Points per Side",
+                         )
+
+                         seg_automask_video_points_per_batch = gr.Slider(
+                             minimum=0,
+                             maximum=64,
+                             step=2,
+                             value=64,
+                             label="Points per Batch",
+                         )
+
+                 seg_automask_video_predict = gr.Button(value="Generator")
+             with gr.Column():
+                 output_video = gr.Video()
+
+         seg_automask_video_predict.click(
+             fn=automask_video_app,
+             inputs=[
+                 seg_automask_video_file,
+                 seg_automask_video_model_type,
+                 seg_automask_video_points_per_side,
+                 seg_automask_video_points_per_batch,
+                 seg_automask_video_min_area,
+             ],
+             outputs=[output_video],
+         )
+
+
+ def sahi_app():
+     with gr.Blocks():
+         with gr.Row():
+             with gr.Column():
+                 sahi_image_file = gr.Image(type="filepath").style(height=260)
+                 sahi_autoseg_model_type = gr.Dropdown(
+                     choices=[
+                         "vit_h",
+                         "vit_l",
+                         "vit_b",
+                     ],
+                     value="vit_l",
+                     label="Sam Model Type",
+                 )
+
+                 with gr.Row():
+                     with gr.Column():
+                         sahi_model_type = gr.Dropdown(
+                             choices=[
+                                 "yolov5",
+                                 "yolov8",
+                             ],
+                             value="yolov5",
+                             label="Detector Model Type",
+                         )
+                         sahi_image_size = gr.Slider(
+                             minimum=0,
+                             maximum=1600,
+                             step=32,
+                             value=640,
+                             label="Image Size",
+                         )
+
+                         sahi_overlap_width = gr.Slider(
+                             minimum=0,
+                             maximum=1,
+                             step=0.1,
+                             value=0.2,
+                             label="Overlap Width",
+                         )
+
+                         sahi_slice_width = gr.Slider(
+                             minimum=0,
+                             maximum=640,
+                             step=32,
+                             value=256,
+                             label="Slice Width",
+                         )
+
+                 with gr.Row():
+                     with gr.Column():
+                         sahi_model_path = gr.Dropdown(
+                             choices=[
+                                 "yolov5l.pt",
+                                 "yolov5l6.pt",
+                                 "yolov8l.pt",
+                                 "yolov8x.pt",
+                             ],
+                             value="yolov5l6.pt",
+                             label="Detector Model Path",
+                         )
+
+                         sahi_conf_th = gr.Slider(
+                             minimum=0,
+                             maximum=1,
+                             step=0.1,
+                             value=0.2,
+                             label="Confidence Threshold",
+                         )
+                         sahi_overlap_height = gr.Slider(
+                             minimum=0,
+                             maximum=1,
+                             step=0.1,
+                             value=0.2,
+                             label="Overlap Height",
+                         )
+                         sahi_slice_height = gr.Slider(
+                             minimum=0,
+                             maximum=640,
+                             step=32,
+                             value=256,
+                             label="Slice Height",
+                         )
+                 sahi_image_predict = gr.Button(value="Generator")
+
+             with gr.Column():
+                 output_image = gr.Image()
+
+         sahi_image_predict.click(
+             fn=sahi_autoseg_app,
+             inputs=[
+                 sahi_image_file,
+                 sahi_autoseg_model_type,
+                 sahi_model_type,
+                 sahi_model_path,
+                 sahi_conf_th,
+                 sahi_image_size,
+                 sahi_slice_height,
+                 sahi_slice_width,
+                 sahi_overlap_height,
+                 sahi_overlap_width,
+             ],
+             outputs=[output_image],
+         )
+
+
+ def metaseg_app():
+     app = gr.Blocks()
+     with app:
+         with gr.Row():
+             with gr.Column():
+                 with gr.Tab("Image"):
+                     image_app()
+                 with gr.Tab("Video"):
+                     video_app()
+                 with gr.Tab("SAHI"):
+                     sahi_app()
+
+     app.queue(concurrency_count=1)
+     app.launch(debug=True, enable_queue=True)
+
+
+ if __name__ == "__main__":
+     metaseg_app()
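For reference, the `Button.click` wiring used throughout `app.py` passes the listed input components to the callback positionally and routes the return value to the output components. A minimal sketch of the same pattern, with a hypothetical `greet` callback and components that are not part of this commit:

```python
import gradio as gr


def greet(name, excitement):
    # Hypothetical callback: receives the Textbox and Slider values positionally.
    return f"Hello {name}" + "!" * int(excitement)


with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    excitement = gr.Slider(minimum=0, maximum=5, step=1, value=1, label="Excitement")
    output = gr.Textbox(label="Greeting")
    button = gr.Button("Greet")
    # inputs map one-to-one onto greet's parameters; outputs receive its return value.
    button.click(fn=greet, inputs=[name, excitement], outputs=[output])

if __name__ == "__main__":
    demo.queue()
    demo.launch()
```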
demo.py ADDED
@@ -0,0 +1,87 @@
+ from metaseg import SegAutoMaskPredictor, SegManualMaskPredictor, SahiAutoSegmentation, sahi_sliced_predict
+
+ # For image
+
+ def automask_image_app(image_path, model_type, points_per_side, points_per_batch, min_area):
+     SegAutoMaskPredictor().image_predict(
+         source=image_path,
+         model_type=model_type,  # vit_l, vit_h, vit_b
+         points_per_side=points_per_side,
+         points_per_batch=points_per_batch,
+         min_area=min_area,
+         output_path="output.png",
+         show=False,
+         save=True,
+     )
+     return "output.png"
+
+
+ # For video
+
+ def automask_video_app(video_path, model_type, points_per_side, points_per_batch, min_area):
+     SegAutoMaskPredictor().video_predict(
+         source=video_path,
+         model_type=model_type,  # vit_l, vit_h, vit_b
+         points_per_side=points_per_side,
+         points_per_batch=points_per_batch,
+         min_area=min_area,
+         output_path="output.mp4",
+     )
+     return "output.mp4"
+
+
+ # For manual box and point selection
+
+ def manual_app(image_path, model_type, input_point, input_label, input_box, multimask_output, random_color):
+     SegManualMaskPredictor().image_predict(
+         source=image_path,
+         model_type=model_type,  # vit_l, vit_h, vit_b
+         input_point=input_point,
+         input_label=input_label,
+         input_box=input_box,
+         multimask_output=multimask_output,
+         random_color=random_color,
+         output_path="output.png",
+         show=False,
+         save=True,
+     )
+     return "output.png"
+
+
+ # For SAHI sliced prediction
+
+ def sahi_autoseg_app(
+     image_path,
+     sam_model_type,
+     detection_model_type,
+     detection_model_path,
+     conf_th,
+     image_size,
+     slice_height,
+     slice_width,
+     overlap_height_ratio,
+     overlap_width_ratio,
+ ):
+     boxes = sahi_sliced_predict(
+         image_path=image_path,
+         detection_model_type=detection_model_type,  # yolov8, detectron2, mmdetection, torchvision
+         detection_model_path=detection_model_path,
+         conf_th=conf_th,
+         image_size=image_size,
+         slice_height=slice_height,
+         slice_width=slice_width,
+         overlap_height_ratio=overlap_height_ratio,
+         overlap_width_ratio=overlap_width_ratio,
+     )
+
+     SahiAutoSegmentation().predict(
+         source=image_path,
+         model_type=sam_model_type,
+         input_box=boxes,
+         multimask_output=False,
+         random_color=False,
+         show=False,
+         save=True,
+     )
+
+     return "output.png"
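The helpers in `demo.py` can also be exercised without the Gradio UI. A minimal smoke-test sketch using the default values from `app.py`; the input file names are hypothetical, and it assumes metaseg fetches the SAM checkpoint for the chosen `model_type` on first use:

```python
from demo import automask_image_app, automask_video_app

# Hypothetical local inputs; the vit_l checkpoint is assumed to be downloaded by metaseg.
mask_image = automask_image_app(
    image_path="example.jpg",
    model_type="vit_l",
    points_per_side=16,
    points_per_batch=64,
    min_area=0,
)
print(mask_image)  # "output.png"

mask_video = automask_video_app(
    video_path="example.mp4",
    model_type="vit_l",
    points_per_side=16,
    points_per_batch=64,
    min_area=1000,
)
print(mask_video)  # "output.mp4"
```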
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ metaseg==0.5.8
+ sahi
+ yolov5