Revert "Florence-2 + SAM2 + FLUX.1"
This reverts commit b38c358bbb73c6626d065b797723ecdb9954331a.
- .gitattributes +0 -1
- app.py +87 -71
- configs/__init__.py +0 -5
- configs/sam2_hiera_b+.yaml +0 -113
- configs/sam2_hiera_l.yaml +0 -117
- configs/sam2_hiera_s.yaml +0 -116
- configs/sam2_hiera_t.yaml +0 -118
- requirements.txt +1 -8
- utils/florence.py +0 -54
- utils/sam.py +0 -45
.gitattributes
CHANGED
@@ -33,4 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-checkpoints/ filter=lfs diff=lfs merge=lfs -text
app.py
CHANGED
@@ -1,18 +1,14 @@
 from typing import Tuple
 
-import
+import requests
 import random
 import numpy as np
 import gradio as gr
 import spaces
 import torch
-from PIL import Image
+from PIL import Image
 from diffusers import FluxInpaintPipeline
 
-from utils.florence import load_florence_model, run_florence_inference, \
-    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
-from utils.sam import load_sam_image_model, run_sam_inference
-
 MARKDOWN = """
 # FLUX.1 Inpainting 🔥
 
@@ -23,16 +19,52 @@ for taking it to the next level by enabling inpainting with the FLUX.
 
 MAX_SEED = np.iinfo(np.int32).max
 IMAGE_SIZE = 1024
-DEVICE =
-
-
-
-
-
-
-
-
-
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
+    image = image.convert("RGBA")
+    data = image.getdata()
+    new_data = []
+    for item in data:
+        avg = sum(item[:3]) / 3
+        if avg < threshold:
+            new_data.append((0, 0, 0, 0))
+        else:
+            new_data.append(item)
+
+    image.putdata(new_data)
+    return image
+
+
+EXAMPLES = [
+    [
+        {
+            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
+            "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-2.png", stream=True).raw))],
+            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-2.png", stream=True).raw),
+        },
+        "little lion",
+        42,
+        False,
+        0.85,
+        30
+    ],
+    [
+        {
+            "background": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-image.png", stream=True).raw),
+            "layers": [remove_background(Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-mask-3.png", stream=True).raw))],
+            "composite": Image.open(requests.get("https://media.roboflow.com/spaces/doge-2-composite-3.png", stream=True).raw),
+        },
+        "tattoos",
+        42,
+        False,
+        0.85,
+        30
+    ]
+]
+
+pipe = FluxInpaintPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(DEVICE)
 
 
@@ -42,6 +74,11 @@ def resize_image_dimensions(
 ) -> Tuple[int, int]:
     width, height = original_resolution_wh
 
+    # if width <= maximum_dimension and height <= maximum_dimension:
+    #     width = width - (width % 32)
+    #     height = height - (height % 32)
+    #     return width, height
+
     if width > height:
         scaling_factor = maximum_dimension / width
     else:
@@ -56,20 +93,17 @@
     return new_width, new_height
 
 
-@spaces.GPU(duration=
-@torch.inference_mode()
-@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+@spaces.GPU(duration=100)
 def process(
     input_image_editor: dict,
-
-    segmentation_prompt_text: str,
+    input_text: str,
     seed_slicer: int,
     randomize_seed_checkbox: bool,
     strength_slider: float,
     num_inference_steps_slider: int,
     progress=gr.Progress(track_tqdm=True)
 ):
-    if not
+    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None
 
@@ -80,50 +114,21 @@ def process(
        gr.Info("Please upload an image.")
        return None, None
 
-    if not mask
-        gr.Info("Please draw a mask
-        return None, None
-
-    if mask and segmentation_prompt_text:
-        gr.Info("Both mask and segmentation prompt are provided. Please provide only "
-                "one.")
+    if not mask:
+        gr.Info("Please draw a mask on the image.")
         return None, None
 
     width, height = resize_image_dimensions(original_resolution_wh=image.size)
-
-
-    if segmentation_prompt_text:
-        _, result = run_florence_inference(
-            model=FLORENCE_MODEL,
-            processor=FLORENCE_PROCESSOR,
-            device=DEVICE,
-            image=image,
-            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
-            text=segmentation_prompt_text
-        )
-        detections = sv.Detections.from_lmm(
-            lmm=sv.LMM.FLORENCE_2,
-            result=result,
-            resolution_wh=image.size
-        )
-        detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)
-
-        if len(detections) == 0:
-            gr.Info(f"{segmentation_prompt_text} prompt did not return any detections.")
-            return None, None
-
-        mask = Image.fromarray((detections.mask[0].astype(np.uint8)) * 255)
-
-    mask = mask.resize((width, height), Image.LANCZOS)
-    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))
+    resized_image = image.resize((width, height), Image.LANCZOS)
+    resized_mask = mask.resize((width, height), Image.LANCZOS)
 
     if randomize_seed_checkbox:
         seed_slicer = random.randint(0, MAX_SEED)
     generator = torch.Generator().manual_seed(seed_slicer)
-    result =
-        prompt=
-        image=
-        mask_image=
+    result = pipe(
+        prompt=input_text,
+        image=resized_image,
+        mask_image=resized_mask,
         width=width,
         height=height,
         strength=strength_slider,
@@ -131,7 +136,7 @@ def process(
         num_inference_steps=num_inference_steps_slider
     ).images[0]
     print('INFERENCE DONE')
-    return result,
+    return result, resized_mask
 
 
 with gr.Blocks() as demo:
@@ -147,24 +152,17 @@ with gr.Blocks() as demo:
                 brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
 
         with gr.Row():
-
+            input_text_component = gr.Text(
                 label="Prompt",
                 show_label=False,
                 max_lines=1,
-                placeholder="Enter
+                placeholder="Enter your prompt",
                 container=False,
             )
            submit_button_component = gr.Button(
                value='Submit', variant='primary', scale=0)
 
        with gr.Accordion("Advanced Settings", open=False):
-            segmentation_prompt_text_component = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter segmentation prompt",
-                container=False,
-            )
            seed_slicer_component = gr.Slider(
                label="Seed",
                minimum=0,
@@ -203,13 +201,31 @@ with gr.Blocks() as demo:
        with gr.Accordion("Debug", open=False):
            output_mask_component = gr.Image(
                type='pil', image_mode='RGB', label='Input mask', format="png")
+    with gr.Row():
+        gr.Examples(
+            fn=process,
+            examples=EXAMPLES,
+            inputs=[
+                input_image_editor_component,
+                input_text_component,
+                seed_slicer_component,
+                randomize_seed_checkbox_component,
+                strength_slider_component,
+                num_inference_steps_slider_component
+            ],
+            outputs=[
+                output_image_component,
+                output_mask_component
+            ],
+            run_on_click=True,
+            cache_examples=True
+        )
 
    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
-
-            segmentation_prompt_text_component,
+            input_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
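After the revert, process() no longer derives a mask with Florence-2 and SAM2; it simply resizes the user-drawn mask alongside the image and hands both to FluxInpaintPipeline. A minimal standalone sketch of that restored flow, assuming a CUDA device, the flux-inpaint diffusers branch pinned in requirements.txt, and hypothetical local files image.png and mask.png:

import torch
from PIL import Image
from diffusers import FluxInpaintPipeline

# Load FLUX.1-schnell the same way app.py does after the revert.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")

# image.png / mask.png are placeholder paths; in the Space both come from
# the gr.ImageEditor component (background image plus a white drawn mask).
image = Image.open("image.png").convert("RGB").resize((1024, 1024), Image.LANCZOS)
mask = Image.open("mask.png").convert("L").resize((1024, 1024), Image.LANCZOS)

result = pipe(
    prompt="little lion",                      # example prompt from EXAMPLES
    image=image,
    mask_image=mask,
    width=1024,
    height=1024,
    strength=0.85,                             # strength used by the examples
    generator=torch.Generator().manual_seed(42),
    num_inference_steps=30
).images[0]
result.save("output.png")

In the Space the same call runs inside process(), wired to the Submit button and to gr.Examples.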
configs/__init__.py
CHANGED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
configs/sam2_hiera_b+.yaml
DELETED
@@ -1,113 +0,0 @@
-# @package _global_
-
-# Model
-model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
-  image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
-    scalp: 1
-    trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
-      embed_dim: 112
-      num_heads: 2
-    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
-      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-        num_pos_feats: 256
-        normalize: true
-        scale: null
-        temperature: 10000
-      d_model: 256
-      backbone_channel_list: [896, 448, 224, 112]
-      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
-      fpn_interp_model: nearest
-
-  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
-    d_model: 256
-    pos_enc_at_input: true
-    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
-      activation: relu
-      dim_feedforward: 2048
-      dropout: 0.1
-      pos_enc_at_attn: false
-      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-      d_model: 256
-      pos_enc_at_cross_attn_keys: true
-      pos_enc_at_cross_attn_queries: false
-      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        rope_k_repeat: True
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-        kv_in_dim: 64
-    num_layers: 4
-
-  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
-    out_dim: 64
-    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-      num_pos_feats: 64
-      normalize: true
-      scale: null
-      temperature: 10000
-    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
-      kernel_size: 3
-      stride: 2
-      padding: 1
-    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
-      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
-        dim: 256
-        kernel_size: 7
-        padding: 3
-        layer_scale_init_value: 1e-6
-        use_dwconv: True  # depth-wise convs
-      num_layers: 2
-
-  num_maskmem: 7
-  image_size: 1024
-  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
-  sigmoid_scale_for_mem_enc: 20.0
-  sigmoid_bias_for_mem_enc: -10.0
-  use_mask_input_as_output_without_sam: true
-  # Memory
-  directly_add_no_mem_embed: true
-  # use high-resolution feature map in the SAM mask decoder
-  use_high_res_features_in_sam: true
-  # output 3 masks on the first click on initial conditioning frames
-  multimask_output_in_sam: true
-  # SAM heads
-  iou_prediction_use_sigmoid: True
-  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
-  use_obj_ptrs_in_encoder: true
-  add_tpos_enc_to_obj_ptrs: false
-  only_obj_ptrs_in_the_past_for_eval: true
-  # object occlusion prediction
-  pred_obj_scores: true
-  pred_obj_scores_mlp: true
-  fixed_no_obj_ptr: true
-  # multimask tracking settings
-  multimask_output_for_tracking: true
-  use_multimask_token_for_obj_ptr: true
-  multimask_min_pt_num: 0
-  multimask_max_pt_num: 1
-  use_mlp_for_obj_ptr_proj: true
-  # Compilation flag
-  compile_image_encoder: False
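These Hydra-style configs were only consumed by the SAM2 builder from the samv2 dependency that this commit also drops from requirements.txt. A minimal sketch of how the removed utils/sam.py (shown further down) loaded the small variant; the other variants deleted below were resolved the same way, assuming a matching checkpoint under checkpoints/:

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Values taken from the removed utils/sam.py: Hydra resolves the config name
# against the configs/ directory, and the checkpoint lives under checkpoints/
# (the path covered by the LFS rule dropped from .gitattributes).
SAM_CONFIG = "sam2_hiera_s.yaml"
SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"

model = build_sam2(SAM_CONFIG, SAM_CHECKPOINT, device=torch.device("cuda"))
predictor = SAM2ImagePredictor(sam_model=model)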
configs/sam2_hiera_l.yaml
DELETED
@@ -1,117 +0,0 @@
-# @package _global_
-
-# Model
-model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
-  image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
-    scalp: 1
-    trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
-      embed_dim: 144
-      num_heads: 2
-      stages: [2, 6, 36, 4]
-      global_att_blocks: [23, 33, 43]
-      window_pos_embed_bkg_spatial_size: [7, 7]
-      window_spec: [8, 4, 16, 8]
-    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
-      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-        num_pos_feats: 256
-        normalize: true
-        scale: null
-        temperature: 10000
-      d_model: 256
-      backbone_channel_list: [1152, 576, 288, 144]
-      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
-      fpn_interp_model: nearest
-
-  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
-    d_model: 256
-    pos_enc_at_input: true
-    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
-      activation: relu
-      dim_feedforward: 2048
-      dropout: 0.1
-      pos_enc_at_attn: false
-      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-      d_model: 256
-      pos_enc_at_cross_attn_keys: true
-      pos_enc_at_cross_attn_queries: false
-      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        rope_k_repeat: True
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-        kv_in_dim: 64
-    num_layers: 4
-
-  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
-    out_dim: 64
-    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-      num_pos_feats: 64
-      normalize: true
-      scale: null
-      temperature: 10000
-    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
-      kernel_size: 3
-      stride: 2
-      padding: 1
-    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
-      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
-        dim: 256
-        kernel_size: 7
-        padding: 3
-        layer_scale_init_value: 1e-6
-        use_dwconv: True  # depth-wise convs
-      num_layers: 2
-
-  num_maskmem: 7
-  image_size: 1024
-  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
-  sigmoid_scale_for_mem_enc: 20.0
-  sigmoid_bias_for_mem_enc: -10.0
-  use_mask_input_as_output_without_sam: true
-  # Memory
-  directly_add_no_mem_embed: true
-  # use high-resolution feature map in the SAM mask decoder
-  use_high_res_features_in_sam: true
-  # output 3 masks on the first click on initial conditioning frames
-  multimask_output_in_sam: true
-  # SAM heads
-  iou_prediction_use_sigmoid: True
-  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
-  use_obj_ptrs_in_encoder: true
-  add_tpos_enc_to_obj_ptrs: false
-  only_obj_ptrs_in_the_past_for_eval: true
-  # object occlusion prediction
-  pred_obj_scores: true
-  pred_obj_scores_mlp: true
-  fixed_no_obj_ptr: true
-  # multimask tracking settings
-  multimask_output_for_tracking: true
-  use_multimask_token_for_obj_ptr: true
-  multimask_min_pt_num: 0
-  multimask_max_pt_num: 1
-  use_mlp_for_obj_ptr_proj: true
-  # Compilation flag
-  compile_image_encoder: False
configs/sam2_hiera_s.yaml
DELETED
@@ -1,116 +0,0 @@
-# @package _global_
-
-# Model
-model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
-  image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
-    scalp: 1
-    trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
-      embed_dim: 96
-      num_heads: 1
-      stages: [1, 2, 11, 2]
-      global_att_blocks: [7, 10, 13]
-      window_pos_embed_bkg_spatial_size: [7, 7]
-    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
-      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-        num_pos_feats: 256
-        normalize: true
-        scale: null
-        temperature: 10000
-      d_model: 256
-      backbone_channel_list: [768, 384, 192, 96]
-      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
-      fpn_interp_model: nearest
-
-  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
-    d_model: 256
-    pos_enc_at_input: true
-    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
-      activation: relu
-      dim_feedforward: 2048
-      dropout: 0.1
-      pos_enc_at_attn: false
-      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-      d_model: 256
-      pos_enc_at_cross_attn_keys: true
-      pos_enc_at_cross_attn_queries: false
-      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        rope_k_repeat: True
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-        kv_in_dim: 64
-    num_layers: 4
-
-  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
-    out_dim: 64
-    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-      num_pos_feats: 64
-      normalize: true
-      scale: null
-      temperature: 10000
-    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
-      kernel_size: 3
-      stride: 2
-      padding: 1
-    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
-      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
-        dim: 256
-        kernel_size: 7
-        padding: 3
-        layer_scale_init_value: 1e-6
-        use_dwconv: True  # depth-wise convs
-      num_layers: 2
-
-  num_maskmem: 7
-  image_size: 1024
-  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
-  sigmoid_scale_for_mem_enc: 20.0
-  sigmoid_bias_for_mem_enc: -10.0
-  use_mask_input_as_output_without_sam: true
-  # Memory
-  directly_add_no_mem_embed: true
-  # use high-resolution feature map in the SAM mask decoder
-  use_high_res_features_in_sam: true
-  # output 3 masks on the first click on initial conditioning frames
-  multimask_output_in_sam: true
-  # SAM heads
-  iou_prediction_use_sigmoid: True
-  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
-  use_obj_ptrs_in_encoder: true
-  add_tpos_enc_to_obj_ptrs: false
-  only_obj_ptrs_in_the_past_for_eval: true
-  # object occlusion prediction
-  pred_obj_scores: true
-  pred_obj_scores_mlp: true
-  fixed_no_obj_ptr: true
-  # multimask tracking settings
-  multimask_output_for_tracking: true
-  use_multimask_token_for_obj_ptr: true
-  multimask_min_pt_num: 0
-  multimask_max_pt_num: 1
-  use_mlp_for_obj_ptr_proj: true
-  # Compilation flag
-  compile_image_encoder: False
configs/sam2_hiera_t.yaml
DELETED
@@ -1,118 +0,0 @@
-# @package _global_
-
-# Model
-model:
-  _target_: sam2.modeling.sam2_base.SAM2Base
-  image_encoder:
-    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
-    scalp: 1
-    trunk:
-      _target_: sam2.modeling.backbones.hieradet.Hiera
-      embed_dim: 96
-      num_heads: 1
-      stages: [1, 2, 7, 2]
-      global_att_blocks: [5, 7, 9]
-      window_pos_embed_bkg_spatial_size: [7, 7]
-    neck:
-      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
-      position_encoding:
-        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-        num_pos_feats: 256
-        normalize: true
-        scale: null
-        temperature: 10000
-      d_model: 256
-      backbone_channel_list: [768, 384, 192, 96]
-      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
-      fpn_interp_model: nearest
-
-  memory_attention:
-    _target_: sam2.modeling.memory_attention.MemoryAttention
-    d_model: 256
-    pos_enc_at_input: true
-    layer:
-      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
-      activation: relu
-      dim_feedforward: 2048
-      dropout: 0.1
-      pos_enc_at_attn: false
-      self_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-      d_model: 256
-      pos_enc_at_cross_attn_keys: true
-      pos_enc_at_cross_attn_queries: false
-      cross_attention:
-        _target_: sam2.modeling.sam.transformer.RoPEAttention
-        rope_theta: 10000.0
-        feat_sizes: [32, 32]
-        rope_k_repeat: True
-        embedding_dim: 256
-        num_heads: 1
-        downsample_rate: 1
-        dropout: 0.1
-        kv_in_dim: 64
-    num_layers: 4
-
-  memory_encoder:
-    _target_: sam2.modeling.memory_encoder.MemoryEncoder
-    out_dim: 64
-    position_encoding:
-      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
-      num_pos_feats: 64
-      normalize: true
-      scale: null
-      temperature: 10000
-    mask_downsampler:
-      _target_: sam2.modeling.memory_encoder.MaskDownSampler
-      kernel_size: 3
-      stride: 2
-      padding: 1
-    fuser:
-      _target_: sam2.modeling.memory_encoder.Fuser
-      layer:
-        _target_: sam2.modeling.memory_encoder.CXBlock
-        dim: 256
-        kernel_size: 7
-        padding: 3
-        layer_scale_init_value: 1e-6
-        use_dwconv: True  # depth-wise convs
-      num_layers: 2
-
-  num_maskmem: 7
-  image_size: 1024
-  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
-  # SAM decoder
-  sigmoid_scale_for_mem_enc: 20.0
-  sigmoid_bias_for_mem_enc: -10.0
-  use_mask_input_as_output_without_sam: true
-  # Memory
-  directly_add_no_mem_embed: true
-  # use high-resolution feature map in the SAM mask decoder
-  use_high_res_features_in_sam: true
-  # output 3 masks on the first click on initial conditioning frames
-  multimask_output_in_sam: true
-  # SAM heads
-  iou_prediction_use_sigmoid: True
-  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
-  use_obj_ptrs_in_encoder: true
-  add_tpos_enc_to_obj_ptrs: false
-  only_obj_ptrs_in_the_past_for_eval: true
-  # object occlusion prediction
-  pred_obj_scores: true
-  pred_obj_scores_mlp: true
-  fixed_no_obj_ptr: true
-  # multimask tracking settings
-  multimask_output_for_tracking: true
-  use_multimask_token_for_obj_ptr: true
-  multimask_min_pt_num: 0
-  multimask_max_pt_num: 1
-  use_mlp_for_obj_ptr_proj: true
-  # Compilation flag
-  # HieraT does not currently support compilation, should always be set to False
-  compile_image_encoder: False
requirements.txt
CHANGED
@@ -1,13 +1,6 @@
-tqdm
-einops
-timm
-samv2
-opencv-python
-pytest
 gradio
 spaces
 accelerate
 transformers==4.42.4
 sentencepiece
-
-git+https://github.com/Gothos/diffusers.git@flux-inpaint
+git+https://github.com/Gothos/diffusers.git@flux-inpaint
utils/florence.py
CHANGED
@@ -1,54 +0,0 @@
-import os
-from typing import Union, Any, Tuple, Dict
-from unittest.mock import patch
-
-import torch
-from PIL import Image
-from transformers import AutoModelForCausalLM, AutoProcessor
-from transformers.dynamic_module_utils import get_imports
-
-FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
-FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
-
-
-def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
-    """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
-    if not str(filename).endswith("/modeling_florence2.py"):
-        return get_imports(filename)
-    imports = get_imports(filename)
-    imports.remove("flash_attn")
-    return imports
-
-
-def load_florence_model(
-    device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
-) -> Tuple[Any, Any]:
-    with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
-        model = AutoModelForCausalLM.from_pretrained(
-            checkpoint, trust_remote_code=True).to(device).eval()
-        processor = AutoProcessor.from_pretrained(
-            checkpoint, trust_remote_code=True)
-        return model, processor
-
-
-def run_florence_inference(
-    model: Any,
-    processor: Any,
-    device: torch.device,
-    image: Image,
-    task: str,
-    text: str = ""
-) -> Tuple[str, Dict]:
-    prompt = task + text
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
-    generated_ids = model.generate(
-        input_ids=inputs["input_ids"],
-        pixel_values=inputs["pixel_values"],
-        max_new_tokens=1024,
-        num_beams=3
-    )
-    generated_text = processor.batch_decode(
-        generated_ids, skip_special_tokens=False)[0]
-    response = processor.post_process_generation(
-        generated_text, task=task, image_size=image.size)
-    return generated_text, response
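For reference, the pre-revert app.py used these removed helpers for open-vocabulary detection roughly as in the sketch below; doge.png and "dog" are placeholders standing in for the uploaded image and the user's segmentation prompt:

import torch
from PIL import Image
from utils.florence import (
    load_florence_model, run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK)

DEVICE = torch.device("cuda")
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)

# doge.png is a placeholder path; the Space passed the uploaded background image.
image = Image.open("doge.png").convert("RGB")
_, result = run_florence_inference(
    model=FLORENCE_MODEL,
    processor=FLORENCE_PROCESSOR,
    device=DEVICE,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="dog"
)
# result maps the task token to bounding boxes and labels, which the
# pre-revert app.py converted with sv.Detections.from_lmm(...).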
utils/sam.py
CHANGED
@@ -1,45 +0,0 @@
-from typing import Any
-
-import numpy as np
-import supervision as sv
-import torch
-from PIL import Image
-from sam2.build_sam import build_sam2, build_sam2_video_predictor
-from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
-SAM_CONFIG = "sam2_hiera_s.yaml"
-
-
-def load_sam_image_model(
-    device: torch.device,
-    config: str = SAM_CONFIG,
-    checkpoint: str = SAM_CHECKPOINT
-) -> SAM2ImagePredictor:
-    model = build_sam2(config, checkpoint, device=device)
-    return SAM2ImagePredictor(sam_model=model)
-
-
-def load_sam_video_model(
-    device: torch.device,
-    config: str = SAM_CONFIG,
-    checkpoint: str = SAM_CHECKPOINT
-) -> Any:
-    return build_sam2_video_predictor(config, checkpoint, device=device)
-
-
-def run_sam_inference(
-    model: Any,
-    image: Image,
-    detections: sv.Detections
-) -> sv.Detections:
-    image = np.array(image.convert("RGB"))
-    model.set_image(image)
-    mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
-
-    # dirty fix; remove this later
-    if len(mask.shape) == 4:
-        mask = np.squeeze(mask)
-
-    detections.mask = mask.astype(bool)
-    return detections
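And this is roughly how the pre-revert app.py turned those Florence-2 detections into the inpainting mask with the removed run_sam_inference; a sketch that continues from the Florence example above (image and result carry over), with 1024 standing in for the target width and height:

import numpy as np
import supervision as sv
import torch
from PIL import Image, ImageFilter
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda")
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Convert the Florence result into detections and refine them with SAM2.
detections = sv.Detections.from_lmm(
    lmm=sv.LMM.FLORENCE_2,
    result=result,
    resolution_wh=image.size
)
detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)

# Build the blurred PIL mask that was then fed to FLUX.1, exactly as the
# pre-revert process() did.
if len(detections) > 0:
    mask = Image.fromarray(detections.mask[0].astype(np.uint8) * 255)
    mask = mask.resize((1024, 1024), Image.LANCZOS)
    mask = mask.filter(ImageFilter.GaussianBlur(radius=10))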