Hadiil committed on
Commit 29465a2 · verified · 1 Parent(s): f0055ff

Update download_models.py

Files changed (1)
  1. download_models.py +100 -0
download_models.py CHANGED
@@ -0,0 +1,100 @@
+ import os
+ import logging
+ import time
+ from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForQuestionAnswering
+ import retrying
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # List of models to download
+ MODELS = [
+     {
+         "name": "sshleifer/distilbart-cnn-12-6",
+         "model_class": AutoModel,
+         "tokenizer_class": AutoTokenizer
+     },
+     {
+         "name": "Salesforce/blip-image-captioning-large",
+         "model_class": AutoModel,
+         "tokenizer_class": AutoProcessor
+     },
+     {
+         "name": "dandelin/vilt-b32-finetuned-vqa",
+         "model_class": AutoModel,
+         "tokenizer_class": AutoProcessor
+     },
+     {
+         "name": "distilbert-base-cased-distilled-squad",
+         "model_class": AutoModelForQuestionAnswering,
+         "tokenizer_class": AutoTokenizer
+     },
+     {
+         "name": "facebook/m2m100_418M",
+         "model_class": AutoModel,
+         "tokenizer_class": AutoTokenizer
+     }
+ ]
+
+ # Retry decorator for network failures
+ @retrying.retry(
+     stop_max_attempt_number=3,
+     wait_fixed=5000,
+     retry_on_exception=lambda e: isinstance(e, Exception)
+ )
+ def download_model(model_name, model_class, tokenizer_class):
+     """Download a model and its tokenizer/processor with retries."""
+     cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
+     logger.info(f"Downloading model: {model_name} to {cache_dir}")
+
+     try:
+         # Download model
+         model = model_class.from_pretrained(model_name, cache_dir=cache_dir)
+         logger.info(f"Successfully downloaded model: {model_name}")
+
+         # Download tokenizer/processor
+         tokenizer = tokenizer_class.from_pretrained(model_name, cache_dir=cache_dir)
+         logger.info(f"Successfully downloaded tokenizer/processor: {model_name}")
+
+         return True
+     except Exception as e:
+         logger.error(f"Failed to download {model_name}: {str(e)}")
+         raise
+
+ def main():
+     """Main function to download all models."""
+     cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
+
+     # Verify cache directory permissions
+     if not os.path.exists(cache_dir):
+         os.makedirs(cache_dir, exist_ok=True)
+     if not os.access(cache_dir, os.W_OK):
+         logger.error(f"Cache directory {cache_dir} is not writable")
+         exit(1)
+
+     success = True
+     for model_info in MODELS:
+         model_name = model_info["name"]
+         model_class = model_info["model_class"]
+         tokenizer_class = model_info["tokenizer_class"]
+
+         try:
+             start_time = time.time()
+             download_model(model_name, model_class, tokenizer_class)
+             logger.info(f"Downloaded {model_name} in {time.time() - start_time:.2f} seconds")
+         except Exception as e:
+             logger.error(f"Failed to download {model_name} after retries: {str(e)}")
+             success = False
+
+     if not success:
+         logger.error("One or more model downloads failed")
+         exit(1)
+     else:
+         logger.info("All models downloaded successfully")
+
+ if __name__ == "__main__":
+     main()
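
For reference, a minimal sketch of how the cache populated by this script might be consumed at runtime, assuming the same HF_HOME default used above. The local_files_only flag and the model chosen here are illustrative and are not part of this commit:

import os
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Assumes download_models.py has already populated this cache (e.g. during an image build step).
cache_dir = os.getenv("HF_HOME", "/cache/huggingface")

model_name = "distilbert-base-cased-distilled-squad"

# local_files_only=True makes transformers read from the local cache and fail fast
# instead of reaching out to the Hub if the files are missing.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True)
model = AutoModelForQuestionAnswering.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=True)

print(f"Loaded {model_name} from {cache_dir} without contacting the Hub")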