"""Gradio demo for browsing attention maps stored in a Hugging Face dataset.

Loads precomputed attention maps from the
``dar-tau/grammar-attention-maps-opt-350m`` dataset, tokenizes each sentence
with the ``facebook/opt-350m`` tokenizer, and renders the selected sentence's
attention as a heatmap, a displaCy arc diagram, or a plain-text list of the
strongest token pairs.
"""
import html
from enum import Enum

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from datasets import load_dataset
from matplotlib.figure import Figure
from spacy import displacy
from transformers import AutoTokenizer


class VisType(Enum):
    """Visualization choices that are not tied to a single layer."""

    SUM = 'Sum over Layers'


num_layers = 24  # number of 'Layer #i' entries offered in the dropdown
# Attention weight above which a (target, source) pair counts as a strong link.
_ATTN_THRESHOLD = 0.5
# Maximum number of token pairs listed by the text view.
_TOP_K = 30

dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m')['train']
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', add_prefix_space=True)


def _load_attention(row):
    """Return the row's attention maps as a clipped array with position 0 removed.

    The flat ``attention_maps`` list is reshaped with ``attention_maps_shape[1:]``
    (the leading dimension is dropped), clipped to [0, 1], and the first
    position is sliced off along both of the last two axes.
    """
    attn_map_shape = row['attention_maps_shape'][1:]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # NOTE(review): presumably position 0 is a BOS/sink token -- confirm.
    return attn_maps[:, 1:, 1:]


def _select_plot_data(attn_maps, vis_type):
    """Reduce the per-layer stack to a single 2-D map according to ``vis_type``.

    Args:
        attn_maps: array indexed as ``attn_maps[layer, target, source]``.
        vis_type: ``VisType.SUM`` (or its string value) to sum over layers, or
            ``'Layer #<i>'`` to pick layer ``i``.

    Raises:
        ValueError: if ``vis_type`` is not one of the supported choices.
    """
    if vis_type in (VisType.SUM, VisType.SUM.value):
        return attn_maps.sum(0)
    if isinstance(vis_type, str) and vis_type.startswith('Layer #'):
        layer_to_inspect = int(vis_type.split('#')[1])
        return attn_maps[layer_to_inspect]
    raise ValueError(f'Unknown visualization type: {vis_type!r}')


def _render_heatmap(plot_data, tokens):
    """Draw ``plot_data`` as a seaborn heatmap labelled with ``tokens[1:]``.

    A standalone ``Figure`` is used instead of pyplot so that each request
    gets its own figure. No global pyplot state is shared between Gradio
    worker threads, and no open figures accumulate across calls.
    """
    fig = Figure(figsize=(0.1 + 0.3 * len(tokens), 0.25 * len(tokens)))
    ax = fig.subplots()
    sns.heatmap(plot_data, ax=ax)
    # +0.5 centres each label on its heatmap cell.
    tick_positions = np.arange(len(plot_data)) + 0.5
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tokens[1:], rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tokens[1:], rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid()
    return fig


def _render_graph(plot_data, tokens):
    """Render strong, non-adjacent attention links as a displaCy arc diagram.

    An arc from word ``j`` to word ``i`` is drawn when ``j < i - 1`` and
    ``plot_data[i, j]`` exceeds ``_ATTN_THRESHOLD``.
    """
    words = tokens[1:]
    ex = [{
        "words": [{"text": word, "tag": ""} for word in words],
        "arcs": [
            {"start": j, "end": i, "label": "", "dir": "right"}
            for i in range(len(words))
            for j in range(i - 1)  # skip adjacent pairs (i - j == 1)
            if plot_data[i, j] > _ATTN_THRESHOLD
        ],
    }]
    svg = displacy.render(ex, style="dep", jupyter=False, manual=True,
                          options={"compact": True, "offset_x": 20, "distance": 130})
    # The class matches the `.displacy_container svg` rule in the Blocks CSS.
    return '<div class="displacy_container">' + svg + '</div>'


def _render_text(plot_data, tokens, text):
    """Return HTML with the sentence and its strongest off-diagonal token pairs.

    Up to ``_TOP_K`` (target, source) pairs with weight at or above
    ``_ATTN_THRESHOLD`` are listed in descending order. Dataset text is
    HTML-escaped because the result is shown in a ``gr.HTML`` component.
    """
    scores = plot_data.copy()
    np.fill_diagonal(scores, 0.)  # ignore each token attending to itself
    # topk raises if k exceeds the number of elements, e.g. for short sentences.
    k = min(_TOP_K, scores.size)
    top_values, top_indices = torch.tensor(scores).flatten().topk(k)
    pairs = []
    for val, flat_index in zip(top_values.tolist(), top_indices.tolist()):
        if val < _ATTN_THRESHOLD:
            break  # topk output is sorted descending; the rest are weaker
        target, source = np.unravel_index(flat_index, scores.shape)
        pairs.append(html.escape(str((tokens[1 + target], tokens[1 + source]))))
    return '<p>' + html.escape(text) + '</p><p>' + '<br>'.join(pairs) + '</p>'


def analyze_sentence(index, vis_type, vis_format):
    """Render the attention of dataset row ``index`` for the Gradio outputs.

    Args:
        index: row index into ``dataset``.
        vis_type: ``VisType.SUM.value`` or ``'Layer #<i>'``.
        vis_format: ``'Plot'`` or ``'Graph'``; any other value (including
            ``None`` when no radio option is selected) produces the text view.

    Returns:
        ``(figure_or_None, html_str, metrics_dict)`` for the Plot, HTML and
        Label components respectively.

    Raises:
        ValueError: if ``vis_type`` is not recognised.
    """
    row = dataset[index]
    text = row['text']
    tokens = tokenizer.batch_decode(tokenizer.encode(text, add_special_tokens=False))
    plot_data = _select_plot_data(_load_attention(row), vis_type)

    fig_output, html_output = None, ""
    if vis_format == 'Plot':
        fig_output = _render_heatmap(plot_data, tokens)
    elif vis_format == 'Graph':
        html_output = _render_graph(plot_data, tokens)
    else:
        html_output = _render_text(plot_data, tokens, text)

    # NOTE(review): the constant 'Metrics' entry presumably makes gr.Label
    # show "Metrics" as its heading -- confirm.
    metrics_dict = {'Metrics': 1}
    metrics_dict.update({k: v for k, v in row.items()
                         if k not in ('text', 'attention_maps', 'attention_maps_shape')})
    return fig_output, html_output, metrics_dict


demo = gr.Blocks(css=".displacy_container svg{height:500px !important; margin-top:-100px; transform: scale(0.5)}")
with demo:
    with gr.Row():
        sentence_dropdown = gr.Dropdown(label="Sentence",
                                        choices=[x.split(' ')[1] for x in dataset['text']],
                                        value=0, min_width=300, type='index')
        # The default must be the string value: the choices are plain strings.
        vis_dropdown = gr.Dropdown(label="Visualization",
                                   choices=[x.value for x in VisType]
                                   + [f'Layer #{i}' for i in range(num_layers)],
                                   min_width=70, value=VisType.SUM.value, type='value')
        btn = gr.Button("Run", min_width=30)
    vis_format_checkbox = gr.Radio(['Plot', 'Graph', 'Text'])
    output = gr.Plot(label="Plot", container=True)
    graph_output = gr.HTML()
    metrics = gr.Label("Metrics")
    btn.click(analyze_sentence,
              [sentence_dropdown, vis_dropdown, vis_format_checkbox],
              [output, graph_output, metrics])

if __name__ == "__main__":
    demo.launch()