mohsenfayyaz committed on
Commit
f34a8cd
1 Parent(s): 44569e5

Update app.py

Files changed (1)
app.py +9 -9
app.py CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
 import matplotlib
 from IPython.display import display, HTML
 from transformers import AutoTokenizer
-from DecompX.src.globenc_utils import GlobencConfig
+from DecompX.src.decompx_utils import DecompXConfig
 from DecompX.src.modeling_bert import BertForSequenceClassification
 from DecompX.src.modeling_roberta import RobertaForSequenceClassification
 
@@ -108,7 +108,7 @@ def run_decompx(text, model):
     SENTENCES = [text, "nothing"]
     CONFIGS = {
         "DecompX":
-            GlobencConfig(
+            DecompXConfig(
                 include_biases=True,
                 bias_decomp_type="absdot",
                 include_LN1=True,
@@ -144,12 +144,12 @@ def run_decompx(text, model):
     # RUN DECOMPX
     with torch.no_grad():
         model.eval()
-        logits, hidden_states, globenc_last_layer_outputs, globenc_all_layers_outputs = model(
+        logits, hidden_states, decompx_last_layer_outputs, decompx_all_layers_outputs = model(
             **tokenized_sentence,
             output_attentions=False,
             return_dict=False,
             output_hidden_states=True,
-            globenc_config=CONFIGS["DecompX"]
+            decompx_config=CONFIGS["DecompX"]
         )
     decompx_outputs = {
         "tokens": [tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][i][:batch_lengths[i]]) for i in range(len(SENTENCES))],
@@ -157,13 +157,13 @@ def run_decompx(text, model):
         "cls": hidden_states[-1][:, 0, :].cpu().detach().numpy().tolist()  # Last layer & only CLS -> (batch, emb_dim)
     }
 
-    ### globenc_last_layer_outputs.classifier ~ (8, 55, 2) ###
-    importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes)
+    ### decompx_last_layer_outputs.classifier ~ (8, 55, 2) ###
+    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last_layer_outputs.classifier]).squeeze()  # (batch, seq_len, classes)
     importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
     decompx_outputs["importance_last_layer_classifier"] = importance
 
-    ### globenc_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
-    importance = np.array([g.squeeze().cpu().detach().numpy() for g in globenc_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
+    ### decompx_all_layers_outputs.aggregated ~ (12, 8, 55, 55) ###
+    importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_all_layers_outputs.aggregated])  # (layers, batch, seq_len, seq_len)
     importance = np.einsum('lbij->blij', importance)  # (batch, layers, seq_len, seq_len)
     importance = [importance[j][:, :batch_lengths[j], :batch_lengths[j]] for j in range(len(importance))]
     decompx_outputs["importance_all_layers_aggregated"] = importance
@@ -196,4 +196,4 @@ demo = gr.Interface(
     description="This is a demo for the ACL 2023 paper [DecompX](https://github.com/mohsenfayyaz/DecompX/)"
 )
 
-demo.launch()
+demo.launch()
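
For context, here is a minimal sketch of how the renamed API is invoked after this commit, pieced together from the hunks above. It is not a drop-in script: the checkpoint name is a placeholder, the padding/batching setup is an assumption, and any DecompXConfig fields beyond the three visible in this diff are assumed to keep whatever values app.py sets elsewhere.

# Minimal usage sketch of the post-rename DecompX API, assembled from the hunks above.
# Assumptions (not in this diff): the checkpoint name is a placeholder, and
# DecompXConfig fields beyond the three shown are left as configured in app.py.
import numpy as np
import torch
from transformers import AutoTokenizer
from DecompX.src.decompx_utils import DecompXConfig
from DecompX.src.modeling_bert import BertForSequenceClassification

MODEL_NAME = "path/to/fine-tuned-bert-sst2"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)

decompx_config = DecompXConfig(
    include_biases=True,
    bias_decomp_type="absdot",
    include_LN1=True,
    # ... remaining fields as configured in app.py (elided in this diff) ...
)

sentences = ["a delightful film", "nothing"]
batch = tokenizer(sentences, return_tensors="pt", padding=True)
batch_lengths = batch["attention_mask"].sum(dim=-1).tolist()  # unpadded token counts

with torch.no_grad():
    model.eval()
    # With return_dict=False the patched model returns a 4-tuple, per this diff.
    logits, hidden_states, decompx_last, decompx_all = model(
        **batch,
        output_attentions=False,
        return_dict=False,
        output_hidden_states=True,
        decompx_config=decompx_config,
    )

# Per-token contribution of each input token to each output class,
# trimmed to each sentence's unpadded length (mirroring app.py).
importance = np.array([g.squeeze().cpu().detach().numpy() for g in decompx_last.classifier]).squeeze()
importance = [importance[j][:batch_lengths[j], :] for j in range(len(importance))]
print(importance[0].shape)  # (seq_len_0, num_classes)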