File size: 4,417 Bytes
3c3eabb 282bfce 3c3eabb ac2cf21 282bfce ac2cf21 282bfce ac2cf21 282bfce ac2cf21 edce2eb ac2cf21 fd9a520 ac2cf21 426c6f1 282bfce fd9a520 426c6f1 edce2eb 426c6f1 ac2cf21 426c6f1 ac2cf21 a5e42e5 ade3ffb e2ed84a 282bfce ade3ffb 282bfce a5e42e5 6fb202b 282bfce ac2cf21 ade3ffb b7c3808 282bfce ade3ffb db3bd52 ac2cf21 3c3eabb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
MODEL_NAME = "google/flan-t5-base"

# Sub-word marker that tokenizers put in front of word-initial tokens.
# SentencePiece tokenizers (e.g. T5's) use U+2581 ("▁"); byte-level BPE
# tokenizers (e.g. GPT-2's) use U+0120 ("Ġ") instead. Change this if you swap
# MODEL_NAME for a model whose tokenizer uses a different marker, to get the
# correct spacing in the final output.
WORD_BOUNDARY_MARKER = "\u2581"

# Color-coding labels: if prob >= x, then label = y. Sorted in descending
# probability order -- the first matching threshold wins. The last threshold is
# 0.0 so that every valid probability (even one that underflows) gets a label.
probs_to_label = [
    (0.1, "p >= 10%"),
    (0.01, "p >= 1%"),
    (0.0, "p < 1%"),
]

label_to_color = {
    "p >= 10%": "green",
    "p >= 1%": "yellow",
    "p < 1%": "red",
}


def probability_to_label(proba, thresholds=None):
    """Map a token probability to its color-coding label.

    Args:
        proba: probability of a generated token; anything ``float()`` accepts
            (Python float, NumPy scalar, single-element tensor).
        thresholds: ``(min_proba, label)`` pairs sorted by descending
            ``min_proba``. Defaults to the module-level ``probs_to_label``.

    Returns:
        The label of the first pair with ``proba >= min_proba``, or ``None``
        if ``proba`` is below every threshold.

    Raises:
        ValueError: if ``proba`` is not in [0, 1] (this includes NaN).
    """
    if thresholds is None:
        thresholds = probs_to_label
    proba = float(proba)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0.0 <= proba <= 1.0:
        raise ValueError(f"Expected a probability in [0, 1], got {proba!r}")
    return next((label for min_proba, label in thresholds if proba >= min_proba), None)


def token_to_text(token):
    """Return ``token`` with the word-boundary marker replaced by a plain space."""
    return token.replace(WORD_BOUNDARY_MARKER, " ")


if __name__ == "__main__":
    # Define your model and your tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)  # or AutoModelForCausalLM
    # Tokenizers without a pad token fall back to EOS for padding.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = model.config.eos_token_id

    def get_tokens_and_labels(prompt):
        """
        Given the prompt (text), return a list of tuples (decoded_token, label).

        Generated tokens are labeled with `probability_to_label` applied to
        their normalized generation probability. For decoder-only models the
        prompt tokens are prepended with a `None` label (left uncolored).
        """
        inputs = tokenizer([prompt], return_tensors="pt")
        outputs = model.generate(
            **inputs, max_new_tokens=50, return_dict_in_generate=True, output_scores=True
        )
        # Important: don't forget to set `normalize_logits=True` to obtain
        # normalized probabilities (i.e. sum(p) = 1)
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        transition_proba = np.exp(transition_scores)

        # We only have scores for the generated tokens, so drop the leading
        # non-generated ids: the prompt for decoder-only models, a single
        # decoder start token for encoder-decoder models.
        input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        generated_ids = outputs.sequences[:, input_length:]
        generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])

        if model.config.is_encoder_decoder:
            highlighted_out = []
        else:
            # [0] selects the only sequence in the batch: convert_ids_to_tokens
            # expects a 1-D sequence of ids, not the 2-D (batch, seq) tensor.
            input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
            highlighted_out = [(token_to_text(token), None) for token in input_tokens]

        # Get the (decoded_token, label) pairs for the generated tokens
        highlighted_out.extend(
            (token_to_text(token), probability_to_label(proba))
            for token, proba in zip(generated_tokens, transition_proba[0])
        )
        return highlighted_out

    demo = gr.Blocks()
    with demo:
        gr.Markdown(
            """
            # Color-Coded Text Generation

            This is a demo of how you can obtain the probabilities of each generated token, and use them to
            color code the model output. Internally, it relies on
            [`compute_transition_scores`](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores),
            which was added in `transformers` v4.26.0.

            ⚠️ For instance, with the pre-populated input and its color-coded output, you can see that
            `google/flan-t5-base` struggles with arithmetics.

            Feel free to clone this demo and modify it to your needs!
            """
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    lines=3,
                    value=(
                        "Answer the following question by reasoning step-by-step. The cafeteria had 23 apples. "
                        "If they used 20 for lunch and bought 6 more, how many apples do they have?"
                    ),
                )
                button = gr.Button(f"Generate with {MODEL_NAME}")
            with gr.Column():
                highlighted_text = gr.HighlightedText(
                    label="Highlighted generation",
                    combine_adjacent=True,
                    show_legend=True,
                    color_map=label_to_color,
                )

        button.click(get_tokens_and_labels, inputs=prompt, outputs=highlighted_text)

    demo.launch()
|