zetavg commited on
Commit
a1c44f4
1 Parent(s): 184ef80

download_base_model: use snapshot_download from huggingface_hub to download models but not load them

Browse files
Files changed (2) hide show
  1. download_base_model.py +7 -7
  2. requirements.txt +1 -0
download_base_model.py CHANGED
@@ -1,6 +1,6 @@
1
  import fire
2
 
3
- from llama_lora.models import get_new_base_model, clear_cache
4
 
5
 
6
  def main(
@@ -16,17 +16,17 @@ def main(
16
  base_model_names
17
  ), "Please specify --base_model_names, e.g. --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'"
18
 
19
- base_model_names = base_model_names.split(',')
20
- base_model_names = [name.strip() for name in base_model_names]
21
 
22
- print(f"Base models: {', '.join(base_model_names)}.")
23
 
24
- for name in base_model_names:
25
  print(f"Preparing {name}...")
26
- get_new_base_model(name)
27
- clear_cache()
28
 
29
  print("Done.")
30
 
 
31
  if __name__ == "__main__":
32
  fire.Fire(main)
 
1
  import fire
2
 
3
+ from huggingface_hub import snapshot_download
4
 
5
 
6
  def main(
 
16
  base_model_names
17
  ), "Please specify --base_model_names, e.g. --base_model_names='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j'"
18
 
19
+ base_model_names_list = base_model_names.split(',')
20
+ base_model_names_list = [name.strip() for name in base_model_names]
21
 
22
+ print(f"Base models: {', '.join(base_model_names_list)}.")
23
 
24
+ for name in base_model_names_list:
25
  print(f"Preparing {name}...")
26
+ snapshot_download(name)
 
27
 
28
  print("Done.")
29
 
30
+
31
  if __name__ == "__main__":
32
  fire.Fire(main)
requirements.txt CHANGED
@@ -7,6 +7,7 @@ datasets
7
  fire
8
  git+https://github.com/huggingface/peft.git
9
  git+https://github.com/huggingface/transformers.git
 
10
  numba
11
  nvidia-ml-py3
12
  gradio
 
7
  fire
8
  git+https://github.com/huggingface/peft.git
9
  git+https://github.com/huggingface/transformers.git
10
+ huggingface_hub
11
  numba
12
  nvidia-ml-py3
13
  gradio