@@ -23,14 +23,14 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
-from ...activations import ACT2FN, gelu
-from ...file_utils import (
+from transformers.activations import ACT2FN, gelu
+from transformers.file_utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import (
+from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -40,14 +40,14 @@
     SequenceClassifierOutput,
     TokenClassifierOutput,
 )
-from ...modeling_utils import (
+from transformers.modeling_utils import (
     PreTrainedModel,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
-from ...utils import logging
-from .configuration_roberta import RobertaConfig
+from transformers.utils import logging
+from transformers.models.roberta.configuration_roberta import RobertaConfig
 
 
 logger = logging.get_logger(__name__)
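
The two hunks above rewrite the package-relative imports (`from ...activations import ...`) as absolute `transformers.*` imports, so the patched `modeling_roberta.py` can be copied out of the library's source tree and used as a standalone module against an installed `transformers` package. A minimal sketch of how such a standalone copy could then be loaded; the module name `modeling_roberta_patched` is an assumption for illustration, not part of the patch:

```python
# Assumes the patched file was saved next to this script as
# modeling_roberta_patched.py (hypothetical name); the installed
# `transformers` package supplies everything the absolute imports refer to.
from transformers import RobertaTokenizer

from modeling_roberta_patched import RobertaForSequenceClassification

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained("roberta-base")

inputs = tokenizer("A short test sentence.", return_tensors="pt")
logits = model(**inputs).logits  # shape: (1, num_labels)
```
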
@@ -183,6 +183,24 @@
 
         self.is_decoder = config.is_decoder
 
+    def get_attn(self):
+        return self.attn
+
+    def save_attn(self, attn):
+        self.attn = attn
+
+    def save_attn_cam(self, cam):
+        self.attn_cam = cam
+
+    def get_attn_cam(self):
+        return self.attn_cam
+
+    def save_attn_gradients(self, attn_gradients):
+        self.attn_gradients = attn_gradients
+
+    def get_attn_gradients(self):
+        return self.attn_gradients
+
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
         x = x.view(*new_x_shape)
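
The third hunk adds plain getter/setter pairs to the self-attention module: `save_attn`/`get_attn` cache the attention probabilities, `save_attn_gradients`/`get_attn_gradients` cache their gradients, and `save_attn_cam`/`get_attn_cam` leave a slot for externally computed relevance (CAM) scores. The accessors only store and return tensors; they become useful once `forward` calls them, which the hunks shown here do not include. A hedged sketch of the expected wiring and of how a caller could read the cached tensors after a backward pass; the relevance formula at the end is a simple gradient-times-attention heuristic, not something the patch itself prescribes:

```python
from transformers import RobertaTokenizer

from modeling_roberta_patched import RobertaForSequenceClassification  # hypothetical local copy

# Inside RobertaSelfAttention.forward, the patch is presumably completed by
# something along these lines (not shown in the hunks above):
#
#     attention_probs = nn.Softmax(dim=-1)(attention_scores)
#     self.save_attn(attention_probs)                              # cache the attention map
#     if attention_probs.requires_grad:
#         attention_probs.register_hook(self.save_attn_gradients)  # cache its gradient on backward

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSequenceClassification.from_pretrained("roberta-base")

inputs = tokenizer("A short test sentence.", return_tensors="pt")
logits = model(**inputs).logits

# Backpropagate from the top predicted class to populate the gradient caches.
model.zero_grad()
logits[0, logits[0].argmax()].backward()

layer = model.roberta.encoder.layer[-1].attention.self
attn = layer.get_attn()            # (batch, heads, seq_len, seq_len)
grad = layer.get_attn_gradients()  # same shape, filled by the backward hook

# Gradient-weighted attention, averaged over heads.
relevance = (grad * attn).clamp(min=0).mean(dim=1)
```
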