"""Gradio demo: browse attention-map heatmaps for sentences in a dataset."""

import gradio as gr
import numpy as np
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns

dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']

# NOTE(review): `tokenizer` is used by analyze_sentence() but is not defined or
# imported anywhere in this file. It must be bound at module level before the
# app can serve a request -- presumably the tokenizer matching the model the
# attention maps were extracted from; confirm and add the definition.


def analyze_sentence(index):
    """Build the attention heatmap for the dataset row at ``index``.

    Args:
        index: Integer position of the row in ``dataset`` (the dropdown below
            is configured with ``type='index'``, so Gradio passes the position
            of the selected sentence rather than its text).

    Returns:
        A ``(text, figure)`` tuple: the row's sentence and a matplotlib
        ``Figure`` containing the heatmap.
    """
    row = dataset[index]
    text = row['text']
    attn_map_shape = row['attention_maps_shape']
    _, seq_len, _ = attn_map_shape

    # The maps are stored flattened; restore the recorded shape.
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape)

    # Tokenize once and reuse the labels for both axes.
    # assumes the tokenizer yields seq_len - 1 tokens for this text -- TODO confirm
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    # Offset by 0.5 so each label sits at the centre of its heatmap cell.
    tick_positions = np.arange(seq_len - 1) + 0.5

    # Draw on a fresh figure per call so successive requests do not pile
    # heatmaps and colorbars onto pyplot's implicit current figure.
    fig, ax = plt.subplots()
    sns.heatmap(attn_maps.sum(1).sum(0), ax=ax)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tokens, rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()

    # Deregister the figure from pyplot so figures do not accumulate across
    # requests; the Figure object itself remains renderable by gr.Plot.
    plt.close(fig)
    return text, fig


iface = gr.Interface(
    fn=analyze_sentence,
    inputs=[gr.Dropdown(choices=dataset['text'], type='index')],
    outputs=[gr.Label(), gr.Plot(label="Plot")],
)
iface.launch()