# Author: dar-tau
# Last commit: "Update app.py" (ca37660, verified)
import html
from enum import Enum

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from datasets import load_dataset
from spacy import displacy
from transformers import AutoTokenizer
class VisType(Enum):
    """Aggregate (non per-layer) visualization modes offered in the UI dropdown.

    Per-layer options are generated separately as ``'Layer #<i>'`` strings.
    """

    # Sum the attention maps over the layer axis.
    SUM = 'Sum over Layers'
# Number of per-layer "Layer #i" options offered in the UI
# (presumably the layer count of facebook/opt-350m — confirm if the model changes).
num_layers = 24

# Tokenizer for the model whose attention maps are stored in the dataset below;
# used to recover the token strings that label the maps.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)

# Precomputed attention maps, one row per sentence (train split only).
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
def _select_plot_data(attn_maps, vis_type):
    """Reduce the stacked per-layer attention maps to the single 2-D map to display.

    ``vis_type`` is either ``VisType.SUM`` (the member or its string value), which
    sums over the leading (layer) axis, or a string ``'Layer #<i>'``, which selects
    layer ``i``.

    Raises:
        ValueError: if ``vis_type`` matches neither form.
    """
    # Accept the enum member as well as its string value: the UI's default used to
    # be the member itself.
    if vis_type in (VisType.SUM, VisType.SUM.value):
        return attn_maps.sum(0)
    if isinstance(vis_type, str) and vis_type.startswith('Layer #'):
        layer_to_inspect = int(vis_type.split('#')[1])
        return attn_maps[layer_to_inspect]
    # Previously signalled with a deliberate ZeroDivisionError (``0/0``).
    raise ValueError(f'Unknown visualization type: {vis_type!r}')


def _render_plot(plot_data, tokenized, seq_len):
    """Draw ``plot_data`` as a heatmap whose ticks are the tokens (first token dropped)."""
    fig = plt.figure(figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized)))
    sns.heatmap(plot_data)
    tick_positions = np.arange(seq_len - 1) + 0.5  # centre the labels on the cells
    plt.xticks(tick_positions, tokenized[1:], rotation=90)
    plt.yticks(tick_positions, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    return fig


def _render_graph(plot_data, tokenized, threshold=0.5):
    """Render attention above ``threshold`` as a displaCy arc diagram (HTML/SVG string).

    Only entries below the diagonal (``j < i``) that skip at least one token
    (``i - j > 1``) become arcs, so adjacent-token attention is not drawn.
    """
    words = [{"text": tok, "tag": ""} for tok in tokenized[1:]]
    arcs = [
        {"start": j, "end": i, "label": "", "dir": "right"}
        for i in range(len(tokenized) - 1)
        for j in range(i)
        if plot_data[i, j] > threshold and abs(i - j) > 1
    ]
    svg = displacy.render([{"words": words, "arcs": arcs}], style="dep", jupyter=False, manual=True,
                          options={"compact": True,
                                   "offset_x": 20,
                                   "distance": 130,
                                   })
    return ("<div class='displacy_container' style='max-width:100%; max-height:500px; overflow:auto'>"
            + svg +
            "</div>")


def _render_text(plot_data, tokenized, text, top_k=30, threshold=0.5):
    """List up to ``top_k`` off-diagonal (target, source) token pairs scoring >= ``threshold``.

    Returns an HTML snippet: the (escaped) sentence in bold, then one pair per line,
    strongest first.
    """
    # Work on a copy so the caller's array (possibly a view into the maps) is untouched.
    scores = plot_data.copy()
    np.fill_diagonal(scores, 0.)  # ignore self-attention
    # topk raises if k exceeds the number of elements (short sentences).
    k = min(top_k, scores.size)
    top_values, top_indices = torch.tensor(scores).flatten().topk(k)
    pairs = []
    for val, flat_idx in zip(top_values.tolist(), top_indices.tolist()):
        if val < threshold:
            break  # values are sorted descending, so the rest are smaller too
        row_idx, col_idx = np.unravel_index(flat_idx, scores.shape)
        pairs.append(str((tokenized[1 + row_idx], tokenized[1 + col_idx])))
    # Escape: the text starts with '</s>', which would otherwise be parsed as a tag.
    body = '<br/>'.join(html.escape(pair) for pair in pairs)
    return '<div><b>' + html.escape(text) + '</b><br/>' + body + '</div>'


def analyze_sentence(index, vis_type, vis_format):
    """Build the visualization of one dataset sentence's attention maps.

    Args:
        index: Row index into ``dataset`` (the sentence dropdown uses ``type='index'``).
        vis_type: ``VisType.SUM`` / ``'Sum over Layers'`` or ``'Layer #<i>'``.
        vis_format: ``'Plot'`` (heatmap), ``'Graph'`` (arc diagram); any other value,
            including ``None``, produces the textual top-pairs listing.

    Returns:
        ``(figure or None, html string, metrics dict)`` for the Plot, HTML and Label
        outputs.

    Raises:
        ValueError: if ``vis_type`` is not recognised.
    """
    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    # The stored shape's first entry is dropped; the remaining three axes are
    # (layer, target, source) judging by the summing/indexing/axis labels below.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # Drop the first token position (presumably the leading '</s>'; tokenized[1:]
    # is used as the matching labels).
    attn_maps = attn_maps[:, 1:, 1:]
    plot_data = _select_plot_data(attn_maps, vis_type)

    fig_output = None
    if vis_format == 'Plot':
        fig_output = _render_plot(plot_data, tokenized, seq_len)
        graph_output = ""
    elif vis_format == 'Graph':
        graph_output = _render_graph(plot_data, tokenized)
    else:
        graph_output = _render_text(plot_data, tokenized, text)

    # NOTE(review): the 'Metrics': 1 placeholder entry is kept from the original;
    # the remaining entries are every non-map column of the row.
    metrics = {'Metrics': 1}
    metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
    return fig_output, graph_output, metrics
# Dropdown labels: the part of each stored text after its leading '</s> ' marker
# (the whole text if the marker is absent, instead of raising IndexError).
sentence_choices = [x.split('</s> ', 1)[-1] for x in dataset['text']]

demo = gr.Blocks(css=".displacy_container svg{height:500px !important; margin-top:-100px; transform: scale(0.5)}")
with demo:
    with gr.Row():
        # type='index' makes the callback receive the row index, but the default
        # `value` must still be one of the choice labels (not the integer 0).
        sentence_dropdown = gr.Dropdown(label="Sentence",
                                        choices=sentence_choices,
                                        value=sentence_choices[0] if sentence_choices else None,
                                        min_width=300, type='index')
        # Default must be the choice string, not the VisType member itself.
        vis_dropdown = gr.Dropdown(label="Visualization", choices=[x.value for x in VisType] +
                                   [f'Layer #{i}' for i in range(num_layers)],
                                   min_width=70, value=VisType.SUM.value, type='value')
        btn = gr.Button("Run", min_width=30)
    # No default: an unselected radio (None) falls through to the text view.
    vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'])
    output = gr.Plot(label="Plot", container=True)
    graph_output = gr.HTML()
    # The first positional parameter of gr.Label is `value`; pass the caption as `label`.
    metrics = gr.Label(label="Metrics")
    btn.click(analyze_sentence,
              [sentence_dropdown, vis_dropdown, vis_format_checkbox],
              [output, graph_output, metrics])
if __name__ == '__main__':
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()