"""
Custom CLIP Model with Register Tokens - Import Safe Version with Complete File Download
"""
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers.utils import logging
from typing import Optional, Union, Tuple
import json
from pathlib import Path
import warnings
import os
import sys
import importlib.util

# Suppress all warnings during import
warnings.filterwarnings("ignore")
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

logger = logging.get_logger(__name__)

def ensure_all_files_downloaded():
    """Ensure all repository files are downloaded when this module is imported"""
    try:
        from huggingface_hub import snapshot_download, HfApi
        
        repo_id = 'amildravid4292/clip-vitb16-test-time-registers'
        
        # Get list of all files in the repository
        api = HfApi()
        all_files = api.list_repo_files(repo_id)
        
        # Download everything to ensure all files are available
        print(f"Ensuring all {len(all_files)} repository files are available...")
        
        local_dir = snapshot_download(
            repo_id=repo_id,
            force_download=False  # Don't re-download existing files
        )
        
        print(f"βœ“ Repository files available at: {local_dir}")
        
        # Add the repository directory to Python path immediately
        if str(local_dir) not in sys.path:
            sys.path.insert(0, str(local_dir))
            print(f"βœ“ Added repository directory to Python path: {local_dir}")
        
        # Verify critical files are present
        critical_files = [f for f in all_files if f.endswith(('.py', '.pt', '.json'))]
        missing_critical = []
        
        for file in critical_files:
            file_path = Path(local_dir) / file
            if not file_path.exists():
                missing_critical.append(file)
        
        if missing_critical:
            print(f"Warning: {len(missing_critical)} critical files still missing")
            # Try individual downloads for missing critical files
            from huggingface_hub import hf_hub_download
            for file in missing_critical[:5]:  # Limit to avoid spam
                try:
                    hf_hub_download(repo_id=repo_id, filename=file, force_download=True)
                    print(f"βœ“ Downloaded {file}")
                except Exception as e:
                    print(f"βœ— Could not download {file}: {e}")
        else:
            print(f"βœ“ All {len(critical_files)} critical files verified present")
            
        # List the Python files we found for debugging
        python_files = [f for f in all_files if f.endswith('.py')]
        print(f"βœ“ Python files available: {python_files}")
            
        return local_dir
            
    except Exception as e:
        print(f"Warning: Could not verify/download all repository files: {e}")
        print("Model may still work if core files are present.")
        return None

# Download all files when this module is imported
_repo_dir = ensure_all_files_downloaded()

def safe_import_from_repo(module_name, repo_path):
    """Safely import a module from the downloaded repository.

    Note: ``repo_path`` is accepted by callers but unused; the search
    locations are derived from the downloaded snapshot, this file's
    directory, and the Transformers cache below.
    """
    
    # First, ensure the repository directory is in Python path
    global _repo_dir
    if _repo_dir and str(_repo_dir) not in sys.path:
        sys.path.insert(0, str(_repo_dir))
        print(f"βœ“ Added {_repo_dir} to Python path")
    
    try:
        # First try direct import (should work now that path is set)
        return __import__(module_name)
    except ImportError:
        try:
            # Multiple locations to search for the module
            search_paths = [
                Path(__file__).parent,  # Same directory as this file
                Path(__file__).parent.parent,  # Parent directory
            ]
            
            # Add the repository directory if we have it
            if _repo_dir:
                search_paths.append(Path(_repo_dir))
            
            # Also try to find the snapshot download location
            try:
                from transformers.utils import TRANSFORMERS_CACHE
                repo_cache_name = "models--amildravid4292--clip-vitb16-test-time-registers"
                cache_path = Path(TRANSFORMERS_CACHE) / repo_cache_name / "snapshots"
                
                # Find the most recent snapshot
                if cache_path.exists():
                    snapshot_dirs = [d for d in cache_path.iterdir() if d.is_dir()]
                    if snapshot_dirs:
                        # Get the most recent snapshot
                        latest_snapshot = max(snapshot_dirs, key=lambda x: x.stat().st_mtime)
                        search_paths.append(latest_snapshot)
                        # Also add this to Python path
                        if str(latest_snapshot) not in sys.path:
                            sys.path.insert(0, str(latest_snapshot))
            except Exception:
                # Cache inspection is best-effort; ignore failures.
                pass
            
            # Search in all possible locations
            for search_dir in search_paths:
                module_path = search_dir / f"{module_name}.py"
                if module_path.exists():
                    # Add this directory to Python path so relative imports work
                    if str(search_dir) not in sys.path:
                        sys.path.insert(0, str(search_dir))
                    
                    # Now try importing again
                    try:
                        return __import__(module_name)
                    except ImportError:
                        # If direct import still fails, try spec loading
                        spec = importlib.util.spec_from_file_location(module_name, module_path)
                        module = importlib.util.module_from_spec(spec)
                        sys.modules[module_name] = module
                        spec.loader.exec_module(module)
                        print(f"βœ“ Successfully imported {module_name} from {search_dir}")
                        return module
            
            # If we get here, we couldn't find the module anywhere
            searched_locations = [str(p) for p in search_paths]
            raise ImportError(f"Could not find {module_name}.py in any of these locations: {searched_locations}")
            
        except Exception as e:
            raise ImportError(f"Failed to import {module_name}: {e}")

class CustomCLIPConfig(PretrainedConfig):
    model_type = "custom_clip_with_registers"
    
    def __init__(
        self,
        vision_config=None,
        text_config=None,
        num_register_tokens=0,
        neuron_dict=None,
        projection_dim=512,
        logit_scale_init_value=2.6592,
        **kwargs
    ):
        super().__init__(**kwargs)
        
        self.vision_config = vision_config or {}
        self.text_config = text_config or {}
        self.num_register_tokens = num_register_tokens
        self.neuron_dict = neuron_dict
        self.projection_dim = projection_dim
        self.logit_scale_init_value = logit_scale_init_value

class CustomCLIPModel(PreTrainedModel):
    config_class = CustomCLIPConfig
    
    def __init__(self, config):
        super().__init__(config)
        
        # Safe import of custom modules
        try:
            model_module = safe_import_from_repo('model', Path(__file__).parent)
            self.CLIP = model_module.CLIP
            self.CLIPVisionCfg = model_module.CLIPVisionCfg  
            self.CLIPTextCfg = model_module.CLIPTextCfg
        except ImportError as e:
            raise ImportError(f"Could not import model components: {e}. Make sure all model files are in the repository.")
        
        # Create vision and text configs
        vision_cfg = self.CLIPVisionCfg(
            layers=config.vision_config.get("num_hidden_layers", 12),
            width=config.vision_config.get("hidden_size", 768),
            patch_size=config.vision_config.get("patch_size", 16),
            image_size=config.vision_config.get("image_size", 224),
        )
        
        text_cfg = self.CLIPTextCfg(
            context_length=config.text_config.get("max_position_embeddings", 77),
            vocab_size=config.text_config.get("vocab_size", 49408),
            width=config.text_config.get("hidden_size", 512),
            layers=config.text_config.get("num_hidden_layers", 12),
        )
        
        # Initialize the custom CLIP model
        self.model = self.CLIP(
            embed_dim=config.projection_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
        )
        
        # These will be set when loading the state dict
        self.neuron_dict = None
        self.num_register_tokens = 0
        
        # These will be loaded separately
        self._tokenizer = None
        self._preprocessor = None
        self._zeroshot_classifier = None
    
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Override to handle custom parameters and load weights properly"""
        
        # Extract custom parameters first
        if 'neuron_dict' in state_dict:
            self.neuron_dict = state_dict.pop('neuron_dict')
        
        if 'num_register_tokens' in state_dict:
            self.num_register_tokens = state_dict.pop('num_register_tokens')
        
        # Set these values in the model
        if hasattr(self.model, 'visual'):
            self.model.visual.num_register_tokens = self.num_register_tokens
            self.model.visual.neuron_dict = self.neuron_dict
            self.model.num_register_tokens = self.num_register_tokens  
            self.model.neuron_dict = self.neuron_dict
        
        # Load the weights properly - suppress ALL warnings and errors
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            
            # Temporarily set logging to critical only
            original_level = logging.get_verbosity()
            logging.set_verbosity_error()
            
            try:
                # Load weights directly into self.model; intentionally ignore
                # any missing/unexpected keys so no warnings are surfaced.
                self.model.load_state_dict(state_dict, strict=False)

            except Exception:
                # If direct loading fails, try the parent method silently
                super()._load_from_state_dict(state_dict, prefix, local_metadata, False, [], [], [])
            finally:
                # Restore logging level
                logging.set_verbosity(original_level)
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override to load cleanly and suppress warnings"""
        
        # Suppress warnings during loading
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            
            # Temporarily suppress transformers logging
            original_level = logging.get_verbosity()
            logging.set_verbosity_error()
            
            try:
                # Load the model
                model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
            finally:
                # Restore logging
                logging.set_verbosity(original_level)
        
        # Load additional components
        model._load_additional_components(pretrained_model_name_or_path)
        
        # Print clean success message
        print("Custom CLIP model loaded successfully!")
   
        
        return model
    
    def _load_additional_components(self, pretrained_model_name_or_path):
        """Load tokenizer, preprocessor, and zero-shot classifier silently"""
        
        try:
            from huggingface_hub import hf_hub_download
            
            # Load tokenizer
            try:
                # Safe import of tokenizer
                tokenizer_module = safe_import_from_repo('tokenizer', Path(__file__).parent)
                self._tokenizer = tokenizer_module.SimpleTokenizer()
            except ImportError:
                # If the tokenizer import fails, leave self._tokenizer unset;
                # tokenize() will raise a descriptive error when it is used.
                pass
            
            # Load preprocessor
            try:
                preprocess_config_file = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="preprocessor_config.json"
                )
                
                with open(preprocess_config_file, 'r') as f:
                    preprocess_config = json.load(f)
                
                self._create_preprocessor(preprocess_config)
            except Exception:
                # The preprocessor is optional; skip silently if its config is missing.
                pass
            
            # Load zero-shot classifier
            try:
                classifier_file = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="zeroshot_classifier.pt"
                )
                
                # Suppress the torch.load warning
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self._zeroshot_classifier = torch.load(classifier_file, map_location='cpu', weights_only=False)
            except Exception:
                # The zero-shot classifier is optional; skip silently if absent.
                pass

        except Exception:
            pass
    
    def _create_preprocessor(self, config):
        """Create image preprocessor from config"""
        try:
            from torchvision import transforms
            
            self._preprocessor = transforms.Compose([
                transforms.Resize(config["image_size"], interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(config["image_size"]),
                transforms.ToTensor(),
                transforms.Normalize(mean=config["image_mean"], std=config["image_std"]),
            ])
        except Exception:
            # torchvision is optional; leave the preprocessor unset if unavailable.
            pass
    
    @property
    def tokenizer(self):
        """Access the tokenizer"""
        return self._tokenizer
    
    @property 
    def preprocessor(self):
        """Access the image preprocessor"""
        return self._preprocessor
    
    @property
    def zeroshot_classifier(self):
        """Access the zero-shot classifier"""
        return self._zeroshot_classifier
    
    def tokenize(self, texts, context_length=77):
        """Tokenize text using the loaded tokenizer"""
        if self._tokenizer is None:
            raise ValueError("Tokenizer not available. Make sure tokenizer.py is in the repository.")
        
        # Safe import of tokenize function
        try:
            tokenizer_module = safe_import_from_repo('tokenizer', Path(__file__).parent)
            return tokenizer_module.tokenize(texts, context_length)
        except ImportError:
            raise ValueError("Could not import tokenize function.")
    
    def preprocess_image(self, image):
        """Preprocess image using the loaded preprocessor"""
        if self._preprocessor is None:
            raise ValueError("Preprocessor not loaded. Make sure preprocessor_config.json is in the repository.")
        
        return self._preprocessor(image)
    
    def forward(self, input_ids=None, pixel_values=None, num_register_tokens=None, neuron_dict=None, **kwargs):
        """Forward pass supporting the custom register-token functionality"""
        
        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict
            
        return self.model(
            image=pixel_values,
            text=input_ids,
            num_register_tokens=num_register_tokens,
            neuron_dict=neuron_dict
        )
    
    def encode_image(self, pixel_values, num_register_tokens=None, neuron_dict=None, **kwargs):
        """Encode images with register token support"""
        if num_register_tokens is None:
            num_register_tokens = self.num_register_tokens
        if neuron_dict is None:
            neuron_dict = self.neuron_dict
            
        return self.model.encode_image(
            pixel_values, 
            num_register_tokens=num_register_tokens, 
            neuron_dict=neuron_dict,
            **kwargs
        )
    
    def encode_text(self, input_ids, **kwargs):
        """Encode text"""
        return self.model.encode_text(input_ids, **kwargs)

# Auto-suppress warnings at module level
import transformers
transformers.logging.set_verbosity_error()
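
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the published API). It assumes
# the repository's config.json maps AutoModel to CustomCLIPModel via
# `auto_map`; if it does not, call CustomCLIPModel.from_pretrained(...)
# directly. The prompt and the blank PIL image below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import AutoModel

    repo_id = "amildravid4292/clip-vitb16-test-time-registers"
    model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

    # Tokenize a prompt and preprocess a dummy 224x224 image.
    tokens = model.tokenize(["a photo of a cat"])
    pixel_values = model.preprocess_image(Image.new("RGB", (224, 224))).unsqueeze(0)

    # Encode using the register-token settings stored in the checkpoint.
    with torch.no_grad():
        image_features = model.encode_image(pixel_values)
        text_features = model.encode_text(tokens)

    print("image features:", tuple(image_features.shape))
    print("text features:", tuple(text_features.shape))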