Saibo Geng committed on
Commit
a1aa766
1 Parent(s): 805081b

use smaller repetition penalty; add doc

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -42,7 +42,7 @@ if __name__ == "__main__":
42
  grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
43
 
44
  outputs = model.generate(
45
- **inputs, max_new_tokens=50, repetition_penalty=1.1, return_dict_in_generate=True, output_scores=True, logits_processor=[grammar_processor]
46
  )
47
  # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
48
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
@@ -70,19 +70,19 @@ if __name__ == "__main__":
70
  with demo:
71
  gr.Markdown(
72
  """
73
- # 🌈 Color Coded Text Generation 🌈
74
- This is a demo of how you can obtain the probabilities of each generated token, and use them to
75
- color code the model output.
76
- Feel free to clone this demo and modify it to your needs 🤗
77
- Internally, it relies on [`compute_transition_scores`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores),
78
- which was added in `transformers` v4.26.0.
79
  """
80
  )
81
 
82
  with gr.Row():
83
  with gr.Column():
84
  prompt = gr.Textbox(label="Prompt", lines=3, value="This is a valid json string for http request:")
85
- button = gr.Button(f"Generate with {MODEL_NAME}, using sampling!")
86
  with gr.Column():
87
  highlighted_text = gr.HighlightedText(
88
  label="Highlighted generation",
 
42
  grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
43
 
44
  outputs = model.generate(
45
+ **inputs, max_new_tokens=50, repetition_penalty=1.05, return_dict_in_generate=True, output_scores=True, logits_processor=[grammar_processor]
46
  )
47
  # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
48
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
 
70
  with demo:
71
  gr.Markdown(
72
  """
73
+ # Grammar-Constrained Decoding with GPT-2
74
+ This is a demo of how you can constrain the output of a GPT-2 model using a formal grammar.
75
+ Here we use a simple JSON grammar to constrain the output of the model to be valid JSON strings.
76
+ The grammar is defined in `json_minimal.ebnf` and is written in the Extended Backus-Naur Form (EBNF).
77
+ Internally, it relies on the library [`transformers-cfg`](https://github.com/epfl-dlab/transformers-CFG).
78
+ For demo purpose, gpt2 is used, but you can use much larger models for better performance.
79
  """
80
  )
81
 
82
  with gr.Row():
83
  with gr.Column():
84
  prompt = gr.Textbox(label="Prompt", lines=3, value="This is a valid json string for http request:")
85
+ button = gr.Button(f"Generate with json object using {MODEL_NAME}!")
86
  with gr.Column():
87
  highlighted_text = gr.HighlightedText(
88
  label="Highlighted generation",