"""Gradio demo: browse grammar attention maps for OPT-350m.

Loads the ``dar-tau/grammar-attention-maps-opt-350m`` dataset and lets the
user pick a sentence; shows the sentence text and a heatmap of its
aggregated attention map.
"""
import gradio as gr
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']


def analyze_sentence(index):
    """Return the selected sentence and a heatmap of its attention map.

    Args:
        index: Row index into ``dataset`` (the dropdown uses ``type='index'``),
            or ``None`` when nothing has been selected yet.

    Returns:
        A ``(text, figure)`` tuple: the row's ``'text'`` field and a new
        matplotlib ``Figure`` containing the heatmap.

    Raises:
        gr.Error: if no sentence has been selected.
    """
    if index is None:
        # Without this guard dataset[None] fails with an opaque TypeError.
        raise gr.Error("Please select a sentence.")

    row = dataset[index]
    # The map is stored flattened alongside its original shape.
    attn_maps = np.asarray(row['attention_maps']).reshape(*row['attention_maps_shape'])
    # NOTE(review): presumably shape is (layers, heads, seq, seq), so summing
    # axis 1 then axis 0 yields a (seq, seq) map -- confirm against the dataset.
    aggregated = attn_maps.sum(1).sum(0)

    # gr.Plot needs a Figure (sns.heatmap returns an Axes). Draw on a fresh
    # figure so successive calls do not overdraw pyplot's shared current axes.
    fig, ax = plt.subplots()
    sns.heatmap(aggregated, ax=ax)
    # Drop the figure from pyplot's registry so figures do not accumulate;
    # the Figure object itself stays valid for Gradio to render.
    plt.close(fig)
    return row['text'], fig


iface = gr.Interface(
    fn=analyze_sentence,
    inputs=[gr.Dropdown(choices=dataset['text'], type='index')],
    outputs=[gr.Label(), gr.Plot(label="Plot")],
)

if __name__ == "__main__":
    iface.launch()