Spaces:
Sleeping
Sleeping
עדכון נסיוני
Browse files- backend.py +37 -21
- checkpoints/sam2_hiera_small.pt +3 -0
- configs/sam2_hiera_s.yaml +116 -0
- requirements.txt +1 -0
backend.py
CHANGED
@@ -145,8 +145,9 @@ YOLO_MODEL_PATH = '../../models/yolo11m.pt'
|
|
145 |
try:
|
146 |
yolo_model = YOLO(YOLO_MODEL_PATH)
|
147 |
yolo_model.to("cpu")
|
|
|
148 |
except Exception as e:
|
149 |
-
print(f"[YOLO] לא מצליח לטעון את המודל בנתיב: {YOLO_MODEL_PATH}")
|
150 |
yolo_model = None
|
151 |
|
152 |
TARGET_CLASS = "person"
|
@@ -155,40 +156,56 @@ CONF_THRESHOLD = 0.2
|
|
155 |
# -----------------------------
|
156 |
# 4) הכנה ל-SAM2
|
157 |
# -----------------------------
|
158 |
-
|
159 |
from typing import Any
|
160 |
import supervision as sv
|
161 |
from sam2.build_sam import build_sam2
|
162 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
163 |
|
164 |
SAM2_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
|
165 |
-
SAM2_CONFIG = "sam2_hiera_s.yaml"
|
166 |
|
167 |
def load_sam_image_model(
|
168 |
device: torch.device,
|
169 |
config: str = SAM2_CONFIG,
|
170 |
checkpoint: str = SAM2_CHECKPOINT
|
171 |
) -> SAM2ImagePredictor:
|
172 |
-
|
173 |
-
|
174 |
-
sam2_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2_hiera_small.pt"
|
175 |
os.makedirs(os.path.dirname(checkpoint), exist_ok=True)
|
176 |
-
response = requests.get(sam2_url)
|
177 |
-
with open(checkpoint, 'wb') as f:
|
178 |
-
f.write(response.content)
|
179 |
-
print("[SAM2] מודל SAM2 הורד בהצלחה.")
|
180 |
-
|
181 |
-
if not os.path.exists(config):
|
182 |
-
print("[SAM2] קובץ הקונפיג SAM2 לא נמצא. מנסה להוריד את הקונפיג...")
|
183 |
-
sam2_config_url = "https://path_to_your_config/sam2_hiera_s.yaml" # עדכן את הקישור המתאים
|
184 |
os.makedirs(os.path.dirname(config), exist_ok=True)
|
185 |
-
response = requests.get(sam2_config_url)
|
186 |
-
with open(config, 'wb') as f:
|
187 |
-
f.write(response.content)
|
188 |
-
print("[SAM2] קובץ הקונפיג SAM2 הורד בהצלחה.")
|
189 |
|
190 |
-
|
191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
193 |
try:
|
194 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
@@ -240,7 +257,6 @@ def blur_regions_with_mask(
|
|
240 |
|
241 |
return Image.fromarray(combined)
|
242 |
|
243 |
-
|
244 |
# -----------------------------
|
245 |
# 6) הפונקציה המרכזית
|
246 |
# -----------------------------
|
|
|
145 |
try:
|
146 |
yolo_model = YOLO(YOLO_MODEL_PATH)
|
147 |
yolo_model.to("cpu")
|
148 |
+
print("[YOLO] מודל YOLO נטען בהצלחה.")
|
149 |
except Exception as e:
|
150 |
+
print(f"[YOLO] לא מצליח לטעון את המודל בנתיב: {YOLO_MODEL_PATH}. שגיאה: {e}")
|
151 |
yolo_model = None
|
152 |
|
153 |
TARGET_CLASS = "person"
|
|
|
156 |
# -----------------------------
|
157 |
# 4) הכנה ל-SAM2
|
158 |
# -----------------------------
|
|
|
159 |
from typing import Any
|
160 |
import supervision as sv
|
161 |
from sam2.build_sam import build_sam2
|
162 |
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
163 |
|
164 |
SAM2_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
|
165 |
+
SAM2_CONFIG = "configs/sam2_hiera_s.yaml"
|
166 |
|
167 |
def load_sam_image_model(
|
168 |
device: torch.device,
|
169 |
config: str = SAM2_CONFIG,
|
170 |
checkpoint: str = SAM2_CHECKPOINT
|
171 |
) -> SAM2ImagePredictor:
|
172 |
+
try:
|
173 |
+
# יצירת התיקיות אם הן לא קיימות
|
|
|
174 |
os.makedirs(os.path.dirname(checkpoint), exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
os.makedirs(os.path.dirname(config), exist_ok=True)
|
|
|
|
|
|
|
|
|
176 |
|
177 |
+
# הורדת קובץ ה-checkpoint אם אינו קיים
|
178 |
+
if not os.path.exists(checkpoint):
|
179 |
+
print("[SAM2] מודל SAM2 לא נמצא. מנסה להוריד את המודל...")
|
180 |
+
sam2_url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2_hiera_small.pt"
|
181 |
+
response = requests.get(sam2_url)
|
182 |
+
if response.status_code == 200:
|
183 |
+
with open(checkpoint, 'wb') as f:
|
184 |
+
f.write(response.content)
|
185 |
+
print("[SAM2] מודל SAM2 הורד בהצלחה.")
|
186 |
+
else:
|
187 |
+
raise FileNotFoundError(f"לא הצליח להוריד את מודל SAM2 מהכתובת: {sam2_url}")
|
188 |
+
|
189 |
+
# הורדת קובץ הקונפיג אם אינו קיים
|
190 |
+
if not os.path.exists(config):
|
191 |
+
print("[SAM2] קובץ הקונפיג SAM2 לא נמצא. מנסה להוריד את הקונפיג...")
|
192 |
+
sam2_config_url = "https://raw.githubusercontent.com/facebookresearch/sam2/refs/heads/main/sam2/configs/sam2/sam2_hiera_s.yaml" # עדכן לכתובת הנכונה
|
193 |
+
response = requests.get(sam2_config_url)
|
194 |
+
if response.status_code == 200:
|
195 |
+
with open(config, 'wb') as f:
|
196 |
+
f.write(response.content)
|
197 |
+
print("[SAM2] קובץ הקונפיג SAM2 הורד בהצלחה.")
|
198 |
+
else:
|
199 |
+
raise FileNotFoundError(f"לא הצליח להוריד את קובץ הקונפיג SAM2 מהכתובת: {sam2_config_url}")
|
200 |
+
|
201 |
+
# בניית המודל
|
202 |
+
print("[SAM2] מנסה לבנות את המודל SAM2...")
|
203 |
+
model = build_sam2(config, checkpoint, device=device)
|
204 |
+
print("[SAM2] המודל SAM2 נבנה בהצלחה.")
|
205 |
+
return SAM2ImagePredictor(sam_model=model)
|
206 |
+
except Exception as e:
|
207 |
+
print(f"[SAM2] שגיאה בטעינת המודל SAM2: {e}")
|
208 |
+
raise e # העלאת השגיאה למעלה כדי שתוכל להיתפס ולהדפיס בהמשך
|
209 |
|
210 |
try:
|
211 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
257 |
|
258 |
return Image.fromarray(combined)
|
259 |
|
|
|
260 |
# -----------------------------
|
261 |
# 6) הפונקציה המרכזית
|
262 |
# -----------------------------
|
checkpoints/sam2_hiera_small.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ec09b256af142490dd9f363799c90f0bbb854e19142070fbe045eb1e7673ed6
|
3 |
+
size 243
|
configs/sam2_hiera_s.yaml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package _global_
|
2 |
+
|
3 |
+
# Model
|
4 |
+
model:
|
5 |
+
_target_: sam2.modeling.sam2_base.SAM2Base
|
6 |
+
image_encoder:
|
7 |
+
_target_: sam2.modeling.backbones.image_encoder.ImageEncoder
|
8 |
+
scalp: 1
|
9 |
+
trunk:
|
10 |
+
_target_: sam2.modeling.backbones.hieradet.Hiera
|
11 |
+
embed_dim: 96
|
12 |
+
num_heads: 1
|
13 |
+
stages: [1, 2, 11, 2]
|
14 |
+
global_att_blocks: [7, 10, 13]
|
15 |
+
window_pos_embed_bkg_spatial_size: [7, 7]
|
16 |
+
neck:
|
17 |
+
_target_: sam2.modeling.backbones.image_encoder.FpnNeck
|
18 |
+
position_encoding:
|
19 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
20 |
+
num_pos_feats: 256
|
21 |
+
normalize: true
|
22 |
+
scale: null
|
23 |
+
temperature: 10000
|
24 |
+
d_model: 256
|
25 |
+
backbone_channel_list: [768, 384, 192, 96]
|
26 |
+
fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
|
27 |
+
fpn_interp_model: nearest
|
28 |
+
|
29 |
+
memory_attention:
|
30 |
+
_target_: sam2.modeling.memory_attention.MemoryAttention
|
31 |
+
d_model: 256
|
32 |
+
pos_enc_at_input: true
|
33 |
+
layer:
|
34 |
+
_target_: sam2.modeling.memory_attention.MemoryAttentionLayer
|
35 |
+
activation: relu
|
36 |
+
dim_feedforward: 2048
|
37 |
+
dropout: 0.1
|
38 |
+
pos_enc_at_attn: false
|
39 |
+
self_attention:
|
40 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
41 |
+
rope_theta: 10000.0
|
42 |
+
feat_sizes: [64, 64]
|
43 |
+
embedding_dim: 256
|
44 |
+
num_heads: 1
|
45 |
+
downsample_rate: 1
|
46 |
+
dropout: 0.1
|
47 |
+
d_model: 256
|
48 |
+
pos_enc_at_cross_attn_keys: true
|
49 |
+
pos_enc_at_cross_attn_queries: false
|
50 |
+
cross_attention:
|
51 |
+
_target_: sam2.modeling.sam.transformer.RoPEAttention
|
52 |
+
rope_theta: 10000.0
|
53 |
+
feat_sizes: [64, 64]
|
54 |
+
rope_k_repeat: True
|
55 |
+
embedding_dim: 256
|
56 |
+
num_heads: 1
|
57 |
+
downsample_rate: 1
|
58 |
+
dropout: 0.1
|
59 |
+
kv_in_dim: 64
|
60 |
+
num_layers: 4
|
61 |
+
|
62 |
+
memory_encoder:
|
63 |
+
_target_: sam2.modeling.memory_encoder.MemoryEncoder
|
64 |
+
out_dim: 64
|
65 |
+
position_encoding:
|
66 |
+
_target_: sam2.modeling.position_encoding.PositionEmbeddingSine
|
67 |
+
num_pos_feats: 64
|
68 |
+
normalize: true
|
69 |
+
scale: null
|
70 |
+
temperature: 10000
|
71 |
+
mask_downsampler:
|
72 |
+
_target_: sam2.modeling.memory_encoder.MaskDownSampler
|
73 |
+
kernel_size: 3
|
74 |
+
stride: 2
|
75 |
+
padding: 1
|
76 |
+
fuser:
|
77 |
+
_target_: sam2.modeling.memory_encoder.Fuser
|
78 |
+
layer:
|
79 |
+
_target_: sam2.modeling.memory_encoder.CXBlock
|
80 |
+
dim: 256
|
81 |
+
kernel_size: 7
|
82 |
+
padding: 3
|
83 |
+
layer_scale_init_value: 1e-6
|
84 |
+
use_dwconv: True # depth-wise convs
|
85 |
+
num_layers: 2
|
86 |
+
|
87 |
+
num_maskmem: 7
|
88 |
+
image_size: 1024
|
89 |
+
# apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
|
90 |
+
sigmoid_scale_for_mem_enc: 20.0
|
91 |
+
sigmoid_bias_for_mem_enc: -10.0
|
92 |
+
use_mask_input_as_output_without_sam: true
|
93 |
+
# Memory
|
94 |
+
directly_add_no_mem_embed: true
|
95 |
+
# use high-resolution feature map in the SAM mask decoder
|
96 |
+
use_high_res_features_in_sam: true
|
97 |
+
# output 3 masks on the first click on initial conditioning frames
|
98 |
+
multimask_output_in_sam: true
|
99 |
+
# SAM heads
|
100 |
+
iou_prediction_use_sigmoid: True
|
101 |
+
# cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
102 |
+
use_obj_ptrs_in_encoder: true
|
103 |
+
add_tpos_enc_to_obj_ptrs: false
|
104 |
+
only_obj_ptrs_in_the_past_for_eval: true
|
105 |
+
# object occlusion prediction
|
106 |
+
pred_obj_scores: true
|
107 |
+
pred_obj_scores_mlp: true
|
108 |
+
fixed_no_obj_ptr: true
|
109 |
+
# multimask tracking settings
|
110 |
+
multimask_output_for_tracking: true
|
111 |
+
use_multimask_token_for_obj_ptr: true
|
112 |
+
multimask_min_pt_num: 0
|
113 |
+
multimask_max_pt_num: 1
|
114 |
+
use_mlp_for_obj_ptr_proj: true
|
115 |
+
# Compilation flag
|
116 |
+
compile_image_encoder: False
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ ultralytics
|
|
8 |
scipy
|
9 |
hydra-core
|
10 |
torchvision
|
|
|
11 |
git+https://github.com/facebookresearch/sam2.git
|
|
|
8 |
scipy
|
9 |
hydra-core
|
10 |
torchvision
|
11 |
+
supervision
|
12 |
git+https://github.com/facebookresearch/sam2.git
|