File size: 2,326 Bytes
a042926
fac5648
cbd7ed1
e991bcb
cbd7ed1
 
c75ae48
 
a042926
c5fa8a7
 
 
43f6895
c5fa8a7
7e273f3
e991bcb
cbd7ed1
43f6895
c75ae48
9ddf875
392fdcd
b3d4e85
0577153
b3d4e85
 
c6c8075
7578fb8
c5fa8a7
c75ae48
c5fa8a7
 
 
f1b9681
376cb17
f1b9681
c75ae48
6ff782d
 
fac5648
 
 
17a2383
c5fa8a7
9fa205f
a042926
822923c
 
 
c75ae48
 
17a2383
dae383f
8d98842
376cb17
a7e4d41
c75ae48
 
cbd7ed1
822923c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
from enum import Enum


class VisType(Enum):
    """Aggregate visualization modes selectable in the UI.

    A member's ``value`` is the exact text listed in the visualization
    dropdown and later compared against the selection in
    ``analyze_sentence``.
    """

    # Add up the attention maps along their first axis.
    SUM = "Sum over Layers"
    

# How many 'Layer #i' entries the visualization dropdown offers
# (presumably the layer count of facebook/opt-350m -- TODO confirm).
num_layers = 24

# Train split of the dataset holding the stored attention maps; rows are
# looked up by position in analyze_sentence.
dataset = load_dataset(
    'dar-tau/grammar-attention-maps-opt-350m',
)['train']

# Tokenizer used to turn each sentence into the heatmap's tick labels.
tokenizer = AutoTokenizer.from_pretrained(
    'facebook/opt-350m',
    add_prefix_space=True,
)


def analyze_sentence(index, vis_type):
    """Plot the attention heatmap of one dataset sentence.

    Args:
        index: Position of the sentence's row in the module-level ``dataset``.
        vis_type: Either ``VisType.SUM.value`` (sum the maps over axis 0) or
            ``'Layer #<i>'`` (show the single map stored at index ``i``).

    Returns:
        tuple: ``(fig, metrics)`` where ``fig`` is the matplotlib figure
        holding the heatmap and ``metrics`` is a dict containing
        ``'Metrics': 1`` plus every column of the row other than ``'text'``,
        ``'attention_maps'`` and ``'attention_maps_shape'``.

    Raises:
        ValueError: If no sentence is selected, or ``vis_type`` is not one
            of the supported visualization names.
    """
    if index is None:
        raise ValueError('No sentence selected')

    row = dataset[index]
    tokenized = tokenizer.batch_decode(
        tokenizer.encode(row['text'], add_special_tokens=False)
    )

    # The stored shape carries a leading dimension that is dropped here
    # (presumably a batch axis of size 1 -- TODO confirm against the dataset).
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # Drop the first row and column so the maps line up with tokenized[1:],
    # which is what the tick labels below use.
    attn_maps = attn_maps[:, 1:, 1:]

    # Resolve the requested view *before* creating a figure so that invalid
    # input cannot leave an orphaned figure behind.
    if vis_type == VisType.SUM.value:
        plot_data = attn_maps.sum(0)
    elif isinstance(vis_type, str) and vis_type.startswith('Layer #'):
        layer_to_inspect = int(vis_type.split('#')[1])
        plot_data = attn_maps[layer_to_inspect]
    else:
        raise ValueError(f'Unknown visualization type: {vis_type!r}')

    fig, ax = plt.subplots(
        figsize=(0.1 + 0.3 * len(tokenized), 0.25 * len(tokenized))
    )
    try:
        # Draw on an explicit Axes rather than pyplot's implicit "current"
        # one, which is shared global state.
        sns.heatmap(plot_data, ax=ax)
        # Heatmap cells are one unit wide; +0.5 centres each label on a cell.
        tick_positions = np.arange(seq_len - 1) + 0.5
        tick_labels = tokenized[1:]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=90)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels, rotation=0)
        ax.set_ylabel('TARGET')
        ax.set_xlabel('SOURCE')
        ax.grid()
    finally:
        # pyplot keeps every figure alive until it is closed; unregister it
        # so repeated calls do not accumulate figures. The Figure object
        # itself stays valid and can still be rendered by the caller.
        plt.close(fig)

    excluded = ('text', 'attention_maps', 'attention_maps_shape')
    metrics = {'Metrics': 1}
    metrics.update({k: v for k, v in row.items() if k not in excluded})
    return fig, metrics

# Assemble the UI: a sentence picker, a visualization picker and a Run
# button in one row, with the plot and the metrics shown underneath.
demo = gr.Blocks()
with demo:
    with gr.Row():
        # Show the part of each text that follows the first '</s> ' separator.
        sentence_choices = [x.split('</s> ')[1] for x in dataset['text']]
        sentence_dropdown = gr.Dropdown(
            label="Sentence",
            choices=sentence_choices,
            # type='index' only changes what the callback receives; the
            # initial value itself must still be one of the choices.
            value=sentence_choices[0] if sentence_choices else None,
            min_width=300,
            type='index',
        )
        vis_choices = [x.value for x in VisType] + [f'Layer #{i}' for i in range(num_layers)]
        vis_dropdown = gr.Dropdown(
            label="Visualization",
            choices=vis_choices,
            min_width=70,
            # The choices are strings, so default to the member's string
            # value rather than the enum member itself.
            value=VisType.SUM.value,
            type='value',
        )
        btn = gr.Button("Run", min_width=30)
    output = gr.Plot(label="Plot", container=True)
    # "Metrics" is the component's caption, not an initial value to display.
    metrics = gr.Label(label="Metrics")
    btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])


# Start the Gradio app only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()