File size: 1,397 Bytes
a4b0060
 
 
 
cad3946
 
a4b0060
 
 
 
 
 
 
 
 
 
 
d3e85c8
 
 
474b6f1
d3e85c8
2101135
 
a4b0060
 
2101135
a4b0060
 
d3e85c8
 
 
 
 
 
 
 
de0f77a
d3e85c8
fca1dff
d3e85c8
 
 
cad3946
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import urllib3

from huggingface_hub import snapshot_download

from available_models import AVAILABLE_MODELS


def change_default_timeout(new_timeout: int) -> None:
    """
    Raise the timeouts huggingface_hub uses when talking to the Hugging Face Hub.

    Intended to prevent errors such as:
    urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
    Read timed out. (read timeout=10)

    Why the old approach did nothing:
    - The previous implementation assigned ``urllib3.util.timeout.DEFAULT_TIMEOUT``.
    - urllib3 does not read such a module-level name; its default lives on
      ``Timeout.DEFAULT_TIMEOUT`` as a sentinel.
    - The error above shows an explicit per-request read timeout, which a
      global default would not override anyway.

    What this version does instead:
    - It exports ``HF_HUB_DOWNLOAD_TIMEOUT`` and ``HF_HUB_ETAG_TIMEOUT``, so any
      huggingface_hub imported afterwards (e.g. in a child process) reads them.
    - It patches the same-named attributes on ``huggingface_hub.constants`` for
      the current process, when those attributes exist.

    NOTE(review): whether patching ``huggingface_hub.constants`` takes effect
    in-process depends on the installed huggingface_hub version (it only works
    if the library reads ``constants.<NAME>`` at call time) — confirm against
    the pinned version.

    :param new_timeout: timeout in seconds.
    """
    import os

    # Environment variables must be strings.
    value = str(new_timeout)
    os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = value
    os.environ["HF_HUB_ETAG_TIMEOUT"] = value

    try:
        from huggingface_hub import constants as hf_constants
    except ImportError:
        # Nothing to patch in-process; the environment variables above still
        # apply to any later import of huggingface_hub.
        return
    for attr in ("HF_HUB_DOWNLOAD_TIMEOUT", "HF_HUB_ETAG_TIMEOUT"):
        # Only touch knobs the installed version actually defines.
        if hasattr(hf_constants, attr):
            setattr(hf_constants, attr, new_timeout)


def download_pytorch_model(name: str, *, timeout: int = 60 * 60 * 24 * 365) -> str:
    """
    Download a model repository from the Hugging Face Hub, skipping unneeded files.

    Every file in the repository is fetched except those matching the ignore
    patterns below:
    - weights in other formats: ``.h5``, ``.tflite``, ``.safetensors``,
      ``.msgpack``, ``.ot``;
    - Markdown files: ``.md``.

    :param name: repository id on the Hub, passed to ``snapshot_download`` as ``repo_id``.
    :param timeout: number of seconds used for two things:
        - the process-wide timeouts set via ``change_default_timeout``;
        - ``snapshot_download``'s ``etag_timeout``.
        Defaults to 365 days, i.e. effectively "never time out".
    :return: the value returned by ``snapshot_download`` (the local snapshot
        folder path). Callers that ignore the result are unaffected.
    """
    change_default_timeout(timeout)
    return snapshot_download(
        repo_id=name,
        etag_timeout=timeout,
        # NOTE(review): newer huggingface_hub releases deprecate this flag
        # (downloads resume automatically); kept so older versions still resume.
        resume_download=True,
        repo_type="model",
        library_name="pt",
        ignore_patterns=[
            # Weight formats this script does not need.
            "*.h5",
            "*.tflite",
            "*.safetensors",
            "*.msgpack",
            "*.ot",
            # Markdown files (e.g. README / model card).
            "*.md",
        ],
    )


if __name__ == "__main__":
    # Script entry point: fetch every repository listed in AVAILABLE_MODELS,
    # sequentially and in list order.
    for repo_id in AVAILABLE_MODELS:
        download_pytorch_model(repo_id)