yonikremer commited on
Commit
cd990a6
1 Parent(s): 41b6658

stopped downloading the models when the servers start

Browse files
Files changed (3) hide show
  1. app.py +0 -2
  2. on_server_start.py +0 -37
  3. tests.py +0 -7
app.py CHANGED
@@ -9,9 +9,7 @@ from torch.cuda import CudaError
9
 
10
  from available_models import AVAILABLE_MODELS
11
  from hanlde_form_submit import on_form_submit
12
- from on_server_start import main as on_server_start_main
13
 
14
- # on_server_start_main()
15
 
16
  st.title("Grouped Sampling Demo")
17
 
 
9
 
10
  from available_models import AVAILABLE_MODELS
11
  from hanlde_form_submit import on_form_submit
 
12
 
 
13
 
14
  st.title("Grouped Sampling Demo")
15
 
on_server_start.py DELETED
@@ -1,37 +0,0 @@
1
- """
2
- A script that is run when the server starts.
3
- """
4
- from concurrent.futures import ThreadPoolExecutor
5
-
6
- from transformers import logging as transformers_logging
7
- from huggingface_hub import logging as huggingface_hub_logging
8
-
9
- from available_models import AVAILABLE_MODELS
10
- from download_repo import download_pytorch_model
11
-
12
-
13
- def disable_progress_bar():
14
- """
15
- Disables the progress bar when downloading models.
16
- """
17
- transformers_logging.disable_progress_bar()
18
- huggingface_hub_logging.disable_propagation()
19
-
20
-
21
- def download_useful_models():
22
- """
23
- Downloads the models that are useful for this project.
24
- So that the user doesn't have to wait for the models to download when they first use the app.
25
- """
26
- print("Downloading useful models. It might take a while...")
27
- with ThreadPoolExecutor() as executor:
28
- executor.map(download_pytorch_model, AVAILABLE_MODELS)
29
-
30
-
31
- def main():
32
- # disable_progress_bar()
33
- download_useful_models()
34
-
35
-
36
- if __name__ == "__main__":
37
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests.py CHANGED
@@ -3,18 +3,11 @@ import os
3
  import pytest as pytest
4
  from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
5
 
6
- from on_server_start import download_useful_models
7
  from hanlde_form_submit import create_pipeline, on_form_submit
8
 
9
  HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
10
 
11
 
12
- def test_on_server_start():
13
- download_useful_models()
14
- assert os.path.exists(HUGGING_FACE_CACHE_DIR)
15
- assert len(os.listdir(HUGGING_FACE_CACHE_DIR)) > 0
16
-
17
-
18
  def test_on_form_submit():
19
  model_name = "gpt2"
20
  output_length = 10
 
3
  import pytest as pytest
4
  from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
5
 
 
6
  from hanlde_form_submit import create_pipeline, on_form_submit
7
 
8
  HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
9
 
10
 
 
 
 
 
 
 
11
  def test_on_form_submit():
12
  model_name = "gpt2"
13
  output_length = 10