pszemraj commited on
Commit
be38ebe
1 Parent(s): b327b2e

⚡️ adjust params for csearch

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (1) hide show
  1. app.py +24 -28
app.py CHANGED
@@ -20,31 +20,28 @@ def generate_text(
20
  gen_length=64,
21
  penalty_alpha=0.6,
22
  top_k=6,
23
- no_repeat_ngram_size=2,
24
  length_penalty=1.0,
25
  # perma params (not set by user)
26
  abs_max_length=512,
27
  verbose=False,
28
  ):
29
  """
30
- generate_text - generate text from a prompt using a text generation pipeline
31
-
32
- Args:
33
- prompt (str): the prompt to generate text from
34
- model_input (_type_): the text generation pipeline
35
- max_length (int, optional): the maximum length of the generated text. Defaults to 128.
36
- method (str, optional): the generation method. Defaults to "Sampling".
37
- verbose (bool, optional): the verbosity of the output. Defaults to False.
38
-
39
- Returns:
40
- str: the generated text
41
  """
42
  global generator
43
  if verbose:
44
  logging.info(f"Generating text from prompt:\n\n{prompt}")
45
  logging.info(
46
  pp.pformat(
47
- f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, top_k={top_k}, no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}"
48
  )
49
  )
50
  st = time.perf_counter()
@@ -59,7 +56,6 @@ def generate_text(
59
  min_length=input_len + 4,
60
  penalty_alpha=penalty_alpha,
61
  top_k=top_k,
62
- no_repeat_ngram_size=no_repeat_ngram_size,
63
  length_penalty=length_penalty,
64
  ) # generate
65
  response = result[0]["generated_text"]
@@ -107,15 +103,14 @@ def get_parser():
107
  default="postbot/distilgpt2-emailgen-V2",
108
  help="Pass an different huggingface model tag to use a custom model",
109
  )
110
-
111
  parser.add_argument(
112
- "-v",
113
- "--verbose",
114
  required=False,
115
- action="store_true",
116
- help="Verbose output",
 
117
  )
118
-
119
  parser.add_argument(
120
  "-a",
121
  "--penalty_alpha",
@@ -131,6 +126,13 @@ def get_parser():
131
  default=6,
132
  help="The top k for the text generation pipeline (contrastive search) - default 6",
133
  )
 
 
 
 
 
 
 
134
  return parser
135
 
136
 
@@ -192,7 +194,7 @@ if __name__ == "__main__":
192
  )
193
  num_gen_tokens = gr.Slider(
194
  label="Generation Tokens",
195
- value=32,
196
  maximum=96,
197
  minimum=16,
198
  step=8,
@@ -217,7 +219,7 @@ if __name__ == "__main__":
217
  gr.Markdown("---")
218
  gr.Markdown("## Advanced Options")
219
  gr.Markdown(
220
- "This demo generates text via the new [constrastive search](https://huggingface.co/blog/introducing-csearch). See details on the csearch blog post for the methods' parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
221
  )
222
  with gr.Row():
223
  model_name = gr.Dropdown(
@@ -229,11 +231,6 @@ if __name__ == "__main__":
229
  "Load Model",
230
  variant="secondary",
231
  )
232
- no_repeat_ngram_size = gr.Radio(
233
- choices=[1, 2, 3, 4],
234
- label="no repeat ngram size",
235
- value=3,
236
- )
237
  with gr.Row():
238
  contrastive_top_k = gr.Radio(
239
  choices=[2, 4, 6, 8],
@@ -280,7 +277,6 @@ if __name__ == "__main__":
280
  num_gen_tokens,
281
  penalty_alpha,
282
  contrastive_top_k,
283
- no_repeat_ngram_size,
284
  length_penalty,
285
  ],
286
  outputs=[email_mailto_button, generated_email],
 
20
  gen_length=64,
21
  penalty_alpha=0.6,
22
  top_k=6,
 
23
  length_penalty=1.0,
24
  # perma params (not set by user)
25
  abs_max_length=512,
26
  verbose=False,
27
  ):
28
  """
29
+ generate_text - generate text using the text generation pipeline
30
+
31
+ :param str prompt: the prompt to use for the text generation pipeline
32
+ :param int gen_length: the number of tokens to generate
33
+ :param float penalty_alpha: the penalty alpha for the text generation pipeline (contrastive search)
34
+ :param int top_k: the top k for the text generation pipeline (contrastive search)
35
+ :param int abs_max_length: the absolute max length for the text generation pipeline
36
+ :param bool verbose: verbose output
37
+ :return str: the generated text
 
 
38
  """
39
  global generator
40
  if verbose:
41
  logging.info(f"Generating text from prompt:\n\n{prompt}")
42
  logging.info(
43
  pp.pformat(
44
+ f"params:\tmax_length={gen_length}, penalty_alpha={penalty_alpha}, top_k={top_k}, length_penalty={length_penalty}"
45
  )
46
  )
47
  st = time.perf_counter()
 
56
  min_length=input_len + 4,
57
  penalty_alpha=penalty_alpha,
58
  top_k=top_k,
 
59
  length_penalty=length_penalty,
60
  ) # generate
61
  response = result[0]["generated_text"]
 
103
  default="postbot/distilgpt2-emailgen-V2",
104
  help="Pass an different huggingface model tag to use a custom model",
105
  )
 
106
  parser.add_argument(
107
+ "-l",
108
+ "--max_length",
109
  required=False,
110
+ type=int,
111
+ default=64,
112
+ help="default max length of the generated text",
113
  )
 
114
  parser.add_argument(
115
  "-a",
116
  "--penalty_alpha",
 
126
  default=6,
127
  help="The top k for the text generation pipeline (contrastive search) - default 6",
128
  )
129
+ parser.add_argument(
130
+ "-v",
131
+ "--verbose",
132
+ required=False,
133
+ action="store_true",
134
+ help="Verbose output",
135
+ )
136
  return parser
137
 
138
 
 
194
  )
195
  num_gen_tokens = gr.Slider(
196
  label="Generation Tokens",
197
+ value=40,
198
  maximum=96,
199
  minimum=16,
200
  step=8,
 
219
  gr.Markdown("---")
220
  gr.Markdown("## Advanced Options")
221
  gr.Markdown(
222
+ "This demo generates text via the new [contrastive search](https://huggingface.co/blog/introducing-csearch). See the csearch blog post for details on the parameters or [here](https://huggingface.co/blog/how-to-generate), for general decoding."
223
  )
224
  with gr.Row():
225
  model_name = gr.Dropdown(
 
231
  "Load Model",
232
  variant="secondary",
233
  )
 
 
 
 
 
234
  with gr.Row():
235
  contrastive_top_k = gr.Radio(
236
  choices=[2, 4, 6, 8],
 
277
  num_gen_tokens,
278
  penalty_alpha,
279
  contrastive_top_k,
 
280
  length_penalty,
281
  ],
282
  outputs=[email_mailto_button, generated_email],