yonikremer commited on
Commit
d2282ae
1 Parent(s): 84ab38f

changed group size to 1024 (constant)

Browse files
Files changed (1) hide show
  1. hanlde_form_submit.py +2 -4
hanlde_form_submit.py CHANGED
@@ -20,11 +20,10 @@ def is_downloaded(model_name: str) -> bool:
20
 
21
 
22
  @lru_cache()
23
- def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine:
24
  """
25
  Creates a pipeline with the given model name and group size.
26
  :param model_name: The name of the model to use.
27
- :param group_size: The size of the groups to use.
28
  :return: A pipeline with the given model name and group size.
29
  """
30
  if not is_downloaded(model_name):
@@ -38,7 +37,7 @@ def create_pipeline(model_name: str, group_size: int) -> GroupedSamplingPipeLine
38
  pipeline_start_time = time()
39
  pipeline = GroupedSamplingPipeLine(
40
  model_name=model_name,
41
- group_size=group_size,
42
  end_of_sentence_stop=False,
43
  top_k=50,
44
  load_in_8bit=False,
@@ -92,7 +91,6 @@ def on_form_submit(
92
  loading_start_time = time()
93
  pipeline = create_pipeline(
94
  model_name=model_name,
95
- group_size=output_length,
96
  )
97
  loading_end_time = time()
98
  loading_time = loading_end_time - loading_start_time
 
20
 
21
 
22
  @lru_cache()
23
+ def create_pipeline(model_name: str) -> GroupedSamplingPipeLine:
24
  """
25
  Creates a pipeline with the given model name and group size.
26
  :param model_name: The name of the model to use.
 
27
  :return: A pipeline with the given model name and group size.
28
  """
29
  if not is_downloaded(model_name):
 
37
  pipeline_start_time = time()
38
  pipeline = GroupedSamplingPipeLine(
39
  model_name=model_name,
40
+ group_size=1024,
41
  end_of_sentence_stop=False,
42
  top_k=50,
43
  load_in_8bit=False,
 
91
  loading_start_time = time()
92
  pipeline = create_pipeline(
93
  model_name=model_name,
 
94
  )
95
  loading_end_time = time()
96
  loading_time = loading_end_time - loading_start_time