yonikremer commited on
Commit
df273ff
โ€ข
1 Parent(s): 25de4e0

downloading models at the start of the app and not at usage time

Browse files
Files changed (2) hide show
  1. app.py +28 -1
  2. hanlde_form_submit.py +3 -41
app.py CHANGED
@@ -4,8 +4,10 @@ In the demo, the user can write a prompt
4
  and the model will generate a response using the grouped sampling algorithm.
5
  """
6
  import os
 
7
 
8
  import streamlit as st
 
9
  from torch.cuda import CudaError
10
  from huggingface_hub import logging as hf_hub_logging
11
 
@@ -13,6 +15,27 @@ from available_models import AVAILABLE_MODELS
13
  from hanlde_form_submit import on_form_submit
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  hf_hub_logging.set_verbosity_error()
17
 
18
  st.set_page_config(
@@ -20,6 +43,10 @@ st.set_page_config(
20
  layout="wide",
21
  )
22
 
 
 
 
 
23
  with st.form("request_form"):
24
  selected_model_name: str = st.selectbox(
25
  label="ื‘ื—ืจื• ืžื•ื“ืœ",
@@ -50,7 +77,7 @@ with st.form("request_form"):
50
  if submitted:
51
  try:
52
  output = on_form_submit(
53
- selected_model_name,
54
  output_length,
55
  submitted_prompt,
56
  )
 
4
  and the model will generate a response using the grouped sampling algorithm.
5
  """
6
  import os
7
+ from time import time
8
 
9
  import streamlit as st
10
+ from grouped_sampling import GroupedSamplingPipeLine
11
  from torch.cuda import CudaError
12
  from huggingface_hub import logging as hf_hub_logging
13
 
 
15
  from hanlde_form_submit import on_form_submit
16
 
17
 
18
+ def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
19
+ """
20
+ Creates a pipeline with the given model name and group size.
21
+ :param model_name: The name of the model to use.
22
+ :param group_size: The size of the groups to use.
23
+ :return: A pipeline with the given model name and group size.
24
+ """
25
+ st.write(f"Starts creating pipeline with model: {model_name}")
26
+ pipeline_start_time = time()
27
+ pipeline = GroupedSamplingPipeLine(
28
+ model_name=model_name,
29
+ group_size=group_size,
30
+ end_of_sentence_stop=False,
31
+ top_k=50,
32
+ )
33
+ pipeline_end_time = time()
34
+ pipeline_time = pipeline_end_time - pipeline_start_time
35
+ st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
36
+ return pipeline
37
+
38
+
39
  hf_hub_logging.set_verbosity_error()
40
 
41
  st.set_page_config(
 
43
  layout="wide",
44
  )
45
 
46
+ pipelines = {
47
+ model_name: create_pipeline(model_name, 1024) for model_name in AVAILABLE_MODELS[1:]
48
+ }
49
+
50
  with st.form("request_form"):
51
  selected_model_name: str = st.selectbox(
52
  label="ื‘ื—ืจื• ืžื•ื“ืœ",
 
77
  if submitted:
78
  try:
79
  output = on_form_submit(
80
+ pipelines[selected_model_name],
81
  output_length,
82
  submitted_prompt,
83
  )
hanlde_form_submit.py CHANGED
@@ -1,28 +1,7 @@
1
  from time import time
2
 
3
  import streamlit as st
4
- from grouped_sampling import GroupedSamplingPipeLine, is_supported, UnsupportedModelNameException
5
-
6
-
7
- def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
8
- """
9
- Creates a pipeline with the given model name and group size.
10
- :param model_name: The name of the model to use.
11
- :param group_size: The size of the groups to use.
12
- :return: A pipeline with the given model name and group size.
13
- """
14
- st.write(f"Starts creating pipeline with model: {model_name}")
15
- pipeline_start_time = time()
16
- pipeline = GroupedSamplingPipeLine(
17
- model_name=model_name,
18
- group_size=group_size,
19
- end_of_sentence_stop=False,
20
- top_k=50,
21
- )
22
- pipeline_end_time = time()
23
- pipeline_time = pipeline_end_time - pipeline_start_time
24
- st.write(f"Finished creating pipeline with model: {model_name} in {pipeline_time:,.2f} seconds.")
25
- return pipeline
26
 
27
 
28
  def generate_text(
@@ -46,13 +25,13 @@ def generate_text(
46
 
47
 
48
  def on_form_submit(
49
- model_name: str,
50
  output_length: int,
51
  prompt: str,
52
  ) -> str:
53
  """
54
  Called when the user submits the form.
55
- :param model_name: The name of the model to use.
56
  :param output_length: The size of the groups to use.
57
  :param prompt: The prompt to use.
58
  :return: The output of the model.
@@ -64,19 +43,6 @@ def on_form_submit(
64
  """
65
  if len(prompt) == 0:
66
  raise ValueError("The prompt must not be empty.")
67
- if not is_supported(model_name):
68
- raise UnsupportedModelNameException(model_name)
69
- st.write(f"Loading model: {model_name}...")
70
- print(f"Loading model: {model_name}...")
71
- loading_start_time = time()
72
- pipeline = create_pipeline(
73
- model_name=model_name,
74
- group_size=output_length,
75
- )
76
- loading_end_time = time()
77
- loading_time = loading_end_time - loading_start_time
78
- st.write(f"Finished loading model: {model_name} in {loading_time:,.2f} seconds.")
79
- print(f"Finished loading model: {model_name} in {loading_time:,} seconds.")
80
  st.write("Generating text...")
81
  print("Generating text...")
82
  generation_start_time = time()
@@ -89,8 +55,4 @@ def on_form_submit(
89
  generation_time = generation_end_time - generation_start_time
90
  st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
91
  print(f"Finished generating text in {generation_time:,.2f} seconds.")
92
- if not isinstance(generated_text, str):
93
- raise RuntimeError(f"The model {model_name} did not generate any text.")
94
- if len(generated_text) == 0:
95
- raise RuntimeError(f"The model {model_name} did not generate any text.")
96
  return generated_text
 
1
  from time import time
2
 
3
  import streamlit as st
4
+ from grouped_sampling import GroupedSamplingPipeLine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  def generate_text(
 
25
 
26
 
27
  def on_form_submit(
28
+ pipeline: GroupedSamplingPipeLine,
29
  output_length: int,
30
  prompt: str,
31
  ) -> str:
32
  """
33
  Called when the user submits the form.
34
+ :param pipeline: The pipeline to use. GroupedSamplingPipeLine.
35
  :param output_length: The size of the groups to use.
36
  :param prompt: The prompt to use.
37
  :return: The output of the model.
 
43
  """
44
  if len(prompt) == 0:
45
  raise ValueError("The prompt must not be empty.")
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  st.write("Generating text...")
47
  print("Generating text...")
48
  generation_start_time = time()
 
55
  generation_time = generation_end_time - generation_start_time
56
  st.write(f"Finished generating text in {generation_time:,.2f} seconds.")
57
  print(f"Finished generating text in {generation_time:,.2f} seconds.")
 
 
 
 
58
  return generated_text