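# Gradio demo that visualizes precomputed attention maps of facebook/opt-350m
# over sentences from the dar-tau/grammar-attention-maps-opt-350m dataset,
# either summed over layers or for a single layer, rendered as a heatmap,
# a displaCy dependency-style graph, or a plain-text list of strong edges.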
import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from enum import Enum
from spacy import displacy

class VisType(Enum):
    SUM = 'Sum over Layers'


num_layers = 24
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
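
# Core analysis: for one dataset row, pick the requested attention view and
# render it as a matplotlib heatmap, a displaCy-style dependency graph (HTML),
# or a plain-text list of the strongest edges; the row's remaining columns are
# returned as metrics.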
def analyze_sentence(index, vis_type, vis_format):
    row = dataset[index]
    text = row['text']
    # Re-tokenize the sentence so heatmap axes / graph nodes can be labelled.
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    # Recover the stacked (layers, seq_len, seq_len) attention maps from flattened storage.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # Drop the leading BOS ('</s>') position from both axes.
    attn_maps = attn_maps[:, 1:, 1:]
    # Choose what to visualize: the sum over all layers or a single layer's map.
    if vis_type == VisType.SUM.value:
        plot_data = attn_maps.sum(0)
    elif vis_type.startswith('Layer #'):
        layer_to_inspect = int(vis_type.split('#')[1])
        plot_data = attn_maps[layer_to_inspect]
    else:
        raise ValueError(f'Unknown visualization type: {vis_type}')
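
    # Render plot_data in the requested output format.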
    if vis_format == 'Plot':
        # Heatmap of attention weights, sized to the number of tokens.
        fig = plt.figure(figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized)))
        sns.heatmap(plot_data)
        plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
        plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
        plt.ylabel('TARGET')
        plt.xlabel('SOURCE')
        plt.grid()
        fig_output, graph_output = fig, ""
    elif vis_format == 'Graph':
        # Strong, non-adjacent attention edges (weight > 0.5) drawn as a
        # displaCy dependency-style graph.
        ex = [{
            "words": [{"text": x, "tag": ""} for x in tokenized[1:]],
            "arcs": [{"start": j, "end": i, "label": "", "dir": "right"}
                     for i in range(len(tokenized) - 1) for j in range(i)
                     if plot_data[i, j] > 0.5 and abs(i - j) > 1]
        }]
        graph_output = displacy.render(ex, style="dep", jupyter=False, manual=True,
                                       options={"compact": True,
                                                "offset_x": 20,
                                                "distance": 130})
        # Scrollable container so long sentences do not overflow the page.
        graph_output = ("<div class='displacy_container' style='max-width:100%; max-height:500px; overflow:auto'>"
                        + graph_output +
                        "</div>")
        fig_output = None
    else:
        # Plain-text fallback: list the strongest source/target token pairs.
        fig_output = None
        # Zero the diagonal so self-attention does not dominate the top-k list.
        plot_data[np.arange(len(plot_data)), np.arange(len(plot_data))] = 0.
        top_values, top_indices = torch.tensor(plot_data).flatten().topk(min(30, plot_data.size))
        topk_data = []
        for val, ind in zip(top_values, top_indices):
            if val < 0.5:
                break
            ind = np.unravel_index(ind.item(), plot_data.shape)
            topk_data += [str((tokenized[1 + ind[0]], tokenized[1 + ind[1]]))]
        graph_output = '<div><b>' + text + '</b><br/>' + '<br/>'.join(topk_data) + '</div>'
    # Expose the row's remaining columns as metrics for the Label component.
    metrics = {'Metrics': 1}
    metrics.update({k: v for k, v in row.items()
                    if k not in ['text', 'attention_maps', 'attention_maps_shape']})
    return fig_output, graph_output, metrics
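
# Gradio UI: sentence and visualization selectors, format radio, run button,
# and the plot / HTML / metrics output panels.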
demo = gr.Blocks(css=".displacy_container svg{height:500px !important; margin-top:-100px; transform: scale(0.5)}")

with demo:
    with gr.Row():
        sentence_dropdown = gr.Dropdown(label="Sentence",
                                        choices=[x.split('</s> ')[1] for x in dataset['text']],
                                        value=0, min_width=300, type='index')
        vis_dropdown = gr.Dropdown(label="Visualization",
                                   choices=[x.value for x in VisType] +
                                           [f'Layer #{i}' for i in range(num_layers)],
                                   min_width=70, value=VisType.SUM.value, type='value')
        btn = gr.Button("Run", min_width=30)
    vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'])
    output = gr.Plot(label="Plot", container=True)
    graph_output = gr.HTML()
    metrics = gr.Label("Metrics")

    btn.click(analyze_sentence,
              [sentence_dropdown, vis_dropdown, vis_format_checkbox],
              [output, graph_output, metrics])


if __name__ == "__main__":
    demo.launch()