File size: 4,379 Bytes
087df0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import shutil
from pathlib import Path
from typing import Union, List

from fire import Fire
from huggingface_hub import hf_hub_download, hf_hub_url, get_hf_file_metadata
from huggingface_hub.utils import disable_progress_bars
from rich.progress import Progress

# Expert checkpoint filenames. Each is fetched from the hub repo's
# "expert_weights" subfolder and stored locally under experts/expert_weights/.
_EXPERTS = [
    "10_model.pth",
    "Unified_learned_OCIM_RS200_6x+2x.pth",
    "dpt_hybrid-midas-501f0c75.pt",
    "icdar2015_hourglass88.pth",
    "model_final_e0c58e.pkl",
    "model_final_f07440.pkl",
    "scannet.pt",
]

# Model checkpoint names. Each names a hub subfolder that holds a single
# "pytorch_model.bin", which is stored locally as logging/<name>/pytorch_model.bin.
# Model names requested by the caller are validated against this list.
_MODELS = [
    "vqa_prismer_base",
    "vqa_prismer_large",
    "vqa_prismerz_base",
    "vqa_prismerz_large",
    "caption_prismerz_base",
    "caption_prismerz_large",
    "caption_prismer_base",
    "caption_prismer_large",
    "pretrain_prismer_base",
    "pretrain_prismer_large",
    "pretrain_prismerz_base",
    "pretrain_prismerz_large",
]

# Hugging Face Hub repository that every checkpoint is downloaded from.
_REPO_ID = "lorenmt/prismer"


def _expert_dest(expert: str) -> Path:
    """Return the local path an expert checkpoint is stored at."""
    return Path("experts") / "expert_weights" / expert


def _model_dest(model: str) -> Path:
    """Return the local path a model checkpoint is stored at."""
    return Path("logging") / model / "pytorch_model.bin"


def _select_models(download_models: Union[bool, str, List[str]]) -> List[str]:
    """Normalize the ``download_models`` argument into a list of model names.

    ``False`` or an empty value selects nothing, ``True`` selects every entry
    of ``_MODELS``, and a single string or a sequence of strings selects
    exactly those names.

    Raises:
        ValueError: if any requested name is not in ``_MODELS``.
    """
    if not download_models:
        return []
    # Handle the bool case before iterating: ``True`` itself is not iterable.
    if isinstance(download_models, bool):
        return list(_MODELS)
    if isinstance(download_models, str):
        download_models = [download_models]
    invalid = [m for m in download_models if m not in _MODELS]
    if invalid:
        raise ValueError(f"Invalid model name(s) {invalid}. Must be one of {_MODELS}")
    return list(download_models)


def _remote_size(filename: str, subfolder: str) -> int:
    """Return the size in bytes of ``subfolder/filename`` in the hub repo (0 if unreported)."""
    url = hf_hub_url(repo_id=_REPO_ID, filename=filename, subfolder=subfolder)
    return get_hf_file_metadata(url).size or 0


def _move_out_of_cache(cached: Path, dest: Path) -> None:
    """Move the file behind ``cached`` (a path returned by ``hf_hub_download``) to ``dest``.

    The resolved target of ``cached`` is moved. Then the ``cached`` entry
    itself is removed if it still exists. It still exists when it was a
    symlink; it is already gone when it was the real file.
    """
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Path.rename refuses to overwrite on Windows; clear any old copy first so
    # force_redownload behaves the same on every platform.
    dest.unlink(missing_ok=True)
    # shutil.move falls back to copy+delete when the cache and the destination
    # are on different filesystems, where Path.rename raises OSError (EXDEV).
    shutil.move(str(cached.resolve()), str(dest))
    cached.unlink(missing_ok=True)


def download_checkpoints(
        download_experts: bool = False,
        download_models: Union[bool, List] = False,
        hide_tqdm: bool = False,
        force_redownload: bool = False,
):
    """Download expert and/or model checkpoints from the Hugging Face Hub.

    Experts are saved under ``experts/expert_weights/`` and models under
    ``logging/<model>/pytorch_model.bin``. Files already present locally are
    skipped unless ``force_redownload`` is set.

    Args:
        download_experts: If true, fetch every checkpoint in ``_EXPERTS``.
        download_models: ``True`` for every model in ``_MODELS``, a single model
            name, or a list of model names. ``False`` for none.
        hide_tqdm: Disable huggingface_hub's own progress bars.
        force_redownload: Download even if the local file already exists.

    Raises:
        ValueError: if an unknown model name is requested.
    """
    if hide_tqdm:
        disable_progress_bars()

    experts = list(_EXPERTS) if download_experts else []
    models = _select_models(download_models)

    # Skip anything already on disk. This uses the same paths the files are
    # written to below, so re-runs really are incremental.
    if not force_redownload:
        experts = [e for e in experts if not _expert_dest(e).exists()]
        models = [m for m in models if not _model_dest(m).exists()]

    if not experts and not models:
        print("Nothing to download.")
        return

    # Each job is (remote filename, hub subfolder, local destination).
    expert_jobs = [(e, "expert_weights", _expert_dest(e)) for e in experts]
    model_jobs = [("pytorch_model.bin", m, _model_dest(m)) for m in models]

    with Progress() as progress:
        progress.print("[blue]Calculating download size...")
        total_size = sum(
            _remote_size(filename, subfolder)
            for filename, subfolder, _ in expert_jobs + model_jobs
        )
        progress.print(f"[blue]Total download size: {total_size / 1e9:.2f} GB")

        total_task = progress.add_task(
            "[green]Downloading files", total=len(expert_jobs) + len(model_jobs)
        )
        for label, jobs in (("experts", expert_jobs), ("models", model_jobs)):
            if not jobs:
                continue
            group_task = progress.add_task(f"[green]Downloading {label}...", total=len(jobs))
            for filename, subfolder, dest in jobs:
                cached = Path(hf_hub_download(
                    repo_id=_REPO_ID,
                    filename=filename,
                    subfolder=subfolder,
                ))
                _move_out_of_cache(cached, dest)
                progress.advance(group_task)
                progress.advance(total_task)
        progress.print("[green]Done!")


if __name__ == "__main__":
    # Expose download_checkpoints as a command-line interface via python-fire;
    # its keyword arguments become command-line flags.
    Fire(download_checkpoints)