Saibo Geng committed on
Commit
92dde49
1 Parent(s): 5204c67

add json grammar constraint

Browse files
Files changed (3) hide show
  1. app.py +10 -2
  2. json_minimal.ebnf +16 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
 
3
  from transformers import GPT2Tokenizer, AutoModelForCausalLM
4
  import numpy as np
 
 
5
 
6
  MODEL_NAME = "gpt2"
7
 
@@ -13,6 +15,12 @@ if __name__ == "__main__":
13
  tokenizer.pad_token_id = tokenizer.eos_token_id
14
  model.config.pad_token_id = model.config.eos_token_id
15
 
 
 
 
 
 
 
16
  # Define your color-coding labels; if prob > x, then label = y; Sorted in descending probability order!
17
  probs_to_label = [
18
  (0.1, "p >= 10%"),
@@ -33,7 +41,7 @@ if __name__ == "__main__":
33
  """
34
  inputs = tokenizer([prompt], return_tensors="pt")
35
  outputs = model.generate(
36
- **inputs, max_new_tokens=50, return_dict_in_generate=True, output_scores=True, do_sample=True
37
  )
38
  # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
39
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
@@ -72,7 +80,7 @@ if __name__ == "__main__":
72
 
73
  with gr.Row():
74
  with gr.Column():
75
- prompt = gr.Textbox(label="Prompt", lines=3, value="Today is")
76
  button = gr.Button(f"Generate with {MODEL_NAME}, using sampling!")
77
  with gr.Column():
78
  highlighted_text = gr.HighlightedText(
 
2
 
3
  from transformers import GPT2Tokenizer, AutoModelForCausalLM
4
  import numpy as np
5
+ from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
6
+ from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
7
 
8
  MODEL_NAME = "gpt2"
9
 
 
15
  tokenizer.pad_token_id = tokenizer.eos_token_id
16
  model.config.pad_token_id = model.config.eos_token_id
17
 
18
+ # Load json grammar
19
+ with open("json_minimal.ebnf", "r") as file:
20
+ grammar_str = file.read()
21
+ grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
22
+ grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
23
+
24
  # Define your color-coding labels; if prob > x, then label = y; Sorted in descending probability order!
25
  probs_to_label = [
26
  (0.1, "p >= 10%"),
 
41
  """
42
  inputs = tokenizer([prompt], return_tensors="pt")
43
  outputs = model.generate(
44
+ **inputs, max_new_tokens=20, return_dict_in_generate=True, output_scores=True, logits_processor=[grammar_processor]
45
  )
46
  # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
47
  transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
 
80
 
81
  with gr.Row():
82
  with gr.Column():
83
+ prompt = gr.Textbox(label="Prompt", lines=3, value="This is a valid json string for http request:")
84
  button = gr.Button(f"Generate with {MODEL_NAME}, using sampling!")
85
  with gr.Column():
86
  highlighted_text = gr.HighlightedText(
json_minimal.ebnf ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ root ::= object
4
+
5
+ object ::= " {" ws ( string ":" ws value ("," ws string ":" ws value)* )? ws "}"
6
+
7
+ value ::= object | array | string | number | ("true" | "false" | "null") ws
8
+
9
+ array ::= "[" ws ( value ("," ws value)* )? "]" ws
10
+
11
+ string ::= "\"" [a-zA-Z0-9]* "\"" ws
12
+
13
+ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
14
+
15
+
16
+ ws ::= ([ \t\n] ws)?
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  torch
2
  transformers>=4.26
 
 
1
  torch
2
  transformers>=4.26
3
+ transformers-cfg==0.2.0