prismer / prismer /download_checkpoints.py
shikunl's picture
Add prismer
087df0e
raw history blame
No virus
4.38 kB
from huggingface_hub import hf_hub_download, hf_hub_url, get_hf_file_metadata
from huggingface_hub.utils import disable_progress_bars
from pathlib import Path
from rich.progress import Progress
from fire import Fire
from typing import Union, List
_EXPERTS = [
"10_model.pth",
"Unified_learned_OCIM_RS200_6x+2x.pth",
"dpt_hybrid-midas-501f0c75.pt",
"icdar2015_hourglass88.pth",
"model_final_e0c58e.pkl",
"model_final_f07440.pkl",
"scannet.pt",
]
_MODELS = [
"vqa_prismer_base",
"vqa_prismer_large",
"vqa_prismerz_base",
"vqa_prismerz_large",
"caption_prismerz_base",
"caption_prismerz_large",
"caption_prismer_base",
"caption_prismer_large",
"pretrain_prismer_base",
"pretrain_prismer_large",
"pretrain_prismerz_base",
"pretrain_prismerz_large",
]
_REPO_ID = "lorenmt/prismer"
def download_checkpoints(
download_experts: bool = False,
download_models: Union[bool, List] = False,
hide_tqdm: bool = False,
force_redownload: bool = False,
):
if hide_tqdm:
disable_progress_bars()
# Convert to list and check for invalid names
download_experts = _EXPERTS if download_experts else []
if download_models:
# only download single model
if isinstance(download_models, str):
download_models = [download_models]
assert all([m in _MODELS for m in download_models]), f"Invalid model name. Must be one of {_MODELS}"
download_models = _MODELS if isinstance(download_models, bool) else download_models
else:
download_models = []
# Check if files already exist
if not force_redownload:
download_experts = [e for e in download_experts if not Path(f"./experts/expert_weights/{e}").exists()]
download_models = [m for m in download_models if not Path(f"{m}/pytorch_model.bin").exists()]
assert download_experts or download_models, "Nothing to download."
with Progress() as progress:
# Calculate total download size
progress.print("[blue]Calculating download size...")
total_size = 0
for expert in download_experts:
url = hf_hub_url(
filename=expert,
repo_id=_REPO_ID,
subfolder="expert_weights"
)
total_size += get_hf_file_metadata(url).size
for model in download_models:
url = hf_hub_url(
filename=f"pytorch_model.bin",
repo_id=_REPO_ID,
subfolder=model
)
total_size += get_hf_file_metadata(url).size
progress.print(f"[blue]Total download size: {total_size / 1e9:.2f} GB")
# Download files
total_files = len(download_experts) + len(download_models)
total_task = progress.add_task(f"[green]Downloading files", total=total_files)
if download_experts:
expert_task = progress.add_task(
f"[green]Downloading experts...", total=len(download_experts)
)
out_folder = Path("experts/expert_weights")
out_folder.mkdir(parents=True, exist_ok=True)
for expert in download_experts:
path = Path(hf_hub_download(
filename=expert,
repo_id=_REPO_ID,
subfolder="expert_weights"
))
path.resolve().rename(out_folder/path.name)
path.unlink()
progress.advance(expert_task)
progress.advance(total_task)
if download_models:
model_task = progress.add_task(
f"[green]Downloading models...", total=len(download_models)
)
for model in download_models:
path = Path(hf_hub_download(
filename=f"pytorch_model.bin",
repo_id=_REPO_ID,
subfolder=model
))
out_folder = Path("./logging")/model
out_folder.mkdir(parents=True, exist_ok=True)
path.resolve().rename(out_folder/"pytorch_model.bin")
path.unlink()
progress.advance(model_task)
progress.advance(total_task)
progress.print("[green]Done!")
if __name__ == "__main__":
Fire(download_checkpoints)