from pathlib import Path
from typing import List, Union

from fire import Fire
from huggingface_hub import get_hf_file_metadata, hf_hub_download, hf_hub_url
from huggingface_hub.utils import disable_progress_bars
from rich.progress import Progress

# Expert (auxiliary-task) checkpoints stored in the repo's expert_weights/ subfolder.
_EXPERTS = [
    "10_model.pth",
    "Unified_learned_OCIM_RS200_6x+2x.pth",
    "dpt_hybrid-midas-501f0c75.pt",
    "icdar2015_hourglass88.pth",
    "model_final_e0c58e.pkl",
    "model_final_f07440.pkl",
    "scannet.pt",
]

# Prismer/PrismerZ model checkpoints; each lives in its own subfolder of the repo
# as <model>/pytorch_model.bin.
_MODELS = [
    "vqa_prismer_base",
    "vqa_prismer_large",
    "vqa_prismerz_base",
    "vqa_prismerz_large",
    "caption_prismerz_base",
    "caption_prismerz_large",
    "caption_prismer_base",
    "caption_prismer_large",
    "pretrain_prismer_base",
    "pretrain_prismer_large",
    "pretrain_prismerz_base",
    "pretrain_prismerz_large",
]

_REPO_ID = "lorenmt/prismer"


def download_checkpoints(
    download_experts: bool = False,
    download_models: Union[bool, str, List] = False,
    hide_tqdm: bool = False,
    force_redownload: bool = False,
):
    """Download Prismer expert and/or model checkpoints from the Hugging Face Hub.

    Experts are placed under ``./experts/expert_weights/`` and models under
    ``./logging/<model_name>/pytorch_model.bin``.

    Args:
        download_experts: If True, download all expert checkpoints in ``_EXPERTS``.
        download_models: True to download every model in ``_MODELS``, a single
            model name, or a list of model names.
        hide_tqdm: If True, suppress huggingface_hub's per-file progress bars.
        force_redownload: If True, download files even when they already exist
            on disk.

    Raises:
        ValueError: If an unknown model name is given, or if nothing remains
            to download.
    """
    if hide_tqdm:
        disable_progress_bars()

    experts = _EXPERTS if download_experts else []

    # BUG FIX: handle the bool case *before* validating names — the original
    # iterated over `True` in the validation step, raising TypeError.
    if isinstance(download_models, bool):
        models = _MODELS if download_models else []
    else:
        # Accept a single model name as a bare string.
        if isinstance(download_models, str):
            download_models = [download_models]
        if not all(m in _MODELS for m in download_models):
            raise ValueError(f"Invalid model name. Must be one of {_MODELS}")
        models = list(download_models)

    # Skip files that are already on disk unless a re-download is forced.
    if not force_redownload:
        experts = [e for e in experts if not Path(f"./experts/expert_weights/{e}").exists()]
        # BUG FIX: models are saved under ./logging/<model>/, so check there
        # (the original checked ./<model>/ and always re-downloaded).
        models = [m for m in models if not (Path("./logging") / m / "pytorch_model.bin").exists()]

    if not (experts or models):
        raise ValueError("Nothing to download.")

    with Progress() as progress:
        # Query file metadata up front so the user knows the total transfer size.
        progress.print("[blue]Calculating download size...")
        total_size = 0
        for expert in experts:
            url = hf_hub_url(filename=expert, repo_id=_REPO_ID, subfolder="expert_weights")
            total_size += get_hf_file_metadata(url).size
        for model in models:
            url = hf_hub_url(filename="pytorch_model.bin", repo_id=_REPO_ID, subfolder=model)
            total_size += get_hf_file_metadata(url).size
        progress.print(f"[blue]Total download size: {total_size / 1e9:.2f} GB")

        total_files = len(experts) + len(models)
        total_task = progress.add_task("[green]Downloading files", total=total_files)

        if experts:
            expert_task = progress.add_task("[green]Downloading experts...", total=len(experts))
            out_folder = Path("experts/expert_weights")
            out_folder.mkdir(parents=True, exist_ok=True)
            for expert in experts:
                path = Path(hf_hub_download(filename=expert, repo_id=_REPO_ID, subfolder="expert_weights"))
                # hf_hub_download may return a symlink into the HF cache: move the
                # real file into place, then remove the now-dangling symlink.
                path.resolve().rename(out_folder / path.name)
                # BUG FIX: only unlink if a symlink remains — when the returned
                # path was a regular file, rename() already moved it and the
                # original unconditional unlink() raised FileNotFoundError.
                if path.is_symlink():
                    path.unlink()
                progress.advance(expert_task)
                progress.advance(total_task)

        if models:
            model_task = progress.add_task("[green]Downloading models...", total=len(models))
            for model in models:
                path = Path(hf_hub_download(filename="pytorch_model.bin", repo_id=_REPO_ID, subfolder=model))
                out_folder = Path("./logging") / model
                out_folder.mkdir(parents=True, exist_ok=True)
                path.resolve().rename(out_folder / "pytorch_model.bin")
                if path.is_symlink():
                    path.unlink()
                progress.advance(model_task)
                progress.advance(total_task)

        progress.print("[green]Done!")


if __name__ == "__main__":
    Fire(download_checkpoints)