File size: 1,292 Bytes
7d52396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os
from typing import Optional
from urllib.request import urlretrieve

# Reference implementations fetched by ``download_original``, as
# (local filename, raw gist URL) pairs; insertion order is preserved.
files = dict(
    [
        (
            "original_model.py",
            "https://gist.githubusercontent.com/lantiga/fd36849fb1c498da949a0af635318a7b/raw/7dd20f51c2a1ff2886387f0e25c1750a485a08e1/llama_model.py",
        ),
        (
            "original_adapter.py",
            "https://gist.githubusercontent.com/awaelchli/546f33fcdb84cc9f1b661ca1ca18418d/raw/e81d8f35fb1fec53af1099349b0c455fc8c9fb01/original_adapter.py",
        ),
    ]
)


def download_original(wd: str, sources: Optional[dict] = None) -> None:
    """Fetch the reference implementation files into directory ``wd``.

    Every ``filename -> url`` entry is saved as ``os.path.join(wd, filename)``.
    Entries whose target file already exists are skipped, so repeated calls do
    not download again.

    Args:
        wd: Directory the files are written into. It is not created here, so
            it is expected to exist already.
        sources: Optional mapping of local filename to URL. When omitted, the
            module-level ``files`` mapping is used.
    """
    if sources is None:
        sources = files
    for file, url in sources.items():
        filepath = os.path.join(wd, file)
        if not os.path.isfile(filepath):
            print(f"Downloading original implementation to {filepath!r}")
            # Write to the same path that was checked above. Passing the bare
            # filename would save into the current working directory instead of
            # ``wd``, and the existence check would then never succeed.
            urlretrieve(url=url, filename=filepath)
            print("Done")
        else:
            print("Original implementation found. Skipping download.")


def download_from_hub(repo_id: Optional[str] = None, local_dir: str = "checkpoints/hf-llama/7B") -> None:
    """Download a snapshot of a Hugging Face Hub repository into ``local_dir``.

    Args:
        repo_id: Hub repository identifier. It is required in practice: the
            ``None`` default exists only so the CLI can print a helpful error.
        local_dir: Destination directory handed to ``snapshot_download``.

    Raises:
        ValueError: If ``repo_id`` is not provided.
    """
    if repo_id is not None:
        # Deferred import so the missing-argument error above does not require
        # ``huggingface_hub`` to be installed.
        from huggingface_hub import snapshot_download

        # Request real files in ``local_dir`` rather than symlinks.
        snapshot_download(repo_id, local_dir=local_dir, local_dir_use_symlinks=False)
        return

    raise ValueError("Please pass `--repo_id=...`. You can try googling 'huggingface hub llama' for options.")


if __name__ == "__main__":
    # Expose download_from_hub's parameters as command-line flags
    # (e.g. ``--repo_id``); imported lazily so plain imports of this module
    # do not need jsonargparse.
    import jsonargparse

    jsonargparse.CLI(download_from_hub)