Download only the PyTorch model and its essential files, not the other formats of the model
d3e85c8
import os
from typing import Optional, Sequence

import urllib3
from huggingface_hub import snapshot_download
def change_default_timeout(new_timeout: int) -> None:
    """
    Overwrite urllib3's module-level ``DEFAULT_TIMEOUT`` with *new_timeout*.

    Intended to avoid errors while downloading from the Hugging Face Hub such as::

        urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
        Read timed out. (read timeout=10)

    NOTE(review): this only rebinds an attribute on ``urllib3.util.timeout``.
    Whether the HTTP requests issued by huggingface_hub actually consult that
    attribute (rather than passing their own explicit timeout) depends on the
    installed urllib3 / huggingface_hub versions -- confirm before relying on it.
    """
    timeout_module = urllib3.util.timeout
    setattr(timeout_module, "DEFAULT_TIMEOUT", new_timeout)
def download_pytorch_model(
    name: str,
    *,
    cache_dir: Optional[str] = None,
    ignore_patterns: Sequence[str] = (
        "*.h5",
        "*.tflite",
        "*.safetensors",
        "*.msgpack",
        "*.ot",
    ),
    timeout: int = 86_400,
) -> str:
    """
    Download a model repository from the Hugging Face Hub, skipping non-PyTorch weights.

    By default, weight files in the other formats (``*.h5``, ``*.tflite``,
    ``*.safetensors``, ``*.msgpack`` and ``*.ot``) are not downloaded; every
    other file in the repository is.

    Args:
        name: Hub repository id, e.g. ``"facebook/opt-125m"``.
        cache_dir: Directory used as the download cache. Defaults to
            ``huggingface/models`` next to this source file.
        ignore_patterns: Glob patterns of repository files to skip. A single
            string is treated as one pattern.
        timeout: Timeout in seconds. It is applied both to the global urllib3
            default (via ``change_default_timeout``) and to the ETag request.
            Defaults to one day (86 400 s).

    Returns:
        The local path of the downloaded snapshot, as returned by
        ``snapshot_download``.
    """
    change_default_timeout(timeout)

    if cache_dir is None:
        # abspath: a relative __file__ would give dirname == "" and silently
        # place the cache under the current working directory instead.
        curr_folder = os.path.dirname(os.path.abspath(__file__))
        cache_dir = os.path.join(curr_folder, "huggingface", "models")

    # Guard against a bare string, which list() would split into characters.
    if isinstance(ignore_patterns, str):
        patterns = [ignore_patterns]
    else:
        patterns = list(ignore_patterns)

    return snapshot_download(
        cache_dir=cache_dir,
        repo_id=name,
        etag_timeout=timeout,
        # NOTE(review): recent huggingface_hub releases deprecate
        # `resume_download` (downloads always resume) -- confirm against the
        # installed version before removing it.
        resume_download=True,
        repo_type="model",
        library_name="pt",
        ignore_patterns=patterns,
    )
if __name__ == "__main__":
    import sys

    # Repository ids may be given on the command line; with no arguments the
    # original default repository is downloaded.
    for repo_id in sys.argv[1:] or ["facebook/opt-125m"]:
        download_pytorch_model(repo_id)