File size: 1,430 Bytes
fca1dff
 
a4b0060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3e85c8
 
 
 
 
a4b0060
 
927596e
a4b0060
fca1dff
a4b0060
 
 
 
d3e85c8
 
 
 
 
 
 
 
de0f77a
d3e85c8
fca1dff
d3e85c8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os

import urllib3

from huggingface_hub import snapshot_download


def change_default_timeout(new_timeout: int) -> None:
    """
    Overwrite the ``DEFAULT_TIMEOUT`` attribute of the ``urllib3.util.timeout`` module.

    Intended as a workaround for read timeouts while downloading repositories
    from the Hugging Face Hub, for example:
    urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
    Read timed out. (read timeout=10)

    The value is stored on the module object itself, so it stays in effect
    after this function returns.

    NOTE(review): whether urllib3 consults this module-level attribute depends
    on the installed urllib3 version -- confirm it has the desired effect.

    :param new_timeout: value assigned to ``DEFAULT_TIMEOUT`` (the caller in
        this file passes a number of seconds).
    """
    timeout_module = urllib3.util.timeout
    timeout_module.DEFAULT_TIMEOUT = new_timeout


def download_pytorch_model(name: str) -> None:
    """
    Fetch a snapshot of the Hub model repository ``name`` via ``snapshot_download``.

    The snapshot is cached in ``huggingface/models`` inside the folder that
    contains this file. Files matching the tensorflow (h5), tflite,
    safetensors, msgpack and ot weight patterns are skipped, as are
    markdown (``*.md``) files.

    Before downloading, ``change_default_timeout`` is called with one day
    (in seconds); the same value is passed as ``etag_timeout``.

    :param name: repository id on the Hugging Face Hub, e.g. ``"facebook/opt-125m"``.
    """
    one_day_in_seconds: int = 86_400
    # File patterns that are not needed alongside the pytorch weights.
    skipped_patterns = [
        "*.h5",
        "*.tflite",
        "*.safetensors",
        "*.msgpack",
        "*.ot",
        "*.md",
    ]
    models_cache_dir: str = os.path.join(
        os.path.dirname(__file__), "huggingface", "models"
    )

    # Raise the timeout before any request is made by the download below.
    change_default_timeout(one_day_in_seconds)
    snapshot_download(
        repo_id=name,
        repo_type="model",
        library_name="pt",
        cache_dir=models_cache_dir,
        etag_timeout=one_day_in_seconds,
        resume_download=True,
        ignore_patterns=skipped_patterns,
    )


if __name__ == "__main__":
    # Script entry point: download the model used by default when this file
    # is executed directly rather than imported.
    default_repo_id = "facebook/opt-125m"
    download_pytorch_model(default_repo_id)