danielhajialigol committed on
Commit
2690a96
1 Parent(s): 1841ebe

fixed seed issue

Browse files
Files changed (1) hide show
  1. model.py +4 -3
model.py CHANGED
@@ -90,9 +90,10 @@ class MimicTransformer(Module):
90
  cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
91
  else:
92
  cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
93
- # last_attn = cls_results[-1][-1] # (batch, attn_heads, tokens, tokens)
94
- last_attn = torch.mean(torch.stack(cls_results[-1])[:], dim=0)
95
- last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
 
96
  xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
97
  return (cls_results, xai_logits)
98
 
 
90
  cls_results = self.model(input_ids, attention_mask=attention_mask, labels=drg_labels, output_attentions=True)
91
  else:
92
  cls_results = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
93
+ last_attn = cls_results[-1][-1] # (batch, attn_heads, tokens, tokens)
94
+ # last_attn = torch.mean(torch.stack(cls_results[-1])[:], dim=0)
95
+ # last_layer_attn = torch.mean(last_attn[:, :-3, :, :], dim=1)
96
+ last_layer_attn = last_attn[:, -1, :, :]
97
  xai_logits = self.linear(last_layer_attn).squeeze(dim=-1)
98
  return (cls_results, xai_logits)
99