KevinX-Penn28 committed (verified)
Commit 2b669dd · Parent(s): 1e0506d

Upload 3 files

Files changed (3):
  1. vine_config.py +18 -2
  2. vine_model.py +66 -20
  3. vine_pipeline.py +1 -1
vine_config.py CHANGED
@@ -40,7 +40,13 @@ class VineConfig(PretrainedConfig):
         self,
         model_name: str = "openai/clip-vit-base-patch32",
         hidden_dim = 768,
-        pretrained_vine_path: Optional[str] = None,
+
+        use_hf_repo: bool = False,
+        model_repo: Optional[str] = None,
+        model_file: Optional[str] = None,
+        local_dir: Optional[str] = None,
+        local_filename: Optional[str] = None,
+
         num_top_pairs: int = 18,
         segmentation_method: str = "grounding_dino_sam2",
         box_threshold: float = 0.35,
@@ -63,7 +69,17 @@
         **kwargs
     ):
         self.model_name = model_name
-        self.pretrained_vine_path = pretrained_vine_path
+        self.use_hf_repo = use_hf_repo
+        if use_hf_repo:
+            self.model_repo = model_repo
+            self.model_file = model_file
+            self.local_dir = None
+            self.local_filename = None
+        else:
+            self.model_repo = None
+            self.model_file = None
+            self.local_dir = local_dir
+            self.local_filename = local_filename
         self.hidden_dim = hidden_dim
         self.num_top_pairs = num_top_pairs
         self.segmentation_method = segmentation_method
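
This change replaces the single pretrained_vine_path argument with an explicit switch between two loading modes: Hub repos (use_hf_repo=True with model_repo/model_file) and local checkpoints (local_dir/local_filename). A minimal sketch of both modes, assuming VineConfig is importable from this repo's vine_config module; the repo id "video-fm/vine_v0" appears elsewhere in this commit, while the local paths are hypothetical placeholders:

# Minimal sketch of the two loading modes (import path assumed).
from vine_config import VineConfig

# Hub mode: weights come from a HuggingFace repo.
hub_config = VineConfig(
    use_hf_repo=True,
    model_repo="video-fm/vine_v0",  # repo id used elsewhere in this commit
)

# Local mode: weights come from a checkpoint on disk (hypothetical paths).
local_config = VineConfig(
    use_hf_repo=False,
    local_dir="checkpoints/vine",
    local_filename="vine_v0.pt",
)

# The constructor nulls out the unused field group, so a config never
# points at both a Hub repo and a local file at once.
assert hub_config.local_dir is None
assert local_config.model_repo is None

Because the constructor zeroes out whichever field group is unused, downstream code can branch on use_hf_repo alone without checking for stale path fields.
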
vine_model.py CHANGED
@@ -57,7 +57,6 @@ class VineModel(PreTrainedModel):
 
 
         # Initialize CLIP components
-
         self.clip_tokenizer = AutoTokenizer.from_pretrained(config.model_name)
         if self.clip_tokenizer.pad_token is None:
             self.clip_tokenizer.pad_token = (
@@ -72,13 +71,36 @@
 
 
         # Then try to load pretrained VINE weights if specified
-        if config.pretrained_vine_path:
-            self._load_pretrained_vine_weights(config.pretrained_vine_path)
+        if config.use_hf_repo:
+            self._load_huggingface_vine_weights(config.model_repo, config.model_file)
+        else:
+            self._load_local_pretrained_vine_weights(config.local_dir, config.local_filename)
 
         # Move models to device
         self.to(self._device)
+
+    def _load_huggingface_vine_weights(self, model_repo: str, model_file: Optional[str] = None):
+        """
+        Load pretrained VINE weights from HuggingFace Hub.
+        """
+        try:
+            print(f"Loading VINE weights from HuggingFace repo: {model_repo}")
+            vine_model = AutoModel.from_pretrained(
+                model_repo,
+                trust_remote_code=True,
+                revision=model_file if model_file else "main"
+            )
+            self.clip_cate_model = vine_model.clip_cate_model
+            self.clip_unary_model = vine_model.clip_unary_model
+            self.clip_binary_model = vine_model.clip_binary_model
+            print("✓ Successfully loaded VINE weights from HuggingFace Hub")
+            return True
+        except Exception as e:
+            print(f"✗ Error loading VINE weights from HuggingFace Hub: {e}")
+            print("Using base CLIP models instead")
+            return False
 
-    def _load_pretrained_vine_weights(self, pretrained_path: str, epoch: int = 0):
+    def _load_local_pretrained_vine_weights(self, local_dir: str, local_filename: Optional[str] = None, epoch: int = 0):
         """
         Load pretrained VINE weights from a saved .pt file or ensemble format.
         """
@@ -86,15 +108,11 @@
 
         # x = torch.load(pretrained_path, map_location=self._device, weights_only=False)
         # print(f"Loaded VINE checkpoint type: {type(x)}")
-        if pretrained_path == "video-fm/vine_v0":
-            self.clip_tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
-            self.clip_cate_model = AutoModel.from_pretrained(pretrained_path)
-            self.clip_unary_model = AutoModel.from_pretrained(pretrained_path)
-            self.clip_binary_model = AutoModel.from_pretrained(pretrained_path)
+        full_path = os.path.join(local_dir, local_filename) if local_filename else local_dir
 
-        if pretrained_path.endswith(".pkl"):
-            print(f"Loading VINE weights from: {pretrained_path}")
-            loaded_vine_model = torch.load(pretrained_path, map_location=self._device, weights_only=False)
+        if full_path.endswith(".pkl"):
+            print(f"Loading VINE weights from: {full_path}")
+            loaded_vine_model = torch.load(full_path, map_location=self._device, weights_only=False)
 
             print(f"Loaded state type: {type(loaded_vine_model)}")
             if not isinstance(loaded_vine_model, dict):
@@ -106,17 +124,17 @@
             self.clip_binary_model.load_state_dict(loaded_vine_model.clip_binary_model.state_dict())
             return True
 
-        elif pretrained_path.endswith(".pt") or pretrained_path.endswith(".pth"):
-            state = torch.load(pretrained_path, map_location=self._device, weights_only=True)
+        elif full_path.endswith(".pt") or full_path.endswith(".pth"):
+            state = torch.load(full_path, map_location=self._device, weights_only=True)
             print(f"Loaded state type: {type(state)}")
             self.load_state_dict(state)
             return True
 
         # handle directory + epoch format
-        if os.path.isdir(pretrained_path):
-            model_files = [f for f in os.listdir(pretrained_path) if f.endswith(f'.{epoch}.model')]
+        if os.path.isdir(full_path):
+            model_files = [f for f in os.listdir(full_path) if f.endswith(f'.{epoch}.model')]
             if model_files:
-                model_file = os.path.join(pretrained_path, model_files[0])
+                model_file = os.path.join(full_path, model_files[0])
                 print(f"Loading VINE weights from: {model_file}")
                 pretrained_model = torch.load(model_file, map_location="cpu")
 
@@ -131,7 +149,7 @@
                 print("✓ Loaded all sub-model weights from ensemble format")
                 return True
             else:
-                print(f"No model file found for epoch {epoch} in {pretrained_path}")
+                print(f"No model file found for epoch {epoch} in {full_path}")
                 return False
 
         print("Unsupported format for pretrained_vine_path")
@@ -249,10 +267,38 @@
         Returns:
             VineModel instance with loaded weights
         """
+        # Normalize the incoming model_path into the new VineConfig fields.
         if config is None:
-            config = VineConfig(pretrained_vine_path=model_path)
+            # Heuristic: if the path looks like a HF repo id (contains a "/"
+            # and doesn't exist on disk), treat it as a repo; otherwise local.
+            if model_path and ("/" in model_path and not os.path.exists(model_path)):
+                config = VineConfig(use_hf_repo=True, model_repo=model_path)
+            else:
+                # Local path: could be a file or a directory
+                if os.path.isdir(model_path):
+                    config = VineConfig(use_hf_repo=False, local_dir=model_path)
+                else:
+                    config = VineConfig(
+                        use_hf_repo=False,
+                        local_dir=os.path.dirname(model_path) or None,
+                        local_filename=os.path.basename(model_path) or None,
+                    )
         else:
-            config.pretrained_vine_path = model_path
+            # Update the provided config to reflect the requested pretrained path
+            if model_path and ("/" in model_path and not os.path.exists(model_path)):
+                config.use_hf_repo = True
+                config.model_repo = model_path
+                config.model_file = None
+                config.local_dir = None
+                config.local_filename = None
+            else:
+                config.use_hf_repo = False
+                if os.path.isdir(model_path):
+                    config.local_dir = model_path
+                    config.local_filename = None
+                else:
+                    config.local_dir = os.path.dirname(model_path) or None
+                    config.local_filename = os.path.basename(model_path) or None
 
         # Create model instance (will automatically load weights)
         model = cls(config, **kwargs)
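
The rewritten from_pretrained still accepts a single model_path string and routes it through a repo-vs-local heuristic. The sketch below restates that routing as a standalone function for illustration only (the real logic lives inline in the classmethod above); the example paths are hypothetical:

import os

def classify_model_path(model_path: str) -> dict:
    """Restates the routing heuristic in from_pretrained: a string that
    contains "/" and does not exist on disk is read as a Hub repo id;
    everything else is treated as a local file or directory."""
    if model_path and ("/" in model_path and not os.path.exists(model_path)):
        return {"use_hf_repo": True, "model_repo": model_path}
    if os.path.isdir(model_path):
        return {"use_hf_repo": False, "local_dir": model_path}
    return {
        "use_hf_repo": False,
        "local_dir": os.path.dirname(model_path) or None,
        "local_filename": os.path.basename(model_path) or None,
    }

# Routed to the Hub branch (no such path on disk):
print(classify_model_path("video-fm/vine_v0"))
# Routed to the local branch only if the file actually exists; otherwise
# the "/" in this hypothetical path sends it to the Hub branch too:
print(classify_model_path("checkpoints/vine/vine_v0.pt"))

The second call shows the trade-off: a relative local path that does not yet exist on disk also contains a "/", so it is misrouted to the Hub branch. Callers loading from not-yet-created paths should pass an explicit config instead.
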
vine_pipeline.py CHANGED
@@ -391,7 +391,7 @@ class VinePipeline(Pipeline):
             crop_n_layers=2,
             box_nms_thresh=0.6,
             crop_n_points_downscale_factor=2,
-            min_mask_region_area=30.0,
+            min_mask_region_area=100,
             use_m2m=True,
         )
         print("✓ SAM2 models initialized successfully")