yonikremer committed on
Commit
a4b0060
·
1 Parent(s): 28a440d

added a function to download models from the hub

Browse files
Files changed (3) hide show
  1. download_repo.py +26 -0
  2. hanlde_form_submit.py +12 -1
  3. on_server_start.py +2 -15
download_repo.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import urllib3
2
+
3
+ from huggingface_hub import snapshot_download
4
+
5
+
6
def change_default_timeout(new_timeout: int) -> None:
    """
    Overrides the ``DEFAULT_TIMEOUT`` attribute of the ``urllib3.util.timeout`` module.

    Intended to avoid read timeouts while downloading repositories from the
    Hugging Face Hub, such as:
    urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='huggingface.co', port=443):
    Read timed out. (read timeout=10)

    :param new_timeout: The value to store as the new default timeout
        (presumably in seconds, matching the error message above).
    """
    # NOTE(review): confirm that urllib3 actually reads this module-level
    # attribute; if it only consults ``Timeout.DEFAULT_TIMEOUT`` or per-request
    # timeouts, this assignment has no effect on downloads.
    timeout_module = urllib3.util.timeout
    timeout_module.DEFAULT_TIMEOUT = new_timeout
14
+
15
+
16
def download_repository(name: str) -> None:
    """
    Downloads a model repository from the Hugging Face Hub.

    The default timeout is first raised to one day via
    ``change_default_timeout``, and the same value is passed to
    ``snapshot_download`` as ``etag_timeout``. The value returned by
    ``snapshot_download`` is discarded.

    :param name: The repository id on the Hub, passed to ``snapshot_download`` as ``repo_id``.
    """
    one_day_in_seconds: int = 86_400

    # Raise the timeout before the download starts.
    change_default_timeout(one_day_in_seconds)

    snapshot_download(
        repo_id=name,
        repo_type="model",
        library_name="pytorch",
        resume_download=True,
        etag_timeout=one_day_in_seconds,
    )
hanlde_form_submit.py CHANGED
@@ -3,6 +3,7 @@ from time import time
3
  import streamlit as st
4
  from grouped_sampling import GroupedSamplingPipeLine
5
 
 
6
  from prompt_engeneering import rewrite_prompt
7
  from supported_models import get_supported_model_names
8
 
@@ -18,6 +19,14 @@ def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine
18
  :return: A pipeline with the given model name and group size.
19
  """
20
  print(f"Starts downloading model: {model_name} from the internet.")
 
 
 
 
 
 
 
 
21
  pipeline = GroupedSamplingPipeLine(
22
  model_name=model_name,
23
  group_size=group_size,
@@ -25,7 +34,9 @@ def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine
25
  temp=0.5,
26
  top_p=0.6,
27
  )
28
- print(f"Finished downloading model: {model_name} from the internet.")
 
 
29
  return pipeline
30
 
31
 
 
3
  import streamlit as st
4
  from grouped_sampling import GroupedSamplingPipeLine
5
 
6
+ from download_repo import download_repository
7
  from prompt_engeneering import rewrite_prompt
8
  from supported_models import get_supported_model_names
9
 
 
19
  :return: A pipeline with the given model name and group size.
20
  """
21
  print(f"Starts downloading model: {model_name} from the internet.")
22
+ download_repository_start_time = time()
23
+ st.write(f"Starts downloading model: {model_name} from the internet.")
24
+ download_repository(model_name)
25
+ download_repository_end_time = time()
26
+ download_time = download_repository_end_time - download_repository_start_time
27
+ st.write(f"Finished downloading model: {model_name} from the internet in {download_time} seconds.")
28
+ st.write(f"Starts creating pipeline with model: {model_name}")
29
+ pipeline_start_time = time()
30
  pipeline = GroupedSamplingPipeLine(
31
  model_name=model_name,
32
  group_size=group_size,
 
34
  temp=0.5,
35
  top_p=0.6,
36
  )
37
+ pipeline_end_time = time()
38
+ pipeline_time = pipeline_end_time - pipeline_start_time
39
+ st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time} seconds.")
40
  return pipeline
41
 
42
 
on_server_start.py CHANGED
@@ -1,20 +1,7 @@
1
  """
2
  A script that is run when the server starts.
3
  """
4
- from huggingface_hub import snapshot_download
5
-
6
-
7
- def download_model(model_name: str):
8
- """
9
- Downloads a model from hugging face hub to the disk but not to the RAM.
10
- :param model_name: The name of the model to download.
11
- """
12
- number_of_seconds_in_a_day: int = 86_400
13
- snapshot_download(
14
- repo_id=model_name,
15
- etag_timeout=number_of_seconds_in_a_day,
16
- resume_download=True,
17
- )
18
 
19
 
20
  def download_useful_models():
@@ -27,7 +14,7 @@ def download_useful_models():
27
  "facebook/opt-125m",
28
  )
29
  for model_name in useful_models:
30
- download_model(model_name)
31
 
32
 
33
  def main():
 
1
  """
2
  A script that is run when the server starts.
3
  """
4
+ from download_repo import download_repository
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def download_useful_models():
 
14
  "facebook/opt-125m",
15
  )
16
  for model_name in useful_models:
17
+ download_repository(model_name)
18
 
19
 
20
  def main():