# NOTE: this file was recovered from a Hugging Face Spaces page scrape
# (the page showed the Space status as "Sleeping").
import logging
import os
import time

import retrying  # NOTE(review): imported but retries are done inline below
from transformers import AutoModel, AutoProcessor, AutoTokenizer
# Configure logging: INFO level, timestamped "time - LEVEL - message" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# List of models to download (excluding distilbert-base-cased-distilled-squad)
# Each entry pairs a Hub repo id with the classes used to fetch its weights
# and its tokenizer/processor.
MODELS = [
    {"name": repo_id, "model_class": model_cls, "tokenizer_class": tok_cls}
    for repo_id, model_cls, tok_cls in (
        ("sshleifer/distilbart-cnn-12-6", AutoModel, AutoTokenizer),
        ("Salesforce/blip-image-captioning-large", AutoModel, AutoProcessor),
        ("dandelin/vilt-b32-finetuned-vqa", AutoModel, AutoProcessor),
        ("facebook/m2m100_418M", AutoModel, AutoTokenizer),
    )
]
# Retry on network failures. The original announced a retry decorator (and the
# `retrying` package is imported at module level) but never applied one, so a
# simple exponential-backoff loop implements the retries instead.
def download_model(model_name, model_class, tokenizer_class,
                   max_attempts=3, backoff=2.0):
    """Download a model and its tokenizer/processor with retries.

    Args:
        model_name: Hugging Face Hub repo id to download.
        model_class: class exposing ``from_pretrained`` for the weights.
        tokenizer_class: class exposing ``from_pretrained`` for the
            tokenizer/processor.
        max_attempts: total attempts before giving up (default 3).
        backoff: base of the exponential sleep between attempts, in seconds
            (sleeps backoff**1, backoff**2, ...).

    Returns:
        True once both the model and its tokenizer/processor are downloaded.

    Raises:
        Exception: the last download error, re-raised after max_attempts
            consecutive failures.
    """
    # Same logger object as the module-level `logger`; fetched locally so the
    # function is self-contained.
    log = logging.getLogger(__name__)
    cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
    log.info(f"Downloading model: {model_name} to {cache_dir}")
    for attempt in range(1, max_attempts + 1):
        try:
            # Download model
            model_class.from_pretrained(model_name, cache_dir=cache_dir)
            log.info(f"Successfully downloaded model: {model_name}")
            # Download tokenizer/processor
            tokenizer_class.from_pretrained(model_name, cache_dir=cache_dir)
            log.info(f"Successfully downloaded tokenizer/processor: {model_name}")
            return True
        except Exception as e:
            log.error(f"Failed to download {model_name}: {str(e)}")
            if attempt == max_attempts:
                raise
            # Exponential backoff before the next attempt.
            time.sleep(backoff ** attempt)
def main():
    """Download every model in MODELS, logging per-model timing.

    Exits with status 1 if the cache directory is not writable. Individual
    download failures are logged but do not abort the remaining downloads.
    """
    cache_dir = os.getenv('HF_HOME', '/cache/huggingface')
    # Create the cache directory if needed; exist_ok avoids the racy
    # exists-then-makedirs pattern the original used.
    os.makedirs(cache_dir, exist_ok=True)
    if not os.access(cache_dir, os.W_OK):
        logger.error(f"Cache directory {cache_dir} is not writable")
        # SystemExit is equivalent to the `exit(1)` builtin but does not
        # depend on the `site` module injecting `exit` (e.g. `python -S`).
        raise SystemExit(1)
    success = True
    for model_info in MODELS:
        model_name = model_info["name"]
        try:
            start_time = time.time()
            download_model(model_name, model_info["model_class"],
                           model_info["tokenizer_class"])
            logger.info(f"Downloaded {model_name} in {time.time() - start_time:.2f} seconds")
        except Exception as e:
            # Keep going: a partial cache is still useful to the build.
            logger.error(f"Failed to download {model_name} after retries: {str(e)}")
            success = False
    if not success:
        logger.warning("Some model downloads failed, but continuing build")
    else:
        logger.info("All models downloaded successfully")


if __name__ == "__main__":
    main()