import urllib3
from huggingface_hub import snapshot_download

from available_models import AVAILABLE_MODELS

# Deliberately huge default timeout (one year, in seconds) so that slow
# downloads of large checkpoints are never aborted by a read timeout.
SECONDS_IN_A_YEAR: int = 60 * 60 * 24 * 365

# Repository files matching these glob patterns are NOT downloaded:
# non-PyTorch weight formats (h5, tflite, safetensors, msgpack, ot) and
# Markdown files (e.g. model cards / READMEs).
IGNORED_FILE_PATTERNS = (
    "*.h5",
    "*.tflite",
    "*.safetensors",
    "*.msgpack",
    "*.ot",
    "*.md",
)


def change_default_timeout(new_timeout: int) -> None:
    """
    Changes the default timeout for downloading repositories from the Hugging Face Hub.

    Prevents the following errors:
    urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
    Read timed out. (read timeout=10)

    Args:
        new_timeout: New value assigned to ``urllib3.util.timeout.DEFAULT_TIMEOUT``.

    NOTE(review): this only rebinds a module-level attribute of urllib3 for the
    rest of the process. Requests issued with an explicit timeout do not consult
    it, and huggingface_hub appears to pass one for file downloads, so this may
    not actually prevent the error above. The documented knob is the
    ``HF_HUB_DOWNLOAD_TIMEOUT`` environment variable, which has to be set before
    ``huggingface_hub`` is imported. TODO confirm against the installed version.
    """
    urllib3.util.timeout.DEFAULT_TIMEOUT = new_timeout


def download_pytorch_model(name: str, timeout: int = SECONDS_IN_A_YEAR) -> None:
    """
    Downloads a pytorch model and the other small files from the model's repository.

    Files matching ``IGNORED_FILE_PATTERNS`` are skipped: other model formats
    (tensorflow h5, tflite, safetensors, msgpack, ot) and Markdown (``*.md``) files.

    Args:
        name: Hugging Face Hub repository id of the model, e.g. ``"org/model"``.
        timeout: Timeout in seconds. It is used both for the global urllib3
            default (via ``change_default_timeout``) and as ``etag_timeout``
            for the metadata requests. Defaults to one year.
    """
    # Side effect: mutates urllib3's module-level default for the whole process.
    change_default_timeout(timeout)

    snapshot_download(
        repo_id=name,
        etag_timeout=timeout,
        # Resume partially downloaded files. Recent huggingface_hub releases
        # always resume and deprecate this flag; it is kept for older versions.
        resume_download=True,
        repo_type="model",
        # Only identifies the caller (user agent); it does not filter files.
        library_name="pt",
        # A fresh list each call so the shared module constant is never exposed
        # to the library as a mutable object.
        ignore_patterns=list(IGNORED_FILE_PATTERNS),
    )


if __name__ == "__main__":
    for model_name in AVAILABLE_MODELS:
        download_pytorch_model(model_name)