yonikremer committed on
Commit
22e2fd1
1 Parent(s): 05393a3

adaptation to new versions of grouped-sampling

Browse files
Files changed (2) hide show
  1. hanlde_form_submit.py +5 -14
  2. tests.py +4 -8
hanlde_form_submit.py CHANGED
@@ -2,12 +2,10 @@ import os
2
  from time import time
3
 
4
  import streamlit as st
5
- from grouped_sampling import GroupedSamplingPipeLine
6
 
7
  from download_repo import download_pytorch_model
8
  from prompt_engeneering import rewrite_prompt
9
- from supported_models import is_supported, SUPPORTED_MODEL_NAME_PAGES_FORMAT, BLACKLISTED_MODEL_NAMES, \
10
- BLACKLISTED_ORGANIZATIONS
11
 
12
 
13
  def is_downloaded(model_name: str) -> bool:
@@ -94,17 +92,10 @@ def on_form_submit(
94
  TypeError: If the output length is not an integer or the prompt is not a string.
95
  RuntimeError: If the model is not found.
96
  """
97
- if not is_supported(model_name, 1, 1):
98
- raise ValueError(
99
- f"The model: {model_name} is not supported."
100
- f"The supported models are the models from {SUPPORTED_MODEL_NAME_PAGES_FORMAT}"
101
- f" that satisfy the following conditions:\n"
102
- f"1. The model has at least one like and one download.\n"
103
- f"2. The model is not one of: {BLACKLISTED_MODEL_NAMES}.\n"
104
- f"3. The model was not created any of those organizations: {BLACKLISTED_ORGANIZATIONS}.\n"
105
- )
106
  if len(prompt) == 0:
107
- raise ValueError(f"The prompt must not be empty.")
 
 
108
  st.write(f"Loading model: {model_name}...")
109
  loading_start_time = time()
110
  pipeline = create_pipeline(
@@ -114,7 +105,7 @@ def on_form_submit(
114
  loading_end_time = time()
115
  loading_time = loading_end_time - loading_start_time
116
  st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
117
- st.write(f"Generating text...")
118
  generation_start_time = time()
119
  generated_text = generate_text(
120
  pipeline=pipeline,
 
2
  from time import time
3
 
4
  import streamlit as st
5
+ from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
6
 
7
  from download_repo import download_pytorch_model
8
  from prompt_engeneering import rewrite_prompt
 
 
9
 
10
 
11
  def is_downloaded(model_name: str) -> bool:
 
92
  TypeError: If the output length is not an integer or the prompt is not a string.
93
  RuntimeError: If the model is not found.
94
  """
 
 
 
 
 
 
 
 
 
95
  if len(prompt) == 0:
96
+ raise ValueError("The prompt must not be empty.")
97
+ if not is_supported(model_name):
98
+ raise UnsupportedModelNameException(model_name)
99
  st.write(f"Loading model: {model_name}...")
100
  loading_start_time = time()
101
  pipeline = create_pipeline(
 
105
  loading_end_time = time()
106
  loading_time = loading_end_time - loading_start_time
107
  st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
108
+ st.write("Generating text...")
109
  generation_start_time = time()
110
  generated_text = generate_text(
111
  pipeline=pipeline,
tests.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
2
 
3
  import pytest as pytest
4
- from grouped_sampling import GroupedSamplingPipeLine
5
 
6
  from on_server_start import download_useful_models
7
  from hanlde_form_submit import create_pipeline, on_form_submit
8
  from prompt_engeneering import rewrite_prompt
9
- from supported_models import get_supported_model_names
10
 
11
  HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
12
 
@@ -21,7 +20,7 @@ def test_prompt_engineering():
21
 
22
 
23
  def test_get_supported_model_names():
24
- supported_model_names = get_supported_model_names()
25
  assert len(supported_model_names) > 0
26
  assert "gpt2" in supported_model_names
27
  assert all(isinstance(name, str) for name in supported_model_names)
@@ -44,16 +43,13 @@ def test_on_form_submit():
44
  with pytest.raises(ValueError):
45
  on_form_submit(model_name, output_length, empty_prompt, web_search=False)
46
  unsupported_model_name = "unsupported_model_name"
47
- with pytest.raises(ValueError):
48
  on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)
49
 
50
 
51
  @pytest.mark.parametrize(
52
  "model_name",
53
- get_supported_model_names(
54
- min_number_of_downloads=1000,
55
- min_number_of_likes=100,
56
- )
57
  )
58
  def test_create_pipeline(model_name: str):
59
  pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)
 
1
  import os
2
 
3
  import pytest as pytest
4
+ from grouped_sampling import GroupedSamplingPipeLine, get_full_models_list, UnsupportedModelNameException
5
 
6
  from on_server_start import download_useful_models
7
  from hanlde_form_submit import create_pipeline, on_form_submit
8
  from prompt_engeneering import rewrite_prompt
 
9
 
10
  HUGGING_FACE_CACHE_DIR = "/home/yoni/.cache/huggingface/hub"
11
 
 
20
 
21
 
22
  def test_get_supported_model_names():
23
+ supported_model_names = get_full_models_list()
24
  assert len(supported_model_names) > 0
25
  assert "gpt2" in supported_model_names
26
  assert all(isinstance(name, str) for name in supported_model_names)
 
43
  with pytest.raises(ValueError):
44
  on_form_submit(model_name, output_length, empty_prompt, web_search=False)
45
  unsupported_model_name = "unsupported_model_name"
46
+ with pytest.raises(UnsupportedModelNameException):
47
  on_form_submit(unsupported_model_name, output_length, prompt, web_search=False)
48
 
49
 
50
  @pytest.mark.parametrize(
51
  "model_name",
52
+ get_full_models_list()[:3]
 
 
 
53
  )
54
  def test_create_pipeline(model_name: str):
55
  pipeline: GroupedSamplingPipeLine = create_pipeline(model_name, 5)