Saibo Geng commited on
Commit
d7755a4
1 Parent(s): 075c5ca

remove repetition penalty

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -42,7 +42,7 @@ if __name__ == "__main__":
42
  grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
43
 
44
  outputs = model.generate(
45
- **inputs, max_new_tokens=50, repetition_penalty=1.05, return_dict_in_generate=True, output_scores=True, logits_processor=[grammar_processor]
46
  )
47
  # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
48
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
@@ -81,7 +81,7 @@ if __name__ == "__main__":
81
 
82
  with gr.Row():
83
  with gr.Column():
84
- prompt = gr.Textbox(label="Prompt", lines=3, value="This is a valid json string for http request:")
85
  button = gr.Button(f"Generate with json object using {MODEL_NAME}!")
86
  with gr.Column():
87
  highlighted_text = gr.HighlightedText(
 
42
  grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
43
 
44
  outputs = model.generate(
45
+ **inputs, max_new_tokens=50, repetition_penalty=1, return_dict_in_generate=True, output_scores=True, logits_processor=[grammar_processor]
46
  )
47
  # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
48
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
 
81
 
82
  with gr.Row():
83
  with gr.Column():
84
+ prompt = gr.Textbox(label="Prompt", lines=3, value="This is a valid json string describing a Pokémon character:")
85
  button = gr.Button(f"Generate with json object using {MODEL_NAME}!")
86
  with gr.Column():
87
  highlighted_text = gr.HighlightedText(