Spaces:
Runtime error
Runtime error
mohsenfayyaz
commited on
Commit
·
094135a
1
Parent(s):
4344c32
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import matplotlib
|
8 |
+
from IPython.display import display, HTML
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
from DecompX.src.globenc_utils import GlobencConfig
|
11 |
+
from DecompX.src.modeling_bert import BertForSequenceClassification
|
12 |
+
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
|
13 |
+
|
14 |
+
plt.style.use("ggplot")
|
15 |
+
MODELS = ["WillHeld/roberta-base-sst2"]
|
16 |
+
|
17 |
+
def plot_clf(tokens, logits, label_names, title="", file_name=None):
|
18 |
+
print(tokens)
|
19 |
+
plt.figure(figsize=(4.5, 5))
|
20 |
+
colors = ["#019875" if l else "#B8293D" for l in (logits >= 0)]
|
21 |
+
plt.barh(range(len(tokens)), logits, color=colors)
|
22 |
+
plt.axvline(0, color='black', ls='-', lw=2, alpha=0.2)
|
23 |
+
plt.gca().invert_yaxis()
|
24 |
+
|
25 |
+
max_limit = np.max(np.abs(logits)) + 0.2
|
26 |
+
min_limit = -0.01 if np.min(logits) > 0 else -max_limit
|
27 |
+
plt.xlim(min_limit, max_limit)
|
28 |
+
plt.gca().set_xticks([min_limit, max_limit])
|
29 |
+
plt.gca().set_xticklabels(label_names, fontsize=14, fontweight="bold")
|
30 |
+
plt.gca().set_yticks(range(len(tokens)))
|
31 |
+
plt.gca().set_yticklabels(tokens)
|
32 |
+
|
33 |
+
plt.gca().yaxis.tick_right()
|
34 |
+
for xtick, color in zip(plt.gca().get_yticklabels(), colors):
|
35 |
+
xtick.set_color(color)
|
36 |
+
xtick.set_fontweight("bold")
|
37 |
+
xtick.set_verticalalignment("center")
|
38 |
+
|
39 |
+
for xtick, color in zip(plt.gca().get_xticklabels(), ["#B8293D", "#019875"]):
|
40 |
+
xtick.set_color(color)
|
41 |
+
# plt.title(title, fontsize=14, fontweight="bold")
|
42 |
+
plt.title(title)
|
43 |
+
plt.tight_layout()
|
44 |
+
|
45 |
+
def print_importance(importance, tokenized_text, discrete=False, prefix="", no_cls_sep=False):
|
46 |
+
"""
|
47 |
+
importance: (sent_len)
|
48 |
+
"""
|
49 |
+
if no_cls_sep:
|
50 |
+
importance = importance[1:-1]
|
51 |
+
tokenized_text = tokenized_text[1:-1]
|
52 |
+
importance = importance / np.abs(importance).max() / 1.5 # Normalize
|
53 |
+
if discrete:
|
54 |
+
importance = np.argsort(np.argsort(importance)) / len(importance) / 1.6
|
55 |
+
|
56 |
+
html = "<pre style='color:black; padding: 3px;'>"+prefix
|
57 |
+
for i in range(len(tokenized_text)):
|
58 |
+
if importance[i] >= 0:
|
59 |
+
rgba = matplotlib.colormaps.get_cmap('Greens')(importance[i]) # Wistia
|
60 |
+
else:
|
61 |
+
rgba = matplotlib.colormaps.get_cmap('Reds')(np.abs(importance[i])) # Wistia
|
62 |
+
text_color = "color: rgba(255, 255, 255, 1.0); " if np.abs(importance[i]) > 0.9 else ""
|
63 |
+
color = f"background-color: rgba({rgba[0]*255}, {rgba[1]*255}, {rgba[2]*255}, {rgba[3]}); " + text_color
|
64 |
+
html += (f"<span style='"
|
65 |
+
f"{color}"
|
66 |
+
f"color:black; border-radius: 5px; padding: 3px;"
|
67 |
+
f"font-weight: {int(800)};"
|
68 |
+
"'>")
|
69 |
+
html += tokenized_text[i].replace('<', "[").replace(">", "]")
|
70 |
+
html += "</span> "
|
71 |
+
html += "</pre>"
|
72 |
+
# display(HTML(html))
|
73 |
+
return html
|
74 |
+
|
75 |
+
def print_preview(decompx_outputs_df, idx=0, discrete=False):
|
76 |
+
html = ""
|
77 |
+
NO_CLS_SEP = False
|
78 |
+
df = decompx_outputs_df
|
79 |
+
for col in ["importance_last_layer_aggregated", "importance_last_layer_classifier"]:
|
80 |
+
if col in df and df[col][idx] is not None:
|
81 |
+
if "aggregated" in col:
|
82 |
+
sentence_importance = df[col].iloc[idx][0, :]
|
83 |
+
if "classifier" in col:
|
84 |
+
for label in range(df[col].iloc[idx].shape[-1]):
|
85 |
+
sentence_importance = df[col].iloc[idx][:, label]
|
86 |
+
html += print_importance(
|
87 |
+
sentence_importance,
|
88 |
+
df["tokens"].iloc[idx],
|
89 |
+
prefix=f"{col.split('_')[-1]} Label{label}:".ljust(20),
|
90 |
+
no_cls_sep=NO_CLS_SEP,
|
91 |
+
discrete=False
|
92 |
+
)
|
93 |
+
break
|
94 |
+
sentence_importance = df[col].iloc[idx][:, df["label"].iloc[idx]]
|
95 |
+
html += print_importance(
|
96 |
+
sentence_importance,
|
97 |
+
df["tokens"].iloc[idx],
|
98 |
+
prefix=f"{col.split('_')[-1]}:".ljust(20),
|
99 |
+
no_cls_sep=NO_CLS_SEP,
|
100 |
+
discrete=discrete
|
101 |
+
)
|
102 |
+
return "<div style='overflow:auto; background-color:white; padding: 10px;'>" + html
|
103 |
+
|
104 |
+
def run_decompx(text, model):
|
105 |
+
"""
|
106 |
+
Provide DecompX Token Explanation of Model on Text
|
107 |
+
"""
|
108 |
+
SENTENCES = [text, "nothing"]
|
109 |
+
CONFIGS = {
|
110 |
+
"DecompX":
|
111 |
+
GlobencConfig(
|
112 |
+
include_biases=True,
|
113 |
+
bias_decomp_type="absdot",
|
114 |
+
include_LN1=True,
|
115 |
+
include_FFN=True,
|
116 |
+
FFN_approx_type="GeLU_ZO",
|
117 |
+
include_LN2=True,
|
118 |
+
aggregation="vector",
|
119 |
+
include_classifier_w_pooler=True,
|
120 |
+
tanh_approx_type="ZO",
|
121 |
+
output_all_layers=True,
|
122 |
+
output_attention=None,
|
123 |
+
output_res1=None,
|
124 |
+
output_LN1=None,
|
125 |
+
output_FFN=None,
|
126 |
+
output_res2=None,
|
127 |
+
output_encoder=None,
|
128 |
+
output_aggregated="norm",
|
129 |
+
output_pooler="norm",
|
130 |
+
output_classifier=True,
|
131 |
+
),
|
132 |
+
}
|
133 |
+
MODEL = model
|
134 |
+
# LOAD MODEL AND TOKENIZER
|
135 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL)
|
136 |
+
tokenized_sentence = tokenizer(SENTENCES, return_tensors="pt", padding=True)
|
137 |
+
batch_lengths = tokenized_sentence['attention_mask'].sum(dim=-1)
|
138 |
+
if "roberta" in MODEL:
|
139 |
+
model = RobertaForSequenceClassification.from_pretrained(MODEL)
|
140 |
+
elif "bert" in MODEL:
|
141 |
+
model = BertForSequenceClassification.from_pretrained(MODEL)
|
142 |
+
else:
|
143 |
+
raise Exception(f"Not implemented model: {MODEL}")
|
144 |
+
# RUN DECOMPX
|
145 |
+
with torch.no_grad():
|
146 |
+
model.eval()
|
147 |
+
logits, hidden_states, globenc_last_layer_outputs, globenc_all_layers_outputs = model(
|
148 |
+
**tokenized_sentence,
|
149 |
+
output_attentions=False,
|
150 |
+
return_dict=False,
|
151 |
+
output_hidden_states=True,
|
152 |
+
globenc_config=CONFIGS["DecompX"]
|
153 |
+
)
|
154 |
+
decompx_outputs = {
|
155 |
+
"tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
|
156 |
+
"logits": logits.cpu().detach().numpy().tolist(), # (batch, classes)
|
157 |
+
"cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist()# Last layer & only CLS -> (batch, emb_dim)
|
158 |
+
}
|
159 |
+
|
160 |
+
### globenc_last_layer_outputs.classifier ~ (8, 55, 2) ###
|
161 |
+
importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_last_layer_outputs.classifier]).squeeze() # (batch, seq_len, classes)
|
162 |
+
importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
|
163 |
+
decompx_outputs["importance_last_layer_classifier"] = importance
|
164 |
+
|
165 |
+
### globenc_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
|
166 |
+
importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_all_layers_outputs.aggregated]) # (layers, batch, seq_len, seq_len)
|
167 |
+
importance = np.einsum('lbij->blij', importance) # (batch, layers, seq_len, seq_len)
|
168 |
+
importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
|
169 |
+
decompx_outputs["importance_all_layers_aggregated"] = importance
|
170 |
+
|
171 |
+
decompx_outputs_df = pd.DataFrame(decompx_outputs)
|
172 |
+
idx = 0
|
173 |
+
pred_label = np.argmax(decompx_outputs_df.iloc[idx]["logits"], axis=-1)
|
174 |
+
label = decompx_outputs_df.iloc[idx]["importance_last_layer_classifier"][:, pred_label]
|
175 |
+
tokens = decompx_outputs_df.iloc[idx]["tokens"][1:-1]
|
176 |
+
label = label[1:-1]
|
177 |
+
label = label / np.max(np.abs(label))
|
178 |
+
plot_clf(tokens, label, ['-','+'], title=f"DecompX for Predicted Label: {pred_label}", file_name="example_sst2_our_method")
|
179 |
+
return plt, print_preview(decompx_outputs_df)
|
180 |
+
|
181 |
+
demo = gr.Interface(
|
182 |
+
fn=run_decompx,
|
183 |
+
inputs=[
|
184 |
+
gr.components.Textbox(label="Text"),
|
185 |
+
gr.components.Dropdown(label="Model", choices=MODELS),
|
186 |
+
],
|
187 |
+
outputs=["plot", "html"],
|
188 |
+
examples=[["Building a translation demo with Gradio is so easy!", "WillHeld/roberta-base-sst2"]],
|
189 |
+
cache_examples=False,
|
190 |
+
title="DecompX Demo",
|
191 |
+
description="This demo is a simplified version of the original [NLLB-Translator](https://huggingface.co/spaces/Narrativaai/NLLB-Translator) space"
|
192 |
+
)
|
193 |
+
|
194 |
+
demo.launch()
|