p1atdev commited on
Commit
e45ac5c
1 Parent(s): 51d0877

feat: add repetition penalty

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -32,6 +32,7 @@ def generate(
32
  top_p=0.95,
33
  top_k=20,
34
  no_repeat_ngram_size=3,
 
35
  num_beams=1,
36
  ):
37
  if input_text.strip() == "":
@@ -55,6 +56,7 @@ def generate(
55
  top_p=top_p,
56
  top_k=top_k,
57
  no_repeat_ngram_size=no_repeat_ngram_size,
 
58
  num_beams=num_beams,
59
  )
60
  return tokenizer.batch_decode(generated)[0]
@@ -144,6 +146,13 @@ with gr.Blocks() as demo:
144
  value=3,
145
  step=1,
146
  )
 
 
 
 
 
 
 
147
  num_beams = gr.Slider(
148
  label="Num beams",
149
  minimum=1,
@@ -167,6 +176,7 @@ with gr.Blocks() as demo:
167
  top_p,
168
  top_k,
169
  no_repeat_ngram_size,
 
170
  num_beams,
171
  ],
172
  outputs=output_text,
@@ -182,6 +192,7 @@ with gr.Blocks() as demo:
182
  top_p,
183
  top_k,
184
  no_repeat_ngram_size,
 
185
  num_beams,
186
  ],
187
  outputs=[input_text, output_text],
 
32
  top_p=0.95,
33
  top_k=20,
34
  no_repeat_ngram_size=3,
35
+ repetition_penalty=1.2,
36
  num_beams=1,
37
  ):
38
  if input_text.strip() == "":
 
56
  top_p=top_p,
57
  top_k=top_k,
58
  no_repeat_ngram_size=no_repeat_ngram_size,
59
+ repetition_penalty=repetition_penalty,
60
  num_beams=num_beams,
61
  )
62
  return tokenizer.batch_decode(generated)[0]
 
146
  value=3,
147
  step=1,
148
  )
149
+ repetition_penalty = gr.Slider(
150
+ label="Repetition penalty",
151
+ minimum=0,
152
+ maximum=2,
153
+ value=1.2,
154
+ step=0.1,
155
+ )
156
  num_beams = gr.Slider(
157
  label="Num beams",
158
  minimum=1,
 
176
  top_p,
177
  top_k,
178
  no_repeat_ngram_size,
179
+ repetition_penalty,
180
  num_beams,
181
  ],
182
  outputs=output_text,
 
192
  top_p,
193
  top_k,
194
  no_repeat_ngram_size,
195
+ repetition_penalty,
196
  num_beams,
197
  ],
198
  outputs=[input_text, output_text],