import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns

# Module-level setup: fetch the precomputed attention-map dataset and the
# matching OPT-350m tokenizer (add_prefix_space so tokens align with how the
# maps were originally produced — TODO confirm against the dataset card).
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)


def analyze_sentence(index):
    """Render a summed attention heatmap for the dataset row at *index*.

    Args:
        index: Integer row index into ``dataset`` (the dropdown passes the
            selected choice's position because it uses ``type='index'``).

    Returns:
        A matplotlib ``Figure`` containing a seaborn heatmap of the attention
        maps summed over the leading axis, with token labels on both axes.
        The first token (position 0) is sliced off of both the map and the
        tick labels — presumably the BOS token; verify against the dataset.
    """
    row = dataset[index]
    text = row['text']
    # Decode each token id individually so we get one label per position.
    tokenized = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    # NOTE(review): the stored shape's first element is dropped before the
    # reshape — assumes it is a batch dimension of size 1; TODO confirm.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)

    # Size the figure proportionally to the token count so labels stay legible.
    fig = plt.figure(figsize=(0.2 + 0.55 * len(tokenized), 0.5 * len(tokenized)))
    sns.heatmap(attn_maps.sum(0)[1:, 1:])
    plt.xticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=90)
    plt.yticks(np.arange(seq_len - 1) + 0.5, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()
    return fig


demo = gr.Blocks()
with demo:
    with gr.Row():
        # type='index' makes the callback receive the selected row's position,
        # which analyze_sentence uses to index the dataset.
        # NOTE(review): value=0 with string choices looks odd — an initial
        # value is normally one of the choice strings; confirm intended.
        dropdown = gr.Dropdown(choices=dataset['text'], value=0, min_width=750, type='index')
        btn = gr.Button("Run")
    output = gr.Plot(label="Plot", container=True)
    btn.click(analyze_sentence, [dropdown], [output])

if __name__ == "__main__":
    demo.launch()