Upload 3 files
- vine_config.py +18 -2
- vine_model.py +66 -20
- vine_pipeline.py +1 -1
vine_config.py
CHANGED

@@ -40,7 +40,13 @@ class VineConfig(PretrainedConfig):
         self,
         model_name: str = "openai/clip-vit-base-patch32",
         hidden_dim = 768,
-
+
+        use_hf_repo: bool = False,
+        model_repo: Optional[str] = None,
+        model_file: Optional[str] = None,
+        local_dir: Optional[str] = None,
+        local_filename: Optional[str] = None,
+
         num_top_pairs: int = 18,
         segmentation_method: str = "grounding_dino_sam2",
         box_threshold: float = 0.35,
@@ -63,7 +69,17 @@ class VineConfig(PretrainedConfig):
         **kwargs
     ):
         self.model_name = model_name
-        self.
+        self.use_hf_repo = use_hf_repo
+        if use_hf_repo:
+            self.model_repo = model_repo
+            self.model_file = model_file
+            self.local_dir = None
+            self.local_filename = None
+        else:
+            self.model_repo = None
+            self.model_file = None
+            self.local_dir = local_dir
+            self.local_filename = local_filename
         self.hidden_dim = hidden_dim
         self.num_top_pairs = num_top_pairs
         self.segmentation_method = segmentation_method
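The two hunks together make the weight source mutually exclusive: whichever mode use_hf_repo selects, the constructor nulls out the other mode's fields. A minimal usage sketch under that reading (the repo id and file names below are placeholders, not part of this commit):

    from vine_config import VineConfig

    # Hub mode: repo fields are kept, local fields are forced to None
    hub_cfg = VineConfig(use_hf_repo=True, model_repo="some-org/vine-weights")
    assert hub_cfg.local_dir is None and hub_cfg.local_filename is None

    # Local mode: local fields are kept, repo fields are forced to None
    local_cfg = VineConfig(
        use_hf_repo=False,
        local_dir="checkpoints",
        local_filename="vine_model.pkl",
    )
    assert local_cfg.model_repo is None and local_cfg.model_file is None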
vine_model.py
CHANGED

@@ -57,7 +57,6 @@ class VineModel(PreTrainedModel):
 
 
         # Initialize CLIP components
-
         self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
         if self.clip_tokenizer.pad_token is None:
             self.clip_tokenizer.pad_token = (
@@ -72,13 +71,36 @@ class VineModel(PreTrainedModel):
 
 
         # Then try to load pretrained VINE weights if specified
-        if config.
-            self.
+        if config.use_hf_repo:
+            self._load_huggingface_vine_weights(config.model_repo, config.model_file)
+        else:
+            self._load_local_pretrained_vine_weights(config.local_dir, config.local_filename)
 
         # Move models to device
         self.to(self._device)
+
+    def _load_huggingface_vine_weights(self, model_repo: str, model_file: Optional[str] = None):
+        """
+        Load pretrained VINE weights from HuggingFace Hub.
+        """
+        try:
+            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
+            vine_model = AutoModel.from_pretrained(
+                model_repo,
+                trust_remote_code=True,
+                revision=model_file if model_file else "main"
+            )
+            self.clip_cate_model = vine_model.clip_cate_model
+            self.clip_unary_model = vine_model.clip_unary_model
+            self.clip_binary_model = vine_model.clip_binary_model
+            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
+            return True
+        except Exception as e:
+            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
+            print("Using base CLIP models instead")
+            return False
 
-    def
+    def _load_local_pretrained_vine_weights(self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0):
         """
         Load pretrained VINE weights from a saved .pt file or ensemble format.
         """
@@ -86,15 +108,11 @@ class VineModel(PreTrainedModel):
 
         # x = torch.load(pretrained_path, map_location=self._device, weights_only=False)
         # print(f"Loaded VINE checkpoint type: {type(x)}")
-        if
-            self.clip_tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
-            self.clip_cate_model = AutoModel.from_pretrained(pretrained_path)
-            self.clip_unary_model = AutoModel.from_pretrained(pretrained_path)
-            self.clip_binary_model = AutoModel.from_pretrained(pretrained_path)
+        full_path = os.path.join(local_dir, local_filename) if local_filename else local_dir
 
-        if
-            print(f"Loading VINE weights from: {
-            loaded_vine_model = torch.load(
+        if full_path.endswith(".pkl"):
+            print(f"Loading VINE weights from: {full_path}")
+            loaded_vine_model = torch.load(full_path, map_location=self._device, weights_only=False)
 
             print(f"Loaded state type: {type(loaded_vine_model)}")
             if not isinstance(loaded_vine_model, dict):
@@ -106,17 +124,17 @@ class VineModel(PreTrainedModel):
             self.clip_binary_model.load_state_dict(loaded_vine_model.clip_binary_model.state_dict())
             return True
 
-        elif
-            state = torch.load(
+        elif full_path.endswith(".pt") or full_path.endswith(".pth"):
+            state = torch.load(full_path, map_location=self._device, weights_only=True)
             print(f"Loaded state type: {type(state)}")
             self.load_state_dict(state)
             return True
 
         # handle directory + epoch format
-        if os.path.isdir(
-            model_files = [f for f in os.listdir(
+        if os.path.isdir(full_path):
+            model_files = [f for f in os.listdir(full_path) if f.endswith(f'.{epoch}.model')]
             if model_files:
-                model_file = os.path.join(
+                model_file = os.path.join(full_path, model_files[0])
                 print(f"Loading VINE weights from: {model_file}")
                 pretrained_model = torch.load(model_file, map_location="cpu")
 
@@ -131,7 +149,7 @@ class VineModel(PreTrainedModel):
                 print("✓ Loaded all sub-model weights from ensemble format")
                 return True
             else:
-                print(f"No model file found for epoch {epoch} in {
+                print(f"No model file found for epoch {epoch} in {full_path}")
                 return False
 
         print("Unsupported format for pretrained_vine_path")
@@ -249,10 +267,38 @@ class VineModel(PreTrainedModel):
         Returns:
            VineModel instance with loaded weights
        """
+        # Normalize the incoming model_path into the new VineConfig fields.
        if config is None:
-
+            # Heuristics: if path looks like a HF repo (contains a "/" and
+            # doesn't exist on disk) treat it as a repo. Otherwise treat as local.
+            if model_path and ("/" in model_path and not os.path.exists(model_path)):
+                config = VineConfig(use_hf_repo=True, model_repo=model_path)
+            else:
+                # Local path: could be a file or directory
+                if os.path.isdir(model_path):
+                    config = VineConfig(use_hf_repo=False, local_dir=model_path)
+                else:
+                    config = VineConfig(
+                        use_hf_repo=False,
+                        local_dir=os.path.dirname(model_path) or None,
+                        local_filename=os.path.basename(model_path) or None,
+                    )
        else:
-            config
+            # Update provided config to reflect the requested pretrained path
+            if model_path and ("/" in model_path and not os.path.exists(model_path)):
+                config.use_hf_repo = True
+                config.model_repo = model_path
+                config.model_file = None
+                config.local_dir = None
+                config.local_filename = None
+            else:
+                config.use_hf_repo = False
+                if os.path.isdir(model_path):
+                    config.local_dir = model_path
+                    config.local_filename = None
+                else:
+                    config.local_dir = os.path.dirname(model_path) or None
+                    config.local_filename = os.path.basename(model_path) or None
 
        # Create model instance (will automatically load weights)
        model = cls(config, **kwargs)
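The from_pretrained hunk repeats one heuristic in both branches: a model_path that contains "/" but does not exist on disk is treated as a Hub repo id, anything else as a local path. A self-contained sketch of just that check (looks_like_hf_repo is a hypothetical helper, not part of the diff):

    import os

    def looks_like_hf_repo(model_path: str) -> bool:
        # Same test as in from_pretrained: a "/" in a path that does not
        # resolve on disk is assumed to name a HuggingFace Hub repo.
        return bool(model_path) and "/" in model_path and not os.path.exists(model_path)

    print(looks_like_hf_repo("some-org/vine-weights"))  # True, assuming no such local path
    print(looks_like_hf_repo("checkpoints"))            # False: no "/" in the path

One caveat worth noting: _load_huggingface_vine_weights passes model_file through as the revision argument of AutoModel.from_pretrained, and transformers expects revision to be a git ref (branch, tag, or commit hash) rather than a filename.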
vine_pipeline.py
CHANGED

@@ -391,7 +391,7 @@ class VinePipeline(Pipeline):
                crop_n_layers=2,
                box_nms_thresh=0.6,
                crop_n_points_downscale_factor=2,
-                min_mask_region_area=
+                min_mask_region_area=100,
                use_m2m=True,
            )
            print("✓ SAM2 models initialized successfully")
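For context on the restored value: min_mask_region_area is a post-processing knob on SAM2's automatic mask generator that removes disconnected mask regions and holes smaller than the given pixel area. A hedged sketch of a comparable constructor call, assuming the sam2 package layout and placeholder config/checkpoint paths:

    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

    sam2_model = build_sam2(
        "configs/sam2.1/sam2.1_hiera_l.yaml",  # placeholder model config
        "checkpoints/sam2.1_hiera_large.pt",   # placeholder checkpoint
    )
    mask_generator = SAM2AutomaticMaskGenerator(
        model=sam2_model,
        crop_n_layers=2,
        box_nms_thresh=0.6,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,  # drop regions/holes under 100 pixels
        use_m2m=True,
    )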