Web-ai-app / download_models.py
Hadiil's picture
Update download_models.py
a2febbf verified
raw
history blame
3.03 kB
import os
import logging
import time
from transformers import AutoModel, AutoTokenizer, AutoProcessor
import retrying
# Logging setup: emit INFO and above, each record prefixed with its
# timestamp and severity level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
# Models to pre-download (distilbert-base-cased-distilled-squad is
# deliberately excluded). Each entry records the repository id, the class
# used to fetch the model, and the class used to fetch its
# tokenizer/processor.
MODELS = [
    {
        "name": repo_id,
        "model_class": AutoModel,
        "tokenizer_class": preprocessor_class,
    }
    for repo_id, preprocessor_class in (
        ("sshleifer/distilbart-cnn-12-6", AutoTokenizer),
        ("Salesforce/blip-image-captioning-large", AutoProcessor),
        ("dandelin/vilt-b32-finetuned-vqa", AutoProcessor),
        ("facebook/m2m100_418M", AutoTokenizer),
    )
]
# Retry decorator for network failures: at most 3 attempts with a fixed
# 5000 ms pause between them; every Exception subclass triggers a retry.
@retrying.retry(
    stop_max_attempt_number=3,
    wait_fixed=5000,
    retry_on_exception=lambda e: isinstance(e, Exception)
)
def download_model(model_name, model_class, tokenizer_class):
    """Download a model and its tokenizer/processor with retries.

    Args:
        model_name: Repository id handed to ``from_pretrained``.
        model_class: Class whose ``from_pretrained`` fetches the model.
        tokenizer_class: Class whose ``from_pretrained`` fetches the
            tokenizer or processor.

    Returns:
        True once both ``from_pretrained`` calls have completed.

    Raises:
        Exception: Any error raised by ``from_pretrained`` is logged and
            re-raised so the retry decorator (and, after the last attempt,
            the caller) sees it.
    """
    cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
    # NOTE(review): cache_dir is set to HF_HOME itself, whereas the default
    # hub cache is documented as $HF_HOME/hub -- confirm the consuming app
    # loads these models with the same cache_dir.
    logger.info(f"Downloading model: {model_name} to {cache_dir}")
    try:
        # Only the files written to the cache matter here, so the loaded
        # objects are not bound to locals; this lets the model be released
        # before the tokenizer/processor download starts.
        model_class.from_pretrained(model_name, cache_dir=cache_dir)
        logger.info(f"Successfully downloaded model: {model_name}")
        tokenizer_class.from_pretrained(model_name, cache_dir=cache_dir)
        logger.info(f"Successfully downloaded tokenizer/processor: {model_name}")
    except Exception as e:
        logger.error(f"Failed to download {model_name}: {e}")
        raise
    return True
def main():
    """Download every model listed in MODELS into the cache directory.

    The cache directory is taken from the HF_HOME environment variable
    (default ``/cache/huggingface``) and is created if it does not exist.
    Individual download failures are logged but do not abort the run; the
    function only terminates the process early when the cache directory is
    not writable.

    Raises:
        SystemExit: With exit status 1 if the cache directory is not
            writable.
    """
    cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
    # Verify cache directory permissions
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir, exist_ok=True)
    if not os.access(cache_dir, os.W_OK):
        logger.error(f"Cache directory {cache_dir} is not writable")
        # The bare `exit` helper is injected by the `site` module and is not
        # available under `python -S` or in frozen builds; raising SystemExit
        # produces the same exit status without relying on it.
        raise SystemExit(1)

    success = True
    for model_info in MODELS:
        model_name = model_info["name"]
        model_class = model_info["model_class"]
        tokenizer_class = model_info["tokenizer_class"]
        # Monotonic clock: the elapsed time cannot be skewed by wall-clock
        # adjustments during a long download.
        start_time = time.monotonic()
        try:
            download_model(model_name, model_class, tokenizer_class)
        except Exception as e:
            # Deliberately best-effort: record the failure and move on to
            # the next model instead of aborting.
            logger.error(f"Failed to download {model_name} after retries: {e}")
            success = False
        else:
            logger.info(f"Downloaded {model_name} in {time.monotonic() - start_time:.2f} seconds")

    if not success:
        logger.warning("Some model downloads failed, but continuing build")
    else:
        logger.info("All models downloaded successfully")
# Script entry point: run the downloads only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()