import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from enum import Enum
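# Visualization modes offered in the UI; per-layer views ('Layer #i') are
# appended to the dropdown choices below.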
class VisType(Enum):
    SUM = 'Sum over Layers'
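# Number of decoder layers in OPT-350m, the dataset of precomputed attention
# maps, and the matching tokenizer.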
num_layers = 24
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)
def analyze_sentence(index, vis_type):
    """Plot the attention map of the selected sentence and return it together with the dataset's metric columns."""
    row = dataset[index]
    text = row['text']
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    fig = plt.figure(figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized)))
    # Drop the leading '</s>' token's row and column before plotting.
    attn_maps = attn_maps[:, 1:, 1:]
    if vis_type == VisType.SUM.value:
        plot_data = attn_maps.sum(0)
    elif vis_type.startswith('Layer #'):
        layer_to_inspect = int(vis_type.split('#')[1])
        plot_data = attn_maps[layer_to_inspect]
    else:
        raise ValueError(f'Unknown visualization type: {vis_type}')
    sns.heatmap(plot_data)
    plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
    plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    # Pass the remaining dataset columns through as metrics for the Label component.
    metrics = {'Metrics': 1}
    metrics.update({k: v for k, v in row.items() if k not in ['text', 'attention_maps', 'attention_maps_shape']})
    return fig, metrics
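# Build the UI: sentence and visualization dropdowns plus a run button in one
# row, followed by the attention-map plot and a Label with the metrics.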
demo = gr.Blocks()
with demo:
    with gr.Row():
        sentence_dropdown = gr.Dropdown(label="Sentence",
                                        choices=[x.split('</s> ')[1] for x in dataset['text']],
                                        value=0, min_width=300, type='index')
        vis_dropdown = gr.Dropdown(label="Visualization",
                                   choices=[x.value for x in VisType] + [f'Layer #{i}' for i in range(num_layers)],
                                   min_width=70, value=VisType.SUM, type='value')
        btn = gr.Button("Run", min_width=30)
    output = gr.Plot(label="Plot", container=True)
    metrics = gr.Label("Metrics")
    btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])
if __name__ == "__main__":
    demo.launch()