pszemraj committed
Commit 038148f
1 Parent(s): 5db51ab

generate with contrastive search


Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (2):
  1. app.py +33 -26
  2. requirements.txt +1 -1
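For context on what the new knobs control: contrastive search restricts each decoding step to the model's top_k most likely tokens and scores each candidate as (1 - alpha) * model confidence minus alpha * its maximum cosine similarity to the hidden states of the tokens generated so far, so confident-but-repetitive continuations get penalized. A toy sketch of that selection rule with made-up tensors (illustrative only, not the transformers internals):

import torch
import torch.nn.functional as F

def contrastive_score(probs, h_cand, h_ctx, alpha=0.6):
    # degeneration penalty: max cosine similarity between each candidate's
    # hidden state and any hidden state of the existing context
    sim = F.cosine_similarity(h_cand.unsqueeze(1), h_ctx.unsqueeze(0), dim=-1)
    penalty = sim.max(dim=1).values
    return (1 - alpha) * probs - alpha * penalty

# hypothetical numbers: 3 top-k candidates, 4 context tokens, hidden size 8
probs = torch.tensor([0.5, 0.3, 0.2])
scores = contrastive_score(probs, torch.randn(3, 8), torch.randn(4, 8))
print(scores.argmax().item())  # candidate winning the confidence/diversity trade-off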
app.py CHANGED
@@ -17,12 +17,11 @@ use_gpu = torch.cuda.is_available()
 def generate_text(
     prompt: str,
     gen_length=64,
-    num_beams=4,
+    penalty_alpha=0.6,
+    top_k=6,
     no_repeat_ngram_size=2,
     length_penalty=1.0,
-    num_beam_groups=1,
     # perma params (not set by user)
-    repetition_penalty=3.5,
     abs_max_length=512,
     verbose=False,
 ):
@@ -53,15 +52,13 @@ def generate_text(
         logging.info(f"Input too long {input_len} > {abs_max_length}, may cause errors")
     result = generator(
         prompt,
-        max_length=gen_length + input_len,
+        max_new_tokens=gen_length,
+        max_length=None,  # in case of default max_length
         min_length=input_len + 4,
-        num_beams=num_beams,
-        num_beam_groups=num_beam_groups,
-        repetition_penalty=repetition_penalty,
+        penalty_alpha=penalty_alpha,
+        top_k=top_k,
         no_repeat_ngram_size=no_repeat_ngram_size,
         length_penalty=length_penalty,
-        do_sample=False,
-        early_stopping=True,
     )  # generate
     response = result[0]["generated_text"]
     rt = time.perf_counter() - st
@@ -118,18 +115,19 @@ def get_parser():
     )

     parser.add_argument(
-        "-nb",
-        "--num_beams",
-        type=int,
-        default=4,
-        help="Number of beams for beam search. 1 means no beam search.",
+        "-a",
+        "--penalty_alpha",
+        type=float,
+        default=0.6,
+        help="The penalty alpha for the text generation pipeline (contrastive search) - default 0.6",
     )

     parser.add_argument(
-        "--num_beam_groups",
+        "-k",
+        "--top_k",
         type=int,
-        default=1,
-        help="Number of groups to divide best candidates into in order to ensure diversity among different groups of beams that yield the best n results. 1 means no group beam search. (default 1)",
+        default=6,
+        help="The top k for the text generation pipeline (contrastive search) - default 6",
     )
     return parser

@@ -146,11 +144,18 @@ available_models = [
 ]

 if __name__ == "__main__":
+
     logging.info("\n\n\nStarting new instance of app.py")
     args = get_parser().parse_args()
     logging.info(f"received args:\t{args}")
     model_tag = args.model
     verbose = args.verbose
+    top_k = args.top_k
+    alpha = args.penalty_alpha
+
+    assert top_k > 0, "top_k must be greater than 0"
+    assert alpha >= 0.0 and alpha <= 1.0, "penalty_alpha must be between 0 and 1"
+
     logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
     generator = pipeline(
         "text-generation",
@@ -228,16 +233,18 @@ if __name__ == "__main__":
             value=2,
         )
     with gr.Row():
-        num_beams = gr.Radio(
-            choices=[2, 4, 8],
+        contrastive_top_k = gr.Radio(
+            choices=[2, 4, 6, 8],
             label="Number of Beams",
-            value=4,
+            value=top_k,
         )

-        num_beam_groups = gr.Radio(
-            choices=[1, 2],
-            label="Number of Beam Groups",
-            value=1,
+        penalty_alpha = gr.Slider(
+            label="Penalty Alpha",
+            value=alpha,
+            maximum=1.0,
+            minimum=0.0,
+            step=0.1,
         )
         length_penalty = gr.Slider(
             minimum=0.5,
@@ -269,10 +276,10 @@ if __name__ == "__main__":
         inputs=[
             prompt_text,
             num_gen_tokens,
-            num_beams,
+            penalty_alpha,
+            contrastive_top_k,
             no_repeat_ngram_size,
             length_penalty,
-            num_beam_groups,
         ],
         outputs=[email_mailto_button, generated_email],
     )
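Pulled out of the app, the rewritten generator call amounts to roughly the following sketch; the checkpoint name is a placeholder, since the real model comes from the --model argument:

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")  # placeholder checkpoint

result = generator(
    "Dear team,",
    max_new_tokens=64,  # replaces the old max_length=gen_length + input_len
    penalty_alpha=0.6,  # weight of the contrastive degeneration penalty
    top_k=6,            # size of the candidate pool at each step
    no_repeat_ngram_size=2,
)
print(result[0]["generated_text"])

From the command line the same settings map onto the new flags, e.g. python app.py -a 0.6 -k 6.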
requirements.txt CHANGED
@@ -1,3 +1,3 @@
 gradio
 torch
-transformers
+transformers>=4.24.0
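The version floor matters because penalty_alpha/top_k contrastive search first shipped in transformers v4.24.0; older releases do not recognize those generate kwargs. If the pin ever gets lost, a fail-fast guard like this (my suggestion, not part of the commit) would catch it:

import transformers
from packaging import version  # packaging ships as a transformers dependency

if version.parse(transformers.__version__) < version.parse("4.24.0"):
    raise RuntimeError("contrastive search (penalty_alpha/top_k) needs transformers >= 4.24.0")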