joaogante HF staff committed on
Commit
3c3eabb
1 Parent(s): 075bb7d

1st commit

Browse files
Files changed (2) hide show
  1. app.py +62 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from transformers import GPT2Tokenizer, AutoModelForCausalLM
4
+ import numpy as np
5
+
6
# Load the "gpt2" checkpoint once at import time: weights first, then the
# matching tokenizer (the two loads are independent of each other).
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Reuse the end-of-sequence token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
9
+
10
# Probability buckets as (inclusive lower bound, label) pairs, sorted by
# descending bound: a token receives the label of the first pair whose bound
# its probability meets (prob >= bound).
probs_to_label = [
    (0.1, "p >= 10%"),
    (0.01, "p >= 1%"),
    # Catch-all bucket. The bound is 0.0 (not a tiny positive sentinel) so that
    # every valid probability, including one that underflows to exactly zero,
    # still lands in a bucket instead of being left unlabelled.
    (0.0, "p < 1%"),
]

# Highlight colour used for each bucket label; keys must match the labels above.
label_to_color = {
    "p >= 10%": "green",
    "p >= 1%": "yellow",
    "p < 1%": "red",
}
22
+
23
def _label_for_probability(proba):
    """Return the label of the first ``probs_to_label`` bucket that *proba* meets, else None."""
    for min_proba, label in probs_to_label:
        if proba >= min_proba:
            return label
    return None


def get_tokens_and_scores(prompt):
    """Sample a continuation of *prompt* and label each new token by its probability.

    Up to 50 new tokens are sampled from the module-level ``model``. The
    result is a list of ``(text, label)`` tuples: first the decoded prompt
    with label ``None``, then one entry per generated token whose label is
    looked up in ``probs_to_label`` (``None`` if no bucket matches).

    Args:
        prompt: The text to continue.

    Returns:
        A list of ``(text, label)`` tuples; an empty list when *prompt* is
        empty, since there is nothing to condition the generation on.

    Raises:
        ValueError: If a computed token probability falls outside ``[0, 1]``.
    """
    if not prompt:
        return []

    inputs = tokenizer([prompt], return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=True,
    )
    transition_scores = model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    # Exponentiate the per-token scores to obtain probabilities.
    transition_proba = np.exp(transition_scores)

    # Skip the positions that precede the generated tokens: the whole prompt
    # for decoder-only models, a single position for encoder-decoder models.
    input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
    generated_tokens = outputs.sequences[:, input_length:]

    # Each row of input_ids is a full prompt sequence; it is shown unlabelled.
    highlighted_out = [(tokenizer.decode(sequence), None) for sequence in inputs.input_ids]

    for token, proba in zip(generated_tokens[0], transition_proba[0]):
        # Work with a plain float so comparisons yield ordinary booleans.
        proba = float(proba)
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if not 0.0 <= proba <= 1.0:
            raise ValueError(f"token probability out of range [0, 1]: {proba!r}")
        highlighted_out.append((tokenizer.decode(token), _label_for_probability(proba)))

    return highlighted_out
44
+
45
+
46
# Input widget: a three-line textbox pre-filled with a sample prompt.
prompt_box = gr.Textbox(label="Prompt", lines=3, value="Today is")

# Output widget: highlighted text with a legend, coloured per label via
# label_to_color; adjacent spans are combined.
highlighted_box = gr.HighlightedText(
    label="Highlighted generation",
    combine_adjacent=True,
    show_legend=True,
).style(color_map=label_to_color)

# Wire the generation function to the widgets.
demo = gr.Interface(
    fn=get_tokens_and_scores,
    inputs=[prompt_box],
    outputs=highlighted_box,
)

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ transformers>=4.26