Spaces:
Running
Running
mohsenfayyaz
committed on
Commit
•
f34a8cd
1
Parent(s):
44569e5
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
|
|
7 |
import matplotlib
|
8 |
from IPython.display import display, HTML
|
9 |
from transformers import AutoTokenizer
|
10 |
-
from DecompX.src.
|
11 |
from DecompX.src.modeling_bert import BertForSequenceClassification
|
12 |
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
|
13 |
|
@@ -108,7 +108,7 @@ def run_decompx(text, model):
|
|
108 |
SENTENCES = [text, "nothing"]
|
109 |
CONFIGS = {
|
110 |
"DecompX":
|
111 |
-
|
112 |
include_biases=True,
|
113 |
bias_decomp_type="absdot",
|
114 |
include_LN1=True,
|
@@ -144,12 +144,12 @@ def run_decompx(text, model):
|
|
144 |
# RUN DECOMPX
|
145 |
with torch.no_grad():
|
146 |
model.eval()
|
147 |
-
logits, hidden_states,
|
148 |
**tokenized_sentence,
|
149 |
output_attentions=False,
|
150 |
return_dict=False,
|
151 |
output_hidden_states=True,
|
152 |
-
|
153 |
)
|
154 |
decompx_outputs = {
|
155 |
"tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
|
@@ -157,13 +157,13 @@ def run_decompx(text, model):
|
|
157 |
"cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist()# Last layer & only CLS -> (batch, emb_dim)
|
158 |
}
|
159 |
|
160 |
-
###
|
161 |
-
importance = np.array([g.squeeze().cpu().detach().numpy() for g in
|
162 |
importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
|
163 |
decompx_outputs["importance_last_layer_classifier"] = importance
|
164 |
|
165 |
-
###
|
166 |
-
importance = np.array([g.squeeze().cpu().detach().numpy() for g in
|
167 |
importance = np.einsum('lbij->blij', importance) # (batch, layers, seq_len, seq_len)
|
168 |
importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
|
169 |
decompx_outputs["importance_all_layers_aggregated"] = importance
|
@@ -196,4 +196,4 @@ demo = gr.Interface(
|
|
196 |
description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)"
|
197 |
)
|
198 |
|
199 |
-
demo.launch()
|
|
|
7 |
import matplotlib
|
8 |
from IPython.display import display, HTML
|
9 |
from transformers import AutoTokenizer
|
10 |
+
from DecompX.src.decompx_utils import DecompXConfig
|
11 |
from DecompX.src.modeling_bert import BertForSequenceClassification
|
12 |
from DecompX.src.modeling_roberta import RobertaForSequenceClassification
|
13 |
|
|
|
108 |
SENTENCES = [text, "nothing"]
|
109 |
CONFIGS = {
|
110 |
"DecompX":
|
111 |
+
DecompXConfig(
|
112 |
include_biases=True,
|
113 |
bias_decomp_type="absdot",
|
114 |
include_LN1=True,
|
|
|
144 |
# RUN DECOMPX
|
145 |
with torch.no_grad():
|
146 |
model.eval()
|
147 |
+
logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
|
148 |
**tokenized_sentence,
|
149 |
output_attentions=False,
|
150 |
return_dict=False,
|
151 |
output_hidden_states=True,
|
152 |
+
decompx_config=CONFIGS["DecompX"]
|
153 |
)
|
154 |
decompx_outputs = {
|
155 |
"tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
|
|
|
157 |
"cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist()# Last layer & only CLS -> (batch, emb_dim)
|
158 |
}
|
159 |
|
160 |
+
### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
|
161 |
+
importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze() # (batch, seq_len, classes)
|
162 |
importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
|
163 |
decompx_outputs["importance_last_layer_classifier"] = importance
|
164 |
|
165 |
+
### decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
|
166 |
+
importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated]) # (layers, batch, seq_len, seq_len)
|
167 |
importance = np.einsum('lbij->blij', importance) # (batch, layers, seq_len, seq_len)
|
168 |
importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
|
169 |
decompx_outputs["importance_all_layers_aggregated"] = importance
|
|
|
196 |
description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)"
|
197 |
)
|
198 |
|
199 |
+
demo.launch()
|