mohsenfayyaz commited on
Commit
094135a
·
1 Parent(s): 4344c32

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +194 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm.auto import tqdm
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib
8
+ from IPython.display import display, HTML
9
+ from transformers import AutoTokenizer
10
+ from DecompX.src.globenc_utils import GlobencConfig
11
+ from DecompX.src.modeling_bert import BertForSequenceClassification
12
+ from DecompX.src.modeling_roberta import RobertaForSequenceClassification
13
+
14
+ plt.style.use("ggplot")
15
+ MODELS = ["WillHeld/roberta-base-sst2"]
16
+
17
+ def plot_clf(tokens, logits, label_names, title="", file_name=None):
18
+ print(tokens)
19
+ plt.figure(figsize=(4.5, 5))
20
+ colors = ["#019875" if l else "#B8293D" for l in (logits >= 0)]
21
+ plt.barh(range(len(tokens)), logits, color=colors)
22
+ plt.axvline(0, color='black', ls='-', lw=2, alpha=0.2)
23
+ plt.gca().invert_yaxis()
24
+
25
+ max_limit = np.max(np.abs(logits)) + 0.2
26
+ min_limit = -0.01 if np.min(logits) > 0 else -max_limit
27
+ plt.xlim(min_limit, max_limit)
28
+ plt.gca().set_xticks([min_limit, max_limit])
29
+ plt.gca().set_xticklabels(label_names, fontsize=14, fontweight="bold")
30
+ plt.gca().set_yticks(range(len(tokens)))
31
+ plt.gca().set_yticklabels(tokens)
32
+
33
+ plt.gca().yaxis.tick_right()
34
+ for xtick, color in zip(plt.gca().get_yticklabels(), colors):
35
+ xtick.set_color(color)
36
+ xtick.set_fontweight("bold")
37
+ xtick.set_verticalalignment("center")
38
+
39
+ for xtick, color in zip(plt.gca().get_xticklabels(), ["#B8293D", "#019875"]):
40
+ xtick.set_color(color)
41
+ # plt.title(title, fontsize=14, fontweight="bold")
42
+ plt.title(title)
43
+ plt.tight_layout()
44
+
45
+ def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
46
+ """
47
+ importance: (sent_len)
48
+ """
49
+ if no_cls_sep:
50
+ importance = importance[1:-1]
51
+ tokenized_text = tokenized_text[1:-1]
52
+ importance = importance / np.abs(importance).max() / 1.5 # Normalize
53
+ if discrete:
54
+ importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6
55
+
56
+ html = "<pre style='color:black; padding: 3px;'>"+prefix
57
+ for i in range(len(tokenized_text)):
58
+ if importance[i] >= 0:
59
+ rgba = matplotlib.colormaps.get_cmap('Greens')(importance[i]) # Wistia
60
+ else:
61
+ rgba = matplotlib.colormaps.get_cmap('Reds')(np.abs(importance[i])) # Wistia
62
+ text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(importance[i]) > 0.9 else ""
63
+ color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
64
+ html += (f"<span style='"
65
+ f"{color}"
66
+ f"color:black; border-radius: 5px; padding: 3px;"
67
+ f"font-weight: {int(800)};"
68
+ "'>")
69
+ html += tokenized_text[i].replace('<', "[").replace(">", "]")
70
+ html += "</span> "
71
+ html += "</pre>"
72
+ # display(HTML(html))
73
+ return html
74
+
75
+ def print_preview(decompx_outputs_df, idx=0, discrete=False):
76
+ html = ""
77
+ NO_CLS_SEP = False
78
+ df = decompx_outputs_df
79
+ for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
80
+ if col in df and df[col][idx] is not None:
81
+ if "aggregated" in col:
82
+ sentence_importance = df[col].iloc[idx][0, :]
83
+ if "classifier" in col:
84
+ for label in range(df[col].iloc[idx].shape[-1]):
85
+ sentence_importance = df[col].iloc[idx][:, label]
86
+ html += print_importance(
87
+ sentence_importance,
88
+ df["tokens"].iloc[idx],
89
+ prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
90
+ no_cls_sep=NO_CLS_SEP,
91
+ discrete=False
92
+ )
93
+ break
94
+ sentence_importance = df[col].iloc[idx][:, df["label"].iloc[idx]]
95
+ html += print_importance(
96
+ sentence_importance,
97
+ df["tokens"].iloc[idx],
98
+ prefix=f"{col.split('_')[-1]}:".ljust(20),
99
+ no_cls_sep=NO_CLS_SEP,
100
+ discrete=discrete
101
+ )
102
+ return "<div style='overflow:auto; background-color:white; padding: 10px;'>" + html
103
+
104
+ def run_decompx(text, model):
105
+ """
106
+ Provide DecompX Token Explanation of Model on Text
107
+ """
108
+ SENTENCES = [text, "nothing"]
109
+ CONFIGS = {
110
+ "DecompX":
111
+ GlobencConfig(
112
+ include_biases=True,
113
+ bias_decomp_type="absdot",
114
+ include_LN1=True,
115
+ include_FFN=True,
116
+ FFN_approx_type="GeLU_ZO",
117
+ include_LN2=True,
118
+ aggregation="vector",
119
+ include_classifier_w_pooler=True,
120
+ tanh_approx_type="ZO",
121
+ output_all_layers=True,
122
+ output_attention=None,
123
+ output_res1=None,
124
+ output_LN1=None,
125
+ output_FFN=None,
126
+ output_res2=None,
127
+ output_encoder=None,
128
+ output_aggregated="norm",
129
+ output_pooler="norm",
130
+ output_classifier=True,
131
+ ),
132
+ }
133
+ MODEL = model
134
+ # LOAD MODEL AND TOKENIZER
135
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
136
+ tokenized_sentence = tokenizer(SENTENCES, return_tensors="pt", padding=True)
137
+ batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)
138
+ if "roberta" in MODEL:
139
+ model = RobertaForSequenceClassification.from_pretrained(MODEL)
140
+ elif "bert" in MODEL:
141
+ model = BertForSequenceClassification.from_pretrained(MODEL)
142
+ else:
143
+ raise Exception(f"Not implemented model: {MODEL}")
144
+ # RUN DECOMPX
145
+ with torch.no_grad():
146
+ model.eval()
147
+ logits, hidden_states, globenc_last_layer_outputs, globenc_all_layers_outputs = model(
148
+ **tokenized_sentence,
149
+ output_attentions=False,
150
+ return_dict=False,
151
+ output_hidden_states=True,
152
+ globenc_config=CONFIGS["DecompX"]
153
+ )
154
+ decompx_outputs = {
155
+ "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
156
+ "logits": logits.cpu().detach().numpy().tolist(), # (batch, classes)
157
+ "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist()# Last layer & only CLS -> (batch, emb_dim)
158
+ }
159
+
160
+ ### globenc_last_layer_outputs.classifier ~ (8, 55, 2) ###
161
+ importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_last_layer_outputs.classifier]).squeeze() # (batch, seq_len, classes)
162
+ importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
163
+ decompx_outputs["importance_last_layer_classifier"] = importance
164
+
165
+ ### globenc_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
166
+ importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_all_layers_outputs.aggregated]) # (layers, batch, seq_len, seq_len)
167
+ importance = np.einsum('lbij->blij', importance) # (batch, layers, seq_len, seq_len)
168
+ importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
169
+ decompx_outputs["importance_all_layers_aggregated"] = importance
170
+
171
+ decompx_outputs_df = pd.DataFrame(decompx_outputs)
172
+ idx = 0
173
+ pred_label = np.argmax(decompx_outputs_df.iloc[idx]["logits"], axis=-1)
174
+ label = decompx_outputs_df.iloc[idx]["importance_last_layer_classifier"][:, pred_label]
175
+ tokens = decompx_outputs_df.iloc[idx]["tokens"][1:-1]
176
+ label = label[1:-1]
177
+ label = label / np.max(np.abs(label))
178
+ plot_clf(tokens, label, ['-','+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")
179
+ return plt, print_preview(decompx_outputs_df)
180
+
181
+ demo = gr.Interface(
182
+ fn=run_decompx,
183
+ inputs=[
184
+ gr.components.Textbox(label="Text"),
185
+ gr.components.Dropdown(label="Model", choices=MODELS),
186
+ ],
187
+ outputs=["plot", "html"],
188
+ examples=[["Building a translation demo with Gradio is so easy!", "WillHeld/roberta-base-sst2"]],
189
+ cache_examples=False,
190
+ title="DecompX Demo",
191
+ description="This demo is a simplified version of the original [NLLB-Translator](https://huggingface.co/spaces/Narrativaai/NLLB-Translator) space"
192
+ )
193
+
194
+ demo.launch()