import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import pandas as pd
import gradio as gr
import spaces
from dataclasses import dataclass
from typing import Dict
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
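# Note: GPT-2 has no pad token, so generate() will fall back to eos_token_id for
# padding during beam search (transformers logs a warning about this).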
print("Loading finished.")
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
STYLE = """
.custom-container {
display: grid;
align-items: center;
margin: 0!important;
overflow: auto;
}
.prose ul ul {
font-size: 10px!important;
}
.prose li {
margin-bottom: 0!important;
}
.prose table {
margin-bottom: 0!important;
}
.prose td, th {
padding-left: 2px;
padding-right: 2px;
padding-top: 0;
padding-bottom: 0;
}
.tree {
padding: 0px;
margin: 0!important;
box-sizing: border-box;
font-size: 10px;
width: 100%;
height: auto;
text-align: center;
display:inline-block;
}
#root {
display: inline-grid!important;
width:auto!important;
min-width: 220px;
}
.tree ul {
padding-left: 20px;
position: relative;
transition: all 0.5s ease 0s;
display: flex;
flex-direction: column;
gap: 10px;
margin: 0px !important;
}
.tree li {
display: flex;
text-align: center;
list-style-type: none;
position: relative;
padding-left: 20px;
transition: all 0.5s ease 0s;
flex-direction: row;
justify-content: start;
align-items: center;
}
.tree li::before, .tree li::after {
content: "";
position: absolute;
left: 0px;
border-left: 1px solid var(--body-text-color);
width: 20px;
}
.tree li::before {
top: 0;
height:50%;
}
.tree li::after {
top: 50%;
height: 55%;
bottom: auto;
border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
border: 0 none;
}
.tree li:last-child::before {
border-bottom: 1px solid var(--body-text-color);
border-radius: 0px 0px 0px 5px;
-webkit-border-radius: 0px 0px 0px 5px;
-moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
border-radius: 5px 0 0 0;
-webkit-border-radius: 5px 0 0 0;
-moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
content: "";
position: absolute;
left: 0;
top: 50%;
border-top: 1px solid var(--body-text-color);
width: 20px;
height: 0;
}
.tree ul:has(> li:only-child)::before {
width:40px;
}
.tree li a:before {
border-right: 1px solid var(--body-text-color);
border-bottom: 1px solid var(--body-text-color);
content: "";
position: absolute;
width: 10px;
left: 0px;
height: 10px;
top: 50%;
margin-top: -5px;
margin-left: 6px;
transform: rotate(315deg);
}
.tree li a {
border: 1px solid var(--body-text-color);
padding: 5px;
    border-radius: 5px;
    text-decoration-line: none;
transition: .5s;
width: 280px;
display: flex;
align-items: center;
justify-content: space-around;
}
.tree li a span {
padding: 5px;
font-size: 12px;
text-transform: uppercase;
letter-spacing: 1px;
font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover+ul li a {
background: #ffedd5;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
border-color: #7c2d12;
}
.chosen {
background-color: #ea580c;
width:auto!important;
}
"""
def clean(s):
    """Escape newlines and tabs so they display literally inside the HTML tables."""
    return s.replace("\n", r"\n").replace("\t", r"\t")
def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Render an HTML table of the top_k next-token candidates with their step and total scores."""
markdown_table = """
<table>
<tr>
<th><b>Token</b></th>
<th><b>Step score</b></th>
<th><b>Total score</b></th>
</tr>"""
for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
token = tokenizer.decode([token_idx])
item_class = ""
if chosen_tokens and token in chosen_tokens:
item_class = "chosen"
markdown_table += f"""
        <tr class="{item_class}">
<td>{clean(token)}</td>
<td>{scores[token_idx]:.4f}</td>
<td>{(scores[token_idx] + previous_cumul_score)/score_divider:.4f}</td>
</tr>"""
markdown_table += """
</table>"""
return markdown_table
def generate_nodes(token_ix, node, step):
    """Recursively generate HTML for the tree nodes."""
    token = tokenizer.decode([token_ix])
    html_content = f" <li> <a href='#' class='{('chosen' if node.table is None else '')}'> <span> <b>{token_ix}:<br>{clean(token)}</b> </span> "
    if node.table is not None:
        html_content += node.table
    html_content += "</a>"
    if len(node.children) > 0:
        html_content += "<ul> "
        for child_ix, subnode in node.children.items():
            html_content += generate_nodes(child_ix, subnode, step=step + 1)
        html_content += "</ul>"
    html_content += "</li>"
    return html_content
def generate_html(start_sentence, original_tree):
html_output = f"""<div class="custom-container">
<div class="tree">
<ul>
<li> <a href='#' id='root'> <span> <b>{start_sentence}</b> </span> {original_tree.table} </a>"""
if len(original_tree.children.keys()) > 0:
html_output += "<ul> "
for token_ix, subnode in original_tree.children.items():
html_output += generate_nodes(token_ix, subnode, step=1)
html_output += "</ul>"
html_output += """
</ul>
</div>
</body>
"""
return html_output
@dataclass
class BeamNode:
    """One node of the beam search tree: a token choice plus its accumulated state."""
    cumulative_score: float
    children_score_divider: float
    table: str
    current_sentence: str
    children: Dict[int, "BeamNode"]
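# A node's displayed total score is cumulative_score / children_score_divider, i.e. the
# summed token log-probabilities normalized by (sequence length ** length_penalty).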
def generate_beams(start_sentence, scores, sequences, length_penalty):
    """Rebuild the beam search tree from the per-step scores returned by generate()."""
    # `sequences` is currently unused: the tree is reconstructed from `scores` alone.
    # Count the prompt tokens (len() on the BatchEncoding would count its fields, not tokens).
    input_length = tokenizer([start_sentence], return_tensors="pt").input_ids.shape[-1]
original_tree = BeamNode(
cumulative_score=0,
table=None,
current_sentence=start_sentence,
children={},
children_score_divider=((input_length + 1) ** length_penalty),
)
n_beams = len(scores[0])
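    # All beams start from the same root, so this list holds n_beams references to one shared node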
beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        # Candidate pool gathered across all beams for this step
        top_token_indexes = []
        top_cumulative_scores = []
        beam_indexes = []
        current_completions = []
        top_tokens = []
for beam_ix in range(n_beams):
current_beam = beam_trees[beam_ix]
# Get top cumulative scores for the current beam
current_top_token_indexes = list(
np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
)
top_token_indexes += current_top_token_indexes
top_cumulative_scores += list(
np.array(scores[step][beam_ix][current_top_token_indexes])
+ current_beam.cumulative_score
)
beam_indexes += [beam_ix] * n_beams
current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]
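        # Pool the n_beams * n_beams candidates into a frame and rank them globally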
top_df = pd.DataFrame.from_dict(
{
"token_index": top_token_indexes,
"cumulative_score": top_cumulative_scores,
"beam_index": beam_indexes,
"current_completions": current_completions,
"token": top_tokens,
}
)
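        # When two beams propose the same token for the same prefix, keep only the
        # highest-scoring duplicate before selecting the top n_beams candidates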
maxes = top_df.groupby(["token_index", "current_completions"])[
"cumulative_score"
].idxmax()
top_df = top_df.loc[maxes]
# Sort all top probabilities and keep top n_beams
top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
:n_beams
]
        # Write the scores table, one per source beam.
        # Edge case: if several beam indexes actually point to the same beam, the tokens
        # selected for the second one will be empty, so we iterate in reverse.
for beam_ix in reversed(list(range(n_beams))):
current_beam = beam_trees[beam_ix]
selected_tokens = top_df_selected.loc[
top_df_selected["beam_index"] == beam_ix
]
markdown_table = generate_markdown_table(
step_scores[beam_ix, :],
current_beam.cumulative_score,
current_beam.children_score_divider,
chosen_tokens=list(selected_tokens["token"].values),
)
beam_trees[beam_ix].table = markdown_table
# Add new children for each beam
cumulative_scores = [beam.cumulative_score for beam in beam_trees]
for beam_ix in range(n_beams):
current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
current_token_choice = tokenizer.decode([current_token_choice_ix])
# Update the source tree
source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])
cumulative_score = (
cumulative_scores[source_beam_ix]
+ scores[step][source_beam_ix][current_token_choice_ix].numpy()
)
beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
table=None,
children={},
current_sentence=beam_trees[source_beam_ix].current_sentence
+ current_token_choice,
cumulative_score=cumulative_score,
children_score_divider=((input_length + step + 1) ** length_penalty),
)
# Reassign all beams at once
beam_trees = [
beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
for beam_ix in range(n_beams)
]
# Advance all beams by one token
for beam_ix in range(n_beams):
current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
return original_tree
@spaces.GPU
def get_beam_search_html(input_text, number_steps, number_beams, length_penalty):
inputs = tokenizer([input_text], return_tensors="pt")
outputs = model.generate(
**inputs,
max_new_tokens=number_steps,
num_beams=number_beams,
num_return_sequences=number_beams,
return_dict_in_generate=True,
length_penalty=length_penalty,
output_scores=True,
do_sample=False,
)
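    # outputs.scores is a tuple with one entry per generated step, each a
    # (num_beams, vocab_size) tensor of processed log-probabilities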
markdown = "Output sequences:"
decoded_sequences = tokenizer.batch_decode(outputs.sequences)
for i, sequence in enumerate(decoded_sequences):
markdown += f"\n- {clean(sequence.replace('<s> ', ''))} (score {outputs.sequences_scores[i]:.2f})"
original_tree = generate_beams(
input_text,
outputs.scores[:],
outputs.sequences[:, :],
length_penalty,
)
html = generate_html(input_text, original_tree)
return html, markdown
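# Example of calling the generation function directly (hypothetical values, bypassing the UI):
#   html, md = get_beam_search_html("Today is", number_steps=4, number_beams=3, length_penalty=1.0)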
with gr.Blocks(
theme=gr.themes.Soft(
text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.yellow
),
css=STYLE,
) as demo:
gr.Markdown("""# Beam search visualizer
Play with the parameters below to understand how beam search decoding works!
#### Parameters:
- **Sentence to decode from**: the input sequence to your decoder.
- **Number of steps**: the number of tokens to generate.
- **Number of beams**: the number of beams to use.
- **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
""")
text = gr.Textbox(label="Sentence to decode from", value="Conclusion: thanks a lot. This article was originally published on")
with gr.Row():
steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
length_penalty = gr.Slider(label="Length penalty", minimum=-4, maximum=4, step=0.5, value=1)
button = gr.Button()
out_html = gr.Markdown()
out_markdown = gr.Markdown()
button.click(get_beam_search_html, inputs=[text, steps, beams, length_penalty], outputs=[out_html, out_markdown])
demo.launch()