dar-tau's picture
Update app.py
392fdcd verified
raw
history blame
986 Bytes
import gradio as gr
import numpy as np
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
# Identifier handed to `load_dataset`; kept in one place so it is easy to swap.
_DATASET_REPO = 'dar-tau/grammar-attention-maps-opt-350m'

# Loaded once at import time; only the 'train' split is used by the app.
dataset = load_dataset(_DATASET_REPO)['train']
def analyze_sentence(index):
    """Build the attention heatmap for one row of the module-level ``dataset``.

    Args:
        index: Position of the row to display in ``dataset``.

    Returns:
        A ``(text, figure)`` pair: the row's ``'text'`` field and the
        matplotlib figure containing the heatmap.
    """
    row = dataset[index]
    text = row['text']

    # Keep the full stored shape: it is needed for the reshape below, and its
    # middle entry fixes how many tick positions are generated.
    attn_map_shape = row['attention_maps_shape']
    _, seq_len, _ = attn_map_shape
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)

    # NOTE(review): `tokenizer` is not defined in the visible part of this
    # file; it has to exist at module level for this call to succeed.
    # Tokenize once and reuse the labels for both axes.
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    # The +0.5 offset places each tick at the middle of a heatmap cell.
    tick_positions = np.arange(seq_len - 1) + 0.5

    # Draw on a dedicated figure so repeated calls do not pile heatmaps and
    # colorbars onto pyplot's implicit "current" figure.
    fig, ax = plt.subplots()
    # NOTE(review): summing axis 1 and then axis 0 of a 3-D array leaves a
    # 1-D vector - confirm this is the intended reduction for the heatmap.
    sns.heatmap(attn_maps.sum(1).sum(0), ax=ax)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tokens, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()

    # Unregister the figure from pyplot so figures do not accumulate across
    # requests; the Figure object itself stays usable by the caller.
    plt.close(fig)
    return text, fig
# The dropdown lists every sentence in the dataset; type='index' makes the
# callback receive the position of the selection rather than its text.
sentence_picker = gr.Dropdown(choices=dataset['text'], type='index')
# One output per value returned by analyze_sentence, in the same order.
result_widgets = [gr.Label(), gr.Plot(label="Plot")]

iface = gr.Interface(
    fn=analyze_sentence,
    inputs=[sentence_picker],
    outputs=result_widgets,
)
iface.launch()