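"""Gradio demo for Value Zeroing: visualizes how much each token contributes to every other
token's representation in BERT/RoBERTa, per layer, with and without attention rollout."""
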
import pathlib
import re

import gradio
import matplotlib
import matplotlib.pyplot as plt
import numpy
import pandas
import seaborn
import torch
from sklearn.metrics.pairwise import cosine_distances

from transformers import (
    AutoConfig,
    AutoModelForMaskedLM,
    AutoTokenizer,
)

## Rollout Helper Function
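# Attention rollout: recursively multiplies the per-layer score maps so that the scores at a
# given layer reflect contributions propagated through all earlier layers. With res=True an
# identity matrix is added first (to account for residual connections) and rows are re-normalized.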
def compute_joint_attention(att_mat, res=True):
    if res:
        residual_att = numpy.eye(att_mat.shape[1])[None,...]
        att_mat = att_mat + residual_att
        att_mat = att_mat / att_mat.sum(axis=-1)[...,None]

    joint_attentions = numpy.zeros(att_mat.shape)
    layers = joint_attentions.shape[0]
    joint_attentions[0] = att_mat[0]
    for i in numpy.arange(1,layers):
        joint_attentions[i] = att_mat[i].dot(joint_attentions[i-1])

    return joint_attentions

def create_plot(all_tokens, score_data):
    """Draw one token-by-token heatmap per layer in a 6 x 2 grid (assumes a 12-layer model)."""
    LAYERS = list(range(12))
    fig, axs = plt.subplots(6, 2, figsize=(8, 24))
    plt.subplots_adjust(top=0.98, bottom=0.05, hspace=0.5, wspace=0.5)
    for layer in LAYERS:
        a = layer // 2
        b = layer % 2
        seaborn.heatmap(
                ax=axs[a, b],
                data=pandas.DataFrame(score_data[layer], index=all_tokens, columns=all_tokens),
                cmap="Blues",
                annot=False,
                cbar=False
            )
        axs[a, b].set_title(f"Layer: {layer+1}")
    return fig

# Use a non-interactive backend so figures can be rendered off-screen for Gradio.
matplotlib.use('agg')

DISTANCE_FUNC = {
    'cosine': cosine_distances
}
MODEL_PATH = {
    'bert': 'bert-base-uncased',
    'roberta': 'roberta-base',
}

MODEL_NAME = 'bert'
#MODEL_NAME = 'roberta'
METRIC = 'cosine'

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)


def run(mname, sent):
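    """Compute value-zeroing heatmaps (with and without rollout) for one input sentence,
    reloading the model first if the dropdown selection changed."""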
    global MODEL_NAME, config, model, tokenizer
    if mname != MODEL_NAME:
        MODEL_NAME = mname
        config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
        model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)
    # Replace a literal "[MASK]" placeholder with the model-specific mask token (e.g. <mask> for RoBERTa).
    sent = re.sub(r"\[MASK\]", tokenizer.mask_token, sent)
    inputs = tokenizer(sent, return_token_type_ids=True, return_tensors="pt")

    ## Compute: layerwise value zeroing
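    # Value Zeroing: for every layer and every token t, re-run that layer with t's value
    # vectors zeroed out and measure how far each output representation moves from the
    # original one; that distance is taken as t's contribution to the other tokens.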
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        token_type_ids=inputs['token_type_ids'],
                        output_hidden_states=True, output_attentions=False)

    org_hidden_states = torch.stack(outputs['hidden_states']).squeeze(1)
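    # outputs['hidden_states'] holds num_hidden_layers + 1 tensors (embedding output plus each
    # layer's output), so org_hidden_states[l] is the input to layer l and
    # org_hidden_states[l + 1] its original output.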
    input_shape = inputs['input_ids'].size()
    batch_size, seq_length = input_shape

    score_matrix = numpy.zeros((config.num_hidden_layers, seq_length, seq_length))
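    # score_matrix[l, i, t]: cosine distance between token i's original layer-(l+1) representation
    # and its representation when token t's value vectors are zeroed in layer l.
    # Note: zero_value_index is not an argument of stock transformers layers; this demo is assumed
    # to ship a modified BERT/RoBERTa layer implementation that supports it.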
    for l, layer_module in enumerate(getattr(model, MODEL_NAME).encoder.layer):
        # The extended mask depends only on the input, so build it once per layer rather than per token.
        extended_blanking_attention_mask: torch.Tensor = getattr(model, MODEL_NAME).get_extended_attention_mask(inputs['attention_mask'], input_shape, device)
        for t in range(seq_length):
            with torch.no_grad():
                layer_outputs = layer_module(org_hidden_states[l].unsqueeze(0), # previous layer's original output
                                            attention_mask=extended_blanking_attention_mask,
                                            output_attentions=False,
                                            zero_value_index=t,
                                            )
            hidden_states = layer_outputs[0].squeeze().detach().cpu().numpy()
            # compute similarity between original and new outputs
            # cosine
            x = hidden_states
            y = org_hidden_states[l+1].detach().cpu().numpy()

            distances = DISTANCE_FUNC[METRIC](x, y).diagonal()
            score_matrix[l, :, t] = distances

    # Normalize so each row is a distribution over zeroed tokens, then aggregate across layers
    # with rollout (res=False: rows are already normalized and residuals are not re-added).
    valuezeroing_scores = score_matrix / numpy.sum(score_matrix, axis=-1, keepdims=True)
    rollout_valuezeroing_scores = compute_joint_attention(valuezeroing_scores, res=False)


    # Plot: one figure per scoring scheme, labelled with the input tokens.
    all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    rollout_fig = create_plot(all_tokens, rollout_valuezeroing_scores)
    value_fig = create_plot(all_tokens, valuezeroing_scores)

    return rollout_fig, value_fig

examples = pandas.read_csv("examples.csv").to_numpy().tolist()

# Gradio UI: model choice and input sentence on top; below, two columns of per-layer
# heatmaps showing the value-zeroing scores with and without rollout.
with gradio.Blocks(
        title="Differences with/without zero-valuing",
        css=".output-image > img {height: 2000px !important; max-height: none !important;} "
) as iface:
    gradio.Markdown(pathlib.Path("description.md").read_text())
    with gradio.Row(equal_height=True):
        with gradio.Column(scale=4):
            sent = gradio.Textbox(label="Input sentence")
        with gradio.Column(scale=1):
            model_choice = gradio.Dropdown(choices=['bert', 'roberta'], value="bert")
            but = gradio.Button("Submit")
    gradio.Examples(examples, [sent])
    with gradio.Row(equal_height=True):
        with gradio.Column():
            gradio.Markdown("### With Rollout")
            rollout_result = gradio.Plot()
        with gradio.Column():
            gradio.Markdown("### Without Rollout")
            value_result = gradio.Plot()
    with gradio.Accordion("Some more details"):
        gradio.Markdown(pathlib.Path("notice.md").read_text())

    but.click(
        run,
        inputs=[model_choice, sent],
        outputs=[rollout_result, value_result],
    )


iface.launch()