Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from transformers import PretrainedConfig | |
| from typing import List, Optional, Dict, Any, Tuple, Union | |
| from pathlib import Path | |
class VineConfig(PretrainedConfig):
    """
    Configuration class for VINE (Video Understanding with Natural Language) model.

    Every constructor argument is stored as a same-named attribute, with these
    exceptions:

    * The weight-source fields are mutually exclusive and selected by
      ``use_hf_repo``. When True, ``model_repo``/``model_file`` are kept and
      ``local_dir``/``local_filename`` are stored as ``None``. When False, the
      reverse applies.
    * ``interested_object_pairs`` is stored as ``[]`` when given ``None``.
    * ``device`` is resolved to a device value and stored in the private
      attribute ``_device``. It is always resolved for the current host and
      is never restored from a saved config.
    """

    model_type = "vine"

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,
        use_hf_repo: bool = True,
        model_repo: Optional[str] = "KevinX-Penn28/testing",
        model_file: Optional[str] = None,
        local_dir: Optional[str] = str(Path(__file__).resolve().parent),
        local_filename: Optional[str] = "laser_model_v1.pkl",
        num_top_pairs: int = 18,
        segmentation_method: str = "grounding_dino_sam2",
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        target_fps: int = 1,
        alpha: float = 0.5,
        white_alpha: float = 0.8,
        topk_cate: int = 3,
        multi_class: bool = False,
        output_logit: bool = False,
        use_pretrained_cate_weights: bool = False,
        categorical_pool: str = "mean",  # "mean" or "max"
        max_video_length: int = 100,
        bbox_min_dim: int = 1,
        visualize: bool = False,
        visualization_dir: Optional[str] = None,
        return_flattened_segments: bool = False,
        return_valid_pairs: bool = False,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: bool = False,
        device: Optional[Union[str, int]] = None,
        **kwargs: Any,
    ):
        """
        Build a VINE configuration.

        Args:
            use_hf_repo: Chooses which weight-source pair is kept. True keeps
                ``model_repo`` and ``model_file``. False keeps ``local_dir``
                and ``local_filename``. The other pair is stored as ``None``.
            model_repo: Kept only when ``use_hf_repo`` is True.
            model_file: Kept only when ``use_hf_repo`` is True.
            local_dir: Kept only when ``use_hf_repo`` is False. Defaults to the
                directory containing this module. The default is evaluated once,
                at import time.
            local_filename: Kept only when ``use_hf_repo`` is False.
            categorical_pool: Expected to be ``"mean"`` or ``"max"``. It is not
                validated here.
            interested_object_pairs: Optional list of ``(int, int)`` pairs.
                ``None`` is stored as ``[]``.
            device: Resolved into ``self._device`` as follows:
                ``None`` becomes ``"cuda"`` if CUDA is available, else ``"cpu"``.
                An int ``i`` becomes ``"cuda:i"`` if CUDA is available, else
                ``"cpu"``. Any other truthy value is stored unchanged. Other
                falsy values fall back to the ``None`` behaviour.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. A ``_device``
                entry, as found in configs written by ``save_pretrained``, is
                discarded.

        The remaining arguments are stored verbatim: ``model_name``,
        ``hidden_dim``, ``num_top_pairs``, ``segmentation_method``,
        ``box_threshold``, ``text_threshold``, ``target_fps``, ``alpha``,
        ``white_alpha``, ``topk_cate``, ``multi_class``, ``output_logit``,
        ``use_pretrained_cate_weights``, ``max_video_length``,
        ``bbox_min_dim``, ``visualize``, ``visualization_dir``,
        ``return_flattened_segments``, ``return_valid_pairs`` and
        ``debug_visualizations``. Their meaning is defined by the model code
        that reads them.
        """
        self.model_name = model_name
        self.use_hf_repo = use_hf_repo

        # Only one weight source is meaningful at a time; null out the other
        # pair so a saved config is unambiguous.
        if use_hf_repo:
            self.model_repo = model_repo
            self.model_file = model_file
            self.local_dir = None
            self.local_filename = None
        else:
            self.model_repo = None
            self.model_file = None
            self.local_dir = local_dir
            self.local_filename = local_filename

        self.hidden_dim = hidden_dim
        self.num_top_pairs = num_top_pairs
        self.segmentation_method = segmentation_method
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.target_fps = target_fps
        self.alpha = alpha
        self.white_alpha = white_alpha
        self.topk_cate = topk_cate
        self.multi_class = multi_class
        self.output_logit = output_logit
        self.use_pretrained_cate_weights = use_pretrained_cate_weights
        self.categorical_pool = categorical_pool
        self.max_video_length = max_video_length
        self.bbox_min_dim = bbox_min_dim
        self.visualize = visualize
        self.visualization_dir = visualization_dir
        self.return_flattened_segments = return_flattened_segments
        self.return_valid_pairs = return_valid_pairs
        # Avoid a None/list split for callers that iterate the pairs.
        self.interested_object_pairs = interested_object_pairs or []
        self.debug_visualizations = debug_visualizations

        if isinstance(device, int):
            # Integer means a CUDA ordinal; degrade to CPU when CUDA is absent.
            self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        else:
            self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # PretrainedConfig serializes instance attributes, private ones
        # included, so a config written by save_pretrained() carries a
        # "_device" entry. On reload that key arrives here via **kwargs.
        # PretrainedConfig.__init__ would setattr() it, replacing the device
        # resolved above with the saving host's device (e.g. "cuda" on a
        # CPU-only machine). Discard it so the device always reflects the
        # current host and the explicit ``device`` argument.
        kwargs.pop("_device", None)

        super().__init__(**kwargs)